Sync to upstream/release/565

This commit is contained in:
Andy Friesen 2023-02-24 10:24:22 -08:00
parent 5c77305609
commit 1e7b23fbfc
62 changed files with 3500 additions and 2429 deletions

View file

@ -159,6 +159,20 @@ struct SetPropConstraint
TypeId propType;
};
// result ~ setIndexer subjectType indexType propType
//
// If the subject is a table or table-like thing that already has an indexer,
// unify its indexType and propType with those from this constraint.
//
// If the table is a free or unsealed table, we augment it with a new indexer.
struct SetIndexerConstraint
{
TypeId resultType;
TypeId subjectType;
TypeId indexType;
TypeId propType;
};
// if negation:
// result ~ if isSingleton D then ~D else unknown where D = discriminantType
// if not negation:
@ -170,9 +184,19 @@ struct SingletonOrTopTypeConstraint
bool negated;
};
// resultType ~ unpack sourceTypePack
//
// Similar to PackSubtypeConstraint, but with one important difference: If the
// sourcePack is blocked, this constraint blocks.
struct UnpackConstraint
{
TypePackId resultPack;
TypePackId sourcePack;
};
using ConstraintV = Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, InstantiationConstraint, UnaryConstraint,
BinaryConstraint, IterableConstraint, NameConstraint, TypeAliasExpansionConstraint, FunctionCallConstraint, PrimitiveTypeConstraint,
HasPropConstraint, SetPropConstraint, SingletonOrTopTypeConstraint>;
HasPropConstraint, SetPropConstraint, SetIndexerConstraint, SingletonOrTopTypeConstraint, UnpackConstraint>;
struct Constraint
{

View file

@ -191,7 +191,7 @@ struct ConstraintGraphBuilder
Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType);
std::tuple<TypeId, TypeId, RefinementId> checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType);
TypePackId checkLValues(const ScopePtr& scope, AstArray<AstExpr*> exprs);
std::vector<TypeId> checkLValues(const ScopePtr& scope, AstArray<AstExpr*> exprs);
TypeId checkLValue(const ScopePtr& scope, AstExpr* expr);
@ -244,10 +244,31 @@ struct ConstraintGraphBuilder
**/
TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& list, bool inTypeArguments);
/**
* Creates generic types given a list of AST definitions, resolving default
* types as required.
* @param scope the scope that the generics should belong to.
* @param generics the AST generics to create types for.
* @param useCache whether to use the generic type cache for the given
* scope.
* @param addTypes whether to add the types to the scope's
* privateTypeBindings map.
**/
std::vector<std::pair<Name, GenericTypeDefinition>> createGenerics(
const ScopePtr& scope, AstArray<AstGenericType> generics, bool useCache = false);
const ScopePtr& scope, AstArray<AstGenericType> generics, bool useCache = false, bool addTypes = true);
/**
* Creates generic type packs given a list of AST definitions, resolving
* default type packs as required.
* @param scope the scope that the generic packs should belong to.
* @param generics the AST generics to create type packs for.
* @param useCache whether to use the generic type pack cache for the given
* scope.
* @param addTypes whether to add the types to the scope's
* privateTypePackBindings map.
**/
std::vector<std::pair<Name, GenericTypePackDefinition>> createGenericPacks(
const ScopePtr& scope, AstArray<AstGenericTypePack> packs, bool useCache = false);
const ScopePtr& scope, AstArray<AstGenericTypePack> packs, bool useCache = false, bool addTypes = true);
Inference flattenPack(const ScopePtr& scope, Location location, InferencePack pack);

View file

@ -8,6 +8,7 @@
#include "Luau/Normalize.h"
#include "Luau/ToString.h"
#include "Luau/Type.h"
#include "Luau/TypeReduction.h"
#include "Luau/Variant.h"
#include <vector>
@ -19,7 +20,12 @@ struct DcrLogger;
// TypeId, TypePackId, or Constraint*. It is impossible to know which, but we
// never dereference this pointer.
using BlockedConstraintId = const void*;
using BlockedConstraintId = Variant<TypeId, TypePackId, const Constraint*>;
struct HashBlockedConstraintId
{
size_t operator()(const BlockedConstraintId& bci) const;
};
struct ModuleResolver;
@ -47,6 +53,7 @@ struct ConstraintSolver
NotNull<BuiltinTypes> builtinTypes;
InternalErrorReporter iceReporter;
NotNull<Normalizer> normalizer;
NotNull<TypeReduction> reducer;
// The entire set of constraints that the solver is trying to resolve.
std::vector<NotNull<Constraint>> constraints;
NotNull<Scope> rootScope;
@ -65,7 +72,7 @@ struct ConstraintSolver
// anything.
std::unordered_map<NotNull<const Constraint>, size_t> blockedConstraints;
// A mapping of type/pack pointers to the constraints they block.
std::unordered_map<BlockedConstraintId, std::vector<NotNull<const Constraint>>> blocked;
std::unordered_map<BlockedConstraintId, std::vector<NotNull<const Constraint>>, HashBlockedConstraintId> blocked;
// Memoized instantiations of type aliases.
DenseHashMap<InstantiationSignature, TypeId, HashInstantiationSignature> instantiatedAliases{{}};
@ -78,7 +85,8 @@ struct ConstraintSolver
DcrLogger* logger;
explicit ConstraintSolver(NotNull<Normalizer> normalizer, NotNull<Scope> rootScope, std::vector<NotNull<Constraint>> constraints,
ModuleName moduleName, NotNull<ModuleResolver> moduleResolver, std::vector<RequireCycle> requireCycles, DcrLogger* logger);
ModuleName moduleName, NotNull<TypeReduction> reducer, NotNull<ModuleResolver> moduleResolver, std::vector<RequireCycle> requireCycles,
DcrLogger* logger);
// Randomize the order in which to dispatch constraints
void randomize(unsigned seed);
@ -112,7 +120,9 @@ struct ConstraintSolver
bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const HasPropConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const SetPropConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const SetIndexerConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const UnpackConstraint& c, NotNull<const Constraint> constraint);
// for a, ... in some_table do
// also handles __iter metamethod
@ -123,6 +133,7 @@ struct ConstraintSolver
TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
std::optional<TypeId> lookupTableProp(TypeId subjectType, const std::string& propName);
std::optional<TypeId> lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set<TypeId>& seen);
void block(NotNull<const Constraint> target, NotNull<const Constraint> constraint);
/**

View file

@ -4,6 +4,7 @@
#include "Luau/Constraint.h"
#include "Luau/NotNull.h"
#include "Luau/Scope.h"
#include "Luau/Module.h"
#include "Luau/ToString.h"
#include "Luau/Error.h"
#include "Luau/Variant.h"
@ -34,11 +35,26 @@ struct TypeBindingSnapshot
std::string typeString;
};
struct ExprTypesAtLocation
{
Location location;
TypeId ty;
std::optional<TypeId> expectedTy;
};
struct AnnotationTypesAtLocation
{
Location location;
TypeId resolvedTy;
};
struct ConstraintGenerationLog
{
std::string source;
std::unordered_map<std::string, Location> constraintLocations;
std::vector<ErrorSnapshot> errors;
std::vector<ExprTypesAtLocation> exprTypeLocations;
std::vector<AnnotationTypesAtLocation> annotationTypeLocations;
};
struct ScopeSnapshot
@ -49,16 +65,11 @@ struct ScopeSnapshot
std::vector<ScopeSnapshot> children;
};
enum class ConstraintBlockKind
{
TypeId,
TypePackId,
ConstraintId,
};
using ConstraintBlockTarget = Variant<TypeId, TypePackId, NotNull<const Constraint>>;
struct ConstraintBlock
{
ConstraintBlockKind kind;
ConstraintBlockTarget target;
std::string stringification;
};
@ -71,16 +82,18 @@ struct ConstraintSnapshot
struct BoundarySnapshot
{
std::unordered_map<std::string, ConstraintSnapshot> constraints;
DenseHashMap<const Constraint*, ConstraintSnapshot> unsolvedConstraints{nullptr};
ScopeSnapshot rootScope;
DenseHashMap<const void*, std::string> typeStrings{nullptr};
};
struct StepSnapshot
{
std::string currentConstraint;
const Constraint* currentConstraint;
bool forced;
std::unordered_map<std::string, ConstraintSnapshot> unsolvedConstraints;
DenseHashMap<const Constraint*, ConstraintSnapshot> unsolvedConstraints{nullptr};
ScopeSnapshot rootScope;
DenseHashMap<const void*, std::string> typeStrings{nullptr};
};
struct TypeSolveLog
@ -95,8 +108,6 @@ struct TypeCheckLog
std::vector<ErrorSnapshot> errors;
};
using ConstraintBlockTarget = Variant<TypeId, TypePackId, NotNull<const Constraint>>;
struct DcrLogger
{
std::string compileOutput();
@ -104,6 +115,7 @@ struct DcrLogger
void captureSource(std::string source);
void captureGenerationError(const TypeError& error);
void captureConstraintLocation(NotNull<const Constraint> constraint, Location location);
void captureGenerationModule(const ModulePtr& module);
void pushBlock(NotNull<const Constraint> constraint, TypeId block);
void pushBlock(NotNull<const Constraint> constraint, TypePackId block);
@ -126,9 +138,10 @@ private:
TypeSolveLog solveLog;
TypeCheckLog checkLog;
ToStringOptions opts;
ToStringOptions opts{true};
std::vector<ConstraintBlock> snapshotBlocks(NotNull<const Constraint> constraint);
void captureBoundaryState(BoundarySnapshot& target, const Scope* rootScope, const std::vector<NotNull<const Constraint>>& unsolvedConstraints);
};
} // namespace Luau

View file

@ -52,7 +52,7 @@ struct Scope
std::optional<TypeId> lookup(Symbol sym) const;
std::optional<TypeId> lookup(DefId def) const;
std::optional<std::pair<TypeId, Scope*>> lookupEx(Symbol sym);
std::optional<std::pair<Binding*, Scope*>> lookupEx(Symbol sym);
std::optional<TypeFun> lookupType(const Name& name);
std::optional<TypeFun> lookupImportedType(const Name& moduleAlias, const Name& name);

View file

@ -37,6 +37,11 @@ struct Symbol
AstLocal* local;
AstName global;
explicit operator bool() const
{
return local != nullptr || global.value != nullptr;
}
bool operator==(const Symbol& rhs) const
{
if (local)

View file

@ -246,6 +246,18 @@ struct WithPredicate
{
T type;
PredicateVec predicates;
WithPredicate() = default;
explicit WithPredicate(T type)
: type(type)
{
}
WithPredicate(T type, PredicateVec predicates)
: type(type)
, predicates(std::move(predicates))
{
}
};
using MagicFunction = std::function<std::optional<WithPredicate<TypePackId>>(
@ -853,4 +865,15 @@ bool hasTag(TypeId ty, const std::string& tagName);
bool hasTag(const Property& prop, const std::string& tagName);
bool hasTag(const Tags& tags, const std::string& tagName); // Do not use in new work.
/*
* Use this to change the kind of a particular type.
*
* LUAU_NOINLINE so that the calling frame doesn't have to pay the stack storage for the new variant.
*/
template<typename T, typename... Args>
LUAU_NOINLINE T* emplaceType(Type* ty, Args&&... args)
{
return &ty->ty.emplace<T>(std::forward<Args>(args)...);
}
} // namespace Luau

View file

@ -146,10 +146,12 @@ struct TypeChecker
WithPredicate<TypePackId> checkExprPackHelper(const ScopePtr& scope, const AstExpr& expr);
WithPredicate<TypePackId> checkExprPackHelper(const ScopePtr& scope, const AstExprCall& expr);
WithPredicate<TypePackId> checkExprPackHelper2(
const ScopePtr& scope, const AstExprCall& expr, TypeId selfType, TypeId actualFunctionType, TypeId functionType, TypePackId retPack);
std::vector<std::optional<TypeId>> getExpectedTypesForCall(const std::vector<TypeId>& overloads, size_t argumentCount, bool selfCall);
std::optional<WithPredicate<TypePackId>> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack,
std::unique_ptr<WithPredicate<TypePackId>> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack,
TypePackId argPack, TypePack* args, const std::vector<Location>* argLocations, const WithPredicate<TypePackId>& argListResult,
std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<TypeId>& overloadsThatDont, std::vector<OverloadErrorEntry>& errors);
bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector<Location>& argLocations,

View file

@ -12,11 +12,36 @@ namespace Luau
namespace detail
{
template<typename T>
struct ReductionContext
struct ReductionEdge
{
T type = nullptr;
bool irreducible = false;
};
struct TypeReductionMemoization
{
TypeReductionMemoization() = default;
TypeReductionMemoization(const TypeReductionMemoization&) = delete;
TypeReductionMemoization& operator=(const TypeReductionMemoization&) = delete;
TypeReductionMemoization(TypeReductionMemoization&&) = default;
TypeReductionMemoization& operator=(TypeReductionMemoization&&) = default;
DenseHashMap<TypeId, ReductionEdge<TypeId>> types{nullptr};
DenseHashMap<TypePackId, ReductionEdge<TypePackId>> typePacks{nullptr};
bool isIrreducible(TypeId ty);
bool isIrreducible(TypePackId tp);
TypeId memoize(TypeId ty, TypeId reducedTy);
TypePackId memoize(TypePackId tp, TypePackId reducedTp);
// Reducing A into B may have a non-irreducible edge A to B for which B is not irreducible, which means B could be reduced into C.
// Because reduction should always be transitive, A should point to C if A points to B and B points to C.
std::optional<ReductionEdge<TypeId>> memoizedof(TypeId ty) const;
std::optional<ReductionEdge<TypePackId>> memoizedof(TypePackId tp) const;
};
} // namespace detail
struct TypeReductionOptions
@ -42,29 +67,19 @@ struct TypeReduction
std::optional<TypePackId> reduce(TypePackId tp);
std::optional<TypeFun> reduce(const TypeFun& fun);
/// Creating a child TypeReduction will allow the parent TypeReduction to share its memoization with the child TypeReductions.
/// This is safe as long as the parent's TypeArena continues to outlive both TypeReduction memoization.
TypeReduction fork(NotNull<TypeArena> arena, const TypeReductionOptions& opts = {}) const;
private:
const TypeReduction* parent = nullptr;
NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtinTypes;
NotNull<struct InternalErrorReporter> handle;
TypeReductionOptions options;
DenseHashMap<TypeId, detail::ReductionContext<TypeId>> memoizedTypes{nullptr};
DenseHashMap<TypePackId, detail::ReductionContext<TypePackId>> memoizedTypePacks{nullptr};
TypeReductionOptions options;
detail::TypeReductionMemoization memoization;
// Computes an *estimated length* of the cartesian product of the given type.
size_t cartesianProductSize(TypeId ty) const;
bool hasExceededCartesianProductLimit(TypeId ty) const;
bool hasExceededCartesianProductLimit(TypePackId tp) const;
std::optional<TypeId> memoizedof(TypeId ty) const;
std::optional<TypePackId> memoizedof(TypePackId tp) const;
};
} // namespace Luau

View file

@ -67,6 +67,12 @@ struct Unifier
UnifierSharedState& sharedState;
// When the Unifier is forced to unify two blocked types (or packs), they
// get added to these vectors. The ConstraintSolver can use this to know
// when it is safe to reattempt dispatching a constraint.
std::vector<TypeId> blockedTypes;
std::vector<TypePackId> blockedTypePacks;
Unifier(
NotNull<Normalizer> normalizer, Mode mode, NotNull<Scope> scope, const Location& location, Variance variance, TxnLog* parentLog = nullptr);

View file

@ -320,6 +320,9 @@ void ConstraintGraphBuilder::visit(AstStatBlock* block)
prepopulateGlobalScope(scope, block);
visitBlockWithoutChildScope(scope, block);
if (FFlag::DebugLuauLogSolverToJson)
logger->captureGenerationModule(module);
}
void ConstraintGraphBuilder::visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block)
@ -357,13 +360,11 @@ void ConstraintGraphBuilder::visitBlockWithoutChildScope(const ScopePtr& scope,
for (const auto& [name, gen] : createGenerics(defnScope, alias->generics, /* useCache */ true))
{
initialFun.typeParams.push_back(gen);
defnScope->privateTypeBindings[name] = TypeFun{gen.ty};
}
for (const auto& [name, genPack] : createGenericPacks(defnScope, alias->genericPacks, /* useCache */ true))
{
initialFun.typePackParams.push_back(genPack);
defnScope->privateTypePackBindings[name] = genPack.tp;
}
if (alias->exported)
@ -503,13 +504,13 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local)
if (j - i < packTypes.head.size())
varTypes[j] = packTypes.head[j - i];
else
varTypes[j] = freshType(scope);
varTypes[j] = arena->addType(BlockedType{});
}
}
std::vector<TypeId> tailValues{varTypes.begin() + i, varTypes.end()};
TypePackId tailPack = arena->addTypePack(std::move(tailValues));
addConstraint(scope, local->location, PackSubtypeConstraint{exprPack, tailPack});
addConstraint(scope, local->location, UnpackConstraint{tailPack, exprPack});
}
}
}
@ -686,6 +687,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct
Checkpoint start = checkpoint(this);
FunctionSignature sig = checkFunctionSignature(scope, function->func);
std::unordered_set<Constraint*> excludeList;
if (AstExprLocal* localName = function->name->as<AstExprLocal>())
{
std::optional<TypeId> existingFunctionTy = scope->lookup(localName->local);
@ -716,9 +719,20 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct
}
else if (AstExprIndexName* indexName = function->name->as<AstExprIndexName>())
{
Checkpoint check1 = checkpoint(this);
TypeId lvalueType = checkLValue(scope, indexName);
Checkpoint check2 = checkpoint(this);
forEachConstraint(check1, check2, this, [&excludeList](const ConstraintPtr& c) {
excludeList.insert(c.get());
});
// TODO figure out how to populate the location field of the table Property.
addConstraint(scope, indexName->location, SubtypeConstraint{lvalueType, generalizedType});
if (get<FreeType>(lvalueType))
asMutable(lvalueType)->ty.emplace<BoundType>(generalizedType);
else
addConstraint(scope, indexName->location, SubtypeConstraint{lvalueType, generalizedType});
}
else if (AstExprError* err = function->name->as<AstExprError>())
{
@ -735,8 +749,9 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct
std::unique_ptr<Constraint> c =
std::make_unique<Constraint>(constraintScope, function->name->location, GeneralizationConstraint{generalizedType, sig.signature});
forEachConstraint(start, end, this, [&c](const ConstraintPtr& constraint) {
c->dependencies.push_back(NotNull{constraint.get()});
forEachConstraint(start, end, this, [&c, &excludeList](const ConstraintPtr& constraint) {
if (!excludeList.count(constraint.get()))
c->dependencies.push_back(NotNull{constraint.get()});
});
addConstraint(scope, std::move(c));
@ -763,16 +778,31 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block)
visitBlockWithoutChildScope(innerScope, block);
}
static void bindFreeType(TypeId a, TypeId b)
{
FreeType* af = getMutable<FreeType>(a);
FreeType* bf = getMutable<FreeType>(b);
LUAU_ASSERT(af || bf);
if (!bf)
asMutable(a)->ty.emplace<BoundType>(b);
else if (!af)
asMutable(b)->ty.emplace<BoundType>(a);
else if (subsumes(bf->scope, af->scope))
asMutable(a)->ty.emplace<BoundType>(b);
else if (subsumes(af->scope, bf->scope))
asMutable(b)->ty.emplace<BoundType>(a);
}
void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign)
{
TypePackId varPackId = checkLValues(scope, assign->vars);
TypePack expectedPack = extendTypePack(*arena, builtinTypes, varPackId, assign->values.size);
std::vector<TypeId> varTypes = checkLValues(scope, assign->vars);
std::vector<std::optional<TypeId>> expectedTypes;
expectedTypes.reserve(expectedPack.head.size());
expectedTypes.reserve(varTypes.size());
for (TypeId ty : expectedPack.head)
for (TypeId ty : varTypes)
{
ty = follow(ty);
if (get<FreeType>(ty))
@ -781,9 +811,10 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign)
expectedTypes.push_back(ty);
}
TypePackId valuePack = checkPack(scope, assign->values, expectedTypes).tp;
TypePackId exprPack = checkPack(scope, assign->values, expectedTypes).tp;
TypePackId varPack = arena->addTypePack({varTypes});
addConstraint(scope, assign->location, PackSubtypeConstraint{valuePack, varPackId});
addConstraint(scope, assign->location, PackSubtypeConstraint{exprPack, varPack});
}
void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign* assign)
@ -865,11 +896,11 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatTypeAlias* alia
asMutable(aliasTy)->ty.emplace<BoundType>(ty);
std::vector<TypeId> typeParams;
for (auto tyParam : createGenerics(*defnScope, alias->generics, /* useCache */ true))
for (auto tyParam : createGenerics(*defnScope, alias->generics, /* useCache */ true, /* addTypes */ false))
typeParams.push_back(tyParam.second.ty);
std::vector<TypePackId> typePackParams;
for (auto tpParam : createGenericPacks(*defnScope, alias->genericPacks, /* useCache */ true))
for (auto tpParam : createGenericPacks(*defnScope, alias->genericPacks, /* useCache */ true, /* addTypes */ false))
typePackParams.push_back(tpParam.second.tp);
addConstraint(scope, alias->type->location,
@ -1010,7 +1041,6 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction
for (auto& [name, generic] : generics)
{
genericTys.push_back(generic.ty);
scope->privateTypeBindings[name] = TypeFun{generic.ty};
}
std::vector<TypePackId> genericTps;
@ -1018,7 +1048,6 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction
for (auto& [name, generic] : genericPacks)
{
genericTps.push_back(generic.tp);
scope->privateTypePackBindings[name] = generic.tp;
}
ScopePtr funScope = scope;
@ -1161,7 +1190,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa
TypePackId expectedArgPack = arena->freshTypePack(scope.get());
TypePackId expectedRetPack = arena->freshTypePack(scope.get());
TypeId expectedFunctionType = arena->addType(FunctionType{expectedArgPack, expectedRetPack});
TypeId expectedFunctionType = arena->addType(FunctionType{expectedArgPack, expectedRetPack, std::nullopt, call->self});
TypeId instantiatedFnType = arena->addType(BlockedType{});
addConstraint(scope, call->location, InstantiationConstraint{instantiatedFnType, fnType});
@ -1264,7 +1293,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa
// TODO: How do expectedTypes play into this? Do they?
TypePackId rets = arena->addTypePack(BlockedTypePack{});
TypePackId argPack = arena->addTypePack(TypePack{args, argTail});
FunctionType ftv(TypeLevel{}, scope.get(), argPack, rets);
FunctionType ftv(TypeLevel{}, scope.get(), argPack, rets, std::nullopt, call->self);
NotNull<Constraint> fcc = addConstraint(scope, call->func->location,
FunctionCallConstraint{
@ -1457,7 +1486,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* gl
Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName)
{
TypeId obj = check(scope, indexName->expr).ty;
TypeId result = freshType(scope);
TypeId result = arena->addType(BlockedType{});
std::optional<DefId> def = dfg->getDef(indexName);
if (def)
@ -1468,13 +1497,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName*
scope->dcrRefinements[*def] = result;
}
TableType::Props props{{indexName->index.value, Property{result}}};
const std::optional<TableIndexer> indexer;
TableType ttv{std::move(props), indexer, TypeLevel{}, scope.get(), TableState::Free};
TypeId expectedTableType = arena->addType(std::move(ttv));
addConstraint(scope, indexName->expr->location, SubtypeConstraint{obj, expectedTableType});
addConstraint(scope, indexName->expr->location, HasPropConstraint{result, obj, indexName->index.value});
if (def)
return Inference{result, refinementArena.proposition(*def, builtinTypes->truthyType)};
@ -1589,6 +1612,8 @@ std::tuple<TypeId, TypeId, RefinementId> ConstraintGraphBuilder::checkBinary(
else if (typeguard->type == "number")
discriminantTy = builtinTypes->numberType;
else if (typeguard->type == "boolean")
discriminantTy = builtinTypes->booleanType;
else if (typeguard->type == "thread")
discriminantTy = builtinTypes->threadType;
else if (typeguard->type == "table")
discriminantTy = builtinTypes->tableType;
@ -1596,8 +1621,8 @@ std::tuple<TypeId, TypeId, RefinementId> ConstraintGraphBuilder::checkBinary(
discriminantTy = builtinTypes->functionType;
else if (typeguard->type == "userdata")
{
// For now, we don't really care about being accurate with userdata if the typeguard was using typeof
discriminantTy = builtinTypes->neverType; // TODO: replace with top class type
// For now, we don't really care about being accurate with userdata if the typeguard was using typeof.
discriminantTy = builtinTypes->classType;
}
else if (!typeguard->isTypeof && typeguard->type == "vector")
discriminantTy = builtinTypes->neverType; // TODO: figure out a way to deal with this quirky type
@ -1649,18 +1674,15 @@ std::tuple<TypeId, TypeId, RefinementId> ConstraintGraphBuilder::checkBinary(
}
}
TypePackId ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray<AstExpr*> exprs)
std::vector<TypeId> ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray<AstExpr*> exprs)
{
std::vector<TypeId> types;
types.reserve(exprs.size);
for (size_t i = 0; i < exprs.size; ++i)
{
AstExpr* const expr = exprs.data[i];
for (AstExpr* expr : exprs)
types.push_back(checkLValue(scope, expr));
}
return arena->addTypePack(std::move(types));
return types;
}
/**
@ -1679,6 +1701,28 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr)
indexExpr->location, indexExpr->expr, syntheticIndex, constantString->location, indexExpr->expr->location.end, '.'};
return checkLValue(scope, &synthetic);
}
// An indexer is only interesting in an lvalue-ey way if it is at the
// tail of an expression.
//
// If the indexer is not at the tail, then we are not interested in
// augmenting the lhs data structure with a new indexer. Constraint
// generation can treat it as an ordinary lvalue.
//
// eg
//
// a.b.c[1] = 44 -- lvalue
// a.b[4].c = 2 -- rvalue
TypeId resultType = arena->addType(BlockedType{});
TypeId subjectType = check(scope, indexExpr->expr).ty;
TypeId indexType = check(scope, indexExpr->index).ty;
TypeId propType = arena->addType(BlockedType{});
addConstraint(scope, expr->location, SetIndexerConstraint{resultType, subjectType, indexType, propType});
module->astTypes[expr] = propType;
return propType;
}
else if (!expr->is<AstExprIndexName>())
return check(scope, expr).ty;
@ -1718,7 +1762,8 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr)
auto lookupResult = scope->lookupEx(sym);
if (!lookupResult)
return check(scope, expr).ty;
const auto [subjectType, symbolScope] = std::move(*lookupResult);
const auto [subjectBinding, symbolScope] = std::move(*lookupResult);
TypeId subjectType = subjectBinding->typeId;
TypeId propTy = freshType(scope);
@ -1739,14 +1784,17 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr)
module->astTypes[expr] = prevSegmentTy;
module->astTypes[e] = updatedType;
symbolScope->bindings[sym].typeId = updatedType;
std::optional<DefId> def = dfg->getDef(sym);
if (def)
if (!subjectType->persistent)
{
// This can fail if the user is erroneously trying to augment a builtin
// table like os or string.
symbolScope->dcrRefinements[*def] = updatedType;
symbolScope->bindings[sym].typeId = updatedType;
std::optional<DefId> def = dfg->getDef(sym);
if (def)
{
// This can fail if the user is erroneously trying to augment a builtin
// table like os or string.
symbolScope->dcrRefinements[*def] = updatedType;
}
}
return propTy;
@ -1904,13 +1952,11 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS
for (const auto& [name, g] : genericDefinitions)
{
genericTypes.push_back(g.ty);
signatureScope->privateTypeBindings[name] = TypeFun{g.ty};
}
for (const auto& [name, g] : genericPackDefinitions)
{
genericTypePacks.push_back(g.tp);
signatureScope->privateTypePackBindings[name] = g.tp;
}
// Local variable works around an odd gcc 11.3 warning: <anonymous> may be used uninitialized
@ -2023,15 +2069,14 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS
actualFunction.generics = std::move(genericTypes);
actualFunction.genericPacks = std::move(genericTypePacks);
actualFunction.argNames = std::move(argNames);
actualFunction.hasSelf = fn->self != nullptr;
TypeId actualFunctionType = arena->addType(std::move(actualFunction));
LUAU_ASSERT(actualFunctionType);
module->astTypes[fn] = actualFunctionType;
if (expectedType && get<FreeType>(*expectedType))
{
asMutable(*expectedType)->ty.emplace<BoundType>(actualFunctionType);
}
bindFreeType(*expectedType, actualFunctionType);
return {
/* signature */ actualFunctionType,
@ -2179,13 +2224,11 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b
for (const auto& [name, g] : genericDefinitions)
{
genericTypes.push_back(g.ty);
signatureScope->privateTypeBindings[name] = TypeFun{g.ty};
}
for (const auto& [name, g] : genericPackDefinitions)
{
genericTypePacks.push_back(g.tp);
signatureScope->privateTypePackBindings[name] = g.tp;
}
}
else
@ -2330,7 +2373,7 @@ TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, const
}
std::vector<std::pair<Name, GenericTypeDefinition>> ConstraintGraphBuilder::createGenerics(
const ScopePtr& scope, AstArray<AstGenericType> generics, bool useCache)
const ScopePtr& scope, AstArray<AstGenericType> generics, bool useCache, bool addTypes)
{
std::vector<std::pair<Name, GenericTypeDefinition>> result;
for (const auto& generic : generics)
@ -2350,6 +2393,9 @@ std::vector<std::pair<Name, GenericTypeDefinition>> ConstraintGraphBuilder::crea
if (generic.defaultValue)
defaultTy = resolveType(scope, generic.defaultValue, /* inTypeArguments */ false);
if (addTypes)
scope->privateTypeBindings[generic.name.value] = TypeFun{genericTy};
result.push_back({generic.name.value, GenericTypeDefinition{genericTy, defaultTy}});
}
@ -2357,7 +2403,7 @@ std::vector<std::pair<Name, GenericTypeDefinition>> ConstraintGraphBuilder::crea
}
std::vector<std::pair<Name, GenericTypePackDefinition>> ConstraintGraphBuilder::createGenericPacks(
const ScopePtr& scope, AstArray<AstGenericTypePack> generics, bool useCache)
const ScopePtr& scope, AstArray<AstGenericTypePack> generics, bool useCache, bool addTypes)
{
std::vector<std::pair<Name, GenericTypePackDefinition>> result;
for (const auto& generic : generics)
@ -2378,6 +2424,9 @@ std::vector<std::pair<Name, GenericTypePackDefinition>> ConstraintGraphBuilder::
if (generic.defaultValue)
defaultTy = resolveTypePack(scope, generic.defaultValue, /* inTypeArguments */ false);
if (addTypes)
scope->privateTypePackBindings[generic.name.value] = genericTy;
result.push_back({generic.name.value, GenericTypePackDefinition{genericTy, defaultTy}});
}
@ -2394,11 +2443,9 @@ Inference ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location lo
if (auto f = first(tp))
return Inference{*f, refinement};
TypeId typeResult = freshType(scope);
TypePack onePack{{typeResult}, freshTypePack(scope)};
TypePackId oneTypePack = arena->addTypePack(std::move(onePack));
addConstraint(scope, location, PackSubtypeConstraint{tp, oneTypePack});
TypeId typeResult = arena->addType(BlockedType{});
TypePackId resultPack = arena->addTypePack({typeResult}, arena->freshTypePack(scope.get()));
addConstraint(scope, location, UnpackConstraint{resultPack, tp});
return Inference{typeResult, refinement};
}

View file

@ -22,6 +22,22 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false);
namespace Luau
{
size_t HashBlockedConstraintId::operator()(const BlockedConstraintId& bci) const
{
size_t result = 0;
if (const TypeId* ty = get_if<TypeId>(&bci))
result = std::hash<TypeId>()(*ty);
else if (const TypePackId* tp = get_if<TypePackId>(&bci))
result = std::hash<TypePackId>()(*tp);
else if (Constraint const* const* c = get_if<const Constraint*>(&bci))
result = std::hash<const Constraint*>()(*c);
else
LUAU_ASSERT(!"Should be unreachable");
return result;
}
[[maybe_unused]] static void dumpBindings(NotNull<Scope> scope, ToStringOptions& opts)
{
for (const auto& [k, v] : scope->bindings)
@ -221,10 +237,12 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts)
}
ConstraintSolver::ConstraintSolver(NotNull<Normalizer> normalizer, NotNull<Scope> rootScope, std::vector<NotNull<Constraint>> constraints,
ModuleName moduleName, NotNull<ModuleResolver> moduleResolver, std::vector<RequireCycle> requireCycles, DcrLogger* logger)
ModuleName moduleName, NotNull<TypeReduction> reducer, NotNull<ModuleResolver> moduleResolver, std::vector<RequireCycle> requireCycles,
DcrLogger* logger)
: arena(normalizer->arena)
, builtinTypes(normalizer->builtinTypes)
, normalizer(normalizer)
, reducer(reducer)
, constraints(std::move(constraints))
, rootScope(rootScope)
, currentModuleName(std::move(moduleName))
@ -326,6 +344,27 @@ void ConstraintSolver::run()
if (force)
printf("Force ");
printf("Dispatched\n\t%s\n", saveMe.c_str());
if (force)
{
printf("Blocked on:\n");
for (const auto& [bci, cv] : blocked)
{
if (end(cv) == std::find(begin(cv), end(cv), c))
continue;
if (auto bty = get_if<TypeId>(&bci))
printf("\tType %s\n", toString(*bty, opts).c_str());
else if (auto btp = get_if<TypePackId>(&bci))
printf("\tPack %s\n", toString(*btp, opts).c_str());
else if (auto cc = get_if<const Constraint*>(&bci))
printf("\tCons %s\n", toString(**cc, opts).c_str());
else
LUAU_ASSERT(!"Unreachable??");
}
}
dump(this, opts);
}
}
@ -411,8 +450,12 @@ bool ConstraintSolver::tryDispatch(NotNull<const Constraint> constraint, bool fo
success = tryDispatch(*hpc, constraint);
else if (auto spc = get<SetPropConstraint>(*constraint))
success = tryDispatch(*spc, constraint, force);
else if (auto spc = get<SetIndexerConstraint>(*constraint))
success = tryDispatch(*spc, constraint, force);
else if (auto sottc = get<SingletonOrTopTypeConstraint>(*constraint))
success = tryDispatch(*sottc, constraint);
else if (auto uc = get<UnpackConstraint>(*constraint))
success = tryDispatch(*uc, constraint);
else
LUAU_ASSERT(false);
@ -424,26 +467,46 @@ bool ConstraintSolver::tryDispatch(NotNull<const Constraint> constraint, bool fo
bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull<const Constraint> constraint, bool force)
{
if (!recursiveBlock(c.subType, constraint))
return false;
if (!recursiveBlock(c.superType, constraint))
return false;
if (isBlocked(c.subType))
return block(c.subType, constraint);
else if (isBlocked(c.superType))
return block(c.superType, constraint);
unify(c.subType, c.superType, constraint->scope);
Unifier u{normalizer, Mode::Strict, constraint->scope, Location{}, Covariant};
u.useScopes = true;
u.tryUnify(c.subType, c.superType);
if (!u.blockedTypes.empty() || !u.blockedTypePacks.empty())
{
for (TypeId bt : u.blockedTypes)
block(bt, constraint);
for (TypePackId btp : u.blockedTypePacks)
block(btp, constraint);
return false;
}
if (!u.errors.empty())
{
TypeId errorType = errorRecoveryType();
u.tryUnify(c.subType, errorType);
u.tryUnify(c.superType, errorType);
}
const auto [changedTypes, changedPacks] = u.log.getChanges();
u.log.commit();
unblock(changedTypes);
unblock(changedPacks);
// unify(c.subType, c.superType, constraint->scope);
return true;
}
bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull<const Constraint> constraint, bool force)
{
if (!recursiveBlock(c.subPack, constraint) || !recursiveBlock(c.superPack, constraint))
return false;
if (isBlocked(c.subPack))
return block(c.subPack, constraint);
else if (isBlocked(c.superPack))
@ -1183,8 +1246,26 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
TypeId instantiatedTy = arena->addType(BlockedType{});
TypeId inferredTy = arena->addType(FunctionType{TypeLevel{}, constraint->scope.get(), argsPack, c.result});
auto ic = pushConstraint(constraint->scope, constraint->location, InstantiationConstraint{instantiatedTy, fn});
auto sc = pushConstraint(constraint->scope, constraint->location, SubtypeConstraint{instantiatedTy, inferredTy});
auto pushConstraintGreedy = [this, constraint](ConstraintV cv) -> Constraint* {
std::unique_ptr<Constraint> c = std::make_unique<Constraint>(constraint->scope, constraint->location, std::move(cv));
NotNull<Constraint> borrow{c.get()};
bool ok = tryDispatch(borrow, false);
if (ok)
return nullptr;
solverConstraints.push_back(std::move(c));
unsolvedConstraints.push_back(borrow);
return borrow;
};
// HACK: We don't want other constraints to act on the free type pack
// created above until after these two constraints are solved, so we try to
// dispatch them directly.
auto ic = pushConstraintGreedy(InstantiationConstraint{instantiatedTy, fn});
auto sc = pushConstraintGreedy(SubtypeConstraint{instantiatedTy, inferredTy});
// Anything that is blocked on this constraint must also be blocked on our
// synthesized constraints.
@ -1193,8 +1274,10 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
{
for (const auto& blockedConstraint : blockedIt->second)
{
block(ic, blockedConstraint);
block(sc, blockedConstraint);
if (ic)
block(NotNull{ic}, blockedConstraint);
if (sc)
block(NotNull{sc}, blockedConstraint);
}
}
@ -1230,6 +1313,8 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull<const Con
return true;
}
subjectType = reducer->reduce(subjectType).value_or(subjectType);
std::optional<TypeId> resultType = lookupTableProp(subjectType, c.prop);
if (!resultType)
{
@ -1360,11 +1445,18 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull<const Con
if (existingPropType)
{
unify(c.propType, *existingPropType, constraint->scope);
if (!isBlocked(c.propType))
unify(c.propType, *existingPropType, constraint->scope);
bind(c.resultType, c.subjectType);
return true;
}
if (get<AnyType>(subjectType) || get<ErrorType>(subjectType) || get<NeverType>(subjectType))
{
bind(c.resultType, subjectType);
return true;
}
if (get<FreeType>(subjectType))
{
TypeId ty = arena->freshType(constraint->scope);
@ -1381,21 +1473,27 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull<const Con
LUAU_ASSERT(ty);
bind(subjectType, ty);
bind(c.resultType, ty);
if (follow(c.resultType) != follow(ty))
bind(c.resultType, ty);
return true;
}
else if (auto ttv = getMutable<TableType>(subjectType))
{
if (ttv->state == TableState::Free)
{
LUAU_ASSERT(!subjectType->persistent);
ttv->props[c.path[0]] = Property{c.propType};
bind(c.resultType, c.subjectType);
return true;
}
else if (ttv->state == TableState::Unsealed)
{
LUAU_ASSERT(!subjectType->persistent);
std::optional<TypeId> augmented = updateTheTableType(NotNull{arena}, subjectType, c.path, c.propType);
bind(c.resultType, augmented.value_or(subjectType));
bind(subjectType, c.resultType);
return true;
}
else
@ -1411,16 +1509,62 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull<const Con
bind(c.resultType, subjectType);
return true;
}
else if (get<AnyType>(subjectType) || get<ErrorType>(subjectType) || get<NeverType>(subjectType))
{
bind(c.resultType, subjectType);
return true;
}
LUAU_ASSERT(0);
return true;
}
bool ConstraintSolver::tryDispatch(const SetIndexerConstraint& c, NotNull<const Constraint> constraint, bool force)
{
TypeId subjectType = follow(c.subjectType);
if (isBlocked(subjectType))
return block(subjectType, constraint);
if (auto ft = get<FreeType>(subjectType))
{
Scope* scope = ft->scope;
TableType* tt = &asMutable(subjectType)->ty.emplace<TableType>(TableState::Free, TypeLevel{}, scope);
tt->indexer = TableIndexer{c.indexType, c.propType};
asMutable(c.resultType)->ty.emplace<BoundType>(subjectType);
asMutable(c.propType)->ty.emplace<FreeType>(scope);
unblock(c.propType);
unblock(c.resultType);
return true;
}
else if (auto tt = get<TableType>(subjectType))
{
if (tt->indexer)
{
// TODO This probably has to be invariant.
unify(c.indexType, tt->indexer->indexType, constraint->scope);
asMutable(c.propType)->ty.emplace<BoundType>(tt->indexer->indexResultType);
asMutable(c.resultType)->ty.emplace<BoundType>(subjectType);
unblock(c.propType);
unblock(c.resultType);
return true;
}
else if (tt->state == TableState::Free || tt->state == TableState::Unsealed)
{
auto mtt = getMutable<TableType>(subjectType);
mtt->indexer = TableIndexer{c.indexType, c.propType};
asMutable(c.propType)->ty.emplace<FreeType>(tt->scope);
asMutable(c.resultType)->ty.emplace<BoundType>(subjectType);
unblock(c.propType);
unblock(c.resultType);
return true;
}
// Do not augment sealed or generic tables that lack indexers
}
asMutable(c.propType)->ty.emplace<BoundType>(builtinTypes->errorRecoveryType());
asMutable(c.resultType)->ty.emplace<BoundType>(builtinTypes->errorRecoveryType());
unblock(c.propType);
unblock(c.resultType);
return true;
}
bool ConstraintSolver::tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull<const Constraint> constraint)
{
if (isBlocked(c.discriminantType))
@ -1439,6 +1583,69 @@ bool ConstraintSolver::tryDispatch(const SingletonOrTopTypeConstraint& c, NotNul
return true;
}
bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull<const Constraint> constraint)
{
TypePackId sourcePack = follow(c.sourcePack);
TypePackId resultPack = follow(c.resultPack);
if (isBlocked(sourcePack))
return block(sourcePack, constraint);
if (isBlocked(resultPack))
{
asMutable(resultPack)->ty.emplace<BoundTypePack>(sourcePack);
unblock(resultPack);
return true;
}
TypePack srcPack = extendTypePack(*arena, builtinTypes, sourcePack, size(resultPack));
auto destIter = begin(resultPack);
auto destEnd = end(resultPack);
size_t i = 0;
while (destIter != destEnd)
{
if (i >= srcPack.head.size())
break;
TypeId srcTy = follow(srcPack.head[i]);
if (isBlocked(*destIter))
{
if (follow(srcTy) == *destIter)
{
// Cyclic type dependency. (????)
asMutable(*destIter)->ty.emplace<FreeType>(constraint->scope);
}
else
asMutable(*destIter)->ty.emplace<BoundType>(srcTy);
unblock(*destIter);
}
else
unify(*destIter, srcTy, constraint->scope);
++destIter;
++i;
}
// We know that resultPack does not have a tail, but we don't know if
// sourcePack is long enough to fill every value. Replace every remaining
// result TypeId with the error recovery type.
while (destIter != destEnd)
{
if (isBlocked(*destIter))
{
asMutable(*destIter)->ty.emplace<BoundType>(builtinTypes->errorRecoveryType());
unblock(*destIter);
}
++destIter;
}
return true;
}
bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force)
{
auto block_ = [&](auto&& t) {
@ -1628,10 +1835,20 @@ bool ConstraintSolver::tryDispatchIterableFunction(
std::optional<TypeId> ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName)
{
std::unordered_set<TypeId> seen;
return lookupTableProp(subjectType, propName, seen);
}
std::optional<TypeId> ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set<TypeId>& seen)
{
if (!seen.insert(subjectType).second)
return std::nullopt;
auto collectParts = [&](auto&& unionOrIntersection) -> std::pair<std::optional<TypeId>, std::vector<TypeId>> {
std::optional<TypeId> blocked;
std::vector<TypeId> parts;
std::vector<TypeId> freeParts;
for (TypeId expectedPart : unionOrIntersection)
{
expectedPart = follow(expectedPart);
@ -1644,6 +1861,29 @@ std::optional<TypeId> ConstraintSolver::lookupTableProp(TypeId subjectType, cons
else if (ttv->indexer && maybeString(ttv->indexer->indexType))
parts.push_back(ttv->indexer->indexResultType);
}
else if (get<FreeType>(expectedPart))
{
freeParts.push_back(expectedPart);
}
}
// If the only thing resembling a match is a single fresh type, we can
// confidently tablify it. If other types match or if there are more
// than one free type, we can't do anything.
if (parts.empty() && 1 == freeParts.size())
{
TypeId freePart = freeParts.front();
const FreeType* ft = get<FreeType>(freePart);
LUAU_ASSERT(ft);
Scope* scope = ft->scope;
TableType* tt = &asMutable(freePart)->ty.emplace<TableType>();
tt->state = TableState::Free;
tt->scope = scope;
TypeId propType = arena->freshType(scope);
tt->props[propName] = Property{propType};
parts.push_back(propType);
}
return {blocked, parts};
@ -1651,12 +1891,75 @@ std::optional<TypeId> ConstraintSolver::lookupTableProp(TypeId subjectType, cons
std::optional<TypeId> resultType;
if (auto ttv = get<TableType>(subjectType))
if (get<AnyType>(subjectType) || get<NeverType>(subjectType))
{
return subjectType;
}
else if (auto ttv = getMutable<TableType>(subjectType))
{
if (auto prop = ttv->props.find(propName); prop != ttv->props.end())
resultType = prop->second.type;
else if (ttv->indexer && maybeString(ttv->indexer->indexType))
resultType = ttv->indexer->indexResultType;
else if (ttv->state == TableState::Free)
{
resultType = arena->addType(FreeType{ttv->scope});
ttv->props[propName] = Property{*resultType};
}
}
else if (auto mt = get<MetatableType>(subjectType))
{
if (auto p = lookupTableProp(mt->table, propName, seen))
return p;
TypeId mtt = follow(mt->metatable);
if (get<BlockedType>(mtt))
return mtt;
else if (auto metatable = get<TableType>(mtt))
{
auto indexProp = metatable->props.find("__index");
if (indexProp == metatable->props.end())
return std::nullopt;
// TODO: __index can be an overloaded function.
TypeId indexType = follow(indexProp->second.type);
if (auto ft = get<FunctionType>(indexType))
{
std::optional<TypeId> ret = first(ft->retTypes);
if (ret)
return *ret;
else
return std::nullopt;
}
return lookupTableProp(indexType, propName, seen);
}
}
else if (auto ct = get<ClassType>(subjectType))
{
while (ct)
{
if (auto prop = ct->props.find(propName); prop != ct->props.end())
return prop->second.type;
else if (ct->parent)
ct = get<ClassType>(follow(*ct->parent));
else
break;
}
}
else if (auto pt = get<PrimitiveType>(subjectType); pt && pt->metatable)
{
const TableType* metatable = get<TableType>(follow(*pt->metatable));
LUAU_ASSERT(metatable);
auto indexProp = metatable->props.find("__index");
if (indexProp == metatable->props.end())
return std::nullopt;
return lookupTableProp(indexProp->second.type, propName, seen);
}
else if (auto utv = get<UnionType>(subjectType))
{
@ -1704,7 +2007,7 @@ void ConstraintSolver::block(NotNull<const Constraint> target, NotNull<const Con
if (FFlag::DebugLuauLogSolver)
printf("block Constraint %s on\t%s\n", toString(*target, opts).c_str(), toString(*constraint, opts).c_str());
block_(target, constraint);
block_(target.get(), constraint);
}
bool ConstraintSolver::block(TypeId target, NotNull<const Constraint> constraint)
@ -1715,7 +2018,7 @@ bool ConstraintSolver::block(TypeId target, NotNull<const Constraint> constraint
if (FFlag::DebugLuauLogSolver)
printf("block TypeId %s on\t%s\n", toString(target, opts).c_str(), toString(*constraint, opts).c_str());
block_(target, constraint);
block_(follow(target), constraint);
return false;
}
@ -1802,7 +2105,7 @@ void ConstraintSolver::unblock(NotNull<const Constraint> progressed)
if (FFlag::DebugLuauLogSolverToJson)
logger->popBlock(progressed);
return unblock_(progressed);
return unblock_(progressed.get());
}
void ConstraintSolver::unblock(TypeId progressed)
@ -1810,7 +2113,10 @@ void ConstraintSolver::unblock(TypeId progressed)
if (FFlag::DebugLuauLogSolverToJson)
logger->popBlock(progressed);
return unblock_(progressed);
unblock_(progressed);
if (auto bt = get<BoundType>(progressed))
unblock(bt->boundTo);
}
void ConstraintSolver::unblock(TypePackId progressed)

View file

@ -9,17 +9,39 @@
namespace Luau
{
template<typename T>
static std::string toPointerId(const T* ptr)
{
return std::to_string(reinterpret_cast<size_t>(ptr));
}
static std::string toPointerId(NotNull<const Constraint> ptr)
{
return std::to_string(reinterpret_cast<size_t>(ptr.get()));
}
namespace Json
{
template<typename T>
void write(JsonEmitter& emitter, const T* ptr)
{
write(emitter, toPointerId(ptr));
}
void write(JsonEmitter& emitter, NotNull<const Constraint> ptr)
{
write(emitter, toPointerId(ptr));
}
void write(JsonEmitter& emitter, const Location& location)
{
ObjectEmitter o = emitter.writeObject();
o.writePair("beginLine", location.begin.line);
o.writePair("beginColumn", location.begin.column);
o.writePair("endLine", location.end.line);
o.writePair("endColumn", location.end.column);
o.finish();
ArrayEmitter a = emitter.writeArray();
a.writeValue(location.begin.line);
a.writeValue(location.begin.column);
a.writeValue(location.end.line);
a.writeValue(location.end.column);
a.finish();
}
void write(JsonEmitter& emitter, const ErrorSnapshot& snapshot)
@ -47,24 +69,43 @@ void write(JsonEmitter& emitter, const TypeBindingSnapshot& snapshot)
o.finish();
}
template<typename K, typename V>
void write(JsonEmitter& emitter, const DenseHashMap<const K*, V>& map)
{
ObjectEmitter o = emitter.writeObject();
for (const auto& [k, v] : map)
o.writePair(toPointerId(k), v);
o.finish();
}
void write(JsonEmitter& emitter, const ExprTypesAtLocation& tys)
{
ObjectEmitter o = emitter.writeObject();
o.writePair("location", tys.location);
o.writePair("ty", toPointerId(tys.ty));
if (tys.expectedTy)
o.writePair("expectedTy", toPointerId(*tys.expectedTy));
o.finish();
}
void write(JsonEmitter& emitter, const AnnotationTypesAtLocation& tys)
{
ObjectEmitter o = emitter.writeObject();
o.writePair("location", tys.location);
o.writePair("resolvedTy", toPointerId(tys.resolvedTy));
o.finish();
}
void write(JsonEmitter& emitter, const ConstraintGenerationLog& log)
{
ObjectEmitter o = emitter.writeObject();
o.writePair("source", log.source);
emitter.writeComma();
write(emitter, "constraintLocations");
emitter.writeRaw(":");
ObjectEmitter locationEmitter = emitter.writeObject();
for (const auto& [id, location] : log.constraintLocations)
{
locationEmitter.writePair(id, location);
}
locationEmitter.finish();
o.writePair("errors", log.errors);
o.writePair("exprTypeLocations", log.exprTypeLocations);
o.writePair("annotationTypeLocations", log.annotationTypeLocations);
o.finish();
}
@ -78,26 +119,34 @@ void write(JsonEmitter& emitter, const ScopeSnapshot& snapshot)
o.finish();
}
void write(JsonEmitter& emitter, const ConstraintBlockKind& kind)
{
switch (kind)
{
case ConstraintBlockKind::TypeId:
return write(emitter, "type");
case ConstraintBlockKind::TypePackId:
return write(emitter, "typePack");
case ConstraintBlockKind::ConstraintId:
return write(emitter, "constraint");
default:
LUAU_ASSERT(0);
}
}
void write(JsonEmitter& emitter, const ConstraintBlock& block)
{
ObjectEmitter o = emitter.writeObject();
o.writePair("kind", block.kind);
o.writePair("stringification", block.stringification);
auto go = [&o](auto&& t) {
using T = std::decay_t<decltype(t)>;
o.writePair("id", toPointerId(t));
if constexpr (std::is_same_v<T, TypeId>)
{
o.writePair("kind", "type");
}
else if constexpr (std::is_same_v<T, TypePackId>)
{
o.writePair("kind", "typePack");
}
else if constexpr (std::is_same_v<T, NotNull<const Constraint>>)
{
o.writePair("kind", "constraint");
}
else
static_assert(always_false_v<T>, "non-exhaustive possibility switch");
};
visit(go, block.target);
o.finish();
}
@ -114,7 +163,8 @@ void write(JsonEmitter& emitter, const BoundarySnapshot& snapshot)
{
ObjectEmitter o = emitter.writeObject();
o.writePair("rootScope", snapshot.rootScope);
o.writePair("constraints", snapshot.constraints);
o.writePair("unsolvedConstraints", snapshot.unsolvedConstraints);
o.writePair("typeStrings", snapshot.typeStrings);
o.finish();
}
@ -125,6 +175,7 @@ void write(JsonEmitter& emitter, const StepSnapshot& snapshot)
o.writePair("forced", snapshot.forced);
o.writePair("unsolvedConstraints", snapshot.unsolvedConstraints);
o.writePair("rootScope", snapshot.rootScope);
o.writePair("typeStrings", snapshot.typeStrings);
o.finish();
}
@ -146,11 +197,6 @@ void write(JsonEmitter& emitter, const TypeCheckLog& log)
} // namespace Json
static std::string toPointerId(NotNull<const Constraint> ptr)
{
return std::to_string(reinterpret_cast<size_t>(ptr.get()));
}
static ScopeSnapshot snapshotScope(const Scope* scope, ToStringOptions& opts)
{
std::unordered_map<Name, BindingSnapshot> bindings;
@ -230,6 +276,32 @@ void DcrLogger::captureSource(std::string source)
generationLog.source = std::move(source);
}
void DcrLogger::captureGenerationModule(const ModulePtr& module)
{
generationLog.exprTypeLocations.reserve(module->astTypes.size());
for (const auto& [expr, ty] : module->astTypes)
{
ExprTypesAtLocation tys;
tys.location = expr->location;
tys.ty = ty;
if (auto expectedTy = module->astExpectedTypes.find(expr))
tys.expectedTy = *expectedTy;
generationLog.exprTypeLocations.push_back(tys);
}
generationLog.annotationTypeLocations.reserve(module->astResolvedTypes.size());
for (const auto& [annot, ty] : module->astResolvedTypes)
{
AnnotationTypesAtLocation tys;
tys.location = annot->location;
tys.resolvedTy = ty;
generationLog.annotationTypeLocations.push_back(tys);
}
}
void DcrLogger::captureGenerationError(const TypeError& error)
{
std::string stringifiedError = toString(error);
@ -239,12 +311,6 @@ void DcrLogger::captureGenerationError(const TypeError& error)
});
}
void DcrLogger::captureConstraintLocation(NotNull<const Constraint> constraint, Location location)
{
std::string id = toPointerId(constraint);
generationLog.constraintLocations[id] = location;
}
void DcrLogger::pushBlock(NotNull<const Constraint> constraint, TypeId block)
{
constraintBlocks[constraint].push_back(block);
@ -284,44 +350,70 @@ void DcrLogger::popBlock(NotNull<const Constraint> block)
}
}
void DcrLogger::captureInitialSolverState(const Scope* rootScope, const std::vector<NotNull<const Constraint>>& unsolvedConstraints)
static void snapshotTypeStrings(const std::vector<ExprTypesAtLocation>& interestedExprs,
const std::vector<AnnotationTypesAtLocation>& interestedAnnots, DenseHashMap<const void*, std::string>& map, ToStringOptions& opts)
{
solveLog.initialState.rootScope = snapshotScope(rootScope, opts);
solveLog.initialState.constraints.clear();
for (const ExprTypesAtLocation& tys : interestedExprs)
{
map[tys.ty] = toString(tys.ty, opts);
if (tys.expectedTy)
map[*tys.expectedTy] = toString(*tys.expectedTy, opts);
}
for (const AnnotationTypesAtLocation& tys : interestedAnnots)
{
map[tys.resolvedTy] = toString(tys.resolvedTy, opts);
}
}
void DcrLogger::captureBoundaryState(
BoundarySnapshot& target, const Scope* rootScope, const std::vector<NotNull<const Constraint>>& unsolvedConstraints)
{
target.rootScope = snapshotScope(rootScope, opts);
target.unsolvedConstraints.clear();
for (NotNull<const Constraint> c : unsolvedConstraints)
{
std::string id = toPointerId(c);
solveLog.initialState.constraints[id] = {
target.unsolvedConstraints[c.get()] = {
toString(*c.get(), opts),
c->location,
snapshotBlocks(c),
};
}
snapshotTypeStrings(generationLog.exprTypeLocations, generationLog.annotationTypeLocations, target.typeStrings, opts);
}
void DcrLogger::captureInitialSolverState(const Scope* rootScope, const std::vector<NotNull<const Constraint>>& unsolvedConstraints)
{
captureBoundaryState(solveLog.initialState, rootScope, unsolvedConstraints);
}
StepSnapshot DcrLogger::prepareStepSnapshot(
const Scope* rootScope, NotNull<const Constraint> current, bool force, const std::vector<NotNull<const Constraint>>& unsolvedConstraints)
{
ScopeSnapshot scopeSnapshot = snapshotScope(rootScope, opts);
std::string currentId = toPointerId(current);
std::unordered_map<std::string, ConstraintSnapshot> constraints;
DenseHashMap<const Constraint*, ConstraintSnapshot> constraints{nullptr};
for (NotNull<const Constraint> c : unsolvedConstraints)
{
std::string id = toPointerId(c);
constraints[id] = {
constraints[c.get()] = {
toString(*c.get(), opts),
c->location,
snapshotBlocks(c),
};
}
DenseHashMap<const void*, std::string> typeStrings{nullptr};
snapshotTypeStrings(generationLog.exprTypeLocations, generationLog.annotationTypeLocations, typeStrings, opts);
return StepSnapshot{
currentId,
current,
force,
constraints,
std::move(constraints),
scopeSnapshot,
std::move(typeStrings),
};
}
@ -332,18 +424,7 @@ void DcrLogger::commitStepSnapshot(StepSnapshot snapshot)
void DcrLogger::captureFinalSolverState(const Scope* rootScope, const std::vector<NotNull<const Constraint>>& unsolvedConstraints)
{
solveLog.finalState.rootScope = snapshotScope(rootScope, opts);
solveLog.finalState.constraints.clear();
for (NotNull<const Constraint> c : unsolvedConstraints)
{
std::string id = toPointerId(c);
solveLog.finalState.constraints[id] = {
toString(*c.get(), opts),
c->location,
snapshotBlocks(c),
};
}
captureBoundaryState(solveLog.finalState, rootScope, unsolvedConstraints);
}
void DcrLogger::captureTypeCheckError(const TypeError& error)
@ -370,21 +451,21 @@ std::vector<ConstraintBlock> DcrLogger::snapshotBlocks(NotNull<const Constraint>
if (const TypeId* ty = get_if<TypeId>(&target))
{
snapshot.push_back({
ConstraintBlockKind::TypeId,
*ty,
toString(*ty, opts),
});
}
else if (const TypePackId* tp = get_if<TypePackId>(&target))
{
snapshot.push_back({
ConstraintBlockKind::TypePackId,
*tp,
toString(*tp, opts),
});
}
else if (const NotNull<const Constraint>* c = get_if<NotNull<const Constraint>>(&target))
{
snapshot.push_back({
ConstraintBlockKind::ConstraintId,
*c,
toString(*(c->get()), opts),
});
}

View file

@ -899,8 +899,8 @@ ModulePtr check(
cgb.visit(sourceModule.root);
result->errors = std::move(cgb.errors);
ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), sourceModule.name, moduleResolver,
requireCycles, logger.get()};
ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), sourceModule.name,
NotNull{result->reduction.get()}, moduleResolver, requireCycles, logger.get()};
if (options.randomizeConstraintResolutionSeed)
cs.randomize(*options.randomizeConstraintResolutionSeed);

View file

@ -1441,6 +1441,8 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor
if (!unionNormals(here, *tn))
return false;
}
else if (get<BlockedType>(there))
LUAU_ASSERT(!"Internal error: Trying to normalize a BlockedType");
else
LUAU_ASSERT(!"Unreachable");

View file

@ -183,7 +183,7 @@ struct PureQuantifier : Substitution
else if (ttv->state == TableState::Generic)
seenGenericType = true;
return ttv->state == TableState::Unsealed || (ttv->state == TableState::Free && subsumes(scope, ttv->scope));
return (ttv->state == TableState::Unsealed || ttv->state == TableState::Free) && subsumes(scope, ttv->scope);
}
return false;

View file

@ -31,12 +31,12 @@ std::optional<TypeId> Scope::lookup(Symbol sym) const
{
auto r = const_cast<Scope*>(this)->lookupEx(sym);
if (r)
return r->first;
return r->first->typeId;
else
return std::nullopt;
}
std::optional<std::pair<TypeId, Scope*>> Scope::lookupEx(Symbol sym)
std::optional<std::pair<Binding*, Scope*>> Scope::lookupEx(Symbol sym)
{
Scope* s = this;
@ -44,7 +44,7 @@ std::optional<std::pair<TypeId, Scope*>> Scope::lookupEx(Symbol sym)
{
auto it = s->bindings.find(sym);
if (it != s->bindings.end())
return std::pair{it->second.typeId, s};
return std::pair{&it->second, s};
if (s->parent)
s = s->parent.get();

View file

@ -1533,6 +1533,10 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts)
const std::string pathStr = c.path.size() == 1 ? "\"" + c.path[0] + "\"" : "[\"" + join(c.path, "\", \"") + "\"]";
return tos(c.resultType) + " ~ setProp " + tos(c.subjectType) + ", " + pathStr + " " + tos(c.propType);
}
else if constexpr (std::is_same_v<T, SetIndexerConstraint>)
{
return tos(c.resultType) + " ~ setIndexer " + tos(c.subjectType) + " [ " + tos(c.indexType) + " ] " + tos(c.propType);
}
else if constexpr (std::is_same_v<T, SingletonOrTopTypeConstraint>)
{
std::string result = tos(c.resultType);
@ -1543,6 +1547,8 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts)
else
return result + " ~ if isSingleton D then D else unknown where D = " + discriminant;
}
else if constexpr (std::is_same_v<T, UnpackConstraint>)
return tos(c.resultPack) + " ~ unpack " + tos(c.sourcePack);
else
static_assert(always_false_v<T>, "Non-exhaustive constraint switch");
};

View file

@ -4,6 +4,7 @@
#include "Luau/Ast.h"
#include "Luau/AstQuery.h"
#include "Luau/Clone.h"
#include "Luau/Common.h"
#include "Luau/DcrLogger.h"
#include "Luau/Error.h"
#include "Luau/Instantiation.h"
@ -329,11 +330,12 @@ struct TypeChecker2
for (size_t i = 0; i < count; ++i)
{
AstExpr* value = i < local->values.size ? local->values.data[i] : nullptr;
const bool isPack = value && (value->is<AstExprCall>() || value->is<AstExprVarargs>());
if (value)
visit(value, RValue);
if (i != local->values.size - 1 || value)
if (i != local->values.size - 1 || !isPack)
{
AstLocal* var = i < local->vars.size ? local->vars.data[i] : nullptr;
@ -351,16 +353,19 @@ struct TypeChecker2
visit(var->annotation);
}
}
else
else if (value)
{
LUAU_ASSERT(value);
TypePackId valuePack = lookupPack(value);
TypePack valueTypes;
if (i < local->vars.size)
valueTypes = extendTypePack(module->internalTypes, builtinTypes, valuePack, local->vars.size - i);
TypePackId valueTypes = lookupPack(value);
auto it = begin(valueTypes);
Location errorLocation;
for (size_t j = i; j < local->vars.size; ++j)
{
if (it == end(valueTypes))
if (j - i >= valueTypes.head.size())
{
errorLocation = local->vars.data[j]->location;
break;
}
@ -368,14 +373,28 @@ struct TypeChecker2
if (var->annotation)
{
TypeId varType = lookupAnnotation(var->annotation);
ErrorVec errors = tryUnify(stack.back(), value->location, *it, varType);
ErrorVec errors = tryUnify(stack.back(), value->location, valueTypes.head[j - i], varType);
if (!errors.empty())
reportErrors(std::move(errors));
visit(var->annotation);
}
}
++it;
if (valueTypes.head.size() < local->vars.size - i)
{
reportError(
CountMismatch{
// We subtract 1 here because the final AST
// expression is not worth one value. It is worth 0
// or more depending on valueTypes.head
local->values.size - 1 + valueTypes.head.size(),
std::nullopt,
local->vars.size,
local->values.data[local->values.size - 1]->is<AstExprCall>() ? CountMismatch::FunctionResult
: CountMismatch::ExprListResult,
},
errorLocation);
}
}
}
@ -810,6 +829,95 @@ struct TypeChecker2
// TODO!
}
ErrorVec visitOverload(AstExprCall* call, NotNull<const FunctionType> overloadFunctionType, const std::vector<Location>& argLocs,
TypePackId expectedArgTypes, TypePackId expectedRetType)
{
ErrorVec overloadErrors =
tryUnify(stack.back(), call->location, overloadFunctionType->retTypes, expectedRetType, CountMismatch::FunctionResult);
size_t argIndex = 0;
auto inferredArgIt = begin(overloadFunctionType->argTypes);
auto expectedArgIt = begin(expectedArgTypes);
while (inferredArgIt != end(overloadFunctionType->argTypes) && expectedArgIt != end(expectedArgTypes))
{
Location argLoc = (argIndex >= argLocs.size()) ? argLocs.back() : argLocs[argIndex];
ErrorVec argErrors = tryUnify(stack.back(), argLoc, *expectedArgIt, *inferredArgIt);
for (TypeError e : argErrors)
overloadErrors.emplace_back(e);
++argIndex;
++inferredArgIt;
++expectedArgIt;
}
// piggyback on the unifier for arity checking, but we can't do this for checking the actual arguments since the locations would be bad
ErrorVec argumentErrors = tryUnify(stack.back(), call->location, expectedArgTypes, overloadFunctionType->argTypes);
for (TypeError e : argumentErrors)
if (get<CountMismatch>(e) != nullptr)
overloadErrors.emplace_back(std::move(e));
return overloadErrors;
}
void reportOverloadResolutionErrors(AstExprCall* call, std::vector<TypeId> overloads, TypePackId expectedArgTypes,
const std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<std::pair<ErrorVec, const FunctionType*>> overloadsErrors)
{
if (overloads.size() == 1)
{
reportErrors(std::get<0>(overloadsErrors.front()));
return;
}
std::vector<TypeId> overloadTypes = overloadsThatMatchArgCount;
if (overloadsThatMatchArgCount.size() == 0)
{
reportError(GenericError{"No overload for function accepts " + std::to_string(size(expectedArgTypes)) + " arguments."}, call->location);
// If no overloads match argument count, just list all overloads.
overloadTypes = overloads;
}
else
{
// Report errors of the first argument-count-matching, but failing overload
TypeId overload = overloadsThatMatchArgCount[0];
// Remove the overload we are reporting errors about from the list of alternatives
overloadTypes.erase(std::remove(overloadTypes.begin(), overloadTypes.end(), overload), overloadTypes.end());
const FunctionType* ftv = get<FunctionType>(overload);
LUAU_ASSERT(ftv); // overload must be a function type here
auto error = std::find_if(overloadsErrors.begin(), overloadsErrors.end(), [ftv](const std::pair<ErrorVec, const FunctionType*>& e) {
return ftv == std::get<1>(e);
});
LUAU_ASSERT(error != overloadsErrors.end());
reportErrors(std::get<0>(*error));
// If only one overload matched, we don't need this error because we provided the previous errors.
if (overloadsThatMatchArgCount.size() == 1)
return;
}
std::string s;
for (size_t i = 0; i < overloadTypes.size(); ++i)
{
TypeId overload = follow(overloadTypes[i]);
if (i > 0)
s += "; ";
if (i > 0 && i == overloadTypes.size() - 1)
s += "and ";
s += toString(overload);
}
if (overloadsThatMatchArgCount.size() == 0)
reportError(ExtraInformation{"Available overloads: " + s}, call->func->location);
else
reportError(ExtraInformation{"Other overloads are also not viable: " + s}, call->func->location);
}
void visit(AstExprCall* call)
{
visit(call->func, RValue);
@ -865,6 +973,10 @@ struct TypeChecker2
return;
}
}
else if (auto itv = get<IntersectionType>(functionType))
{
// We do nothing here because we'll flatten the intersection later, but we don't want to report it as a non-function.
}
else if (auto utv = get<UnionType>(functionType))
{
// Sometimes it's okay to call a union of functions, but only if all of the functions are the same.
@ -930,48 +1042,105 @@ struct TypeChecker2
TypePackId expectedArgTypes = arena->addTypePack(args);
const FunctionType* inferredFunctionType = get<FunctionType>(testFunctionType);
LUAU_ASSERT(inferredFunctionType); // testFunctionType should always be a FunctionType here
std::vector<TypeId> overloads = flattenIntersection(testFunctionType);
std::vector<std::pair<ErrorVec, const FunctionType*>> overloadsErrors;
overloadsErrors.reserve(overloads.size());
size_t argIndex = 0;
auto inferredArgIt = begin(inferredFunctionType->argTypes);
auto expectedArgIt = begin(expectedArgTypes);
while (inferredArgIt != end(inferredFunctionType->argTypes) && expectedArgIt != end(expectedArgTypes))
std::vector<TypeId> overloadsThatMatchArgCount;
for (TypeId overload : overloads)
{
Location argLoc = (argIndex >= argLocs.size()) ? argLocs.back() : argLocs[argIndex];
reportErrors(tryUnify(stack.back(), argLoc, *expectedArgIt, *inferredArgIt));
overload = follow(overload);
++argIndex;
++inferredArgIt;
++expectedArgIt;
const FunctionType* overloadFn = get<FunctionType>(overload);
if (!overloadFn)
{
reportError(CannotCallNonFunction{overload}, call->func->location);
return;
}
else
{
// We may have to instantiate the overload in order for it to typecheck.
if (std::optional<TypeId> instantiatedFunctionType = instantiation.substitute(overload))
{
overloadFn = get<FunctionType>(*instantiatedFunctionType);
}
else
{
overloadsErrors.emplace_back(std::vector{TypeError{call->func->location, UnificationTooComplex{}}}, overloadFn);
return;
}
}
ErrorVec overloadErrors = visitOverload(call, NotNull{overloadFn}, argLocs, expectedArgTypes, expectedRetType);
if (overloadErrors.empty())
return;
bool argMismatch = false;
for (auto error : overloadErrors)
{
CountMismatch* cm = get<CountMismatch>(error);
if (!cm)
continue;
if (cm->context == CountMismatch::Arg)
{
argMismatch = true;
break;
}
}
if (!argMismatch)
overloadsThatMatchArgCount.push_back(overload);
overloadsErrors.emplace_back(std::move(overloadErrors), overloadFn);
}
// piggyback on the unifier for arity checking, but we can't do this for checking the actual arguments since the locations would be bad
ErrorVec errors = tryUnify(stack.back(), call->location, expectedArgTypes, inferredFunctionType->argTypes);
for (TypeError e : errors)
if (get<CountMismatch>(e) != nullptr)
reportError(std::move(e));
reportOverloadResolutionErrors(call, overloads, expectedArgTypes, overloadsThatMatchArgCount, overloadsErrors);
}
reportErrors(tryUnify(stack.back(), call->location, inferredFunctionType->retTypes, expectedRetType, CountMismatch::FunctionResult));
void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context)
{
visit(expr, RValue);
TypeId leftType = lookupType(expr);
const NormalizedType* norm = normalizer.normalize(leftType);
if (!norm)
reportError(NormalizationTooComplex{}, location);
checkIndexTypeFromType(leftType, *norm, propName, location, context);
}
void visit(AstExprIndexName* indexName, ValueContext context)
{
visit(indexName->expr, RValue);
TypeId leftType = lookupType(indexName->expr);
const NormalizedType* norm = normalizer.normalize(leftType);
if (!norm)
reportError(NormalizationTooComplex{}, indexName->indexLocation);
checkIndexTypeFromType(leftType, *norm, indexName->index.value, indexName->location, context);
visitExprName(indexName->expr, indexName->location, indexName->index.value, context);
}
void visit(AstExprIndexExpr* indexExpr, ValueContext context)
{
if (auto str = indexExpr->index->as<AstExprConstantString>())
{
const std::string stringValue(str->value.data, str->value.size);
visitExprName(indexExpr->expr, indexExpr->location, stringValue, context);
return;
}
// TODO!
visit(indexExpr->expr, LValue);
visit(indexExpr->index, RValue);
NotNull<Scope> scope = stack.back();
TypeId exprType = lookupType(indexExpr->expr);
TypeId indexType = lookupType(indexExpr->index);
if (auto tt = get<TableType>(exprType))
{
if (tt->indexer)
reportErrors(tryUnify(scope, indexExpr->index->location, indexType, tt->indexer->indexType));
else
reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location);
}
}
void visit(AstExprFunction* fn)
@ -1879,8 +2048,17 @@ struct TypeChecker2
ty = *mtIndex;
}
if (getTableType(ty))
return bool(findTablePropertyRespectingMeta(builtinTypes, module->errors, ty, prop, location));
if (auto tt = getTableType(ty))
{
if (findTablePropertyRespectingMeta(builtinTypes, module->errors, ty, prop, location))
return true;
else if (tt->indexer && isPrim(tt->indexer->indexResultType, PrimitiveType::String))
return tt->indexer->indexResultType;
else
return false;
}
else if (const ClassType* cls = get<ClassType>(ty))
return bool(lookupClassProp(cls, prop));
else if (const UnionType* utv = get<UnionType>(ty))

View file

@ -1759,7 +1759,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit)
{
reportErrorCodeTooComplex(expr.location);
return {errorRecoveryType(scope)};
return WithPredicate{errorRecoveryType(scope)};
}
WithPredicate<TypeId> result;
@ -1767,23 +1767,23 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
if (auto a = expr.as<AstExprGroup>())
result = checkExpr(scope, *a->expr, expectedType);
else if (expr.is<AstExprConstantNil>())
result = {nilType};
result = WithPredicate{nilType};
else if (const AstExprConstantBool* bexpr = expr.as<AstExprConstantBool>())
{
if (forceSingleton || (expectedType && maybeSingleton(*expectedType)))
result = {singletonType(bexpr->value)};
result = WithPredicate{singletonType(bexpr->value)};
else
result = {booleanType};
result = WithPredicate{booleanType};
}
else if (const AstExprConstantString* sexpr = expr.as<AstExprConstantString>())
{
if (forceSingleton || (expectedType && maybeSingleton(*expectedType)))
result = {singletonType(std::string(sexpr->value.data, sexpr->value.size))};
result = WithPredicate{singletonType(std::string(sexpr->value.data, sexpr->value.size))};
else
result = {stringType};
result = WithPredicate{stringType};
}
else if (expr.is<AstExprConstantNumber>())
result = {numberType};
result = WithPredicate{numberType};
else if (auto a = expr.as<AstExprLocal>())
result = checkExpr(scope, *a);
else if (auto a = expr.as<AstExprGlobal>())
@ -1837,7 +1837,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
// TODO: tempting to ice here, but this breaks very often because our toposort doesn't enforce this constraint
// ice("AstExprLocal exists but no binding definition for it?", expr.location);
reportError(TypeError{expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}});
return {errorRecoveryType(scope)};
return WithPredicate{errorRecoveryType(scope)};
}
WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGlobal& expr)
@ -1849,7 +1849,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
return {*ty, {TruthyPredicate{std::move(*lvalue), expr.location}}};
reportError(TypeError{expr.location, UnknownSymbol{expr.name.value, UnknownSymbol::Binding}});
return {errorRecoveryType(scope)};
return WithPredicate{errorRecoveryType(scope)};
}
WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVarargs& expr)
@ -1859,26 +1859,26 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
if (get<TypePack>(varargPack))
{
if (std::optional<TypeId> ty = first(varargPack))
return {*ty};
return WithPredicate{*ty};
return {nilType};
return WithPredicate{nilType};
}
else if (get<FreeTypePack>(varargPack))
{
TypeId head = freshType(scope);
TypePackId tail = freshTypePack(scope);
*asMutable(varargPack) = TypePack{{head}, tail};
return {head};
return WithPredicate{head};
}
if (get<ErrorType>(varargPack))
return {errorRecoveryType(scope)};
return WithPredicate{errorRecoveryType(scope)};
else if (auto vtp = get<VariadicTypePack>(varargPack))
return {vtp->ty};
return WithPredicate{vtp->ty};
else if (get<Unifiable::Generic>(varargPack))
{
// TODO: Better error?
reportError(expr.location, GenericError{"Trying to get a type from a variadic type parameter"});
return {errorRecoveryType(scope)};
return WithPredicate{errorRecoveryType(scope)};
}
else
ice("Unknown TypePack type in checkExpr(AstExprVarargs)!");
@ -1929,9 +1929,9 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
lhsType = stripFromNilAndReport(lhsType, expr.expr->location);
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, lhsType, name, expr.location, /* addErrors= */ true))
return {*ty};
return WithPredicate{*ty};
return {errorRecoveryType(scope)};
return WithPredicate{errorRecoveryType(scope)};
}
std::optional<TypeId> TypeChecker::findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors)
@ -2138,7 +2138,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
if (std::optional<TypeId> refiTy = resolveLValue(scope, *lvalue))
return {*refiTy, {TruthyPredicate{std::move(*lvalue), expr.location}}};
return {ty};
return WithPredicate{ty};
}
WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional<TypeId> expectedType)
@ -2147,7 +2147,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
checkFunctionBody(funScope, funTy, expr);
return {quantify(funScope, funTy, expr.location)};
return WithPredicate{quantify(funScope, funTy, expr.location)};
}
TypeId TypeChecker::checkExprTable(
@ -2252,7 +2252,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit)
{
reportErrorCodeTooComplex(expr.location);
return {errorRecoveryType(scope)};
return WithPredicate{errorRecoveryType(scope)};
}
std::vector<std::pair<TypeId, TypeId>> fieldTypes(expr.items.size);
@ -2339,7 +2339,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
expectedIndexResultType = fieldTypes[i].second;
}
return {checkExprTable(scope, expr, fieldTypes, expectedType)};
return WithPredicate{checkExprTable(scope, expr, fieldTypes, expectedType)};
}
WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUnary& expr)
@ -2356,7 +2356,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
const bool operandIsAny = get<AnyType>(operandType) || get<ErrorType>(operandType) || get<NeverType>(operandType);
if (operandIsAny)
return {operandType};
return WithPredicate{operandType};
if (typeCouldHaveMetatable(operandType))
{
@ -2377,16 +2377,16 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
if (!state.errors.empty())
retType = errorRecoveryType(retType);
return {retType};
return WithPredicate{retType};
}
reportError(expr.location,
GenericError{format("Unary operator '%s' not supported by type '%s'", toString(expr.op).c_str(), toString(operandType).c_str())});
return {errorRecoveryType(scope)};
return WithPredicate{errorRecoveryType(scope)};
}
reportErrors(tryUnify(operandType, numberType, scope, expr.location));
return {numberType};
return WithPredicate{numberType};
}
case AstExprUnary::Len:
{
@ -2396,7 +2396,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
// # operator is guaranteed to return number
if (get<AnyType>(operandType) || get<ErrorType>(operandType) || get<NeverType>(operandType))
return {numberType};
return WithPredicate{numberType};
DenseHashSet<TypeId> seen{nullptr};
@ -2420,7 +2420,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
if (!hasLength(operandType, seen, &recursionCount))
reportError(TypeError{expr.location, NotATable{operandType}});
return {numberType};
return WithPredicate{numberType};
}
default:
ice("Unknown AstExprUnary " + std::to_string(int(expr.op)));
@ -3014,7 +3014,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
WithPredicate<TypeId> rhs = checkExpr(scope, *expr.right);
// Intentionally discarding predicates with other operators.
return {checkBinaryOperation(scope, expr, lhs.type, rhs.type, lhs.predicates)};
return WithPredicate{checkBinaryOperation(scope, expr, lhs.type, rhs.type, lhs.predicates)};
}
}
@ -3045,7 +3045,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
// any type errors that may arise from it are going to be useless.
currentModule->errors.resize(oldSize);
return {errorRecoveryType(scope)};
return WithPredicate{errorRecoveryType(scope)};
}
WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional<TypeId> expectedType)
@ -3061,12 +3061,12 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
WithPredicate<TypeId> falseType = checkExpr(falseScope, *expr.falseExpr, expectedType);
if (falseType.type == trueType.type)
return {trueType.type};
return WithPredicate{trueType.type};
std::vector<TypeId> types = reduceUnion({trueType.type, falseType.type});
if (types.empty())
return {neverType};
return {types.size() == 1 ? types[0] : addType(UnionType{std::move(types)})};
return WithPredicate{neverType};
return WithPredicate{types.size() == 1 ? types[0] : addType(UnionType{std::move(types)})};
}
WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprInterpString& expr)
@ -3074,7 +3074,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
for (AstExpr* expr : expr.expressions)
checkExpr(scope, *expr);
return {stringType};
return WithPredicate{stringType};
}
TypeId TypeChecker::checkLValue(const ScopePtr& scope, const AstExpr& expr, ValueContext ctx)
@ -3704,7 +3704,7 @@ WithPredicate<TypePackId> TypeChecker::checkExprPack(const ScopePtr& scope, cons
{
WithPredicate<TypePackId> result = checkExprPackHelper(scope, expr);
if (containsNever(result.type))
return {uninhabitableTypePack};
return WithPredicate{uninhabitableTypePack};
return result;
}
@ -3715,14 +3715,14 @@ WithPredicate<TypePackId> TypeChecker::checkExprPackHelper(const ScopePtr& scope
else if (expr.is<AstExprVarargs>())
{
if (!scope->varargPack)
return {errorRecoveryTypePack(scope)};
return WithPredicate{errorRecoveryTypePack(scope)};
return {*scope->varargPack};
return WithPredicate{*scope->varargPack};
}
else
{
TypeId type = checkExpr(scope, expr).type;
return {addTypePack({type})};
return WithPredicate{addTypePack({type})};
}
}
@ -3994,71 +3994,77 @@ WithPredicate<TypePackId> TypeChecker::checkExprPackHelper(const ScopePtr& scope
{
retPack = freshTypePack(free->level);
TypePackId freshArgPack = freshTypePack(free->level);
asMutable(actualFunctionType)->ty.emplace<FunctionType>(free->level, freshArgPack, retPack);
emplaceType<FunctionType>(asMutable(actualFunctionType), free->level, freshArgPack, retPack);
}
else
retPack = freshTypePack(scope->level);
// checkExpr will log the pre-instantiated type of the function.
// That's not nearly as interesting as the instantiated type, which will include details about how
// generic functions are being instantiated for this particular callsite.
currentModule->astOriginalCallTypes[expr.func] = follow(functionType);
currentModule->astTypes[expr.func] = actualFunctionType;
// We break this function up into a lambda here to limit our stack footprint.
// The vectors used by this function aren't allocated until the lambda is actually called.
auto the_rest = [&]() -> WithPredicate<TypePackId> {
// checkExpr will log the pre-instantiated type of the function.
// That's not nearly as interesting as the instantiated type, which will include details about how
// generic functions are being instantiated for this particular callsite.
currentModule->astOriginalCallTypes[expr.func] = follow(functionType);
currentModule->astTypes[expr.func] = actualFunctionType;
std::vector<TypeId> overloads = flattenIntersection(actualFunctionType);
std::vector<TypeId> overloads = flattenIntersection(actualFunctionType);
std::vector<std::optional<TypeId>> expectedTypes = getExpectedTypesForCall(overloads, expr.args.size, expr.self);
std::vector<std::optional<TypeId>> expectedTypes = getExpectedTypesForCall(overloads, expr.args.size, expr.self);
WithPredicate<TypePackId> argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes);
TypePackId argPack = argListResult.type;
WithPredicate<TypePackId> argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes);
TypePackId argPack = argListResult.type;
if (get<Unifiable::Error>(argPack))
return {errorRecoveryTypePack(scope)};
if (get<Unifiable::Error>(argPack))
return WithPredicate{errorRecoveryTypePack(scope)};
TypePack* args = nullptr;
if (expr.self)
{
argPack = addTypePack(TypePack{{selfType}, argPack});
argListResult.type = argPack;
}
args = getMutable<TypePack>(argPack);
LUAU_ASSERT(args);
TypePack* args = nullptr;
if (expr.self)
{
argPack = addTypePack(TypePack{{selfType}, argPack});
argListResult.type = argPack;
}
args = getMutable<TypePack>(argPack);
LUAU_ASSERT(args);
std::vector<Location> argLocations;
argLocations.reserve(expr.args.size + 1);
if (expr.self)
argLocations.push_back(expr.func->as<AstExprIndexName>()->expr->location);
for (AstExpr* arg : expr.args)
argLocations.push_back(arg->location);
std::vector<Location> argLocations;
argLocations.reserve(expr.args.size + 1);
if (expr.self)
argLocations.push_back(expr.func->as<AstExprIndexName>()->expr->location);
for (AstExpr* arg : expr.args)
argLocations.push_back(arg->location);
std::vector<OverloadErrorEntry> errors; // errors encountered for each overload
std::vector<OverloadErrorEntry> errors; // errors encountered for each overload
std::vector<TypeId> overloadsThatMatchArgCount;
std::vector<TypeId> overloadsThatDont;
std::vector<TypeId> overloadsThatMatchArgCount;
std::vector<TypeId> overloadsThatDont;
for (TypeId fn : overloads)
{
fn = follow(fn);
for (TypeId fn : overloads)
{
fn = follow(fn);
if (auto ret = checkCallOverload(
scope, expr, fn, retPack, argPack, args, &argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors))
return *ret;
}
if (auto ret = checkCallOverload(
scope, expr, fn, retPack, argPack, args, &argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors))
return *ret;
}
if (handleSelfCallMismatch(scope, expr, args, argLocations, errors))
return {retPack};
if (handleSelfCallMismatch(scope, expr, args, argLocations, errors))
return WithPredicate{retPack};
reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors);
reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors);
const FunctionType* overload = nullptr;
if (!overloadsThatMatchArgCount.empty())
overload = get<FunctionType>(overloadsThatMatchArgCount[0]);
if (!overload && !overloadsThatDont.empty())
overload = get<FunctionType>(overloadsThatDont[0]);
if (overload)
return {errorRecoveryTypePack(overload->retTypes)};
const FunctionType* overload = nullptr;
if (!overloadsThatMatchArgCount.empty())
overload = get<FunctionType>(overloadsThatMatchArgCount[0]);
if (!overload && !overloadsThatDont.empty())
overload = get<FunctionType>(overloadsThatDont[0]);
if (overload)
return WithPredicate{errorRecoveryTypePack(overload->retTypes)};
return {errorRecoveryTypePack(retPack)};
return WithPredicate{errorRecoveryTypePack(retPack)};
};
return the_rest();
}
std::vector<std::optional<TypeId>> TypeChecker::getExpectedTypesForCall(const std::vector<TypeId>& overloads, size_t argumentCount, bool selfCall)
@ -4119,8 +4125,13 @@ std::vector<std::optional<TypeId>> TypeChecker::getExpectedTypesForCall(const st
return expectedTypes;
}
std::optional<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack,
TypePackId argPack, TypePack* args, const std::vector<Location>* argLocations, const WithPredicate<TypePackId>& argListResult,
/*
* Note: We return a std::unique_ptr here rather than an optional to manage our stack consumption.
* If this was an optional, callers would have to pay the stack cost for the result. This is problematic
* for functions that need to support recursion up to 600 levels deep.
*/
std::unique_ptr<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn,
TypePackId retPack, TypePackId argPack, TypePack* args, const std::vector<Location>* argLocations, const WithPredicate<TypePackId>& argListResult,
std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<TypeId>& overloadsThatDont, std::vector<OverloadErrorEntry>& errors)
{
LUAU_ASSERT(argLocations);
@ -4130,16 +4141,16 @@ std::optional<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(const Sc
if (get<AnyType>(fn))
{
unify(anyTypePack, argPack, scope, expr.location);
return {{anyTypePack}};
return std::make_unique<WithPredicate<TypePackId>>(anyTypePack);
}
if (get<ErrorType>(fn))
{
return {{errorRecoveryTypePack(scope)}};
return std::make_unique<WithPredicate<TypePackId>>(errorRecoveryTypePack(scope));
}
if (get<NeverType>(fn))
return {{uninhabitableTypePack}};
return std::make_unique<WithPredicate<TypePackId>>(uninhabitableTypePack);
if (auto ftv = get<FreeType>(fn))
{
@ -4152,7 +4163,7 @@ std::optional<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(const Sc
options.isFunctionCall = true;
unify(r, fn, scope, expr.location, options);
return {{retPack}};
return std::make_unique<WithPredicate<TypePackId>>(retPack);
}
std::vector<Location> metaArgLocations;
@ -4191,7 +4202,7 @@ std::optional<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(const Sc
{
reportError(TypeError{expr.func->location, CannotCallNonFunction{fn}});
unify(errorRecoveryTypePack(scope), retPack, scope, expr.func->location);
return {{errorRecoveryTypePack(retPack)}};
return std::make_unique<WithPredicate<TypePackId>>(errorRecoveryTypePack(retPack));
}
// When this function type has magic functions and did return something, we select that overload instead.
@ -4200,7 +4211,7 @@ std::optional<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(const Sc
{
// TODO: We're passing in the wrong TypePackId. Should be argPack, but a unit test fails otherwise. CLI-40458
if (std::optional<WithPredicate<TypePackId>> ret = ftv->magicFunction(*this, scope, expr, argListResult))
return *ret;
return std::make_unique<WithPredicate<TypePackId>>(std::move(*ret));
}
Unifier state = mkUnifier(scope, expr.location);
@ -4209,7 +4220,7 @@ std::optional<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(const Sc
checkArgumentList(scope, *expr.func, state, retPack, ftv->retTypes, /*argLocations*/ {});
if (!state.errors.empty())
{
return {};
return nullptr;
}
checkArgumentList(scope, *expr.func, state, argPack, ftv->argTypes, *argLocations);
@ -4244,10 +4255,10 @@ std::optional<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(const Sc
currentModule->astOverloadResolvedTypes[&expr] = fn;
// We select this overload
return {{retPack}};
return std::make_unique<WithPredicate<TypePackId>>(retPack);
}
return {};
return nullptr;
}
bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector<Location>& argLocations,
@ -4404,7 +4415,7 @@ WithPredicate<TypePackId> TypeChecker::checkExprList(const ScopePtr& scope, cons
};
if (exprs.size == 0)
return {pack};
return WithPredicate{pack};
TypePack* tp = getMutable<TypePack>(pack);
@ -4484,7 +4495,7 @@ WithPredicate<TypePackId> TypeChecker::checkExprList(const ScopePtr& scope, cons
log.commit();
if (uninhabitable)
return {uninhabitableTypePack};
return WithPredicate{uninhabitableTypePack};
return {pack, predicates};
}

View file

@ -16,11 +16,167 @@ LUAU_FASTFLAGVARIABLE(DebugLuauDontReduceTypes, false)
namespace Luau
{
namespace detail
{
bool TypeReductionMemoization::isIrreducible(TypeId ty)
{
ty = follow(ty);
// Only does shallow check, the TypeReducer itself already does deep traversal.
if (auto edge = types.find(ty); edge && edge->irreducible)
return true;
else if (get<FreeType>(ty) || get<BlockedType>(ty) || get<PendingExpansionType>(ty))
return false;
else if (auto tt = get<TableType>(ty); tt && (tt->state == TableState::Free || tt->state == TableState::Unsealed))
return false;
else
return true;
}
bool TypeReductionMemoization::isIrreducible(TypePackId tp)
{
tp = follow(tp);
// Only does shallow check, the TypeReducer itself already does deep traversal.
if (auto edge = typePacks.find(tp); edge && edge->irreducible)
return true;
else if (get<FreeTypePack>(tp) || get<BlockedTypePack>(tp))
return false;
else if (auto vtp = get<VariadicTypePack>(tp))
return isIrreducible(vtp->ty);
else
return true;
}
TypeId TypeReductionMemoization::memoize(TypeId ty, TypeId reducedTy)
{
ty = follow(ty);
reducedTy = follow(reducedTy);
// The irreducibility of this [`reducedTy`] depends on whether its contents are themselves irreducible.
// We don't need to recurse much further than that, because we already record the irreducibility from
// the bottom up.
bool irreducible = isIrreducible(reducedTy);
if (auto it = get<IntersectionType>(reducedTy))
{
for (TypeId part : it)
irreducible &= isIrreducible(part);
}
else if (auto ut = get<UnionType>(reducedTy))
{
for (TypeId option : ut)
irreducible &= isIrreducible(option);
}
else if (auto tt = get<TableType>(reducedTy))
{
for (auto& [k, p] : tt->props)
irreducible &= isIrreducible(p.type);
if (tt->indexer)
{
irreducible &= isIrreducible(tt->indexer->indexType);
irreducible &= isIrreducible(tt->indexer->indexResultType);
}
for (auto ta : tt->instantiatedTypeParams)
irreducible &= isIrreducible(ta);
for (auto tpa : tt->instantiatedTypePackParams)
irreducible &= isIrreducible(tpa);
}
else if (auto mt = get<MetatableType>(reducedTy))
{
irreducible &= isIrreducible(mt->table);
irreducible &= isIrreducible(mt->metatable);
}
else if (auto ft = get<FunctionType>(reducedTy))
{
irreducible &= isIrreducible(ft->argTypes);
irreducible &= isIrreducible(ft->retTypes);
}
else if (auto nt = get<NegationType>(reducedTy))
irreducible &= isIrreducible(nt->ty);
types[ty] = {reducedTy, irreducible};
types[reducedTy] = {reducedTy, irreducible};
return reducedTy;
}
TypePackId TypeReductionMemoization::memoize(TypePackId tp, TypePackId reducedTp)
{
tp = follow(tp);
reducedTp = follow(reducedTp);
bool irreducible = isIrreducible(reducedTp);
TypePackIterator it = begin(tp);
while (it != end(tp))
{
irreducible &= isIrreducible(*it);
++it;
}
if (it.tail())
irreducible &= isIrreducible(*it.tail());
typePacks[tp] = {reducedTp, irreducible};
typePacks[reducedTp] = {reducedTp, irreducible};
return reducedTp;
}
std::optional<ReductionEdge<TypeId>> TypeReductionMemoization::memoizedof(TypeId ty) const
{
auto fetchContext = [this](TypeId ty) -> std::optional<ReductionEdge<TypeId>> {
if (auto edge = types.find(ty))
return *edge;
else
return std::nullopt;
};
TypeId currentTy = ty;
std::optional<ReductionEdge<TypeId>> lastEdge;
while (auto edge = fetchContext(currentTy))
{
lastEdge = edge;
if (edge->irreducible)
return edge;
else if (edge->type == currentTy)
return edge;
else
currentTy = edge->type;
}
return lastEdge;
}
std::optional<ReductionEdge<TypePackId>> TypeReductionMemoization::memoizedof(TypePackId tp) const
{
auto fetchContext = [this](TypePackId tp) -> std::optional<ReductionEdge<TypePackId>> {
if (auto edge = typePacks.find(tp))
return *edge;
else
return std::nullopt;
};
TypePackId currentTp = tp;
std::optional<ReductionEdge<TypePackId>> lastEdge;
while (auto edge = fetchContext(currentTp))
{
lastEdge = edge;
if (edge->irreducible)
return edge;
else if (edge->type == currentTp)
return edge;
else
currentTp = edge->type;
}
return lastEdge;
}
} // namespace detail
namespace
{
using detail::ReductionContext;
template<typename A, typename B, typename Thing>
std::pair<const A*, const B*> get2(const Thing& one, const Thing& two)
{
@ -34,9 +190,7 @@ struct TypeReducer
NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtinTypes;
NotNull<InternalErrorReporter> handle;
DenseHashMap<TypeId, ReductionContext<TypeId>>* memoizedTypes;
DenseHashMap<TypePackId, ReductionContext<TypePackId>>* memoizedTypePacks;
NotNull<detail::TypeReductionMemoization> memoization;
DenseHashSet<const void*>* cyclics;
int depth = 0;
@ -50,12 +204,6 @@ struct TypeReducer
TypeId functionType(TypeId ty);
TypeId negationType(TypeId ty);
bool isIrreducible(TypeId ty);
bool isIrreducible(TypePackId tp);
TypeId memoize(TypeId ty, TypeId reducedTy);
TypePackId memoize(TypePackId tp, TypePackId reducedTp);
using BinaryFold = std::optional<TypeId> (TypeReducer::*)(TypeId, TypeId);
using UnaryFold = TypeId (TypeReducer::*)(TypeId);
@ -64,12 +212,15 @@ struct TypeReducer
{
ty = follow(ty);
if (auto ctx = memoizedTypes->find(ty))
return {ctx->type, getMutable<T>(ctx->type)};
if (auto edge = memoization->memoizedof(ty))
return {edge->type, getMutable<T>(edge->type)};
// We specifically do not want to use [`detail::TypeReductionMemoization::memoize`] because that will
// potentially consider these copiedTy to be reducible, but we need this to resolve cyclic references
// without attempting to recursively reduce it, causing copies of copies of copies of...
TypeId copiedTy = arena->addType(*t);
(*memoizedTypes)[ty] = {copiedTy, true};
(*memoizedTypes)[copiedTy] = {copiedTy, true};
memoization->types[ty] = {copiedTy, true};
memoization->types[copiedTy] = {copiedTy, true};
return {copiedTy, getMutable<T>(copiedTy)};
}
@ -175,8 +326,13 @@ TypeId TypeReducer::reduce(TypeId ty)
{
ty = follow(ty);
if (auto ctx = memoizedTypes->find(ty); ctx && ctx->irreducible)
return ctx->type;
if (auto edge = memoization->memoizedof(ty))
{
if (edge->irreducible)
return edge->type;
else
ty = edge->type;
}
else if (cyclics->contains(ty))
return ty;
@ -196,15 +352,20 @@ TypeId TypeReducer::reduce(TypeId ty)
else
result = ty;
return memoize(ty, result);
return memoization->memoize(ty, result);
}
TypePackId TypeReducer::reduce(TypePackId tp)
{
tp = follow(tp);
if (auto ctx = memoizedTypePacks->find(tp); ctx && ctx->irreducible)
return ctx->type;
if (auto edge = memoization->memoizedof(tp))
{
if (edge->irreducible)
return edge->type;
else
tp = edge->type;
}
else if (cyclics->contains(tp))
return tp;
@ -237,11 +398,11 @@ TypePackId TypeReducer::reduce(TypePackId tp)
}
if (!didReduce)
return memoize(tp, tp);
return memoization->memoize(tp, tp);
else if (head.empty() && tail)
return memoize(tp, *tail);
return memoization->memoize(tp, *tail);
else
return memoize(tp, arena->addTypePack(TypePack{std::move(head), tail}));
return memoization->memoize(tp, arena->addTypePack(TypePack{std::move(head), tail}));
}
std::optional<TypeId> TypeReducer::intersectionType(TypeId left, TypeId right)
@ -832,111 +993,6 @@ TypeId TypeReducer::negationType(TypeId ty)
return ty; // for all T except the ones handled above, ~T ~ ~T
}
bool TypeReducer::isIrreducible(TypeId ty)
{
ty = follow(ty);
// Only does shallow check, the TypeReducer itself already does deep traversal.
if (auto ctx = memoizedTypes->find(ty); ctx && ctx->irreducible)
return true;
else if (get<FreeType>(ty) || get<BlockedType>(ty) || get<PendingExpansionType>(ty))
return false;
else if (auto tt = get<TableType>(ty); tt && (tt->state == TableState::Free || tt->state == TableState::Unsealed))
return false;
else
return true;
}
bool TypeReducer::isIrreducible(TypePackId tp)
{
tp = follow(tp);
// Only does shallow check, the TypeReducer itself already does deep traversal.
if (auto ctx = memoizedTypePacks->find(tp); ctx && ctx->irreducible)
return true;
else if (get<FreeTypePack>(tp) || get<BlockedTypePack>(tp))
return false;
else if (auto vtp = get<VariadicTypePack>(tp))
return isIrreducible(vtp->ty);
else
return true;
}
TypeId TypeReducer::memoize(TypeId ty, TypeId reducedTy)
{
ty = follow(ty);
reducedTy = follow(reducedTy);
// The irreducibility of this [`reducedTy`] depends on whether its contents are themselves irreducible.
// We don't need to recurse much further than that, because we already record the irreducibility from
// the bottom up.
bool irreducible = isIrreducible(reducedTy);
if (auto it = get<IntersectionType>(reducedTy))
{
for (TypeId part : it)
irreducible &= isIrreducible(part);
}
else if (auto ut = get<UnionType>(reducedTy))
{
for (TypeId option : ut)
irreducible &= isIrreducible(option);
}
else if (auto tt = get<TableType>(reducedTy))
{
for (auto& [k, p] : tt->props)
irreducible &= isIrreducible(p.type);
if (tt->indexer)
{
irreducible &= isIrreducible(tt->indexer->indexType);
irreducible &= isIrreducible(tt->indexer->indexResultType);
}
for (auto ta : tt->instantiatedTypeParams)
irreducible &= isIrreducible(ta);
for (auto tpa : tt->instantiatedTypePackParams)
irreducible &= isIrreducible(tpa);
}
else if (auto mt = get<MetatableType>(reducedTy))
{
irreducible &= isIrreducible(mt->table);
irreducible &= isIrreducible(mt->metatable);
}
else if (auto ft = get<FunctionType>(reducedTy))
{
irreducible &= isIrreducible(ft->argTypes);
irreducible &= isIrreducible(ft->retTypes);
}
else if (auto nt = get<NegationType>(reducedTy))
irreducible &= isIrreducible(nt->ty);
(*memoizedTypes)[ty] = {reducedTy, irreducible};
(*memoizedTypes)[reducedTy] = {reducedTy, irreducible};
return reducedTy;
}
TypePackId TypeReducer::memoize(TypePackId tp, TypePackId reducedTp)
{
tp = follow(tp);
reducedTp = follow(reducedTp);
bool irreducible = isIrreducible(reducedTp);
TypePackIterator it = begin(tp);
while (it != end(tp))
{
irreducible &= isIrreducible(*it);
++it;
}
if (it.tail())
irreducible &= isIrreducible(*it.tail());
(*memoizedTypePacks)[tp] = {reducedTp, irreducible};
(*memoizedTypePacks)[reducedTp] = {reducedTp, irreducible};
return reducedTp;
}
struct MarkCycles : TypeVisitor
{
DenseHashSet<const void*> cyclics{nullptr};
@ -961,7 +1017,6 @@ struct MarkCycles : TypeVisitor
return !cyclics.find(follow(tp));
}
};
} // namespace
TypeReduction::TypeReduction(
@ -981,8 +1036,13 @@ std::optional<TypeId> TypeReduction::reduce(TypeId ty)
return ty;
else if (!options.allowTypeReductionsFromOtherArenas && ty->owningArena != arena)
return ty;
else if (auto memoized = memoizedof(ty))
return *memoized;
else if (auto edge = memoization.memoizedof(ty))
{
if (edge->irreducible)
return edge->type;
else
ty = edge->type;
}
else if (hasExceededCartesianProductLimit(ty))
return std::nullopt;
@ -991,7 +1051,7 @@ std::optional<TypeId> TypeReduction::reduce(TypeId ty)
MarkCycles finder;
finder.traverse(ty);
TypeReducer reducer{arena, builtinTypes, handle, &memoizedTypes, &memoizedTypePacks, &finder.cyclics};
TypeReducer reducer{arena, builtinTypes, handle, NotNull{&memoization}, &finder.cyclics};
return reducer.reduce(ty);
}
catch (const RecursionLimitException&)
@ -1008,8 +1068,13 @@ std::optional<TypePackId> TypeReduction::reduce(TypePackId tp)
return tp;
else if (!options.allowTypeReductionsFromOtherArenas && tp->owningArena != arena)
return tp;
else if (auto memoized = memoizedof(tp))
return *memoized;
else if (auto edge = memoization.memoizedof(tp))
{
if (edge->irreducible)
return edge->type;
else
tp = edge->type;
}
else if (hasExceededCartesianProductLimit(tp))
return std::nullopt;
@ -1018,7 +1083,7 @@ std::optional<TypePackId> TypeReduction::reduce(TypePackId tp)
MarkCycles finder;
finder.traverse(tp);
TypeReducer reducer{arena, builtinTypes, handle, &memoizedTypes, &memoizedTypePacks, &finder.cyclics};
TypeReducer reducer{arena, builtinTypes, handle, NotNull{&memoization}, &finder.cyclics};
return reducer.reduce(tp);
}
catch (const RecursionLimitException&)
@ -1039,13 +1104,6 @@ std::optional<TypeFun> TypeReduction::reduce(const TypeFun& fun)
return std::nullopt;
}
TypeReduction TypeReduction::fork(NotNull<TypeArena> arena, const TypeReductionOptions& opts) const
{
TypeReduction child{arena, builtinTypes, handle, opts};
child.parent = this;
return child;
}
size_t TypeReduction::cartesianProductSize(TypeId ty) const
{
ty = follow(ty);
@ -1093,24 +1151,4 @@ bool TypeReduction::hasExceededCartesianProductLimit(TypePackId tp) const
return false;
}
std::optional<TypeId> TypeReduction::memoizedof(TypeId ty) const
{
if (auto ctx = memoizedTypes.find(ty); ctx && ctx->irreducible)
return ctx->type;
else if (parent)
return parent->memoizedof(ty);
else
return std::nullopt;
}
std::optional<TypePackId> TypeReduction::memoizedof(TypePackId tp) const
{
if (auto ctx = memoizedTypePacks.find(tp); ctx && ctx->irreducible)
return ctx->type;
else if (parent)
return parent->memoizedof(tp);
else
return std::nullopt;
}
} // namespace Luau

View file

@ -520,7 +520,12 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
size_t errorCount = errors.size();
if (const UnionType* subUnion = log.getMutable<UnionType>(subTy))
if (log.getMutable<BlockedType>(subTy) && log.getMutable<BlockedType>(superTy))
{
blockedTypes.push_back(subTy);
blockedTypes.push_back(superTy);
}
else if (const UnionType* subUnion = log.getMutable<UnionType>(subTy))
{
tryUnifyUnionWithType(subTy, subUnion, superTy);
}

View file

@ -42,6 +42,7 @@ struct IrBuilder
IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c);
IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d);
IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e);
IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f);
IrOp block(IrBlockKind kind); // Requested kind can be ignored if we are in an outlined sequence
IrOp blockAtInst(uint32_t index);
@ -57,6 +58,8 @@ struct IrBuilder
IrFunction function;
uint32_t activeBlockIdx = ~0u;
std::vector<uint32_t> instIndexToBlock; // Block index at the bytecode instruction
};

View file

@ -5,6 +5,7 @@
#include "Luau/RegisterX64.h"
#include "Luau/RegisterA64.h"
#include <optional>
#include <vector>
#include <stdint.h>
@ -186,6 +187,16 @@ enum class IrCmd : uint8_t
// A: int
INT_TO_NUM,
// Adjust stack top (L->top) to point at 'B' TValues *after* the specified register
// This is used to return muliple values
// A: Rn
// B: int (offset)
ADJUST_STACK_TO_REG,
// Restore stack top (L->top) to point to the function stack top (L->ci->top)
// This is used to recover after calling a variadic function
ADJUST_STACK_TO_TOP,
// Fallback functions
// Perform an arithmetic operation on TValues of any type
@ -329,7 +340,7 @@ enum class IrCmd : uint8_t
// Call specified function
// A: unsigned int (bytecode instruction index)
// B: Rn (function, followed by arguments)
// C: int (argument count or -1 to preserve all arguments up to stack top)
// C: int (argument count or -1 to use all arguments up to stack top)
// D: int (result count or -1 to preserve all results and adjust stack top)
// Note: return values are placed starting from Rn specified in 'B'
LOP_CALL,
@ -337,13 +348,13 @@ enum class IrCmd : uint8_t
// Return specified values from the function
// A: unsigned int (bytecode instruction index)
// B: Rn (value start)
// B: int (result count or -1 to return all values up to stack top)
// C: int (result count or -1 to return all values up to stack top)
LOP_RETURN,
// Perform a fast call of a built-in function
// A: unsigned int (bytecode instruction index)
// B: Rn (argument start)
// C: int (argument count or -1 preserve all arguments up to stack top)
// C: int (argument count or -1 use all arguments up to stack top)
// D: block (fallback)
// Note: return values are placed starting from Rn specified in 'B'
LOP_FASTCALL,
@ -560,6 +571,7 @@ struct IrInst
IrOp c;
IrOp d;
IrOp e;
IrOp f;
uint32_t lastUse = 0;
uint16_t useCount = 0;
@ -584,9 +596,10 @@ struct IrBlock
uint16_t useCount = 0;
// Start points to an instruction index in a stream
// End is implicit
// 'start' and 'finish' define an inclusive range of instructions which belong to this block inside the function
// When block has been constructed, 'finish' always points to the first and only terminating instruction
uint32_t start = ~0u;
uint32_t finish = ~0u;
Label label;
};
@ -633,6 +646,19 @@ struct IrFunction
return value.valueTag;
}
std::optional<uint8_t> asTagOp(IrOp op)
{
if (op.kind != IrOpKind::Constant)
return std::nullopt;
IrConst& value = constOp(op);
if (value.kind != IrConstKind::Tag)
return std::nullopt;
return value.valueTag;
}
bool boolOp(IrOp op)
{
IrConst& value = constOp(op);
@ -641,6 +667,19 @@ struct IrFunction
return value.valueBool;
}
std::optional<bool> asBoolOp(IrOp op)
{
if (op.kind != IrOpKind::Constant)
return std::nullopt;
IrConst& value = constOp(op);
if (value.kind != IrConstKind::Bool)
return std::nullopt;
return value.valueBool;
}
int intOp(IrOp op)
{
IrConst& value = constOp(op);
@ -649,6 +688,19 @@ struct IrFunction
return value.valueInt;
}
std::optional<int> asIntOp(IrOp op)
{
if (op.kind != IrOpKind::Constant)
return std::nullopt;
IrConst& value = constOp(op);
if (value.kind != IrConstKind::Int)
return std::nullopt;
return value.valueInt;
}
unsigned uintOp(IrOp op)
{
IrConst& value = constOp(op);
@ -657,6 +709,19 @@ struct IrFunction
return value.valueUint;
}
std::optional<unsigned> asUintOp(IrOp op)
{
if (op.kind != IrOpKind::Constant)
return std::nullopt;
IrConst& value = constOp(op);
if (value.kind != IrConstKind::Uint)
return std::nullopt;
return value.valueUint;
}
double doubleOp(IrOp op)
{
IrConst& value = constOp(op);
@ -665,11 +730,31 @@ struct IrFunction
return value.valueDouble;
}
std::optional<double> asDoubleOp(IrOp op)
{
if (op.kind != IrOpKind::Constant)
return std::nullopt;
IrConst& value = constOp(op);
if (value.kind != IrConstKind::Double)
return std::nullopt;
return value.valueDouble;
}
IrCondition conditionOp(IrOp op)
{
LUAU_ASSERT(op.kind == IrOpKind::Condition);
return IrCondition(op.index);
}
uint32_t getBlockIndex(const IrBlock& block)
{
// Can only be called with blocks from our vector
LUAU_ASSERT(&block >= blocks.data() && &block <= blocks.data() + blocks.size());
return uint32_t(&block - blocks.data());
}
};
} // namespace CodeGen

View file

@ -162,6 +162,8 @@ inline bool isPseudo(IrCmd cmd)
return cmd == IrCmd::NOP || cmd == IrCmd::SUBSTITUTE;
}
bool isGCO(uint8_t tag);
// Remove a single instruction
void kill(IrFunction& function, IrInst& inst);
@ -179,7 +181,7 @@ void replace(IrFunction& function, IrOp& original, IrOp replacement);
// Replace a single instruction
// Target instruction index instead of reference is used to handle introduction of a new block terminator
void replace(IrFunction& function, uint32_t instIdx, IrInst replacement);
void replace(IrFunction& function, IrBlock& block, uint32_t instIdx, IrInst replacement);
// Replace instruction with a different value (using IrCmd::SUBSTITUTE)
void substitute(IrFunction& function, IrInst& inst, IrOp replacement);
@ -188,10 +190,13 @@ void substitute(IrFunction& function, IrInst& inst, IrOp replacement);
void applySubstitutions(IrFunction& function, IrOp& op);
void applySubstitutions(IrFunction& function, IrInst& inst);
// Compare numbers using IR condition value
bool compare(double a, double b, IrCondition cond);
// Perform constant folding on instruction at index
// For most instructions, successful folding results in a IrCmd::SUBSTITUTE
// But it can also be successful on conditional control-flow, replacing it with an unconditional IrCmd::JUMP
void foldConstants(IrBuilder& build, IrFunction& function, uint32_t instIdx);
void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint32_t instIdx);
} // namespace CodeGen
} // namespace Luau

View file

@ -0,0 +1,16 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/IrData.h"
namespace Luau
{
namespace CodeGen
{
struct IrBuilder;
void constPropInBlockChains(IrBuilder& build);
} // namespace CodeGen
} // namespace Luau

View file

@ -7,6 +7,7 @@
#include "Luau/CodeBlockUnwind.h"
#include "Luau/IrAnalysis.h"
#include "Luau/IrBuilder.h"
#include "Luau/OptimizeConstProp.h"
#include "Luau/OptimizeFinalX64.h"
#include "Luau/UnwindBuilder.h"
#include "Luau/UnwindBuilderDwarf2.h"
@ -31,7 +32,7 @@
#endif
#endif
LUAU_FASTFLAGVARIABLE(DebugUseOldCodegen, false)
LUAU_FASTFLAGVARIABLE(DebugCodegenNoOpt, false)
namespace Luau
{
@ -40,12 +41,6 @@ namespace CodeGen
constexpr uint32_t kFunctionAlignment = 32;
struct InstructionOutline
{
int pcpos;
int length;
};
static void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers)
{
if (build.logText)
@ -64,346 +59,6 @@ static void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers)
emitContinueCallInVm(build);
}
static int emitInst(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, LuauOpcode op, const Instruction* pc, int i,
Label* labelarr, Label& next, Label& fallback)
{
int skip = 0;
switch (op)
{
case LOP_NOP:
break;
case LOP_LOADNIL:
emitInstLoadNil(build, pc);
break;
case LOP_LOADB:
emitInstLoadB(build, pc, i, labelarr);
break;
case LOP_LOADN:
emitInstLoadN(build, pc);
break;
case LOP_LOADK:
emitInstLoadK(build, pc);
break;
case LOP_LOADKX:
emitInstLoadKX(build, pc);
break;
case LOP_MOVE:
emitInstMove(build, pc);
break;
case LOP_GETGLOBAL:
emitInstGetGlobal(build, pc, i, fallback);
break;
case LOP_SETGLOBAL:
emitInstSetGlobal(build, pc, i, next, fallback);
break;
case LOP_NAMECALL:
emitInstNameCall(build, pc, i, proto->k, next, fallback);
break;
case LOP_CALL:
emitInstCall(build, helpers, pc, i);
break;
case LOP_RETURN:
emitInstReturn(build, helpers, pc, i);
break;
case LOP_GETTABLE:
emitInstGetTable(build, pc, fallback);
break;
case LOP_SETTABLE:
emitInstSetTable(build, pc, next, fallback);
break;
case LOP_GETTABLEKS:
emitInstGetTableKS(build, pc, i, fallback);
break;
case LOP_SETTABLEKS:
emitInstSetTableKS(build, pc, i, next, fallback);
break;
case LOP_GETTABLEN:
emitInstGetTableN(build, pc, fallback);
break;
case LOP_SETTABLEN:
emitInstSetTableN(build, pc, next, fallback);
break;
case LOP_JUMP:
emitInstJump(build, pc, i, labelarr);
break;
case LOP_JUMPBACK:
emitInstJumpBack(build, pc, i, labelarr);
break;
case LOP_JUMPIF:
emitInstJumpIf(build, pc, i, labelarr, /* not_ */ false);
break;
case LOP_JUMPIFNOT:
emitInstJumpIf(build, pc, i, labelarr, /* not_ */ true);
break;
case LOP_JUMPIFEQ:
emitInstJumpIfEq(build, pc, i, labelarr, /* not_ */ false, fallback);
break;
case LOP_JUMPIFLE:
emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::LessEqual, fallback);
break;
case LOP_JUMPIFLT:
emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::Less, fallback);
break;
case LOP_JUMPIFNOTEQ:
emitInstJumpIfEq(build, pc, i, labelarr, /* not_ */ true, fallback);
break;
case LOP_JUMPIFNOTLE:
emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::NotLessEqual, fallback);
break;
case LOP_JUMPIFNOTLT:
emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::NotLess, fallback);
break;
case LOP_JUMPX:
emitInstJumpX(build, pc, i, labelarr);
break;
case LOP_JUMPXEQKNIL:
emitInstJumpxEqNil(build, pc, i, labelarr);
break;
case LOP_JUMPXEQKB:
emitInstJumpxEqB(build, pc, i, labelarr);
break;
case LOP_JUMPXEQKN:
emitInstJumpxEqN(build, pc, proto->k, i, labelarr);
break;
case LOP_JUMPXEQKS:
emitInstJumpxEqS(build, pc, i, labelarr);
break;
case LOP_ADD:
emitInstBinary(build, pc, TM_ADD, fallback);
break;
case LOP_SUB:
emitInstBinary(build, pc, TM_SUB, fallback);
break;
case LOP_MUL:
emitInstBinary(build, pc, TM_MUL, fallback);
break;
case LOP_DIV:
emitInstBinary(build, pc, TM_DIV, fallback);
break;
case LOP_MOD:
emitInstBinary(build, pc, TM_MOD, fallback);
break;
case LOP_POW:
emitInstBinary(build, pc, TM_POW, fallback);
break;
case LOP_ADDK:
emitInstBinaryK(build, pc, TM_ADD, fallback);
break;
case LOP_SUBK:
emitInstBinaryK(build, pc, TM_SUB, fallback);
break;
case LOP_MULK:
emitInstBinaryK(build, pc, TM_MUL, fallback);
break;
case LOP_DIVK:
emitInstBinaryK(build, pc, TM_DIV, fallback);
break;
case LOP_MODK:
emitInstBinaryK(build, pc, TM_MOD, fallback);
break;
case LOP_POWK:
emitInstPowK(build, pc, proto->k, fallback);
break;
case LOP_NOT:
emitInstNot(build, pc);
break;
case LOP_MINUS:
emitInstMinus(build, pc, fallback);
break;
case LOP_LENGTH:
emitInstLength(build, pc, fallback);
break;
case LOP_NEWTABLE:
emitInstNewTable(build, pc, i, next);
break;
case LOP_DUPTABLE:
emitInstDupTable(build, pc, i, next);
break;
case LOP_SETLIST:
emitInstSetList(build, pc, next);
break;
case LOP_GETUPVAL:
emitInstGetUpval(build, pc);
break;
case LOP_SETUPVAL:
emitInstSetUpval(build, pc, next);
break;
case LOP_CLOSEUPVALS:
emitInstCloseUpvals(build, pc, next);
break;
case LOP_FASTCALL:
// We want to lower next instruction at skip+2, but this instruction is only 1 long, so we need to add 1
skip = emitInstFastCall(build, pc, i, next) + 1;
break;
case LOP_FASTCALL1:
// We want to lower next instruction at skip+2, but this instruction is only 1 long, so we need to add 1
skip = emitInstFastCall1(build, pc, i, next) + 1;
break;
case LOP_FASTCALL2:
skip = emitInstFastCall2(build, pc, i, next);
break;
case LOP_FASTCALL2K:
skip = emitInstFastCall2K(build, pc, i, next);
break;
case LOP_FORNPREP:
emitInstForNPrep(build, pc, i, next, labelarr[i + 1 + LUAU_INSN_D(*pc)]);
break;
case LOP_FORNLOOP:
emitInstForNLoop(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)], next);
break;
case LOP_FORGLOOP:
emitinstForGLoop(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)], next, fallback);
break;
case LOP_FORGPREP_NEXT:
emitInstForGPrepNext(build, pc, labelarr[i + 1 + LUAU_INSN_D(*pc)], fallback);
break;
case LOP_FORGPREP_INEXT:
emitInstForGPrepInext(build, pc, labelarr[i + 1 + LUAU_INSN_D(*pc)], fallback);
break;
case LOP_AND:
emitInstAnd(build, pc);
break;
case LOP_ANDK:
emitInstAndK(build, pc);
break;
case LOP_OR:
emitInstOr(build, pc);
break;
case LOP_ORK:
emitInstOrK(build, pc);
break;
case LOP_GETIMPORT:
emitInstGetImport(build, pc, fallback);
break;
case LOP_CONCAT:
emitInstConcat(build, pc, i, next);
break;
case LOP_COVERAGE:
emitInstCoverage(build, i);
break;
default:
emitFallback(build, data, op, i);
break;
}
return skip;
}
static void emitInstFallback(AssemblyBuilderX64& build, NativeState& data, LuauOpcode op, const Instruction* pc, int i, Label* labelarr)
{
switch (op)
{
case LOP_GETIMPORT:
emitSetSavedPc(build, i + 1);
emitInstGetImportFallback(build, LUAU_INSN_A(*pc), pc[1]);
break;
case LOP_GETTABLE:
emitInstGetTableFallback(build, pc, i);
break;
case LOP_SETTABLE:
emitInstSetTableFallback(build, pc, i);
break;
case LOP_GETTABLEN:
emitInstGetTableNFallback(build, pc, i);
break;
case LOP_SETTABLEN:
emitInstSetTableNFallback(build, pc, i);
break;
case LOP_NAMECALL:
// TODO: fast-paths that we've handled can be removed from the fallback
emitFallback(build, data, op, i);
break;
case LOP_JUMPIFEQ:
emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ false);
break;
case LOP_JUMPIFLE:
emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::LessEqual);
break;
case LOP_JUMPIFLT:
emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::Less);
break;
case LOP_JUMPIFNOTEQ:
emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ true);
break;
case LOP_JUMPIFNOTLE:
emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::NotLessEqual);
break;
case LOP_JUMPIFNOTLT:
emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::NotLess);
break;
case LOP_ADD:
emitInstBinaryFallback(build, pc, i, TM_ADD);
break;
case LOP_SUB:
emitInstBinaryFallback(build, pc, i, TM_SUB);
break;
case LOP_MUL:
emitInstBinaryFallback(build, pc, i, TM_MUL);
break;
case LOP_DIV:
emitInstBinaryFallback(build, pc, i, TM_DIV);
break;
case LOP_MOD:
emitInstBinaryFallback(build, pc, i, TM_MOD);
break;
case LOP_POW:
emitInstBinaryFallback(build, pc, i, TM_POW);
break;
case LOP_ADDK:
emitInstBinaryKFallback(build, pc, i, TM_ADD);
break;
case LOP_SUBK:
emitInstBinaryKFallback(build, pc, i, TM_SUB);
break;
case LOP_MULK:
emitInstBinaryKFallback(build, pc, i, TM_MUL);
break;
case LOP_DIVK:
emitInstBinaryKFallback(build, pc, i, TM_DIV);
break;
case LOP_MODK:
emitInstBinaryKFallback(build, pc, i, TM_MOD);
break;
case LOP_POWK:
emitInstBinaryKFallback(build, pc, i, TM_POW);
break;
case LOP_MINUS:
emitInstMinusFallback(build, pc, i);
break;
case LOP_LENGTH:
emitInstLengthFallback(build, pc, i);
break;
case LOP_FORGLOOP:
emitinstForGLoopFallback(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)]);
break;
case LOP_FORGPREP_NEXT:
case LOP_FORGPREP_INEXT:
emitInstForGPrepXnextFallback(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)]);
break;
case LOP_GETGLOBAL:
// TODO: luaV_gettable + cachedslot update instead of full fallback
emitFallback(build, data, op, i);
break;
case LOP_SETGLOBAL:
// TODO: luaV_settable + cachedslot update instead of full fallback
emitFallback(build, data, op, i);
break;
case LOP_GETTABLEKS:
// Full fallback required for LOP_GETTABLEKS because 'luaV_gettable' doesn't handle builtin vector field access
// It is also required to perform cached slot update
// TODO: extra fast-paths could be lowered before the full fallback
emitFallback(build, data, op, i);
break;
case LOP_SETTABLEKS:
// TODO: luaV_settable + cachedslot update instead of full fallback
emitFallback(build, data, op, i);
break;
default:
LUAU_ASSERT(!"Expected fallback for instruction");
}
}
static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options)
{
NativeProto* result = new NativeProto();
@ -423,153 +78,32 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat
build.logAppend("\n");
}
if (!FFlag::DebugUseOldCodegen)
{
build.align(kFunctionAlignment, AlignmentDataX64::Ud2);
Label start = build.setLabel();
IrBuilder builder;
builder.buildFunctionIr(proto);
optimizeMemoryOperandsX64(builder.function);
IrLoweringX64 lowering(build, helpers, data, proto, builder.function);
lowering.lower(options);
result->instTargets = new uintptr_t[proto->sizecode];
for (int i = 0; i < proto->sizecode; i++)
{
auto [irLocation, asmLocation] = builder.function.bcMapping[i];
result->instTargets[i] = irLocation == ~0u ? 0 : asmLocation - start.location;
}
result->location = start.location;
if (build.logText)
build.logAppend("\n");
return result;
}
std::vector<Label> instLabels;
instLabels.resize(proto->sizecode);
std::vector<Label> instFallbacks;
instFallbacks.resize(proto->sizecode);
std::vector<InstructionOutline> instOutlines;
instOutlines.reserve(64);
build.align(kFunctionAlignment, AlignmentDataX64::Ud2);
Label start = build.setLabel();
for (int i = 0; i < proto->sizecode;)
IrBuilder builder;
builder.buildFunctionIr(proto);
if (!FFlag::DebugCodegenNoOpt)
{
const Instruction* pc = &proto->code[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc));
int nexti = i + getOpLength(op);
LUAU_ASSERT(nexti <= proto->sizecode);
build.setLabel(instLabels[i]);
if (options.annotator)
options.annotator(options.annotatorContext, build.text, proto->bytecodeid, i);
Label& next = nexti < proto->sizecode ? instLabels[nexti] : start; // Last instruction can't use 'next' label
int skip = emitInst(build, data, helpers, proto, op, pc, i, instLabels.data(), next, instFallbacks[i]);
if (skip != 0)
instOutlines.push_back({nexti, skip});
i = nexti + skip;
LUAU_ASSERT(i <= proto->sizecode);
constPropInBlockChains(builder);
}
size_t textSize = build.text.size();
uint32_t codeSize = build.getCodeSize();
optimizeMemoryOperandsX64(builder.function);
if (options.annotator && options.includeOutlinedCode)
build.logAppend("; outlined instructions\n");
IrLoweringX64 lowering(build, helpers, data, proto, builder.function);
for (auto [pcpos, length] : instOutlines)
{
int i = pcpos;
while (i < pcpos + length)
{
const Instruction* pc = &proto->code[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc));
int nexti = i + getOpLength(op);
LUAU_ASSERT(nexti <= proto->sizecode);
build.setLabel(instLabels[i]);
if (options.annotator && options.includeOutlinedCode)
options.annotator(options.annotatorContext, build.text, proto->bytecodeid, i);
Label& next = nexti < proto->sizecode ? instLabels[nexti] : start; // Last instruction can't use 'next' label
int skip = emitInst(build, data, helpers, proto, op, pc, i, instLabels.data(), next, instFallbacks[i]);
LUAU_ASSERT(skip == 0);
i = nexti;
}
if (i < proto->sizecode)
build.jmp(instLabels[i]);
}
if (options.annotator && options.includeOutlinedCode)
build.logAppend("; outlined code\n");
for (int i = 0, instid = 0; i < proto->sizecode; ++instid)
{
const Instruction* pc = &proto->code[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc));
int nexti = i + getOpLength(op);
LUAU_ASSERT(nexti <= proto->sizecode);
if (instFallbacks[i].id == 0)
{
i = nexti;
continue;
}
if (options.annotator && options.includeOutlinedCode)
options.annotator(options.annotatorContext, build.text, proto->bytecodeid, instid);
build.setLabel(instFallbacks[i]);
emitInstFallback(build, data, op, pc, i, instLabels.data());
// Jump back to the next instruction handler
if (nexti < proto->sizecode)
build.jmp(instLabels[nexti]);
i = nexti;
}
// Truncate assembly output if we don't care for outlined code part
if (!options.includeOutlinedCode)
{
build.text.resize(textSize);
build.logAppend("; skipping %u bytes of outlined code\n", build.getCodeSize() - codeSize);
}
lowering.lower(options);
result->instTargets = new uintptr_t[proto->sizecode];
for (int i = 0; i < proto->sizecode; i++)
result->instTargets[i] = instLabels[i].location - start.location;
{
auto [irLocation, asmLocation] = builder.function.bcMapping[i];
result->instTargets[i] = irLocation == ~0u ? 0 : asmLocation - start.location;
}
result->location = start.location;

View file

@ -5,33 +5,18 @@
#include "Luau/Bytecode.h"
#include "EmitCommonX64.h"
#include "IrTranslateBuiltins.h" // Used temporarily for shared definition of BuiltinImplResult
#include "NativeState.h"
#include "lstate.h"
// TODO: LBF_MATH_FREXP and LBF_MATH_MODF can work for 1 result case if second store is removed
namespace Luau
{
namespace CodeGen
{
BuiltinImplResult emitBuiltinAssert(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback)
{
if (nparams < 1 || nresults != 0)
return {BuiltinImplType::None, -1};
if (build.logText)
build.logAppend("; inlined LBF_ASSERT\n");
Label skip;
jumpIfFalsy(build, arg, fallback, skip);
// TODO: use of 'skip' causes a jump to a jump instruction that skips the fallback - can be optimized
build.setLabel(skip);
return {BuiltinImplType::UsesFallback, 0};
}
BuiltinImplResult emitBuiltinMathFloor(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback)
{
if (nparams < 1 || nresults > 1)
@ -620,7 +605,8 @@ BuiltinImplResult emitBuiltin(AssemblyBuilderX64& build, int bfid, int nparams,
switch (bfid)
{
case LBF_ASSERT:
return emitBuiltinAssert(build, nparams, ra, arg, args, nresults, fallback);
// This builtin fast-path was already translated to IR
return {BuiltinImplType::None, -1};
case LBF_MATH_FLOOR:
return emitBuiltinMathFloor(build, nparams, ra, arg, args, nresults, fallback);
case LBF_MATH_CEIL:

View file

@ -9,18 +9,7 @@ namespace CodeGen
class AssemblyBuilderX64;
struct Label;
struct OperandX64;
enum class BuiltinImplType
{
None,
UsesFallback, // Uses fallback for unsupported cases
};
struct BuiltinImplResult
{
BuiltinImplType type;
int actualResultCount;
};
struct BuiltinImplResult;
BuiltinImplResult emitBuiltin(AssemblyBuilderX64& build, int bfid, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback);

View file

@ -151,14 +151,6 @@ inline void setLuauReg(AssemblyBuilderX64& build, RegisterX64 tmp, int ri, Opera
build.vmovups(luauReg(ri), tmp);
}
inline void setNodeValue(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 op, int ri)
{
LUAU_ASSERT(op.cat == CategoryX64::mem);
build.vmovups(tmp, luauReg(ri));
build.vmovups(op, tmp);
}
inline void jumpIfTagIs(AssemblyBuilderX64& build, int ri, lua_Type tag, Label& label)
{
build.cmp(luauRegTag(ri), tag);

View file

@ -7,6 +7,7 @@
#include "EmitBuiltinsX64.h"
#include "EmitCommonX64.h"
#include "NativeState.h"
#include "IrTranslateBuiltins.h" // Used temporarily until emitInstFastCallN is removed
#include "lobject.h"
#include "ltm.h"
@ -16,59 +17,6 @@ namespace Luau
namespace CodeGen
{
void emitInstLoadNil(AssemblyBuilderX64& build, const Instruction* pc)
{
int ra = LUAU_INSN_A(*pc);
build.mov(luauRegTag(ra), LUA_TNIL);
}
void emitInstLoadB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{
int ra = LUAU_INSN_A(*pc);
build.mov(luauRegValue(ra), LUAU_INSN_B(*pc));
build.mov(luauRegTag(ra), LUA_TBOOLEAN);
if (int target = LUAU_INSN_C(*pc))
build.jmp(labelarr[pcpos + 1 + target]);
}
void emitInstLoadN(AssemblyBuilderX64& build, const Instruction* pc)
{
int ra = LUAU_INSN_A(*pc);
build.vmovsd(xmm0, build.f64(double(LUAU_INSN_D(*pc))));
build.vmovsd(luauRegValue(ra), xmm0);
build.mov(luauRegTag(ra), LUA_TNUMBER);
}
void emitInstLoadK(AssemblyBuilderX64& build, const Instruction* pc)
{
int ra = LUAU_INSN_A(*pc);
build.vmovups(xmm0, luauConstant(LUAU_INSN_D(*pc)));
build.vmovups(luauReg(ra), xmm0);
}
void emitInstLoadKX(AssemblyBuilderX64& build, const Instruction* pc)
{
int ra = LUAU_INSN_A(*pc);
uint32_t aux = pc[1];
build.vmovups(xmm0, luauConstant(aux));
build.vmovups(luauReg(ra), xmm0);
}
void emitInstMove(AssemblyBuilderX64& build, const Instruction* pc)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
build.vmovups(xmm0, luauReg(rb));
build.vmovups(luauReg(ra), xmm0);
}
void emitInstNameCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, const TValue* k, Label& next, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
@ -429,363 +377,6 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Ins
build.jmp(qword[rdx + rax * 2]);
}
void emitInstJump(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{
build.jmp(labelarr[pcpos + 1 + LUAU_INSN_D(*pc)]);
}
void emitInstJumpBack(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{
emitInterrupt(build, pcpos);
build.jmp(labelarr[pcpos + 1 + LUAU_INSN_D(*pc)]);
}
void emitInstJumpIf(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_)
{
int ra = LUAU_INSN_A(*pc);
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
Label& exit = labelarr[pcpos + 1];
if (not_)
jumpIfFalsy(build, ra, target, exit);
else
jumpIfTruthy(build, ra, target, exit);
}
void emitInstJumpIfEq(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int rb = pc[1];
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
Label& exit = labelarr[pcpos + 2];
build.mov(eax, luauRegTag(ra));
build.cmp(eax, luauRegTag(rb));
build.jcc(ConditionX64::NotEqual, not_ ? target : exit);
// fast-path: number
build.cmp(eax, LUA_TNUMBER);
build.jcc(ConditionX64::NotEqual, fallback);
jumpOnNumberCmp(build, xmm0, luauRegValue(ra), luauRegValue(rb), ConditionX64::NotEqual, not_ ? target : exit);
if (!not_)
build.jmp(target);
}
void emitInstJumpIfEqFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_)
{
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
emitSetSavedPc(build, pcpos + 1);
jumpOnAnyCmpFallback(build, LUAU_INSN_A(*pc), pc[1], not_ ? ConditionX64::NotEqual : ConditionX64::Equal, target);
}
void emitInstJumpIfCond(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, ConditionX64 cond, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int rb = pc[1];
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
// fast-path: number
jumpIfTagIsNot(build, ra, LUA_TNUMBER, fallback);
jumpIfTagIsNot(build, rb, LUA_TNUMBER, fallback);
jumpOnNumberCmp(build, xmm0, luauRegValue(ra), luauRegValue(rb), cond, target);
}
void emitInstJumpIfCondFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, ConditionX64 cond)
{
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
emitSetSavedPc(build, pcpos + 1);
jumpOnAnyCmpFallback(build, LUAU_INSN_A(*pc), pc[1], cond, target);
}
void emitInstJumpX(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{
emitInterrupt(build, pcpos);
build.jmp(labelarr[pcpos + 1 + LUAU_INSN_E(*pc)]);
}
void emitInstJumpxEqNil(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{
int ra = LUAU_INSN_A(*pc);
bool not_ = (pc[1] & 0x80000000) != 0;
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
build.cmp(luauRegTag(ra), LUA_TNIL);
build.jcc(not_ ? ConditionX64::NotEqual : ConditionX64::Equal, target);
}
void emitInstJumpxEqB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{
int ra = LUAU_INSN_A(*pc);
uint32_t aux = pc[1];
bool not_ = (aux & 0x80000000) != 0;
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
Label& exit = labelarr[pcpos + 2];
jumpIfTagIsNot(build, ra, LUA_TBOOLEAN, not_ ? target : exit);
build.test(luauRegValueInt(ra), 1);
build.jcc((aux & 0x1) ^ not_ ? ConditionX64::NotZero : ConditionX64::Zero, target);
}
void emitInstJumpxEqN(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr)
{
int ra = LUAU_INSN_A(*pc);
uint32_t aux = pc[1];
bool not_ = (aux & 0x80000000) != 0;
TValue kv = k[aux & 0xffffff];
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
Label& exit = labelarr[pcpos + 2];
jumpIfTagIsNot(build, ra, LUA_TNUMBER, not_ ? target : exit);
if (not_)
{
jumpOnNumberCmp(build, xmm0, luauRegValue(ra), build.f64(kv.value.n), ConditionX64::NotEqual, target);
}
else
{
// Compact equality check requires two labels, so it's not supported in generic 'jumpOnNumberCmp'
build.vmovsd(xmm0, luauRegValue(ra));
build.vucomisd(xmm0, build.f64(kv.value.n));
build.jcc(ConditionX64::Parity, exit); // We first have to check PF=1 for NaN operands, because it also sets ZF=1
build.jcc(ConditionX64::Zero, target); // Now that NaN is out of the way, we can check ZF=1 for equality
}
}
void emitInstJumpxEqS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{
int ra = LUAU_INSN_A(*pc);
uint32_t aux = pc[1];
bool not_ = (aux & 0x80000000) != 0;
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
Label& exit = labelarr[pcpos + 2];
jumpIfTagIsNot(build, ra, LUA_TSTRING, not_ ? target : exit);
build.mov(rax, luauRegValue(ra));
build.cmp(rax, luauConstantValue(aux & 0xffffff));
build.jcc(not_ ? ConditionX64::NotEqual : ConditionX64::Equal, target);
}
static void emitInstBinaryNumeric(AssemblyBuilderX64& build, int ra, int rb, int rc, OperandX64 opc, TMS tm, Label& fallback)
{
jumpIfTagIsNot(build, rb, LUA_TNUMBER, fallback);
if (rc != -1 && rc != rb)
jumpIfTagIsNot(build, rc, LUA_TNUMBER, fallback);
// fast-path: number
build.vmovsd(xmm0, luauRegValue(rb));
switch (tm)
{
case TM_ADD:
build.vaddsd(xmm0, xmm0, opc);
break;
case TM_SUB:
build.vsubsd(xmm0, xmm0, opc);
break;
case TM_MUL:
build.vmulsd(xmm0, xmm0, opc);
break;
case TM_DIV:
build.vdivsd(xmm0, xmm0, opc);
break;
case TM_MOD:
// This follows the implementation of 'luai_nummod' which is less precise than 'fmod' for better performance
build.vmovsd(xmm1, opc);
build.vdivsd(xmm2, xmm0, xmm1);
build.vroundsd(xmm2, xmm2, xmm2, RoundingModeX64::RoundToNegativeInfinity);
build.vmulsd(xmm1, xmm2, xmm1);
build.vsubsd(xmm0, xmm0, xmm1);
break;
case TM_POW:
build.vmovsd(xmm1, luauRegValue(rc));
build.call(qword[rNativeContext + offsetof(NativeContext, libm_pow)]);
break;
default:
LUAU_ASSERT(!"unsupported binary op");
}
build.vmovsd(luauRegValue(ra), xmm0);
if (ra != rb && ra != rc)
build.mov(luauRegTag(ra), LUA_TNUMBER);
}
void emitInstBinary(AssemblyBuilderX64& build, const Instruction* pc, TMS tm, Label& fallback)
{
emitInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), LUAU_INSN_C(*pc), luauRegValue(LUAU_INSN_C(*pc)), tm, fallback);
}
void emitInstBinaryFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm)
{
emitSetSavedPc(build, pcpos + 1);
callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauRegAddress(LUAU_INSN_C(*pc)), tm);
}
void emitInstBinaryK(AssemblyBuilderX64& build, const Instruction* pc, TMS tm, Label& fallback)
{
emitInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), -1, luauConstantValue(LUAU_INSN_C(*pc)), tm, fallback);
}
void emitInstBinaryKFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm)
{
emitSetSavedPc(build, pcpos + 1);
callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauConstantAddress(LUAU_INSN_C(*pc)), tm);
}
void emitInstPowK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
double kv = nvalue(&k[LUAU_INSN_C(*pc)]);
jumpIfTagIsNot(build, rb, LUA_TNUMBER, fallback);
// fast-path: number
build.vmovsd(xmm0, luauRegValue(rb));
// Specialize for a few constants, similar to how it's done in the VM
if (kv == 2.0)
{
build.vmulsd(xmm0, xmm0, xmm0);
}
else if (kv == 0.5)
{
build.vsqrtsd(xmm0, xmm0, xmm0);
}
else if (kv == 3.0)
{
build.vmulsd(xmm1, xmm0, xmm0);
build.vmulsd(xmm0, xmm0, xmm1);
}
else
{
build.vmovsd(xmm1, build.f64(kv));
build.call(qword[rNativeContext + offsetof(NativeContext, libm_pow)]);
}
build.vmovsd(luauRegValue(ra), xmm0);
if (ra != rb)
build.mov(luauRegTag(ra), LUA_TNUMBER);
}
void emitInstNot(AssemblyBuilderX64& build, const Instruction* pc)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
Label saveone, savezero, exit;
jumpIfFalsy(build, rb, saveone, savezero);
build.setLabel(savezero);
build.mov(luauRegValueInt(ra), 0);
build.jmp(exit);
build.setLabel(saveone);
build.mov(luauRegValueInt(ra), 1);
build.setLabel(exit);
build.mov(luauRegTag(ra), LUA_TBOOLEAN);
}
void emitInstMinus(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
jumpIfTagIsNot(build, rb, LUA_TNUMBER, fallback);
// fast-path: number
build.vxorpd(xmm0, xmm0, xmm0);
build.vsubsd(xmm0, xmm0, luauRegValue(rb));
build.vmovsd(luauRegValue(ra), xmm0);
if (ra != rb)
build.mov(luauRegTag(ra), LUA_TNUMBER);
}
void emitInstMinusFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
{
emitSetSavedPc(build, pcpos + 1);
callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauRegAddress(LUAU_INSN_B(*pc)), TM_UNM);
}
void emitInstLength(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback);
// fast-path: table without __len
build.mov(rArg1, luauRegValue(rb));
jumpIfMetatablePresent(build, rArg1, fallback);
// First argument (Table*) is already in rArg1
build.call(qword[rNativeContext + offsetof(NativeContext, luaH_getn)]);
build.vcvtsi2sd(xmm0, xmm0, eax);
build.vmovsd(luauRegValue(ra), xmm0);
build.mov(luauRegTag(ra), LUA_TNUMBER);
}
void emitInstLengthFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
{
emitSetSavedPc(build, pcpos + 1);
callLengthHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc));
}
void emitInstNewTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next)
{
int ra = LUAU_INSN_A(*pc);
int b = LUAU_INSN_B(*pc);
uint32_t aux = pc[1];
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
build.mov(dwordReg(rArg2), aux);
build.mov(dwordReg(rArg3), b == 0 ? 0 : 1 << (b - 1));
build.call(qword[rNativeContext + offsetof(NativeContext, luaH_new)]);
build.mov(luauRegValue(ra), rax);
build.mov(luauRegTag(ra), LUA_TTABLE);
callCheckGc(build, pcpos, /* savepc = */ false, next);
}
void emitInstDupTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next)
{
int ra = LUAU_INSN_A(*pc);
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
build.mov(rArg2, luauConstantValue(LUAU_INSN_D(*pc)));
build.call(qword[rNativeContext + offsetof(NativeContext, luaH_clone)]);
build.mov(luauRegValue(ra), rax);
build.mov(luauRegTag(ra), LUA_TTABLE);
callCheckGc(build, pcpos, /* savepc= */ false, next);
}
void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, Label& next)
{
int ra = LUAU_INSN_A(*pc);
@ -890,68 +481,7 @@ void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, Label& ne
callBarrierTableFast(build, table, next);
}
void emitInstGetUpval(AssemblyBuilderX64& build, const Instruction* pc)
{
int ra = LUAU_INSN_A(*pc);
int up = LUAU_INSN_B(*pc);
build.mov(rax, sClosure);
build.add(rax, offsetof(Closure, l.uprefs) + sizeof(TValue) * up);
// uprefs[] is either an actual value, or it points to UpVal object which has a pointer to value
Label skip;
// TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though
build.cmp(dword[rax + offsetof(TValue, tt)], LUA_TUPVAL);
build.jcc(ConditionX64::NotEqual, skip);
// UpVal.v points to the value (either on stack, or on heap inside each UpVal, but we can deref it unconditionally)
build.mov(rax, qword[rax + offsetof(TValue, value.gc)]);
build.mov(rax, qword[rax + offsetof(UpVal, v)]);
build.setLabel(skip);
build.vmovups(xmm0, xmmword[rax]);
build.vmovups(luauReg(ra), xmm0);
}
void emitInstSetUpval(AssemblyBuilderX64& build, const Instruction* pc, Label& next)
{
int ra = LUAU_INSN_A(*pc);
int up = LUAU_INSN_B(*pc);
RegisterX64 upval = rax;
RegisterX64 tmp = rcx;
build.mov(tmp, sClosure);
build.mov(upval, qword[tmp + offsetof(Closure, l.uprefs) + sizeof(TValue) * up + offsetof(TValue, value.gc)]);
build.mov(tmp, qword[upval + offsetof(UpVal, v)]);
build.vmovups(xmm0, luauReg(ra));
build.vmovups(xmmword[tmp], xmm0);
callBarrierObject(build, tmp, upval, ra, next);
}
void emitInstCloseUpvals(AssemblyBuilderX64& build, const Instruction* pc, Label& next)
{
int ra = LUAU_INSN_A(*pc);
// L->openupval != 0
build.mov(rax, qword[rState + offsetof(lua_State, openupval)]);
build.test(rax, rax);
build.jcc(ConditionX64::Zero, next);
// ra <= L->openuval->v
build.lea(rcx, addr[rBase + ra * sizeof(TValue)]);
build.cmp(rcx, qword[rax + offsetof(UpVal, v)]);
build.jcc(ConditionX64::Above, next);
build.mov(rArg2, rcx);
build.mov(rArg1, rState);
build.call(qword[rNativeContext + offsetof(NativeContext, luaF_close)]);
}
static int emitInstFastCallN(
static void emitInstFastCallN(
AssemblyBuilderX64& build, const Instruction* pc, bool customParams, int customParamCount, OperandX64 customArgs, int pcpos, Label& fallback)
{
int bfid = LUAU_INSN_A(*pc);
@ -966,8 +496,6 @@ static int emitInstFastCallN(
int arg = customParams ? LUAU_INSN_B(*pc) : ra + 1;
OperandX64 args = customParams ? customArgs : luauRegAddress(ra + 2);
jumpIfUnsafeEnv(build, rax, fallback);
BuiltinImplResult br = emitBuiltin(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback);
if (br.type == BuiltinImplType::UsesFallback)
@ -986,7 +514,7 @@ static int emitInstFastCallN(
build.mov(qword[rState + offsetof(lua_State, top)], rax);
}
return skip; // Return fallback instruction sequence length
return;
}
// TODO: we can skip saving pc for some well-behaved builtins which we didn't inline
@ -1052,108 +580,29 @@ static int emitInstFastCallN(
build.mov(rax, qword[rax + offsetof(CallInfo, top)]);
build.mov(qword[rState + offsetof(lua_State, top)], rax);
}
return skip; // Return fallback instruction sequence length
}
int emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
void emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
{
return emitInstFastCallN(build, pc, /* customParams */ true, /* customParamCount */ 1, /* customArgs */ 0, pcpos, fallback);
}
int emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
void emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
{
return emitInstFastCallN(build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauRegAddress(pc[1]), pcpos, fallback);
}
int emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
void emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
{
return emitInstFastCallN(
build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauConstantAddress(pc[1]), pcpos, fallback);
}
int emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
void emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
{
return emitInstFastCallN(build, pc, /* customParams */ false, /* customParamCount */ 0, /* customArgs */ 0, pcpos, fallback);
}
void emitInstForNPrep(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopStart, Label& loopExit)
{
int ra = LUAU_INSN_A(*pc);
Label tryConvert;
jumpIfTagIsNot(build, ra + 0, LUA_TNUMBER, tryConvert);
jumpIfTagIsNot(build, ra + 1, LUA_TNUMBER, tryConvert);
jumpIfTagIsNot(build, ra + 2, LUA_TNUMBER, tryConvert);
// After successful conversion of arguments to number, we return here
Label retry = build.setLabel();
RegisterX64 limit = xmm0;
RegisterX64 step = xmm1;
RegisterX64 idx = xmm2;
RegisterX64 zero = xmm3;
build.vxorpd(zero, xmm0, xmm0);
build.vmovsd(limit, luauRegValue(ra + 0));
build.vmovsd(step, luauRegValue(ra + 1));
build.vmovsd(idx, luauRegValue(ra + 2));
Label reverse;
// step <= 0
jumpOnNumberCmp(build, noreg, step, zero, ConditionX64::LessEqual, reverse);
// TODO: target branches can probably be arranged better, but we need tests for NaN behavior preservation
// false: idx <= limit
jumpOnNumberCmp(build, noreg, idx, limit, ConditionX64::LessEqual, loopStart);
build.jmp(loopExit);
// true: limit <= idx
build.setLabel(reverse);
jumpOnNumberCmp(build, noreg, limit, idx, ConditionX64::LessEqual, loopStart);
build.jmp(loopExit);
// TOOD: place at the end of the function
build.setLabel(tryConvert);
emitSetSavedPc(build, pcpos + 1);
callPrepareForN(build, ra + 0, ra + 1, ra + 2);
build.jmp(retry);
}
void emitInstForNLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat, Label& loopExit)
{
emitInterrupt(build, pcpos);
int ra = LUAU_INSN_A(*pc);
RegisterX64 limit = xmm0;
RegisterX64 step = xmm1;
RegisterX64 idx = xmm2;
RegisterX64 zero = xmm3;
build.vxorpd(zero, xmm0, xmm0);
build.vmovsd(limit, luauRegValue(ra + 0));
build.vmovsd(step, luauRegValue(ra + 1));
build.vmovsd(idx, luauRegValue(ra + 2));
build.vaddsd(idx, idx, step);
build.vmovsd(luauRegValue(ra + 2), idx);
Label reverse;
// step <= 0
jumpOnNumberCmp(build, noreg, step, zero, ConditionX64::LessEqual, reverse);
// false: idx <= limit
jumpOnNumberCmp(build, noreg, idx, limit, ConditionX64::LessEqual, loopRepeat);
build.jmp(loopExit);
// true: limit <= idx
build.setLabel(reverse);
jumpOnNumberCmp(build, noreg, limit, idx, ConditionX64::LessEqual, loopRepeat);
}
void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat, Label& loopExit, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
@ -1248,46 +697,6 @@ void emitinstForGLoopFallback(AssemblyBuilderX64& build, const Instruction* pc,
build.jcc(ConditionX64::NotZero, loopRepeat);
}
void emitInstForGPrepNext(AssemblyBuilderX64& build, const Instruction* pc, Label& target, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
// fast-path: pairs/next
jumpIfUnsafeEnv(build, rax, fallback);
jumpIfTagIsNot(build, ra + 1, LUA_TTABLE, fallback);
jumpIfTagIsNot(build, ra + 2, LUA_TNIL, fallback);
build.mov(luauRegTag(ra), LUA_TNIL);
// setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(0)));
build.mov(luauRegValue(ra + 2), 0);
build.mov(luauRegTag(ra + 2), LUA_TLIGHTUSERDATA);
build.jmp(target);
}
void emitInstForGPrepInext(AssemblyBuilderX64& build, const Instruction* pc, Label& target, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
// fast-path: ipairs/inext
jumpIfUnsafeEnv(build, rax, fallback);
jumpIfTagIsNot(build, ra + 1, LUA_TTABLE, fallback);
jumpIfTagIsNot(build, ra + 2, LUA_TNUMBER, fallback);
build.vxorpd(xmm0, xmm0, xmm0);
build.vmovsd(xmm1, luauRegValue(ra + 2));
jumpOnNumberCmp(build, noreg, xmm0, xmm1, ConditionX64::NotEqual, fallback);
build.mov(luauRegTag(ra), LUA_TNIL);
// setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(0)));
build.mov(luauRegValue(ra + 2), 0);
build.mov(luauRegTag(ra + 2), LUA_TLIGHTUSERDATA);
build.jmp(target);
}
void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& target)
{
int ra = LUAU_INSN_A(*pc);
@ -1375,169 +784,6 @@ void emitInstOrK(AssemblyBuilderX64& build, const Instruction* pc)
emitInstOrX(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauConstant(LUAU_INSN_C(*pc)));
}
void emitInstGetTableN(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
int c = LUAU_INSN_C(*pc);
jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback);
RegisterX64 table = rcx;
build.mov(table, luauRegValue(rb));
// unsigned(c) < unsigned(h->sizearray)
build.cmp(dword[table + offsetof(Table, sizearray)], c);
build.jcc(ConditionX64::BelowEqual, fallback);
jumpIfMetatablePresent(build, table, fallback);
build.mov(rax, qword[table + offsetof(Table, array)]);
setLuauReg(build, xmm0, ra, xmmword[rax + c * sizeof(TValue)]);
}
void emitInstGetTableNFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
{
emitSetSavedPc(build, pcpos + 1);
TValue n;
setnvalue(&n, LUAU_INSN_C(*pc) + 1);
callGetTable(build, LUAU_INSN_B(*pc), build.bytes(&n, sizeof(n)), LUAU_INSN_A(*pc));
}
void emitInstSetTableN(AssemblyBuilderX64& build, const Instruction* pc, Label& next, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
int c = LUAU_INSN_C(*pc);
jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback);
RegisterX64 table = rcx;
build.mov(table, luauRegValue(rb));
// unsigned(c) < unsigned(h->sizearray)
build.cmp(dword[table + offsetof(Table, sizearray)], c);
build.jcc(ConditionX64::BelowEqual, fallback);
jumpIfMetatablePresent(build, table, fallback);
jumpIfTableIsReadOnly(build, table, fallback);
// setobj2t(L, &h->array[c], ra);
build.mov(rax, qword[table + offsetof(Table, array)]);
build.vmovups(xmm0, luauReg(ra));
build.vmovups(xmmword[rax + c * sizeof(TValue)], xmm0);
callBarrierTable(build, rax, table, ra, next);
}
void emitInstSetTableNFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
{
emitSetSavedPc(build, pcpos + 1);
TValue n;
setnvalue(&n, LUAU_INSN_C(*pc) + 1);
callSetTable(build, LUAU_INSN_B(*pc), build.bytes(&n, sizeof(n)), LUAU_INSN_A(*pc));
}
void emitInstGetTable(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
int rc = LUAU_INSN_C(*pc);
jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback);
jumpIfTagIsNot(build, rc, LUA_TNUMBER, fallback);
// fast-path: table with a number index
RegisterX64 table = rcx;
build.mov(table, luauRegValue(rb));
RegisterX64 intIndex = eax;
RegisterX64 fpIndex = xmm0;
build.vmovsd(fpIndex, luauRegValue(rc));
convertNumberToIndexOrJump(build, xmm1, fpIndex, intIndex, fallback);
// index - 1
build.dec(intIndex);
// unsigned(index - 1) < unsigned(h->sizearray)
build.cmp(dword[table + offsetof(Table, sizearray)], intIndex);
build.jcc(ConditionX64::BelowEqual, fallback);
jumpIfMetatablePresent(build, table, fallback);
// setobj2s(L, ra, &h->array[unsigned(index - 1)]);
build.mov(rdx, qword[table + offsetof(Table, array)]);
build.shl(intIndex, kTValueSizeLog2);
setLuauReg(build, xmm0, ra, xmmword[rdx + rax]);
}
void emitInstGetTableFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
{
emitSetSavedPc(build, pcpos + 1);
callGetTable(build, LUAU_INSN_B(*pc), luauRegAddress(LUAU_INSN_C(*pc)), LUAU_INSN_A(*pc));
}
void emitInstSetTable(AssemblyBuilderX64& build, const Instruction* pc, Label& next, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
int rc = LUAU_INSN_C(*pc);
jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback);
jumpIfTagIsNot(build, rc, LUA_TNUMBER, fallback);
// fast-path: table with a number index
RegisterX64 table = rcx;
build.mov(table, luauRegValue(rb));
RegisterX64 intIndex = eax;
RegisterX64 fpIndex = xmm0;
build.vmovsd(fpIndex, luauRegValue(rc));
convertNumberToIndexOrJump(build, xmm1, fpIndex, intIndex, fallback);
// index - 1
build.dec(intIndex);
// unsigned(index - 1) < unsigned(h->sizearray)
build.cmp(dword[table + offsetof(Table, sizearray)], intIndex);
build.jcc(ConditionX64::BelowEqual, fallback);
jumpIfMetatablePresent(build, table, fallback);
jumpIfTableIsReadOnly(build, table, fallback);
// setobj2t(L, &h->array[unsigned(index - 1)], ra);
build.mov(rdx, qword[table + offsetof(Table, array)]);
build.shl(intIndex, kTValueSizeLog2);
build.vmovups(xmm0, luauReg(ra));
build.vmovups(xmmword[rdx + rax], xmm0);
callBarrierTable(build, rdx, table, ra, next);
}
void emitInstSetTableFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
{
emitSetSavedPc(build, pcpos + 1);
callSetTable(build, LUAU_INSN_B(*pc), luauRegAddress(LUAU_INSN_C(*pc)), LUAU_INSN_A(*pc));
}
void emitInstGetImport(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int k = LUAU_INSN_D(*pc);
jumpIfUnsafeEnv(build, rax, fallback);
// note: if import failed, k[] is nil; we could check this during codegen, but we instead use runtime fallback
// this allows us to handle ahead-of-time codegen smoothly when an import fails to resolve at runtime
build.cmp(luauConstantTag(k), LUA_TNIL);
build.jcc(ConditionX64::Equal, fallback);
build.vmovups(xmm0, luauConstant(k));
build.vmovups(luauReg(ra), xmm0);
}
void emitInstGetImportFallback(AssemblyBuilderX64& build, int ra, uint32_t aux)
{
build.mov(rax, sClosure);
@ -1567,105 +813,6 @@ void emitInstGetImportFallback(AssemblyBuilderX64& build, int ra, uint32_t aux)
build.mov(qword[rState + offsetof(lua_State, top)], rax);
}
void emitInstGetTableKS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
uint32_t aux = pc[1];
jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback);
RegisterX64 table = rcx;
build.mov(table, luauRegValue(rb));
RegisterX64 node = rdx;
getTableNodeAtCachedSlot(build, rax, node, table, pcpos);
jumpIfNodeKeyNotInExpectedSlot(build, rax, node, luauConstantValue(aux), fallback);
setLuauReg(build, xmm0, ra, luauNodeValue(node));
}
void emitInstSetTableKS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
uint32_t aux = pc[1];
jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback);
RegisterX64 table = rcx;
build.mov(table, luauRegValue(rb));
// fast-path: set value at the expected slot
RegisterX64 node = rdx;
getTableNodeAtCachedSlot(build, rax, node, table, pcpos);
jumpIfNodeKeyNotInExpectedSlot(build, rax, node, luauConstantValue(aux), fallback);
jumpIfTableIsReadOnly(build, table, fallback);
setNodeValue(build, xmm0, luauNodeValue(node), ra);
callBarrierTable(build, rax, table, ra, next);
}
void emitInstGetGlobal(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
uint32_t aux = pc[1];
RegisterX64 table = rcx;
build.mov(rax, sClosure);
build.mov(table, qword[rax + offsetof(Closure, env)]);
RegisterX64 node = rdx;
getTableNodeAtCachedSlot(build, rax, node, table, pcpos);
jumpIfNodeKeyNotInExpectedSlot(build, rax, node, luauConstantValue(aux), fallback);
setLuauReg(build, xmm0, ra, luauNodeValue(node));
}
void emitInstSetGlobal(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
uint32_t aux = pc[1];
RegisterX64 table = rcx;
build.mov(rax, sClosure);
build.mov(table, qword[rax + offsetof(Closure, env)]);
RegisterX64 node = rdx;
getTableNodeAtCachedSlot(build, rax, node, table, pcpos);
jumpIfNodeKeyNotInExpectedSlot(build, rax, node, luauConstantValue(aux), fallback);
jumpIfTableIsReadOnly(build, table, fallback);
setNodeValue(build, xmm0, luauNodeValue(node), ra);
callBarrierTable(build, rax, table, ra, next);
}
void emitInstConcat(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
int rc = LUAU_INSN_C(*pc);
emitSetSavedPc(build, pcpos + 1);
// luaV_concat(L, c - b + 1, c)
build.mov(rArg1, rState);
build.mov(dwordReg(rArg2), rc - rb + 1);
build.mov(dwordReg(rArg3), rc);
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_concat)]);
emitUpdateBase(build);
// setobj2s(L, ra, base + b)
build.vmovups(xmm0, luauReg(rb));
build.vmovups(luauReg(ra), xmm0);
callCheckGc(build, pcpos, /* savepc= */ false, next);
}
void emitInstCoverage(AssemblyBuilderX64& build, int pcpos)
{
build.mov(rcx, sCode);

View file

@ -14,78 +14,25 @@ namespace CodeGen
{
class AssemblyBuilderX64;
enum class ConditionX64 : uint8_t;
struct Label;
struct ModuleHelpers;
struct NativeState;
void emitInstLoadNil(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstLoadB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstLoadN(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstLoadK(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstLoadKX(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstMove(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstNameCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, const TValue* k, Label& next, Label& fallback);
void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos);
void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos);
void emitInstJump(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpBack(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpIf(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_);
void emitInstJumpIfEq(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_, Label& fallback);
void emitInstJumpIfEqFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_);
void emitInstJumpIfCond(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, ConditionX64 cond, Label& fallback);
void emitInstJumpIfCondFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, ConditionX64 cond);
void emitInstJumpX(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpxEqNil(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpxEqB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpxEqN(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr);
void emitInstJumpxEqS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstBinary(AssemblyBuilderX64& build, const Instruction* pc, TMS tm, Label& fallback);
void emitInstBinaryFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm);
void emitInstBinaryK(AssemblyBuilderX64& build, const Instruction* pc, TMS tm, Label& fallback);
void emitInstBinaryKFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm);
void emitInstPowK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, Label& fallback);
void emitInstNot(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstMinus(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback);
void emitInstMinusFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstLength(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback);
void emitInstLengthFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstNewTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next);
void emitInstDupTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next);
void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, Label& next);
void emitInstGetUpval(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstSetUpval(AssemblyBuilderX64& build, const Instruction* pc, Label& next);
void emitInstCloseUpvals(AssemblyBuilderX64& build, const Instruction* pc, Label& next);
int emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
int emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
int emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
int emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstForNPrep(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopStart, Label& loopExit);
void emitInstForNLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat, Label& loopExit);
void emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat, Label& loopExit, Label& fallback);
void emitinstForGLoopFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat);
void emitInstForGPrepNext(AssemblyBuilderX64& build, const Instruction* pc, Label& target, Label& fallback);
void emitInstForGPrepInext(AssemblyBuilderX64& build, const Instruction* pc, Label& target, Label& fallback);
void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& target);
void emitInstAnd(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstAndK(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstOr(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstOrK(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstGetTableN(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback);
void emitInstGetTableNFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstSetTableN(AssemblyBuilderX64& build, const Instruction* pc, Label& next, Label& fallback);
void emitInstSetTableNFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstGetTable(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback);
void emitInstGetTableFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstSetTable(AssemblyBuilderX64& build, const Instruction* pc, Label& next, Label& fallback);
void emitInstSetTableFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstGetImport(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback);
void emitInstGetImportFallback(AssemblyBuilderX64& build, int ra, uint32_t aux);
void emitInstGetTableKS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstSetTableKS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next, Label& fallback);
void emitInstGetGlobal(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstSetGlobal(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next, Label& fallback);
void emitInstConcat(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next);
void emitInstCoverage(AssemblyBuilderX64& build, int pcpos);
} // namespace CodeGen

View file

@ -44,6 +44,7 @@ void updateUseCounts(IrFunction& function)
checkOp(inst.c);
checkOp(inst.d);
checkOp(inst.e);
checkOp(inst.f);
}
}
@ -68,6 +69,7 @@ void updateLastUseLocations(IrFunction& function)
checkOp(inst.c);
checkOp(inst.d);
checkOp(inst.e);
checkOp(inst.f);
}
}

View file

@ -269,17 +269,9 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i)
case LOP_FASTCALL:
{
int skip = LUAU_INSN_C(*pc);
IrOp fallback = block(IrBlockKind::Fallback);
IrOp next = blockAtInst(i + skip + 2);
Instruction call = pc[skip + 1];
LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL);
inst(IrCmd::LOP_FASTCALL, constUint(i), vmReg(LUAU_INSN_A(call)), constInt(LUAU_INSN_B(call) - 1), fallback);
inst(IrCmd::JUMP, next);
beginBlock(fallback);
translateFastCallN(*this, pc, i, false, 0, {}, next, IrCmd::LOP_FASTCALL);
activeFastcallFallback = true;
fastcallFallbackReturn = next;
@ -288,17 +280,9 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i)
case LOP_FASTCALL1:
{
int skip = LUAU_INSN_C(*pc);
IrOp fallback = block(IrBlockKind::Fallback);
IrOp next = blockAtInst(i + skip + 2);
Instruction call = pc[skip + 1];
LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL);
inst(IrCmd::LOP_FASTCALL1, constUint(i), vmReg(LUAU_INSN_A(call)), vmReg(LUAU_INSN_B(*pc)), fallback);
inst(IrCmd::JUMP, next);
beginBlock(fallback);
translateFastCallN(*this, pc, i, true, 1, constBool(false), next, IrCmd::LOP_FASTCALL1);
activeFastcallFallback = true;
fastcallFallbackReturn = next;
@ -307,17 +291,9 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i)
case LOP_FASTCALL2:
{
int skip = LUAU_INSN_C(*pc);
IrOp fallback = block(IrBlockKind::Fallback);
IrOp next = blockAtInst(i + skip + 2);
Instruction call = pc[skip + 1];
LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL);
inst(IrCmd::LOP_FASTCALL2, constUint(i), vmReg(LUAU_INSN_A(call)), vmReg(LUAU_INSN_B(*pc)), vmReg(pc[1]), fallback);
inst(IrCmd::JUMP, next);
beginBlock(fallback);
translateFastCallN(*this, pc, i, true, 2, vmReg(pc[1]), next, IrCmd::LOP_FASTCALL2);
activeFastcallFallback = true;
fastcallFallbackReturn = next;
@ -326,17 +302,9 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i)
case LOP_FASTCALL2K:
{
int skip = LUAU_INSN_C(*pc);
IrOp fallback = block(IrBlockKind::Fallback);
IrOp next = blockAtInst(i + skip + 2);
Instruction call = pc[skip + 1];
LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL);
inst(IrCmd::LOP_FASTCALL2K, constUint(i), vmReg(LUAU_INSN_A(call)), vmReg(LUAU_INSN_B(*pc)), vmConst(pc[1]), fallback);
inst(IrCmd::JUMP, next);
beginBlock(fallback);
translateFastCallN(*this, pc, i, true, 2, vmConst(pc[1]), next, IrCmd::LOP_FASTCALL2K);
activeFastcallFallback = true;
fastcallFallbackReturn = next;
@ -449,6 +417,7 @@ bool IrBuilder::isInternalBlock(IrOp block)
void IrBuilder::beginBlock(IrOp block)
{
IrBlock& target = function.blocks[block.index];
activeBlockIdx = block.index;
LUAU_ASSERT(target.start == ~0u || target.start == uint32_t(function.instructions.size()));
@ -511,36 +480,46 @@ IrOp IrBuilder::cond(IrCondition cond)
IrOp IrBuilder::inst(IrCmd cmd)
{
return inst(cmd, {}, {}, {}, {}, {});
return inst(cmd, {}, {}, {}, {}, {}, {});
}
IrOp IrBuilder::inst(IrCmd cmd, IrOp a)
{
return inst(cmd, a, {}, {}, {}, {});
return inst(cmd, a, {}, {}, {}, {}, {});
}
IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b)
{
return inst(cmd, a, b, {}, {}, {});
return inst(cmd, a, b, {}, {}, {}, {});
}
IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c)
{
return inst(cmd, a, b, c, {}, {});
return inst(cmd, a, b, c, {}, {}, {});
}
IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d)
{
return inst(cmd, a, b, c, d, {});
return inst(cmd, a, b, c, d, {}, {});
}
IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e)
{
return inst(cmd, a, b, c, d, e, {});
}
IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f)
{
uint32_t index = uint32_t(function.instructions.size());
function.instructions.push_back({cmd, a, b, c, d, e});
function.instructions.push_back({cmd, a, b, c, d, e, f});
LUAU_ASSERT(!inTerminatedBlock);
if (isBlockTerminator(cmd))
{
function.blocks[activeBlockIdx].finish = index;
inTerminatedBlock = true;
}
return {IrOpKind::Inst, index};
}

View file

@ -148,6 +148,10 @@ const char* getCmdName(IrCmd cmd)
return "NUM_TO_INDEX";
case IrCmd::INT_TO_NUM:
return "INT_TO_NUM";
case IrCmd::ADJUST_STACK_TO_REG:
return "ADJUST_STACK_TO_REG";
case IrCmd::ADJUST_STACK_TO_TOP:
return "ADJUST_STACK_TO_TOP";
case IrCmd::DO_ARITH:
return "DO_ARITH";
case IrCmd::DO_LEN:
@ -280,35 +284,20 @@ void toString(IrToStringContext& ctx, const IrInst& inst, uint32_t index)
ctx.result.append(getCmdName(inst.cmd));
if (inst.a.kind != IrOpKind::None)
{
append(ctx.result, " ");
toString(ctx, inst.a);
}
auto checkOp = [&ctx](IrOp op, const char* sep) {
if (op.kind != IrOpKind::None)
{
ctx.result.append(sep);
toString(ctx, op);
}
};
if (inst.b.kind != IrOpKind::None)
{
append(ctx.result, ", ");
toString(ctx, inst.b);
}
if (inst.c.kind != IrOpKind::None)
{
append(ctx.result, ", ");
toString(ctx, inst.c);
}
if (inst.d.kind != IrOpKind::None)
{
append(ctx.result, ", ");
toString(ctx, inst.d);
}
if (inst.e.kind != IrOpKind::None)
{
append(ctx.result, ", ");
toString(ctx, inst.e);
}
checkOp(inst.a, " ");
checkOp(inst.b, ", ");
checkOp(inst.c, ", ");
checkOp(inst.d, ", ");
checkOp(inst.e, ", ");
checkOp(inst.f, ", ");
}
void toString(IrToStringContext& ctx, const IrBlock& block, uint32_t index)
@ -421,7 +410,7 @@ std::string toString(IrFunction& function, bool includeDetails)
}
// To allow dumping blocks that are still being constructed, we can't rely on terminator and need a bounds check
for (uint32_t index = block.start; index < uint32_t(function.instructions.size()); index++)
for (uint32_t index = block.start; index <= block.finish && index < uint32_t(function.instructions.size()); index++)
{
IrInst& inst = function.instructions[index];
@ -440,13 +429,9 @@ std::string toString(IrFunction& function, bool includeDetails)
toString(ctx, inst, index);
ctx.result.append("\n");
}
if (isBlockTerminator(inst.cmd))
{
append(ctx.result, "\n");
break;
}
}
append(ctx.result, "\n");
}
return result;

View file

@ -7,6 +7,7 @@
#include "Luau/IrDump.h"
#include "Luau/IrUtils.h"
#include "EmitBuiltinsX64.h"
#include "EmitCommonX64.h"
#include "EmitInstructionX64.h"
#include "NativeState.h"
@ -20,18 +21,14 @@ namespace Luau
namespace CodeGen
{
static const RegisterX64 kGprAllocOrder[] = {rax, rdx, rcx, rbx, rsi, rdi, r8, r9, r10, r11};
IrLoweringX64::IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, NativeState& data, Proto* proto, IrFunction& function)
: build(build)
, helpers(helpers)
, data(data)
, proto(proto)
, function(function)
, regs(function)
{
freeGprMap.fill(true);
freeXmmMap.fill(true);
// In order to allocate registers during lowering, we need to know where instruction results are last used
updateLastUseLocations(function);
}
@ -95,11 +92,13 @@ void IrLoweringX64::lower(AssemblyOptions options)
uint32_t blockIndex = sortedBlocks[i];
IrBlock& block = function.blocks[blockIndex];
LUAU_ASSERT(block.start != ~0u);
if (block.kind == IrBlockKind::Dead)
continue;
LUAU_ASSERT(block.start != ~0u);
LUAU_ASSERT(block.finish != ~0u);
// If we want to skip fallback code IR/asm, we'll record when those blocks start once we see them
if (block.kind == IrBlockKind::Fallback && !seenFallback)
{
@ -116,7 +115,7 @@ void IrLoweringX64::lower(AssemblyOptions options)
build.setLabel(block.label);
for (uint32_t index = block.start; true; index++)
for (uint32_t index = block.start; index <= block.finish; index++)
{
LUAU_ASSERT(index < function.instructions.size());
@ -151,15 +150,11 @@ void IrLoweringX64::lower(AssemblyOptions options)
lowerInst(inst, index, next);
freeLastUseRegs(inst, index);
if (isBlockTerminator(inst.cmd))
{
if (options.includeIr)
build.logAppend("#\n");
break;
}
regs.freeLastUseRegs(inst, index);
}
if (options.includeIr)
build.logAppend("#\n");
}
if (outputEnabled && !options.includeOutlinedCode && seenFallback)
@ -183,7 +178,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
switch (inst.cmd)
{
case IrCmd::LOAD_TAG:
inst.regX64 = allocGprReg(SizeX64::dword);
inst.regX64 = regs.allocGprReg(SizeX64::dword);
if (inst.a.kind == IrOpKind::VmReg)
build.mov(inst.regX64, luauRegTag(inst.a.index));
@ -197,7 +192,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
LUAU_ASSERT(!"Unsupported instruction form");
break;
case IrCmd::LOAD_POINTER:
inst.regX64 = allocGprReg(SizeX64::qword);
inst.regX64 = regs.allocGprReg(SizeX64::qword);
if (inst.a.kind == IrOpKind::VmReg)
build.mov(inst.regX64, luauRegValue(inst.a.index));
@ -207,7 +202,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
LUAU_ASSERT(!"Unsupported instruction form");
break;
case IrCmd::LOAD_DOUBLE:
inst.regX64 = allocXmmReg();
inst.regX64 = regs.allocXmmReg();
if (inst.a.kind == IrOpKind::VmReg)
build.vmovsd(inst.regX64, luauRegValue(inst.a.index));
@ -219,12 +214,12 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
case IrCmd::LOAD_INT:
LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg);
inst.regX64 = allocGprReg(SizeX64::dword);
inst.regX64 = regs.allocGprReg(SizeX64::dword);
build.mov(inst.regX64, luauRegValueInt(inst.a.index));
break;
case IrCmd::LOAD_TVALUE:
inst.regX64 = allocXmmReg();
inst.regX64 = regs.allocXmmReg();
if (inst.a.kind == IrOpKind::VmReg)
build.vmovups(inst.regX64, luauReg(inst.a.index));
@ -236,12 +231,12 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
LUAU_ASSERT(!"Unsupported instruction form");
break;
case IrCmd::LOAD_NODE_VALUE_TV:
inst.regX64 = allocXmmReg();
inst.regX64 = regs.allocXmmReg();
build.vmovups(inst.regX64, luauNodeValue(regOp(inst.a)));
break;
case IrCmd::LOAD_ENV:
inst.regX64 = allocGprReg(SizeX64::qword);
inst.regX64 = regs.allocGprReg(SizeX64::qword);
build.mov(inst.regX64, sClosure);
build.mov(inst.regX64, qword[inst.regX64 + offsetof(Closure, env)]);
@ -249,7 +244,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
case IrCmd::GET_ARR_ADDR:
if (inst.b.kind == IrOpKind::Inst)
{
inst.regX64 = allocGprRegOrReuse(SizeX64::qword, index, {inst.b});
inst.regX64 = regs.allocGprRegOrReuse(SizeX64::qword, index, {inst.b});
if (dwordReg(inst.regX64) != regOp(inst.b))
build.mov(dwordReg(inst.regX64), regOp(inst.b));
@ -259,7 +254,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
}
else if (inst.b.kind == IrOpKind::Constant)
{
inst.regX64 = allocGprRegOrReuse(SizeX64::qword, index, {inst.a});
inst.regX64 = regs.allocGprRegOrReuse(SizeX64::qword, index, {inst.a});
build.mov(inst.regX64, qword[regOp(inst.a) + offsetof(Table, array)]);
@ -273,9 +268,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
break;
case IrCmd::GET_SLOT_NODE_ADDR:
{
inst.regX64 = allocGprReg(SizeX64::qword);
inst.regX64 = regs.allocGprReg(SizeX64::qword);
ScopedReg tmp{*this, SizeX64::qword};
ScopedRegX64 tmp{regs, SizeX64::qword};
getTableNodeAtCachedSlot(build, tmp.reg, inst.regX64, regOp(inst.a), uintOp(inst.b));
break;
@ -298,7 +293,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
if (inst.b.kind == IrOpKind::Constant)
{
ScopedReg tmp{*this, SizeX64::xmmword};
ScopedRegX64 tmp{regs, SizeX64::xmmword};
build.vmovsd(tmp.reg, build.f64(doubleOp(inst.b)));
build.vmovsd(luauRegValue(inst.a.index), tmp.reg);
@ -336,7 +331,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
build.vmovups(luauNodeValue(regOp(inst.a)), regOp(inst.b));
break;
case IrCmd::ADD_INT:
inst.regX64 = allocGprRegOrReuse(SizeX64::dword, index, {inst.a});
inst.regX64 = regs.allocGprRegOrReuse(SizeX64::dword, index, {inst.a});
if (inst.regX64 == regOp(inst.a) && intOp(inst.b) == 1)
build.inc(inst.regX64);
@ -346,7 +341,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
build.lea(inst.regX64, addr[regOp(inst.a) + intOp(inst.b)]);
break;
case IrCmd::SUB_INT:
inst.regX64 = allocGprRegOrReuse(SizeX64::dword, index, {inst.a});
inst.regX64 = regs.allocGprRegOrReuse(SizeX64::dword, index, {inst.a});
if (inst.regX64 == regOp(inst.a) && intOp(inst.b) == 1)
build.dec(inst.regX64);
@ -356,34 +351,74 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
build.lea(inst.regX64, addr[regOp(inst.a) - intOp(inst.b)]);
break;
case IrCmd::ADD_NUM:
inst.regX64 = allocXmmRegOrReuse(index, {inst.a, inst.b});
inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b});
build.vaddsd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b));
if (inst.a.kind == IrOpKind::Constant)
{
ScopedRegX64 tmp{regs, SizeX64::xmmword};
build.vmovsd(tmp.reg, memRegDoubleOp(inst.a));
build.vaddsd(inst.regX64, tmp.reg, memRegDoubleOp(inst.b));
}
else
{
build.vaddsd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b));
}
break;
case IrCmd::SUB_NUM:
inst.regX64 = allocXmmRegOrReuse(index, {inst.a, inst.b});
inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b});
build.vsubsd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b));
if (inst.a.kind == IrOpKind::Constant)
{
ScopedRegX64 tmp{regs, SizeX64::xmmword};
build.vmovsd(tmp.reg, memRegDoubleOp(inst.a));
build.vsubsd(inst.regX64, tmp.reg, memRegDoubleOp(inst.b));
}
else
{
build.vsubsd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b));
}
break;
case IrCmd::MUL_NUM:
inst.regX64 = allocXmmRegOrReuse(index, {inst.a, inst.b});
inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b});
build.vmulsd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b));
if (inst.a.kind == IrOpKind::Constant)
{
ScopedRegX64 tmp{regs, SizeX64::xmmword};
build.vmovsd(tmp.reg, memRegDoubleOp(inst.a));
build.vmulsd(inst.regX64, tmp.reg, memRegDoubleOp(inst.b));
}
else
{
build.vmulsd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b));
}
break;
case IrCmd::DIV_NUM:
inst.regX64 = allocXmmRegOrReuse(index, {inst.a, inst.b});
inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b});
build.vdivsd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b));
if (inst.a.kind == IrOpKind::Constant)
{
ScopedRegX64 tmp{regs, SizeX64::xmmword};
build.vmovsd(tmp.reg, memRegDoubleOp(inst.a));
build.vdivsd(inst.regX64, tmp.reg, memRegDoubleOp(inst.b));
}
else
{
build.vdivsd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b));
}
break;
case IrCmd::MOD_NUM:
{
inst.regX64 = allocXmmRegOrReuse(index, {inst.a, inst.b});
inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b});
RegisterX64 lhs = regOp(inst.a);
if (inst.b.kind == IrOpKind::Inst)
{
ScopedReg tmp{*this, SizeX64::xmmword};
ScopedRegX64 tmp{regs, SizeX64::xmmword};
build.vdivsd(tmp.reg, lhs, memRegDoubleOp(inst.b));
build.vroundsd(tmp.reg, tmp.reg, tmp.reg, RoundingModeX64::RoundToNegativeInfinity);
@ -392,8 +427,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
}
else
{
ScopedReg tmp1{*this, SizeX64::xmmword};
ScopedReg tmp2{*this, SizeX64::xmmword};
ScopedRegX64 tmp1{regs, SizeX64::xmmword};
ScopedRegX64 tmp2{regs, SizeX64::xmmword};
build.vmovsd(tmp1.reg, memRegDoubleOp(inst.b));
build.vdivsd(tmp2.reg, lhs, tmp1.reg);
@ -405,9 +440,21 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
}
case IrCmd::POW_NUM:
{
inst.regX64 = allocXmmRegOrReuse(index, {inst.a, inst.b});
inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b});
RegisterX64 lhs = regOp(inst.a);
ScopedRegX64 tmp{regs, SizeX64::xmmword};
RegisterX64 lhs;
if (inst.a.kind == IrOpKind::Constant)
{
build.vmovsd(tmp.reg, memRegDoubleOp(inst.a));
lhs = tmp.reg;
}
else
{
lhs = regOp(inst.a);
}
if (inst.b.kind == IrOpKind::Inst)
{
@ -439,7 +486,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
}
else if (rhs == 3.0)
{
ScopedReg tmp{*this, SizeX64::xmmword};
ScopedRegX64 tmp{regs, SizeX64::xmmword};
build.vmulsd(tmp.reg, lhs, lhs);
build.vmulsd(inst.regX64, lhs, tmp.reg);
@ -472,7 +519,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
}
case IrCmd::UNM_NUM:
{
inst.regX64 = allocXmmRegOrReuse(index, {inst.a});
inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a});
RegisterX64 src = regOp(inst.a);
@ -491,15 +538,23 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
case IrCmd::NOT_ANY:
{
// TODO: if we have a single user which is a STORE_INT, we are missing the opportunity to write directly to target
inst.regX64 = allocGprRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
inst.regX64 = regs.allocGprRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
Label saveone, savezero, exit;
build.cmp(regOp(inst.a), LUA_TNIL);
build.jcc(ConditionX64::Equal, saveone);
if (inst.a.kind == IrOpKind::Constant)
{
// Other cases should've been constant folded
LUAU_ASSERT(tagOp(inst.a) == LUA_TBOOLEAN);
}
else
{
build.cmp(regOp(inst.a), LUA_TNIL);
build.jcc(ConditionX64::Equal, saveone);
build.cmp(regOp(inst.a), LUA_TBOOLEAN);
build.jcc(ConditionX64::NotEqual, savezero);
build.cmp(regOp(inst.a), LUA_TBOOLEAN);
build.jcc(ConditionX64::NotEqual, savezero);
}
build.cmp(regOp(inst.b), 0);
build.jcc(ConditionX64::Equal, saveone);
@ -566,7 +621,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
IrCondition cond = IrCondition(inst.c.index);
ScopedReg tmp{*this, SizeX64::xmmword};
ScopedRegX64 tmp{regs, SizeX64::xmmword};
// TODO: jumpOnNumberCmp should work on IrCondition directly
jumpOnNumberCmp(build, tmp.reg, memRegDoubleOp(inst.a), memRegDoubleOp(inst.b), getX64Condition(cond), labelOp(inst.d));
@ -586,14 +641,14 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
break;
}
case IrCmd::TABLE_LEN:
inst.regX64 = allocXmmReg();
inst.regX64 = regs.allocXmmReg();
build.mov(rArg1, regOp(inst.a));
build.call(qword[rNativeContext + offsetof(NativeContext, luaH_getn)]);
build.vcvtsi2sd(inst.regX64, inst.regX64, eax);
break;
case IrCmd::NEW_TABLE:
inst.regX64 = allocGprReg(SizeX64::qword);
inst.regX64 = regs.allocGprReg(SizeX64::qword);
build.mov(rArg1, rState);
build.mov(dwordReg(rArg2), uintOp(inst.a));
@ -604,7 +659,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
build.mov(inst.regX64, rax);
break;
case IrCmd::DUP_TABLE:
inst.regX64 = allocGprReg(SizeX64::qword);
inst.regX64 = regs.allocGprReg(SizeX64::qword);
// Re-ordered to avoid register conflict
build.mov(rArg2, regOp(inst.a));
@ -616,18 +671,51 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
break;
case IrCmd::NUM_TO_INDEX:
{
inst.regX64 = allocGprReg(SizeX64::dword);
inst.regX64 = regs.allocGprReg(SizeX64::dword);
ScopedReg tmp{*this, SizeX64::xmmword};
ScopedRegX64 tmp{regs, SizeX64::xmmword};
convertNumberToIndexOrJump(build, tmp.reg, regOp(inst.a), inst.regX64, labelOp(inst.b));
break;
}
case IrCmd::INT_TO_NUM:
inst.regX64 = allocXmmReg();
inst.regX64 = regs.allocXmmReg();
build.vcvtsi2sd(inst.regX64, inst.regX64, regOp(inst.a));
break;
case IrCmd::ADJUST_STACK_TO_REG:
{
LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg);
if (inst.b.kind == IrOpKind::Constant)
{
ScopedRegX64 tmp{regs, SizeX64::qword};
build.lea(tmp.reg, addr[rBase + (inst.a.index + intOp(inst.b)) * sizeof(TValue)]);
build.mov(qword[rState + offsetof(lua_State, top)], tmp.reg);
}
else if (inst.b.kind == IrOpKind::Inst)
{
ScopedRegX64 tmp(regs, regs.allocGprRegOrReuse(SizeX64::dword, index, {inst.b}));
build.shl(qwordReg(tmp.reg), kTValueSizeLog2);
build.lea(qwordReg(tmp.reg), addr[rBase + qwordReg(tmp.reg) + inst.a.index * sizeof(TValue)]);
build.mov(qword[rState + offsetof(lua_State, top)], qwordReg(tmp.reg));
}
else
{
LUAU_ASSERT(!"Unsupported instruction form");
}
break;
}
case IrCmd::ADJUST_STACK_TO_TOP:
{
ScopedRegX64 tmp{regs, SizeX64::qword};
build.mov(tmp.reg, qword[rState + offsetof(lua_State, ci)]);
build.mov(tmp.reg, qword[tmp.reg + offsetof(CallInfo, top)]);
build.mov(qword[rState + offsetof(lua_State, top)], tmp.reg);
break;
}
case IrCmd::DO_ARITH:
LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg);
LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg);
@ -702,8 +790,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg);
LUAU_ASSERT(inst.b.kind == IrOpKind::VmUpvalue);
ScopedReg tmp1{*this, SizeX64::qword};
ScopedReg tmp2{*this, SizeX64::xmmword};
ScopedRegX64 tmp1{regs, SizeX64::qword};
ScopedRegX64 tmp2{regs, SizeX64::xmmword};
build.mov(tmp1.reg, sClosure);
build.add(tmp1.reg, offsetof(Closure, l.uprefs) + sizeof(TValue) * inst.b.index);
@ -730,9 +818,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg);
Label next;
ScopedReg tmp1{*this, SizeX64::qword};
ScopedReg tmp2{*this, SizeX64::qword};
ScopedReg tmp3{*this, SizeX64::xmmword};
ScopedRegX64 tmp1{regs, SizeX64::qword};
ScopedRegX64 tmp2{regs, SizeX64::qword};
ScopedRegX64 tmp3{regs, SizeX64::xmmword};
build.mov(tmp1.reg, sClosure);
build.mov(tmp2.reg, qword[tmp1.reg + offsetof(Closure, l.uprefs) + sizeof(TValue) * inst.a.index + offsetof(TValue, value.gc)]);
@ -758,6 +846,11 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
{
jumpIfTagIsNot(build, inst.a.index, lua_Type(tagOp(inst.b)), labelOp(inst.c));
}
else if (inst.a.kind == IrOpKind::VmConst)
{
build.cmp(luauConstantTag(inst.a.index), tagOp(inst.b));
build.jcc(ConditionX64::NotEqual, labelOp(inst.c));
}
else
{
LUAU_ASSERT(!"Unsupported instruction form");
@ -771,7 +864,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
break;
case IrCmd::CHECK_SAFE_ENV:
{
ScopedReg tmp{*this, SizeX64::qword};
ScopedRegX64 tmp{regs, SizeX64::qword};
jumpIfUnsafeEnv(build, tmp.reg, labelOp(inst.a));
break;
@ -790,7 +883,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
{
LUAU_ASSERT(inst.b.kind == IrOpKind::VmConst);
ScopedReg tmp{*this, SizeX64::qword};
ScopedRegX64 tmp{regs, SizeX64::qword};
jumpIfNodeKeyNotInExpectedSlot(build, tmp.reg, regOp(inst.a), luauConstantValue(inst.b.index), labelOp(inst.c));
break;
@ -810,7 +903,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg);
Label skip;
ScopedReg tmp{*this, SizeX64::qword};
ScopedRegX64 tmp{regs, SizeX64::qword};
callBarrierObject(build, tmp.reg, regOp(inst.a), inst.b.index, skip);
build.setLabel(skip);
@ -829,7 +922,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg);
Label skip;
ScopedReg tmp{*this, SizeX64::qword};
ScopedRegX64 tmp{regs, SizeX64::qword};
callBarrierTable(build, tmp.reg, regOp(inst.a), inst.b.index, skip);
build.setLabel(skip);
@ -838,8 +931,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
case IrCmd::SET_SAVEDPC:
{
// This is like emitSetSavedPc, but using register allocation instead of relying on rax/rdx
ScopedReg tmp1{*this, SizeX64::qword};
ScopedReg tmp2{*this, SizeX64::qword};
ScopedRegX64 tmp1{regs, SizeX64::qword};
ScopedRegX64 tmp2{regs, SizeX64::qword};
build.mov(tmp2.reg, sCode);
build.add(tmp2.reg, uintOp(inst.a) * sizeof(Instruction));
@ -852,8 +945,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg);
Label next;
ScopedReg tmp1{*this, SizeX64::qword};
ScopedReg tmp2{*this, SizeX64::qword};
ScopedRegX64 tmp1{regs, SizeX64::qword};
ScopedRegX64 tmp2{regs, SizeX64::qword};
// L->openupval != 0
build.mov(tmp1.reg, qword[rState + offsetof(lua_State, openupval)]);
@ -1131,124 +1224,6 @@ Label& IrLoweringX64::labelOp(IrOp op) const
return blockOp(op).label;
}
RegisterX64 IrLoweringX64::allocGprReg(SizeX64 preferredSize)
{
LUAU_ASSERT(
preferredSize == SizeX64::byte || preferredSize == SizeX64::word || preferredSize == SizeX64::dword || preferredSize == SizeX64::qword);
for (RegisterX64 reg : kGprAllocOrder)
{
if (freeGprMap[reg.index])
{
freeGprMap[reg.index] = false;
return RegisterX64{preferredSize, reg.index};
}
}
LUAU_ASSERT(!"Out of GPR registers to allocate");
return noreg;
}
RegisterX64 IrLoweringX64::allocXmmReg()
{
for (size_t i = 0; i < freeXmmMap.size(); ++i)
{
if (freeXmmMap[i])
{
freeXmmMap[i] = false;
return RegisterX64{SizeX64::xmmword, uint8_t(i)};
}
}
LUAU_ASSERT(!"Out of XMM registers to allocate");
return noreg;
}
RegisterX64 IrLoweringX64::allocGprRegOrReuse(SizeX64 preferredSize, uint32_t index, std::initializer_list<IrOp> oprefs)
{
for (IrOp op : oprefs)
{
if (op.kind != IrOpKind::Inst)
continue;
IrInst& source = function.instructions[op.index];
if (source.lastUse == index && !source.reusedReg)
{
LUAU_ASSERT(source.regX64.size != SizeX64::xmmword);
LUAU_ASSERT(source.regX64 != noreg);
source.reusedReg = true;
return RegisterX64{preferredSize, source.regX64.index};
}
}
return allocGprReg(preferredSize);
}
RegisterX64 IrLoweringX64::allocXmmRegOrReuse(uint32_t index, std::initializer_list<IrOp> oprefs)
{
for (IrOp op : oprefs)
{
if (op.kind != IrOpKind::Inst)
continue;
IrInst& source = function.instructions[op.index];
if (source.lastUse == index && !source.reusedReg)
{
LUAU_ASSERT(source.regX64.size == SizeX64::xmmword);
LUAU_ASSERT(source.regX64 != noreg);
source.reusedReg = true;
return source.regX64;
}
}
return allocXmmReg();
}
void IrLoweringX64::freeReg(RegisterX64 reg)
{
if (reg.size == SizeX64::xmmword)
{
LUAU_ASSERT(!freeXmmMap[reg.index]);
freeXmmMap[reg.index] = true;
}
else
{
LUAU_ASSERT(!freeGprMap[reg.index]);
freeGprMap[reg.index] = true;
}
}
void IrLoweringX64::freeLastUseReg(IrInst& target, uint32_t index)
{
if (target.lastUse == index && !target.reusedReg)
{
// Register might have already been freed if it had multiple uses inside a single instruction
if (target.regX64 == noreg)
return;
freeReg(target.regX64);
target.regX64 = noreg;
}
}
void IrLoweringX64::freeLastUseRegs(const IrInst& inst, uint32_t index)
{
auto checkOp = [this, index](IrOp op) {
if (op.kind == IrOpKind::Inst)
freeLastUseReg(function.instructions[op.index], index);
};
checkOp(inst.a);
checkOp(inst.b);
checkOp(inst.c);
checkOp(inst.d);
checkOp(inst.e);
}
ConditionX64 IrLoweringX64::getX64Condition(IrCondition cond) const
{
// TODO: this function will not be required when jumpOnNumberCmp starts accepting an IrCondition
@ -1282,27 +1257,5 @@ ConditionX64 IrLoweringX64::getX64Condition(IrCondition cond) const
return ConditionX64::Count;
}
IrLoweringX64::ScopedReg::ScopedReg(IrLoweringX64& owner, SizeX64 size)
: owner(owner)
{
if (size == SizeX64::xmmword)
reg = owner.allocXmmReg();
else
reg = owner.allocGprReg(size);
}
IrLoweringX64::ScopedReg::~ScopedReg()
{
if (reg != noreg)
owner.freeReg(reg);
}
void IrLoweringX64::ScopedReg::free()
{
LUAU_ASSERT(reg != noreg);
owner.freeReg(reg);
reg = noreg;
}
} // namespace CodeGen
} // namespace Luau

View file

@ -4,8 +4,8 @@
#include "Luau/AssemblyBuilderX64.h"
#include "Luau/IrData.h"
#include <array>
#include <initializer_list>
#include "IrRegAllocX64.h"
#include <vector>
struct Proto;
@ -46,33 +46,8 @@ struct IrLoweringX64
IrBlock& blockOp(IrOp op) const;
Label& labelOp(IrOp op) const;
// Unscoped register allocation
RegisterX64 allocGprReg(SizeX64 preferredSize);
RegisterX64 allocXmmReg();
RegisterX64 allocGprRegOrReuse(SizeX64 preferredSize, uint32_t index, std::initializer_list<IrOp> oprefs);
RegisterX64 allocXmmRegOrReuse(uint32_t index, std::initializer_list<IrOp> oprefs);
void freeReg(RegisterX64 reg);
void freeLastUseReg(IrInst& target, uint32_t index);
void freeLastUseRegs(const IrInst& inst, uint32_t index);
ConditionX64 getX64Condition(IrCondition cond) const;
struct ScopedReg
{
ScopedReg(IrLoweringX64& owner, SizeX64 size);
~ScopedReg();
ScopedReg(const ScopedReg&) = delete;
ScopedReg& operator=(const ScopedReg&) = delete;
void free();
IrLoweringX64& owner;
RegisterX64 reg;
};
AssemblyBuilderX64& build;
ModuleHelpers& helpers;
NativeState& data;
@ -80,8 +55,7 @@ struct IrLoweringX64
IrFunction& function;
std::array<bool, 16> freeGprMap;
std::array<bool, 16> freeXmmMap;
IrRegAllocX64 regs;
};
} // namespace CodeGen

View file

@ -0,0 +1,181 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "IrRegAllocX64.h"
#include "Luau/CodeGen.h"
#include "Luau/DenseHash.h"
#include "Luau/IrAnalysis.h"
#include "Luau/IrDump.h"
#include "Luau/IrUtils.h"
#include "EmitCommonX64.h"
#include "EmitInstructionX64.h"
#include "NativeState.h"
#include "lstate.h"
#include <algorithm>
namespace Luau
{
namespace CodeGen
{
static const RegisterX64 kGprAllocOrder[] = {rax, rdx, rcx, rbx, rsi, rdi, r8, r9, r10, r11};
IrRegAllocX64::IrRegAllocX64(IrFunction& function)
: function(function)
{
freeGprMap.fill(true);
freeXmmMap.fill(true);
}
RegisterX64 IrRegAllocX64::allocGprReg(SizeX64 preferredSize)
{
LUAU_ASSERT(
preferredSize == SizeX64::byte || preferredSize == SizeX64::word || preferredSize == SizeX64::dword || preferredSize == SizeX64::qword);
for (RegisterX64 reg : kGprAllocOrder)
{
if (freeGprMap[reg.index])
{
freeGprMap[reg.index] = false;
return RegisterX64{preferredSize, reg.index};
}
}
LUAU_ASSERT(!"Out of GPR registers to allocate");
return noreg;
}
RegisterX64 IrRegAllocX64::allocXmmReg()
{
for (size_t i = 0; i < freeXmmMap.size(); ++i)
{
if (freeXmmMap[i])
{
freeXmmMap[i] = false;
return RegisterX64{SizeX64::xmmword, uint8_t(i)};
}
}
LUAU_ASSERT(!"Out of XMM registers to allocate");
return noreg;
}
RegisterX64 IrRegAllocX64::allocGprRegOrReuse(SizeX64 preferredSize, uint32_t index, std::initializer_list<IrOp> oprefs)
{
for (IrOp op : oprefs)
{
if (op.kind != IrOpKind::Inst)
continue;
IrInst& source = function.instructions[op.index];
if (source.lastUse == index && !source.reusedReg)
{
LUAU_ASSERT(source.regX64.size != SizeX64::xmmword);
LUAU_ASSERT(source.regX64 != noreg);
source.reusedReg = true;
return RegisterX64{preferredSize, source.regX64.index};
}
}
return allocGprReg(preferredSize);
}
RegisterX64 IrRegAllocX64::allocXmmRegOrReuse(uint32_t index, std::initializer_list<IrOp> oprefs)
{
for (IrOp op : oprefs)
{
if (op.kind != IrOpKind::Inst)
continue;
IrInst& source = function.instructions[op.index];
if (source.lastUse == index && !source.reusedReg)
{
LUAU_ASSERT(source.regX64.size == SizeX64::xmmword);
LUAU_ASSERT(source.regX64 != noreg);
source.reusedReg = true;
return source.regX64;
}
}
return allocXmmReg();
}
void IrRegAllocX64::freeReg(RegisterX64 reg)
{
if (reg.size == SizeX64::xmmword)
{
LUAU_ASSERT(!freeXmmMap[reg.index]);
freeXmmMap[reg.index] = true;
}
else
{
LUAU_ASSERT(!freeGprMap[reg.index]);
freeGprMap[reg.index] = true;
}
}
void IrRegAllocX64::freeLastUseReg(IrInst& target, uint32_t index)
{
if (target.lastUse == index && !target.reusedReg)
{
// Register might have already been freed if it had multiple uses inside a single instruction
if (target.regX64 == noreg)
return;
freeReg(target.regX64);
target.regX64 = noreg;
}
}
void IrRegAllocX64::freeLastUseRegs(const IrInst& inst, uint32_t index)
{
auto checkOp = [this, index](IrOp op) {
if (op.kind == IrOpKind::Inst)
freeLastUseReg(function.instructions[op.index], index);
};
checkOp(inst.a);
checkOp(inst.b);
checkOp(inst.c);
checkOp(inst.d);
checkOp(inst.e);
checkOp(inst.f);
}
ScopedRegX64::ScopedRegX64(IrRegAllocX64& owner, SizeX64 size)
: owner(owner)
{
if (size == SizeX64::xmmword)
reg = owner.allocXmmReg();
else
reg = owner.allocGprReg(size);
}
ScopedRegX64::ScopedRegX64(IrRegAllocX64& owner, RegisterX64 reg)
: owner(owner)
, reg(reg)
{
}
ScopedRegX64::~ScopedRegX64()
{
if (reg != noreg)
owner.freeReg(reg);
}
void ScopedRegX64::free()
{
LUAU_ASSERT(reg != noreg);
owner.freeReg(reg);
reg = noreg;
}
} // namespace CodeGen
} // namespace Luau

View file

@ -0,0 +1,51 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/IrData.h"
#include "Luau/RegisterX64.h"
#include <array>
#include <initializer_list>
namespace Luau
{
namespace CodeGen
{
struct IrRegAllocX64
{
IrRegAllocX64(IrFunction& function);
RegisterX64 allocGprReg(SizeX64 preferredSize);
RegisterX64 allocXmmReg();
RegisterX64 allocGprRegOrReuse(SizeX64 preferredSize, uint32_t index, std::initializer_list<IrOp> oprefs);
RegisterX64 allocXmmRegOrReuse(uint32_t index, std::initializer_list<IrOp> oprefs);
void freeReg(RegisterX64 reg);
void freeLastUseReg(IrInst& target, uint32_t index);
void freeLastUseRegs(const IrInst& inst, uint32_t index);
IrFunction& function;
std::array<bool, 16> freeGprMap;
std::array<bool, 16> freeXmmMap;
};
struct ScopedRegX64
{
ScopedRegX64(IrRegAllocX64& owner, SizeX64 size);
ScopedRegX64(IrRegAllocX64& owner, RegisterX64 reg);
~ScopedRegX64();
ScopedRegX64(const ScopedRegX64&) = delete;
ScopedRegX64& operator=(const ScopedRegX64&) = delete;
void free();
IrRegAllocX64& owner;
RegisterX64 reg;
};
} // namespace CodeGen
} // namespace Luau

View file

@ -0,0 +1,40 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "IrTranslateBuiltins.h"
#include "Luau/Bytecode.h"
#include "Luau/IrBuilder.h"
#include "lstate.h"
namespace Luau
{
namespace CodeGen
{
BuiltinImplResult translateBuiltinAssert(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback)
{
if (nparams < 1 || nresults != 0)
return {BuiltinImplType::None, -1};
IrOp cont = build.block(IrBlockKind::Internal);
// TODO: maybe adding a guard like CHECK_TRUTHY can be useful
build.inst(IrCmd::JUMP_IF_FALSY, build.vmReg(arg), fallback, cont);
build.beginBlock(cont);
return {BuiltinImplType::UsesFallback, 0};
}
BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults, IrOp fallback)
{
switch (bfid)
{
case LBF_ASSERT:
return translateBuiltinAssert(build, nparams, ra, arg, args, nresults, fallback);
default:
return {BuiltinImplType::None, -1};
}
}
} // namespace CodeGen
} // namespace Luau

View file

@ -0,0 +1,27 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
namespace Luau
{
namespace CodeGen
{
struct IrBuilder;
struct IrOp;
enum class BuiltinImplType
{
None,
UsesFallback, // Uses fallback for unsupported cases
};
struct BuiltinImplResult
{
BuiltinImplType type;
int actualResultCount;
};
BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults, IrOp fallback);
} // namespace CodeGen
} // namespace Luau

View file

@ -6,6 +6,7 @@
#include "Luau/IrUtils.h"
#include "CustomExecUtils.h"
#include "IrTranslateBuiltins.h"
#include "lobject.h"
#include "ltm.h"
@ -68,7 +69,6 @@ void translateInstLoadK(IrBuilder& build, const Instruction* pc)
{
int ra = LUAU_INSN_A(*pc);
// TODO: per-component loads and stores might be preferable
IrOp load = build.inst(IrCmd::LOAD_TVALUE, build.vmConst(LUAU_INSN_D(*pc)));
build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), load);
}
@ -78,7 +78,6 @@ void translateInstLoadKX(IrBuilder& build, const Instruction* pc)
int ra = LUAU_INSN_A(*pc);
uint32_t aux = pc[1];
// TODO: per-component loads and stores might be preferable
IrOp load = build.inst(IrCmd::LOAD_TVALUE, build.vmConst(aux));
build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), load);
}
@ -88,7 +87,6 @@ void translateInstMove(IrBuilder& build, const Instruction* pc)
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
// TODO: per-component loads and stores might be preferable
IrOp load = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(rb));
build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), load);
}
@ -146,10 +144,11 @@ void translateInstJumpIfEq(IrBuilder& build, const Instruction* pc, int pcpos, b
build.inst(IrCmd::JUMP_CMP_NUM, va, vb, build.cond(IrCondition::NotEqual), not_ ? target : next, not_ ? next : target);
FallbackStreamScope scope(build, fallback, next);
build.beginBlock(fallback);
build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1));
build.inst(IrCmd::JUMP_CMP_ANY, build.vmReg(ra), build.vmReg(rb), build.cond(not_ ? IrCondition::NotEqual : IrCondition::Equal), target, next);
build.beginBlock(next);
}
void translateInstJumpIfCond(IrBuilder& build, const Instruction* pc, int pcpos, IrCondition cond)
@ -173,10 +172,11 @@ void translateInstJumpIfCond(IrBuilder& build, const Instruction* pc, int pcpos,
build.inst(IrCmd::JUMP_CMP_NUM, va, vb, build.cond(cond), target, next);
FallbackStreamScope scope(build, fallback, next);
build.beginBlock(fallback);
build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1));
build.inst(IrCmd::JUMP_CMP_ANY, build.vmReg(ra), build.vmReg(rb), build.cond(cond), target, next);
build.beginBlock(next);
}
void translateInstJumpX(IrBuilder& build, const Instruction* pc, int pcpos)
@ -479,6 +479,61 @@ void translateInstCloseUpvals(IrBuilder& build, const Instruction* pc)
build.inst(IrCmd::CLOSE_UPVALS, build.vmReg(ra));
}
void translateFastCallN(
IrBuilder& build, const Instruction* pc, int pcpos, bool customParams, int customParamCount, IrOp customArgs, IrOp next, IrCmd fallbackCmd)
{
int bfid = LUAU_INSN_A(*pc);
int skip = LUAU_INSN_C(*pc);
IrOp fallback = build.block(IrBlockKind::Fallback);
Instruction call = pc[skip + 1];
LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL);
int ra = LUAU_INSN_A(call);
int nparams = customParams ? customParamCount : LUAU_INSN_B(call) - 1;
int nresults = LUAU_INSN_C(call) - 1;
int arg = customParams ? LUAU_INSN_B(*pc) : ra + 1;
IrOp args = customParams ? customArgs : build.vmReg(ra + 2);
build.inst(IrCmd::CHECK_SAFE_ENV, fallback);
BuiltinImplResult br = translateBuiltin(build, LuauBuiltinFunction(bfid), ra, arg, args, nparams, nresults, fallback);
if (br.type == BuiltinImplType::UsesFallback)
{
if (nresults == LUA_MULTRET)
build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(ra), build.constInt(br.actualResultCount));
else if (nparams == LUA_MULTRET)
build.inst(IrCmd::ADJUST_STACK_TO_TOP);
}
else
{
switch (fallbackCmd)
{
case IrCmd::LOP_FASTCALL:
build.inst(IrCmd::LOP_FASTCALL, build.constUint(pcpos), build.vmReg(ra), build.constInt(nparams), fallback);
break;
case IrCmd::LOP_FASTCALL1:
build.inst(IrCmd::LOP_FASTCALL1, build.constUint(pcpos), build.vmReg(ra), build.vmReg(arg), fallback);
break;
case IrCmd::LOP_FASTCALL2:
build.inst(IrCmd::LOP_FASTCALL2, build.constUint(pcpos), build.vmReg(ra), build.vmReg(arg), build.vmReg(pc[1]), fallback);
break;
case IrCmd::LOP_FASTCALL2K:
build.inst(IrCmd::LOP_FASTCALL2K, build.constUint(pcpos), build.vmReg(ra), build.vmReg(arg), build.vmConst(pc[1]), fallback);
break;
default:
LUAU_ASSERT(!"unexpected command");
}
}
build.inst(IrCmd::JUMP, next);
// this will be filled with IR corresponding to instructions after FASTCALL until skip+1
build.beginBlock(fallback);
}
void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos)
{
int ra = LUAU_INSN_A(*pc);
@ -589,7 +644,6 @@ void translateInstForGPrepNext(IrBuilder& build, const Instruction* pc, int pcpo
build.inst(IrCmd::JUMP, target);
// FallbackStreamScope not used here because this instruction doesn't fallthrough to next instruction
build.beginBlock(fallback);
build.inst(IrCmd::LOP_FORGPREP_XNEXT_FALLBACK, build.constUint(pcpos), target);
}
@ -622,7 +676,6 @@ void translateInstForGPrepInext(IrBuilder& build, const Instruction* pc, int pcp
build.inst(IrCmd::JUMP, target);
// FallbackStreamScope not used here because this instruction doesn't fallthrough to next instruction
build.beginBlock(fallback);
build.inst(IrCmd::LOP_FORGPREP_XNEXT_FALLBACK, build.constUint(pcpos), target);
}
@ -700,7 +753,6 @@ void translateInstGetTableN(IrBuilder& build, const Instruction* pc, int pcpos)
IrOp arrEl = build.inst(IrCmd::GET_ARR_ADDR, vb, build.constInt(c));
// TODO: per-component loads and stores might be preferable
IrOp arrElTval = build.inst(IrCmd::LOAD_TVALUE, arrEl);
build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), arrElTval);
@ -731,7 +783,6 @@ void translateInstSetTableN(IrBuilder& build, const Instruction* pc, int pcpos)
IrOp arrEl = build.inst(IrCmd::GET_ARR_ADDR, vb, build.constInt(c));
// TODO: per-component loads and stores might be preferable
IrOp tva = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(ra));
build.inst(IrCmd::STORE_TVALUE, arrEl, tva);
@ -771,7 +822,6 @@ void translateInstGetTable(IrBuilder& build, const Instruction* pc, int pcpos)
IrOp arrEl = build.inst(IrCmd::GET_ARR_ADDR, vb, index);
// TODO: per-component loads and stores might be preferable
IrOp arrElTval = build.inst(IrCmd::LOAD_TVALUE, arrEl);
build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), arrElTval);
@ -810,7 +860,6 @@ void translateInstSetTable(IrBuilder& build, const Instruction* pc, int pcpos)
IrOp arrEl = build.inst(IrCmd::GET_ARR_ADDR, vb, index);
// TODO: per-component loads and stores might be preferable
IrOp tva = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(ra));
build.inst(IrCmd::STORE_TVALUE, arrEl, tva);
@ -842,7 +891,6 @@ void translateInstGetImport(IrBuilder& build, const Instruction* pc, int pcpos)
build.beginBlock(fastPath);
// TODO: per-component loads and stores might be preferable
IrOp tvk = build.inst(IrCmd::LOAD_TVALUE, build.vmConst(k));
build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), tvk);
@ -871,7 +919,6 @@ void translateInstGetTableKS(IrBuilder& build, const Instruction* pc, int pcpos)
build.inst(IrCmd::CHECK_SLOT_MATCH, addrSlotEl, build.vmConst(aux), fallback);
// TODO: per-component loads and stores might be preferable
IrOp tvn = build.inst(IrCmd::LOAD_NODE_VALUE_TV, addrSlotEl);
build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), tvn);
@ -900,7 +947,6 @@ void translateInstSetTableKS(IrBuilder& build, const Instruction* pc, int pcpos)
build.inst(IrCmd::CHECK_SLOT_MATCH, addrSlotEl, build.vmConst(aux), fallback);
build.inst(IrCmd::CHECK_READONLY, vb, fallback);
// TODO: per-component loads and stores might be preferable
IrOp tva = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(ra));
build.inst(IrCmd::STORE_NODE_VALUE_TV, addrSlotEl, tva);
@ -925,7 +971,6 @@ void translateInstGetGlobal(IrBuilder& build, const Instruction* pc, int pcpos)
build.inst(IrCmd::CHECK_SLOT_MATCH, addrSlotEl, build.vmConst(aux), fallback);
// TODO: per-component loads and stores might be preferable
IrOp tvn = build.inst(IrCmd::LOAD_NODE_VALUE_TV, addrSlotEl);
build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), tvn);
@ -949,7 +994,6 @@ void translateInstSetGlobal(IrBuilder& build, const Instruction* pc, int pcpos)
build.inst(IrCmd::CHECK_SLOT_MATCH, addrSlotEl, build.vmConst(aux), fallback);
build.inst(IrCmd::CHECK_READONLY, env, fallback);
// TODO: per-component loads and stores might be preferable
IrOp tva = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(ra));
build.inst(IrCmd::STORE_NODE_VALUE_TV, addrSlotEl, tva);
@ -971,7 +1015,6 @@ void translateInstConcat(IrBuilder& build, const Instruction* pc, int pcpos)
build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1));
build.inst(IrCmd::CONCAT, build.vmReg(rb), build.constUint(rc - rb + 1));
// TODO: per-component loads and stores might be preferable
IrOp tvb = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(rb));
build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), tvb);

View file

@ -15,6 +15,7 @@ namespace CodeGen
enum class IrCondition : uint8_t;
struct IrOp;
struct IrBuilder;
enum class IrCmd : uint8_t;
void translateInstLoadNil(IrBuilder& build, const Instruction* pc);
void translateInstLoadB(IrBuilder& build, const Instruction* pc, int pcpos);
@ -42,6 +43,8 @@ void translateInstDupTable(IrBuilder& build, const Instruction* pc, int pcpos);
void translateInstGetUpval(IrBuilder& build, const Instruction* pc, int pcpos);
void translateInstSetUpval(IrBuilder& build, const Instruction* pc, int pcpos);
void translateInstCloseUpvals(IrBuilder& build, const Instruction* pc);
void translateFastCallN(
IrBuilder& build, const Instruction* pc, int pcpos, bool customParams, int customParamCount, IrOp customArgs, IrOp next, IrCmd fallbackCmd);
void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos);
void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos);
void translateInstForGPrepNext(IrBuilder& build, const Instruction* pc, int pcpos);

View file

@ -14,20 +14,6 @@ namespace Luau
namespace CodeGen
{
static uint32_t getBlockEnd(IrFunction& function, uint32_t start)
{
LUAU_ASSERT(start < function.instructions.size());
uint32_t end = start;
// Find previous block terminator
while (!isBlockTerminator(function.instructions[end].cmd))
end++;
LUAU_ASSERT(end < function.instructions.size());
return end;
}
void addUse(IrFunction& function, IrOp op)
{
if (op.kind == IrOpKind::Inst)
@ -44,6 +30,12 @@ void removeUse(IrFunction& function, IrOp op)
removeUse(function, function.blocks[op.index]);
}
bool isGCO(uint8_t tag)
{
// mirrors iscollectable(o) from VM/lobject.h
return tag >= LUA_TSTRING;
}
void kill(IrFunction& function, IrInst& inst)
{
LUAU_ASSERT(inst.useCount == 0);
@ -55,12 +47,14 @@ void kill(IrFunction& function, IrInst& inst)
removeUse(function, inst.c);
removeUse(function, inst.d);
removeUse(function, inst.e);
removeUse(function, inst.f);
inst.a = {};
inst.b = {};
inst.c = {};
inst.d = {};
inst.e = {};
inst.f = {};
}
void kill(IrFunction& function, uint32_t start, uint32_t end)
@ -84,10 +78,9 @@ void kill(IrFunction& function, IrBlock& block)
block.kind = IrBlockKind::Dead;
uint32_t start = block.start;
uint32_t end = getBlockEnd(function, start);
kill(function, start, end);
kill(function, block.start, block.finish);
block.start = ~0u;
block.finish = ~0u;
}
void removeUse(IrFunction& function, IrInst& inst)
@ -117,7 +110,7 @@ void replace(IrFunction& function, IrOp& original, IrOp replacement)
original = replacement;
}
void replace(IrFunction& function, uint32_t instIdx, IrInst replacement)
void replace(IrFunction& function, IrBlock& block, uint32_t instIdx, IrInst replacement)
{
IrInst& inst = function.instructions[instIdx];
@ -127,19 +120,18 @@ void replace(IrFunction& function, uint32_t instIdx, IrInst replacement)
addUse(function, replacement.c);
addUse(function, replacement.d);
addUse(function, replacement.e);
addUse(function, replacement.f);
// If we introduced an earlier terminating instruction, all following instructions become dead
if (!isBlockTerminator(inst.cmd) && isBlockTerminator(replacement.cmd))
{
uint32_t start = instIdx + 1;
// Block has has to be fully constructed before replacement is performed
LUAU_ASSERT(block.finish != ~0u);
LUAU_ASSERT(instIdx + 1 <= block.finish);
// If we are in the process of constructing a block, replacement might happen at the last instruction
if (start < function.instructions.size())
{
uint32_t end = getBlockEnd(function, start);
kill(function, instIdx + 1, block.finish);
kill(function, start, end);
}
block.finish = instIdx;
}
removeUse(function, inst.a);
@ -147,6 +139,7 @@ void replace(IrFunction& function, uint32_t instIdx, IrInst replacement)
removeUse(function, inst.c);
removeUse(function, inst.d);
removeUse(function, inst.e);
removeUse(function, inst.f);
inst = replacement;
}
@ -162,12 +155,14 @@ void substitute(IrFunction& function, IrInst& inst, IrOp replacement)
removeUse(function, inst.c);
removeUse(function, inst.d);
removeUse(function, inst.e);
removeUse(function, inst.f);
inst.a = replacement;
inst.b = {};
inst.c = {};
inst.d = {};
inst.e = {};
inst.f = {};
}
void applySubstitutions(IrFunction& function, IrOp& op)
@ -203,9 +198,10 @@ void applySubstitutions(IrFunction& function, IrInst& inst)
applySubstitutions(function, inst.c);
applySubstitutions(function, inst.d);
applySubstitutions(function, inst.e);
applySubstitutions(function, inst.f);
}
static bool compare(double a, double b, IrCondition cond)
bool compare(double a, double b, IrCondition cond)
{
switch (cond)
{
@ -236,7 +232,7 @@ static bool compare(double a, double b, IrCondition cond)
return false;
}
void foldConstants(IrBuilder& build, IrFunction& function, uint32_t index)
void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint32_t index)
{
IrInst& inst = function.instructions[index];
@ -311,27 +307,27 @@ void foldConstants(IrBuilder& build, IrFunction& function, uint32_t index)
if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant)
{
if (function.tagOp(inst.a) == function.tagOp(inst.b))
replace(function, index, {IrCmd::JUMP, inst.c});
replace(function, block, index, {IrCmd::JUMP, inst.c});
else
replace(function, index, {IrCmd::JUMP, inst.d});
replace(function, block, index, {IrCmd::JUMP, inst.d});
}
break;
case IrCmd::JUMP_EQ_INT:
if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant)
{
if (function.intOp(inst.a) == function.intOp(inst.b))
replace(function, index, {IrCmd::JUMP, inst.c});
replace(function, block, index, {IrCmd::JUMP, inst.c});
else
replace(function, index, {IrCmd::JUMP, inst.d});
replace(function, block, index, {IrCmd::JUMP, inst.d});
}
break;
case IrCmd::JUMP_CMP_NUM:
if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant)
{
if (compare(function.doubleOp(inst.a), function.doubleOp(inst.b), function.conditionOp(inst.c)))
replace(function, index, {IrCmd::JUMP, inst.d});
replace(function, block, index, {IrCmd::JUMP, inst.d});
else
replace(function, index, {IrCmd::JUMP, inst.e});
replace(function, block, index, {IrCmd::JUMP, inst.e});
}
break;
case IrCmd::NUM_TO_INDEX:
@ -347,11 +343,11 @@ void foldConstants(IrBuilder& build, IrFunction& function, uint32_t index)
if (double(arrIndex) == value)
substitute(function, inst, build.constInt(arrIndex));
else
replace(function, index, {IrCmd::JUMP, inst.b});
replace(function, block, index, {IrCmd::JUMP, inst.b});
}
else
{
replace(function, index, {IrCmd::JUMP, inst.b});
replace(function, block, index, {IrCmd::JUMP, inst.b});
}
}
break;
@ -365,7 +361,7 @@ void foldConstants(IrBuilder& build, IrFunction& function, uint32_t index)
if (function.tagOp(inst.a) == function.tagOp(inst.b))
kill(function, inst);
else
replace(function, index, {IrCmd::JUMP, inst.c}); // Shows a conflict in assumptions on this path
replace(function, block, index, {IrCmd::JUMP, inst.c}); // Shows a conflict in assumptions on this path
}
break;
default:

View file

@ -0,0 +1,565 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/OptimizeConstProp.h"
#include "Luau/DenseHash.h"
#include "Luau/IrBuilder.h"
#include "Luau/IrUtils.h"
#include "lua.h"
namespace Luau
{
namespace CodeGen
{
// Data we know about the register value
struct RegisterInfo
{
uint8_t tag = 0xff;
IrOp value;
// Used to quickly invalidate links between SSA values and register memory
// It's a bit imprecise where value and tag both always invalidate together
uint32_t version = 0;
bool knownNotReadonly = false;
bool knownNoMetatable = false;
};
// Load instructions are linked to target register to carry knowledge about the target
// We track a register version at the point of the load so it's easy to break the link when register is updated
struct RegisterLink
{
uint8_t reg = 0;
uint32_t version = 0;
};
// Data we know about the current VM state
struct ConstPropState
{
uint8_t tryGetTag(IrOp op)
{
if (RegisterInfo* info = tryGetRegisterInfo(op))
return info->tag;
return 0xff;
}
void saveTag(IrOp op, uint8_t tag)
{
if (RegisterInfo* info = tryGetRegisterInfo(op))
info->tag = tag;
}
IrOp tryGetValue(IrOp op)
{
if (RegisterInfo* info = tryGetRegisterInfo(op))
return info->value;
return IrOp{IrOpKind::None, 0u};
}
void saveValue(IrOp op, IrOp value)
{
LUAU_ASSERT(value.kind == IrOpKind::Constant);
if (RegisterInfo* info = tryGetRegisterInfo(op))
info->value = value;
}
void invalidate(RegisterInfo& reg, bool invalidateTag, bool invalidateValue)
{
if (invalidateTag)
{
reg.tag = 0xff;
}
if (invalidateValue)
{
reg.value = {};
reg.knownNotReadonly = false;
reg.knownNoMetatable = false;
}
reg.version++;
}
void invalidateTag(IrOp regOp)
{
LUAU_ASSERT(regOp.kind == IrOpKind::VmReg);
invalidate(regs[regOp.index], /* invalidateTag */ true, /* invalidateValue */ false);
}
void invalidateValue(IrOp regOp)
{
LUAU_ASSERT(regOp.kind == IrOpKind::VmReg);
invalidate(regs[regOp.index], /* invalidateTag */ false, /* invalidateValue */ true);
}
void invalidate(IrOp regOp)
{
LUAU_ASSERT(regOp.kind == IrOpKind::VmReg);
invalidate(regs[regOp.index], /* invalidateTag */ true, /* invalidateValue */ true);
}
void invalidateRegistersFrom(uint32_t firstReg)
{
for (int i = int(firstReg); i <= maxReg; ++i)
invalidate(regs[i], /* invalidateTag */ true, /* invalidateValue */ true);
maxReg = int(firstReg) - 1;
}
void invalidateHeap()
{
for (int i = 0; i <= maxReg; ++i)
invalidateHeap(regs[i]);
}
void invalidateHeap(RegisterInfo& reg)
{
reg.knownNotReadonly = false;
reg.knownNoMetatable = false;
}
void invalidateAll()
{
// Invalidating registers also invalidates what we know about the heap (stored in RegisterInfo)
invalidateRegistersFrom(0u);
inSafeEnv = false;
}
void createRegLink(uint32_t instIdx, IrOp regOp)
{
LUAU_ASSERT(regOp.kind == IrOpKind::VmReg);
LUAU_ASSERT(!instLink.contains(instIdx));
instLink[instIdx] = RegisterLink{uint8_t(regOp.index), regs[regOp.index].version};
}
RegisterInfo* tryGetRegisterInfo(IrOp op)
{
if (op.kind == IrOpKind::VmReg)
{
maxReg = int(op.index) > maxReg ? int(op.index) : maxReg;
return &regs[op.index];
}
if (RegisterLink* link = tryGetRegLink(op))
{
maxReg = int(link->reg) > maxReg ? int(link->reg) : maxReg;
return &regs[link->reg];
}
return nullptr;
}
RegisterLink* tryGetRegLink(IrOp instOp)
{
if (instOp.kind != IrOpKind::Inst)
return nullptr;
if (RegisterLink* link = instLink.find(instOp.index))
{
// Check that the target register hasn't changed the value
if (link->version > regs[link->reg].version)
return nullptr;
return link;
}
return nullptr;
}
RegisterInfo regs[256];
// For range/full invalidations, we only want to visit a limited number of data that we have recorded
int maxReg = 0;
bool inSafeEnv = false;
bool checkedGc = false;
DenseHashMap<uint32_t, RegisterLink> instLink{~0u};
};
static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& function, IrBlock& block, IrInst& inst, uint32_t index)
{
switch (inst.cmd)
{
case IrCmd::LOAD_TAG:
if (uint8_t tag = state.tryGetTag(inst.a); tag != 0xff)
substitute(function, inst, build.constTag(tag));
else if (inst.a.kind == IrOpKind::VmReg)
state.createRegLink(index, inst.a);
break;
case IrCmd::LOAD_POINTER:
if (inst.a.kind == IrOpKind::VmReg)
state.createRegLink(index, inst.a);
break;
case IrCmd::LOAD_DOUBLE:
if (IrOp value = state.tryGetValue(inst.a); value.kind == IrOpKind::Constant)
substitute(function, inst, value);
else if (inst.a.kind == IrOpKind::VmReg)
state.createRegLink(index, inst.a);
break;
case IrCmd::LOAD_INT:
if (IrOp value = state.tryGetValue(inst.a); value.kind == IrOpKind::Constant)
substitute(function, inst, value);
else if (inst.a.kind == IrOpKind::VmReg)
state.createRegLink(index, inst.a);
break;
case IrCmd::LOAD_TVALUE:
if (inst.a.kind == IrOpKind::VmReg)
state.createRegLink(index, inst.a);
break;
case IrCmd::STORE_TAG:
if (inst.a.kind == IrOpKind::VmReg)
{
if (inst.b.kind == IrOpKind::Constant)
{
uint8_t value = function.tagOp(inst.b);
if (state.tryGetTag(inst.a) == value)
kill(function, inst);
else
state.saveTag(inst.a, value);
}
else
{
state.invalidateTag(inst.a);
}
}
break;
case IrCmd::STORE_POINTER:
if (inst.a.kind == IrOpKind::VmReg)
state.invalidateValue(inst.a);
break;
case IrCmd::STORE_DOUBLE:
if (inst.a.kind == IrOpKind::VmReg)
{
if (inst.b.kind == IrOpKind::Constant)
{
std::optional<double> oldValue = function.asDoubleOp(state.tryGetValue(inst.a));
double newValue = function.doubleOp(inst.b);
if (oldValue && *oldValue == newValue)
kill(function, inst);
else
state.saveValue(inst.a, inst.b);
}
else
{
state.invalidateValue(inst.a);
}
}
break;
case IrCmd::STORE_INT:
if (inst.a.kind == IrOpKind::VmReg)
{
if (inst.b.kind == IrOpKind::Constant)
{
std::optional<int> oldValue = function.asIntOp(state.tryGetValue(inst.a));
int newValue = function.intOp(inst.b);
if (oldValue && *oldValue == newValue)
kill(function, inst);
else
state.saveValue(inst.a, inst.b);
}
else
{
state.invalidateValue(inst.a);
}
}
break;
case IrCmd::STORE_TVALUE:
if (inst.a.kind == IrOpKind::VmReg)
{
state.invalidate(inst.a);
if (uint8_t tag = state.tryGetTag(inst.b); tag != 0xff)
state.saveTag(inst.a, tag);
if (IrOp value = state.tryGetValue(inst.b); value.kind != IrOpKind::None)
state.saveValue(inst.a, value);
}
break;
case IrCmd::JUMP_IF_TRUTHY:
if (uint8_t tag = state.tryGetTag(inst.a); tag != 0xff)
{
if (tag == LUA_TNIL)
replace(function, block, index, {IrCmd::JUMP, inst.c});
else if (tag != LUA_TBOOLEAN)
replace(function, block, index, {IrCmd::JUMP, inst.b});
}
break;
case IrCmd::JUMP_IF_FALSY:
if (uint8_t tag = state.tryGetTag(inst.a); tag != 0xff)
{
if (tag == LUA_TNIL)
replace(function, block, index, {IrCmd::JUMP, inst.b});
else if (tag != LUA_TBOOLEAN)
replace(function, block, index, {IrCmd::JUMP, inst.c});
}
break;
case IrCmd::JUMP_EQ_TAG:
{
uint8_t tagA = inst.a.kind == IrOpKind::Constant ? function.tagOp(inst.a) : state.tryGetTag(inst.a);
uint8_t tagB = inst.b.kind == IrOpKind::Constant ? function.tagOp(inst.b) : state.tryGetTag(inst.b);
if (tagA != 0xff && tagB != 0xff)
{
if (tagA == tagB)
replace(function, block, index, {IrCmd::JUMP, inst.c});
else
replace(function, block, index, {IrCmd::JUMP, inst.d});
}
break;
}
case IrCmd::JUMP_EQ_INT:
{
std::optional<int> valueA = function.asIntOp(inst.a.kind == IrOpKind::Constant ? inst.a : state.tryGetValue(inst.a));
std::optional<int> valueB = function.asIntOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b));
if (valueA && valueB)
{
if (*valueA == *valueB)
replace(function, block, index, {IrCmd::JUMP, inst.c});
else
replace(function, block, index, {IrCmd::JUMP, inst.d});
}
break;
}
case IrCmd::JUMP_CMP_NUM:
{
std::optional<double> valueA = function.asDoubleOp(inst.a.kind == IrOpKind::Constant ? inst.a : state.tryGetValue(inst.a));
std::optional<double> valueB = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b));
if (valueA && valueB)
{
if (compare(*valueA, *valueB, function.conditionOp(inst.c)))
replace(function, block, index, {IrCmd::JUMP, inst.d});
else
replace(function, block, index, {IrCmd::JUMP, inst.e});
}
break;
}
case IrCmd::GET_UPVALUE:
state.invalidate(inst.a);
break;
case IrCmd::CHECK_TAG:
{
uint8_t b = function.tagOp(inst.b);
if (uint8_t tag = state.tryGetTag(inst.a); tag != 0xff)
{
if (tag == b)
kill(function, inst);
else
replace(function, block, index, {IrCmd::JUMP, inst.c}); // Shows a conflict in assumptions on this path
}
else
{
state.saveTag(inst.a, b); // We can assume the tag value going forward
}
break;
}
case IrCmd::CHECK_READONLY:
if (RegisterInfo* info = state.tryGetRegisterInfo(inst.a))
{
if (info->knownNotReadonly)
kill(function, inst);
else
info->knownNotReadonly = true;
}
break;
case IrCmd::CHECK_NO_METATABLE:
if (RegisterInfo* info = state.tryGetRegisterInfo(inst.a))
{
if (info->knownNoMetatable)
kill(function, inst);
else
info->knownNoMetatable = true;
}
break;
case IrCmd::CHECK_SAFE_ENV:
if (state.inSafeEnv)
kill(function, inst);
else
state.inSafeEnv = true;
break;
case IrCmd::CHECK_GC:
// It is enough to perform a GC check once in a block
if (state.checkedGc)
kill(function, inst);
else
state.checkedGc = true;
break;
case IrCmd::BARRIER_OBJ:
case IrCmd::BARRIER_TABLE_FORWARD:
if (inst.b.kind == IrOpKind::VmReg)
{
if (uint8_t tag = state.tryGetTag(inst.b); tag != 0xff)
{
// If the written object is not collectable, barrier is not required
if (!isGCO(tag))
kill(function, inst);
}
}
break;
case IrCmd::LOP_FASTCALL:
case IrCmd::LOP_FASTCALL1:
case IrCmd::LOP_FASTCALL2:
case IrCmd::LOP_FASTCALL2K:
// TODO: classify fast call behaviors to avoid heap invalidation
state.invalidateHeap(); // Even a builtin method can change table properties
state.invalidateRegistersFrom(inst.b.index);
break;
case IrCmd::LOP_AND:
case IrCmd::LOP_ANDK:
case IrCmd::LOP_OR:
case IrCmd::LOP_ORK:
state.invalidate(inst.b);
break;
// These instructions don't have an effect on register/memory state we are tracking
case IrCmd::NOP:
case IrCmd::LOAD_NODE_VALUE_TV:
case IrCmd::LOAD_ENV:
case IrCmd::GET_ARR_ADDR:
case IrCmd::GET_SLOT_NODE_ADDR:
case IrCmd::STORE_NODE_VALUE_TV:
case IrCmd::ADD_INT:
case IrCmd::SUB_INT:
case IrCmd::ADD_NUM:
case IrCmd::SUB_NUM:
case IrCmd::MUL_NUM:
case IrCmd::DIV_NUM:
case IrCmd::MOD_NUM:
case IrCmd::POW_NUM:
case IrCmd::UNM_NUM:
case IrCmd::NOT_ANY:
case IrCmd::JUMP:
case IrCmd::JUMP_EQ_POINTER:
case IrCmd::TABLE_LEN:
case IrCmd::NEW_TABLE:
case IrCmd::DUP_TABLE:
case IrCmd::NUM_TO_INDEX:
case IrCmd::INT_TO_NUM:
case IrCmd::CHECK_ARRAY_SIZE:
case IrCmd::CHECK_SLOT_MATCH:
case IrCmd::BARRIER_TABLE_BACK:
case IrCmd::LOP_RETURN:
case IrCmd::LOP_COVERAGE:
case IrCmd::SET_UPVALUE:
case IrCmd::LOP_SETLIST: // We don't track table state that this can invalidate
case IrCmd::SET_SAVEDPC: // TODO: we may be able to remove some updates to PC
case IrCmd::CLOSE_UPVALS: // Doesn't change memory that we track
case IrCmd::CAPTURE:
case IrCmd::SUBSTITUTE:
case IrCmd::ADJUST_STACK_TO_REG: // Changes stack top, but not the values
case IrCmd::ADJUST_STACK_TO_TOP: // Changes stack top, but not the values
break;
// We don't model the following instructions, so we just clear all the knowledge we have built up
// Many of these call user functions that can change memory and captured registers
// Some of these might yield with similar effects
case IrCmd::JUMP_CMP_ANY:
case IrCmd::DO_ARITH:
case IrCmd::DO_LEN:
case IrCmd::GET_TABLE:
case IrCmd::SET_TABLE:
case IrCmd::GET_IMPORT:
case IrCmd::CONCAT:
case IrCmd::PREPARE_FORN:
case IrCmd::INTERRUPT: // TODO: it will be important to keep tag/value state, but we have to track register capture
case IrCmd::LOP_NAMECALL:
case IrCmd::LOP_CALL:
case IrCmd::LOP_FORGLOOP:
case IrCmd::LOP_FORGLOOP_FALLBACK:
case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK:
case IrCmd::FALLBACK_GETGLOBAL:
case IrCmd::FALLBACK_SETGLOBAL:
case IrCmd::FALLBACK_GETTABLEKS:
case IrCmd::FALLBACK_SETTABLEKS:
case IrCmd::FALLBACK_NAMECALL:
case IrCmd::FALLBACK_PREPVARARGS:
case IrCmd::FALLBACK_GETVARARGS:
case IrCmd::FALLBACK_NEWCLOSURE:
case IrCmd::FALLBACK_DUPCLOSURE:
case IrCmd::FALLBACK_FORGPREP:
// TODO: this is very conservative, some of there instructions can be tracked better
// TODO: non-captured register tags and values should not be cleared here
state.invalidateAll();
break;
}
}
static void constPropInBlock(IrBuilder& build, IrBlock& block, ConstPropState& state)
{
IrFunction& function = build.function;
for (uint32_t index = block.start; index <= block.finish; index++)
{
LUAU_ASSERT(index < function.instructions.size());
IrInst& inst = function.instructions[index];
applySubstitutions(function, inst);
foldConstants(build, function, block, index);
constPropInInst(state, build, function, block, inst, index);
}
}
static void constPropInBlockChain(IrBuilder& build, std::vector<uint8_t>& visited, IrBlock* block)
{
IrFunction& function = build.function;
ConstPropState state;
while (block)
{
uint32_t blockIdx = function.getBlockIndex(*block);
LUAU_ASSERT(!visited[blockIdx]);
visited[blockIdx] = true;
constPropInBlock(build, *block, state);
IrInst& termInst = function.instructions[block->finish];
IrBlock* nextBlock = nullptr;
// Unconditional jump into a block with a single user (current block) allows us to continue optimization
// with the information we have gathered so far (unless we have already visited that block earlier)
if (termInst.cmd == IrCmd::JUMP)
{
IrBlock& target = function.blockOp(termInst.a);
if (target.useCount == 1 && !visited[function.getBlockIndex(target)] && target.kind != IrBlockKind::Fallback)
nextBlock = &target;
}
block = nextBlock;
}
}
void constPropInBlockChains(IrBuilder& build)
{
IrFunction& function = build.function;
std::vector<uint8_t> visited(function.blocks.size(), false);
for (IrBlock& block : function.blocks)
{
if (block.kind == IrBlockKind::Fallback || block.kind == IrBlockKind::Dead)
continue;
if (visited[function.getBlockIndex(block)])
continue;
constPropInBlockChain(build, visited, &block);
}
}
} // namespace CodeGen
} // namespace Luau

View file

@ -17,7 +17,7 @@ static void optimizeMemoryOperandsX64(IrFunction& function, IrBlock& block)
{
LUAU_ASSERT(block.kind != IrBlockKind::Dead);
for (uint32_t index = block.start; true; index++)
for (uint32_t index = block.start; index <= block.finish; index++)
{
LUAU_ASSERT(index < function.instructions.size());
IrInst& inst = function.instructions[index];
@ -90,9 +90,6 @@ static void optimizeMemoryOperandsX64(IrFunction& function, IrBlock& block)
default:
break;
}
if (isBlockTerminator(inst.cmd))
break;
}
}

View file

@ -42,7 +42,7 @@
// Note: due to limitations of the versioning scheme, some bytecode blobs that carry version 2 are using features from version 3. Starting from version 3, version should be sufficient to indicate bytecode compatibility.
//
// Version 1: Baseline version for the open-source release. Supported until 0.521.
// Version 2: Adds Proto::linedefined. Currently supported.
// Version 2: Adds Proto::linedefined. Supported until 0.544.
// Version 3: Adds FORGPREP/JUMPXEQK* and enhances AUX encoding for FORGLOOP. Removes FORGLOOP_NEXT/INEXT and JUMPIFEQK/JUMPIFNOTEQK. Currently supported.
// Bytecode opcode, part of the instruction header

View file

@ -25,6 +25,8 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25)
LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300)
LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5)
LUAU_FASTFLAGVARIABLE(LuauCompileTerminateBC, false)
namespace Luau
{
@ -132,14 +134,18 @@ struct Compiler
return uint8_t(upvals.size() - 1);
}
bool allPathsEndWithReturn(AstStat* node)
// true iff all execution paths through node subtree result in return/break/continue
// note: because this function doesn't visit loop nodes, it (correctly) only detects break/continue that refer to the outer control flow
bool alwaysTerminates(AstStat* node)
{
if (AstStatBlock* stat = node->as<AstStatBlock>())
return stat->body.size > 0 && allPathsEndWithReturn(stat->body.data[stat->body.size - 1]);
return stat->body.size > 0 && alwaysTerminates(stat->body.data[stat->body.size - 1]);
else if (node->is<AstStatReturn>())
return true;
else if (FFlag::LuauCompileTerminateBC && (node->is<AstStatBreak>() || node->is<AstStatContinue>()))
return true;
else if (AstStatIf* stat = node->as<AstStatIf>())
return stat->elsebody && allPathsEndWithReturn(stat->thenbody) && allPathsEndWithReturn(stat->elsebody);
return stat->elsebody && alwaysTerminates(stat->thenbody) && alwaysTerminates(stat->elsebody);
else
return false;
}
@ -213,7 +219,7 @@ struct Compiler
// valid function bytecode must always end with RETURN
// we elide this if we're guaranteed to hit a RETURN statement regardless of the control flow
if (!allPathsEndWithReturn(stat))
if (!alwaysTerminates(stat))
{
setDebugLineEnd(stat);
closeLocals(0);
@ -257,7 +263,7 @@ struct Compiler
f.costModel = modelCost(func->body, func->args.data, func->args.size, builtins);
// track functions that only ever return a single value so that we can convert multret calls to fixedret calls
if (allPathsEndWithReturn(func->body))
if (alwaysTerminates(func->body))
{
ReturnVisitor returnVisitor(this);
stat->visit(&returnVisitor);
@ -640,7 +646,7 @@ struct Compiler
}
// for the fallthrough path we need to ensure we clear out target registers
if (!usedFallthrough && !allPathsEndWithReturn(func->body))
if (!usedFallthrough && !alwaysTerminates(func->body))
{
for (size_t i = 0; i < targetCount; ++i)
bytecode.emitABC(LOP_LOADNIL, uint8_t(target + i), 0, 0);
@ -2435,9 +2441,9 @@ struct Compiler
if (stat->elsebody && elseJump.size() > 0)
{
// we don't need to skip past "else" body if "then" ends with return
// we don't need to skip past "else" body if "then" ends with return/break/continue
// this is important because, if "else" also ends with return, we may *not* have any statement to skip to!
if (allPathsEndWithReturn(stat->thenbody))
if (alwaysTerminates(stat->thenbody))
{
size_t elseLabel = bytecode.emitLabel();

View file

@ -70,6 +70,7 @@ target_sources(Luau.CodeGen PRIVATE
CodeGen/include/Luau/IrUtils.h
CodeGen/include/Luau/Label.h
CodeGen/include/Luau/OperandX64.h
CodeGen/include/Luau/OptimizeConstProp.h
CodeGen/include/Luau/OptimizeFinalX64.h
CodeGen/include/Luau/RegisterA64.h
CodeGen/include/Luau/RegisterX64.h
@ -92,9 +93,12 @@ target_sources(Luau.CodeGen PRIVATE
CodeGen/src/IrBuilder.cpp
CodeGen/src/IrDump.cpp
CodeGen/src/IrLoweringX64.cpp
CodeGen/src/IrRegAllocX64.cpp
CodeGen/src/IrTranslateBuiltins.cpp
CodeGen/src/IrTranslation.cpp
CodeGen/src/IrUtils.cpp
CodeGen/src/NativeState.cpp
CodeGen/src/OptimizeConstProp.cpp
CodeGen/src/OptimizeFinalX64.cpp
CodeGen/src/UnwindBuilderDwarf2.cpp
CodeGen/src/UnwindBuilderWin.cpp
@ -109,6 +113,8 @@ target_sources(Luau.CodeGen PRIVATE
CodeGen/src/Fallbacks.h
CodeGen/src/FallbacksProlog.h
CodeGen/src/IrLoweringX64.h
CodeGen/src/IrRegAllocX64.h
CodeGen/src/IrTranslateBuiltins.h
CodeGen/src/IrTranslation.h
CodeGen/src/NativeState.h
)

View file

@ -1677,6 +1677,8 @@ RETURN R0 0
TEST_CASE("LoopBreak")
{
ScopedFastFlag sff("LuauCompileTerminateBC", true);
// default codegen: compile breaks as unconditional jumps
CHECK_EQ("\n" + compileFunction0("while true do if math.random() < 0.5 then break else end end"), R"(
L0: GETIMPORT R0 2 [math.random]
@ -1684,7 +1686,6 @@ CALL R0 0 1
LOADK R1 K3 [0.5]
JUMPIFNOTLT R0 R1 L1
RETURN R0 0
JUMP L1
L1: JUMPBACK L0
RETURN R0 0
)");
@ -1702,6 +1703,8 @@ L1: RETURN R0 0
TEST_CASE("LoopContinue")
{
ScopedFastFlag sff("LuauCompileTerminateBC", true);
// default codegen: compile continue as unconditional jumps
CHECK_EQ("\n" + compileFunction0("repeat if math.random() < 0.5 then continue else end break until false error()"), R"(
L0: GETIMPORT R0 2 [math.random]
@ -1710,7 +1713,6 @@ LOADK R1 K3 [0.5]
JUMPIFNOTLT R0 R1 L2
JUMP L1
JUMP L2
JUMP L2
L1: JUMPBACK L0
L2: GETIMPORT R0 5 [error]
CALL R0 0 0
@ -6808,4 +6810,59 @@ RETURN R0 0
)");
}
TEST_CASE("ElideJumpAfterIf")
{
ScopedFastFlag sff("LuauCompileTerminateBC", true);
// break refers to outer loop => we can elide unconditional branches
CHECK_EQ("\n" + compileFunction0(R"(
local foo, bar = ...
repeat
if foo then break
elseif bar then break
end
print(1234)
until foo == bar
)"),
R"(
GETVARARGS R0 2
L0: JUMPIFNOT R0 L1
RETURN R0 0
L1: JUMPIF R1 L2
GETIMPORT R2 1 [print]
LOADN R3 1234
CALL R2 1 0
JUMPIFEQ R0 R1 L2
JUMPBACK L0
L2: RETURN R0 0
)");
// break refers to inner loop => branches remain
CHECK_EQ("\n" + compileFunction0(R"(
local foo, bar = ...
repeat
if foo then while true do break end
elseif bar then while true do break end
end
print(1234)
until foo == bar
)"),
R"(
GETVARARGS R0 2
L0: JUMPIFNOT R0 L1
JUMP L2
JUMPBACK L2
JUMP L2
L1: JUMPIFNOT R1 L2
JUMP L2
JUMPBACK L2
L2: GETIMPORT R2 1 [print]
LOADN R3 1234
CALL R2 1 0
JUMPIFEQ R0 R1 L3
JUMPBACK L0
L3: RETURN R0 0
)");
}
TEST_SUITE_END();

View file

@ -31,7 +31,8 @@ void ConstraintGraphBuilderFixture::generateConstraints(const std::string& code)
void ConstraintGraphBuilderFixture::solve(const std::string& code)
{
generateConstraints(code);
ConstraintSolver cs{NotNull{&normalizer}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger};
ConstraintSolver cs{NotNull{&normalizer}, NotNull{rootScope}, constraints, "MainModule", NotNull{mainModule->reduction.get()},
NotNull(&moduleResolver), {}, &logger};
cs.run();
}

View file

@ -3,6 +3,7 @@
#include "Luau/IrAnalysis.h"
#include "Luau/IrDump.h"
#include "Luau/IrUtils.h"
#include "Luau/OptimizeConstProp.h"
#include "Luau/OptimizeFinalX64.h"
#include "doctest.h"
@ -16,12 +17,18 @@ class IrBuilderFixture
public:
void constantFold()
{
for (size_t i = 0; i < build.function.instructions.size(); i++)
for (IrBlock& block : build.function.blocks)
{
IrInst& inst = build.function.instructions[i];
if (block.kind == IrBlockKind::Dead)
continue;
applySubstitutions(build.function, inst);
foldConstants(build, build.function, uint32_t(i));
for (size_t i = block.start; i <= block.finish; i++)
{
IrInst& inst = build.function.instructions[i];
applySubstitutions(build.function, inst);
foldConstants(build, build.function, block, uint32_t(i));
}
}
}
@ -96,7 +103,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptCheckTag")
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
build.beginBlock(fallback);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
build.inst(IrCmd::LOP_RETURN, build.constUint(1));
updateUseCounts(build.function);
optimizeMemoryOperandsX64(build.function);
@ -109,7 +116,7 @@ bb_0:
LOP_RETURN 0u
bb_fallback_1:
LOP_RETURN 0u
LOP_RETURN 1u
)");
}
@ -147,7 +154,6 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag1")
IrOp opA = build.inst(IrCmd::LOAD_TAG, build.vmReg(1));
IrOp opB = build.inst(IrCmd::LOAD_TAG, build.vmReg(2));
build.inst(IrCmd::JUMP_EQ_TAG, opA, opB, trueBlock, falseBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
build.beginBlock(trueBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
@ -184,7 +190,6 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag2")
IrOp opB = build.inst(IrCmd::LOAD_TAG, build.vmReg(2));
build.inst(IrCmd::STORE_TAG, build.vmReg(6), opA);
build.inst(IrCmd::JUMP_EQ_TAG, opA, opB, trueBlock, falseBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
build.beginBlock(trueBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
@ -223,7 +228,6 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag3")
IrOp arrElem = build.inst(IrCmd::GET_ARR_ADDR, table, build.constInt(0));
IrOp opA = build.inst(IrCmd::LOAD_TAG, arrElem);
build.inst(IrCmd::JUMP_EQ_TAG, opA, build.constTag(0), trueBlock, falseBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
build.beginBlock(trueBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
@ -261,7 +265,6 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptJumpCmpNum")
IrOp opA = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1));
IrOp opB = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2));
build.inst(IrCmd::JUMP_CMP_NUM, opA, opB, trueBlock, falseBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
build.beginBlock(trueBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
@ -466,14 +469,14 @@ bb_3:
TEST_CASE_FIXTURE(IrBuilderFixture, "ControlFlowCmpNum")
{
IrOp nan = build.inst(IrCmd::DIV_NUM, build.constDouble(0.0), build.constDouble(0.0));
auto compareFold = [this](IrOp lhs, IrOp rhs, IrCondition cond, bool result) {
IrOp instOp;
IrInst instExpected;
withTwoBlocks([&](IrOp a, IrOp b) {
instOp = build.inst(IrCmd::JUMP_CMP_NUM, lhs, rhs, build.cond(cond), a, b);
IrOp nan = build.inst(IrCmd::DIV_NUM, build.constDouble(0.0), build.constDouble(0.0));
instOp = build.inst(
IrCmd::JUMP_CMP_NUM, lhs.kind == IrOpKind::None ? nan : lhs, rhs.kind == IrOpKind::None ? nan : rhs, build.cond(cond), a, b);
instExpected = IrInst{IrCmd::JUMP, result ? a : b};
});
@ -482,6 +485,8 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ControlFlowCmpNum")
checkEq(instOp, instExpected);
};
IrOp nan; // Empty operand is used to signal a placement of a 'nan'
compareFold(build.constDouble(1), build.constDouble(1), IrCondition::Equal, true);
compareFold(build.constDouble(1), build.constDouble(2), IrCondition::Equal, false);
compareFold(nan, nan, IrCondition::Equal, false);
@ -532,3 +537,661 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ControlFlowCmpNum")
}
TEST_SUITE_END();
TEST_SUITE_BEGIN("ConstantPropagation");
TEST_CASE_FIXTURE(IrBuilderFixture, "RememberTagsAndValues")
{
IrOp block = build.block(IrBlockKind::Internal);
build.beginBlock(block);
build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber));
build.inst(IrCmd::STORE_INT, build.vmReg(1), build.constInt(10));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.constDouble(0.5));
// We know constants from those loads
build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.inst(IrCmd::LOAD_TAG, build.vmReg(0)));
build.inst(IrCmd::STORE_INT, build.vmReg(4), build.inst(IrCmd::LOAD_INT, build.vmReg(1)));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(5), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2)));
// We know that these overrides have no effect
build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber));
build.inst(IrCmd::STORE_INT, build.vmReg(1), build.constInt(10));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.constDouble(0.5));
// But we can invalidate them with unknown values
build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.inst(IrCmd::LOAD_TAG, build.vmReg(6)));
build.inst(IrCmd::STORE_INT, build.vmReg(1), build.inst(IrCmd::LOAD_INT, build.vmReg(7)));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(8)));
// So now the constant stores have to be made
build.inst(IrCmd::STORE_TAG, build.vmReg(9), build.inst(IrCmd::LOAD_TAG, build.vmReg(0)));
build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::LOAD_INT, build.vmReg(1)));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(11), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2)));
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
STORE_TAG R0, tnumber
STORE_INT R1, 10i
STORE_DOUBLE R2, 0.5
STORE_TAG R3, tnumber
STORE_INT R4, 10i
STORE_DOUBLE R5, 0.5
%12 = LOAD_TAG R6
STORE_TAG R0, %12
%14 = LOAD_INT R7
STORE_INT R1, %14
%16 = LOAD_DOUBLE R8
STORE_DOUBLE R2, %16
%18 = LOAD_TAG R0
STORE_TAG R9, %18
%20 = LOAD_INT R1
STORE_INT R10, %20
%22 = LOAD_DOUBLE R2
STORE_DOUBLE R11, %22
LOP_RETURN 0u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "PropagateThroughTvalue")
{
IrOp block = build.block(IrBlockKind::Internal);
build.beginBlock(block);
build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(0.5));
IrOp tv = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0));
build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), tv);
// We know constants from those loads
build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.inst(IrCmd::LOAD_TAG, build.vmReg(1)));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(3), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)));
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
STORE_TAG R0, tnumber
STORE_DOUBLE R0, 0.5
%2 = LOAD_TVALUE R0
STORE_TVALUE R1, %2
STORE_TAG R3, tnumber
STORE_DOUBLE R3, 0.5
LOP_RETURN 0u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "SkipCheckTag")
{
IrOp block = build.block(IrBlockKind::Internal);
IrOp fallback = build.block(IrBlockKind::Fallback);
build.beginBlock(block);
build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber));
build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(tnumber), fallback);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
build.beginBlock(fallback);
build.inst(IrCmd::LOP_RETURN, build.constUint(1));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
STORE_TAG R0, tnumber
LOP_RETURN 0u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "SkipOncePerBlockChecks")
{
IrOp block = build.block(IrBlockKind::Internal);
build.beginBlock(block);
build.inst(IrCmd::CHECK_SAFE_ENV);
build.inst(IrCmd::CHECK_SAFE_ENV);
build.inst(IrCmd::CHECK_GC);
build.inst(IrCmd::CHECK_GC);
build.inst(IrCmd::DO_LEN, build.vmReg(1), build.vmReg(2)); // Can make env unsafe
build.inst(IrCmd::CHECK_SAFE_ENV);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
CHECK_SAFE_ENV
CHECK_GC
DO_LEN R1, R2
CHECK_SAFE_ENV
LOP_RETURN 0u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "RememberTableState")
{
IrOp block = build.block(IrBlockKind::Internal);
IrOp fallback = build.block(IrBlockKind::Fallback);
build.beginBlock(block);
IrOp table = build.inst(IrCmd::LOAD_POINTER, build.vmReg(0));
build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback);
build.inst(IrCmd::CHECK_READONLY, table, fallback);
build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback);
build.inst(IrCmd::CHECK_READONLY, table, fallback);
build.inst(IrCmd::DO_LEN, build.vmReg(1), build.vmReg(2)); // Can access all heap memory
build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback);
build.inst(IrCmd::CHECK_READONLY, table, fallback);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
build.beginBlock(fallback);
build.inst(IrCmd::LOP_RETURN, build.constUint(1));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
%0 = LOAD_POINTER R0
CHECK_NO_METATABLE %0, bb_fallback_1
CHECK_READONLY %0, bb_fallback_1
DO_LEN R1, R2
CHECK_NO_METATABLE %0, bb_fallback_1
CHECK_READONLY %0, bb_fallback_1
LOP_RETURN 0u
bb_fallback_1:
LOP_RETURN 1u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "SkipUselessBarriers")
{
IrOp block = build.block(IrBlockKind::Internal);
build.beginBlock(block);
build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber));
IrOp table = build.inst(IrCmd::LOAD_POINTER, build.vmReg(1));
build.inst(IrCmd::BARRIER_TABLE_FORWARD, table, build.vmReg(0));
IrOp something = build.inst(IrCmd::LOAD_POINTER, build.vmReg(2));
build.inst(IrCmd::BARRIER_OBJ, something, build.vmReg(0));
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
STORE_TAG R0, tnumber
LOP_RETURN 0u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "ConcatInvalidation")
{
IrOp block = build.block(IrBlockKind::Internal);
build.beginBlock(block);
build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber));
build.inst(IrCmd::STORE_INT, build.vmReg(1), build.constInt(10));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.constDouble(0.5));
build.inst(IrCmd::CONCAT, build.vmReg(0), build.vmReg(3)); // Concat invalidates more than the target register
build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.inst(IrCmd::LOAD_TAG, build.vmReg(0)));
build.inst(IrCmd::STORE_INT, build.vmReg(4), build.inst(IrCmd::LOAD_INT, build.vmReg(1)));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(5), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2)));
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
STORE_TAG R0, tnumber
STORE_INT R1, 10i
STORE_DOUBLE R2, 0.5
CONCAT R0, R3
%4 = LOAD_TAG R0
STORE_TAG R3, %4
%6 = LOAD_INT R1
STORE_INT R4, %6
%8 = LOAD_DOUBLE R2
STORE_DOUBLE R5, %8
LOP_RETURN 0u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "BuiltinFastcallsMayInvalidateMemory")
{
IrOp block = build.block(IrBlockKind::Internal);
IrOp fallback = build.block(IrBlockKind::Fallback);
build.beginBlock(block);
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(0.5));
IrOp table = build.inst(IrCmd::LOAD_POINTER, build.vmReg(0));
build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback);
build.inst(IrCmd::CHECK_READONLY, table, fallback);
build.inst(IrCmd::LOP_FASTCALL1, build.constUint(0), build.vmReg(1), build.vmReg(2), fallback);
build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback);
build.inst(IrCmd::CHECK_READONLY, table, fallback);
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0))); // At least R0 wasn't touched
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
build.beginBlock(fallback);
build.inst(IrCmd::LOP_RETURN, build.constUint(1));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
STORE_DOUBLE R0, 0.5
%1 = LOAD_POINTER R0
CHECK_NO_METATABLE %1, bb_fallback_1
CHECK_READONLY %1, bb_fallback_1
LOP_FASTCALL1 0u, R1, R2, bb_fallback_1
CHECK_NO_METATABLE %1, bb_fallback_1
CHECK_READONLY %1, bb_fallback_1
STORE_DOUBLE R1, 0.5
LOP_RETURN 0u
bb_fallback_1:
LOP_RETURN 1u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "RedundantStoreCheckConstantType")
{
IrOp block = build.block(IrBlockKind::Internal);
build.beginBlock(block);
build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(10));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(0.5));
build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(10));
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
STORE_INT R0, 10i
STORE_DOUBLE R0, 0.5
STORE_INT R0, 10i
LOP_RETURN 0u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "RedundantStoreCheckConstantType")
{
IrOp block = build.block(IrBlockKind::Internal);
build.beginBlock(block);
build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(10));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(0.5));
build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(10));
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
STORE_INT R0, 10i
STORE_DOUBLE R0, 0.5
STORE_INT R0, 10i
LOP_RETURN 0u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "TagCheckPropagation")
{
IrOp block = build.block(IrBlockKind::Internal);
IrOp fallback = build.block(IrBlockKind::Fallback);
build.beginBlock(block);
IrOp unknown = build.inst(IrCmd::LOAD_TAG, build.vmReg(0));
build.inst(IrCmd::CHECK_TAG, unknown, build.constTag(tnumber), fallback);
build.inst(IrCmd::CHECK_TAG, unknown, build.constTag(tnumber), fallback);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
build.beginBlock(fallback);
build.inst(IrCmd::LOP_RETURN, build.constUint(1));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
%0 = LOAD_TAG R0
CHECK_TAG %0, tnumber, bb_fallback_1
LOP_RETURN 0u
bb_fallback_1:
LOP_RETURN 1u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "TagCheckPropagationConflicting")
{
IrOp block = build.block(IrBlockKind::Internal);
IrOp fallback = build.block(IrBlockKind::Fallback);
build.beginBlock(block);
IrOp unknown = build.inst(IrCmd::LOAD_TAG, build.vmReg(0));
build.inst(IrCmd::CHECK_TAG, unknown, build.constTag(tnumber), fallback);
build.inst(IrCmd::CHECK_TAG, unknown, build.constTag(tnil), fallback);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
build.beginBlock(fallback);
build.inst(IrCmd::LOP_RETURN, build.constUint(1));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
%0 = LOAD_TAG R0
CHECK_TAG %0, tnumber, bb_fallback_1
JUMP bb_fallback_1
bb_fallback_1:
LOP_RETURN 1u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "TruthyTestRemoval")
{
IrOp block = build.block(IrBlockKind::Internal);
IrOp trueBlock = build.block(IrBlockKind::Internal);
IrOp falseBlock = build.block(IrBlockKind::Internal);
IrOp fallback = build.block(IrBlockKind::Fallback);
build.beginBlock(block);
IrOp unknown = build.inst(IrCmd::LOAD_TAG, build.vmReg(1));
build.inst(IrCmd::CHECK_TAG, unknown, build.constTag(tnumber), fallback);
build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(1), trueBlock, falseBlock);
build.beginBlock(trueBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(1));
build.beginBlock(falseBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(2));
build.beginBlock(fallback);
build.inst(IrCmd::LOP_RETURN, build.constUint(3));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
%0 = LOAD_TAG R1
CHECK_TAG %0, tnumber, bb_fallback_3
JUMP bb_1
bb_1:
LOP_RETURN 1u
bb_fallback_3:
LOP_RETURN 3u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "FalsyTestRemoval")
{
IrOp block = build.block(IrBlockKind::Internal);
IrOp trueBlock = build.block(IrBlockKind::Internal);
IrOp falseBlock = build.block(IrBlockKind::Internal);
IrOp fallback = build.block(IrBlockKind::Fallback);
build.beginBlock(block);
IrOp unknown = build.inst(IrCmd::LOAD_TAG, build.vmReg(1));
build.inst(IrCmd::CHECK_TAG, unknown, build.constTag(tnumber), fallback);
build.inst(IrCmd::JUMP_IF_FALSY, build.vmReg(1), trueBlock, falseBlock);
build.beginBlock(trueBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(1));
build.beginBlock(falseBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(2));
build.beginBlock(fallback);
build.inst(IrCmd::LOP_RETURN, build.constUint(3));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
%0 = LOAD_TAG R1
CHECK_TAG %0, tnumber, bb_fallback_3
JUMP bb_2
bb_2:
LOP_RETURN 2u
bb_fallback_3:
LOP_RETURN 3u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "TagEqRemoval")
{
IrOp block = build.block(IrBlockKind::Internal);
IrOp trueBlock = build.block(IrBlockKind::Internal);
IrOp falseBlock = build.block(IrBlockKind::Internal);
build.beginBlock(block);
IrOp tag = build.inst(IrCmd::LOAD_TAG, build.vmReg(1));
build.inst(IrCmd::CHECK_TAG, tag, build.constTag(tboolean));
build.inst(IrCmd::JUMP_EQ_TAG, tag, build.constTag(tnumber), trueBlock, falseBlock);
build.beginBlock(trueBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(1));
build.beginBlock(falseBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(2));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
%0 = LOAD_TAG R1
CHECK_TAG %0, tboolean
JUMP bb_2
bb_2:
LOP_RETURN 2u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "IntEqRemoval")
{
IrOp block = build.block(IrBlockKind::Internal);
IrOp trueBlock = build.block(IrBlockKind::Internal);
IrOp falseBlock = build.block(IrBlockKind::Internal);
build.beginBlock(block);
IrOp value = build.inst(IrCmd::LOAD_INT, build.vmReg(1));
build.inst(IrCmd::STORE_INT, build.vmReg(1), build.constInt(5));
build.inst(IrCmd::JUMP_EQ_INT, value, build.constInt(5), trueBlock, falseBlock);
build.beginBlock(trueBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(1));
build.beginBlock(falseBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(2));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
STORE_INT R1, 5i
JUMP bb_1
bb_1:
LOP_RETURN 1u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "NumCmpRemoval")
{
IrOp block = build.block(IrBlockKind::Internal);
IrOp trueBlock = build.block(IrBlockKind::Internal);
IrOp falseBlock = build.block(IrBlockKind::Internal);
build.beginBlock(block);
IrOp value = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(4.0));
build.inst(IrCmd::JUMP_CMP_NUM, value, build.constDouble(8.0), build.cond(IrCondition::Greater), trueBlock, falseBlock);
build.beginBlock(trueBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(1));
build.beginBlock(falseBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(2));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
STORE_DOUBLE R1, 4
JUMP bb_2
bb_2:
LOP_RETURN 2u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "DataFlowsThroughDirectJumpToUniqueSuccessor")
{
IrOp block1 = build.block(IrBlockKind::Internal);
IrOp block2 = build.block(IrBlockKind::Internal);
build.beginBlock(block1);
build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber));
build.inst(IrCmd::JUMP, block2);
build.beginBlock(block2);
build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.inst(IrCmd::LOAD_TAG, build.vmReg(0)));
build.inst(IrCmd::LOP_RETURN, build.constUint(1));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
STORE_TAG R0, tnumber
JUMP bb_1
bb_1:
STORE_TAG R1, tnumber
LOP_RETURN 1u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "DataDoesNotFlowThroughDirectJumpToNonUniqueSuccessor")
{
IrOp block1 = build.block(IrBlockKind::Internal);
IrOp block2 = build.block(IrBlockKind::Internal);
IrOp block3 = build.block(IrBlockKind::Internal);
build.beginBlock(block1);
build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber));
build.inst(IrCmd::JUMP, block2);
build.beginBlock(block2);
build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.inst(IrCmd::LOAD_TAG, build.vmReg(0)));
build.inst(IrCmd::LOP_RETURN, build.constUint(1));
build.beginBlock(block3);
build.inst(IrCmd::JUMP, block2);
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
STORE_TAG R0, tnumber
JUMP bb_1
bb_1:
%2 = LOAD_TAG R0
STORE_TAG R1, %2
LOP_RETURN 1u
bb_2:
JUMP bb_1
)");
}
TEST_SUITE_END();

View file

@ -71,6 +71,14 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_non_persistent_primitive")
TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table")
{
// Under DCR, we don't seal the outer occurrance of the table `Cyclic` which
// breaks this test. I'm not sure if that behaviour change is important or
// not, but it's tangental to the core purpose of this test.
ScopedFastFlag sff[] = {
{"DebugLuauDeferredConstraintResolution", false},
};
CheckResult result = check(R"(
local Cyclic = {}
function Cyclic.get()
@ -85,13 +93,13 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table")
* Assert that the return type of get() is the same as the outer table.
*/
TypeId counterType = requireType("Cyclic");
TypeId ty = requireType("Cyclic");
TypeArena dest;
CloneState cloneState;
TypeId counterCopy = clone(counterType, dest, cloneState);
TypeId cloneTy = clone(ty, dest, cloneState);
TableType* ttv = getMutable<TableType>(counterCopy);
TableType* ttv = getMutable<TableType>(cloneTy);
REQUIRE(ttv != nullptr);
CHECK_EQ(std::optional<std::string>{"Cyclic"}, ttv->syntheticName);
@ -105,11 +113,42 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table")
std::optional<TypeId> methodReturnType = first(ftv->retTypes);
REQUIRE(methodReturnType);
CHECK_EQ(methodReturnType, counterCopy);
CHECK_MESSAGE(methodReturnType == cloneTy, toString(methodType, {true}) << " should be pointer identical to " << toString(cloneTy, {true}));
CHECK_EQ(2, dest.typePacks.size()); // one for the function args, and another for its return type
CHECK_EQ(2, dest.types.size()); // One table and one function
}
TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table_2")
{
TypeArena src;
TypeId tableTy = src.addType(TableType{});
TableType* tt = getMutable<TableType>(tableTy);
REQUIRE(tt);
TypeId methodTy = src.addType(FunctionType{src.addTypePack({}), src.addTypePack({tableTy})});
tt->props["get"].type = methodTy;
TypeArena dest;
CloneState cloneState;
TypeId cloneTy = clone(tableTy, dest, cloneState);
TableType* ctt = getMutable<TableType>(cloneTy);
REQUIRE(ctt);
TypeId clonedMethodType = ctt->props["get"].type;
REQUIRE(clonedMethodType);
const FunctionType* cmf = get<FunctionType>(clonedMethodType);
REQUIRE(cmf);
std::optional<TypeId> cloneMethodReturnType = first(cmf->retTypes);
REQUIRE(bool(cloneMethodReturnType));
CHECK(*cloneMethodReturnType == cloneTy);
}
TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_point_into_globalTypes_arena")
{
CheckResult result = check(R"(

View file

@ -194,7 +194,7 @@ TEST_CASE_FIXTURE(Fixture, "inline_table_props_are_also_any")
CHECK_EQ(*typeChecker.anyType, *ttv->props["one"].type);
CHECK_EQ(*typeChecker.anyType, *ttv->props["two"].type);
CHECK_MESSAGE(get<FunctionType>(ttv->props["three"].type), "Should be a function: " << *ttv->props["three"].type);
CHECK_MESSAGE(get<FunctionType>(follow(ttv->props["three"].type)), "Should be a function: " << *ttv->props["three"].type);
}
TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_iterator_variables_are_any")

View file

@ -786,7 +786,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_include_self_param")
TypeId parentTy = requireType("foo");
auto ttv = get<TableType>(follow(parentTy));
auto ftv = get<FunctionType>(ttv->props.at("method").type);
auto ftv = get<FunctionType>(follow(ttv->props.at("method").type));
CHECK_EQ("foo:method<a>(self: a, arg: string): ()", toStringNamedFunction("foo:method", *ftv));
}
@ -803,12 +803,16 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_self_param")
end
)");
TypeId parentTy = requireType("foo");
auto ttv = get<TableType>(follow(parentTy));
auto ftv = get<FunctionType>(ttv->props.at("method").type);
ToStringOptions opts;
opts.hideFunctionSelfArgument = true;
TypeId parentTy = requireType("foo");
auto ttv = get<TableType>(follow(parentTy));
REQUIRE_MESSAGE(ttv, "Expected a table but got " << toString(parentTy, opts));
TypeId methodTy = follow(ttv->props.at("method").type);
auto ftv = get<FunctionType>(methodTy);
REQUIRE_MESSAGE(ftv, "Expected a function but got " << toString(methodTy, opts));
CHECK_EQ("foo:method<a>(arg: string): ()", toStringNamedFunction("foo:method", *ftv, opts));
}

View file

@ -855,16 +855,8 @@ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_uni
type FutureIntersection = A & B
)");
if (FFlag::DebugLuauDeferredConstraintResolution)
{
// To be quite honest, I don't know exactly why DCR fixes this.
LUAU_REQUIRE_NO_ERRORS(result);
}
else
{
// TODO: shared self causes this test to break in bizarre ways.
LUAU_REQUIRE_ERRORS(result);
}
// TODO: shared self causes this test to break in bizarre ways.
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_ok")

View file

@ -660,7 +660,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "higher_order_function_4")
)");
LUAU_REQUIRE_NO_ERRORS(result);
dumpErrors(result);
/*
* mergesort takes two arguments: an array of some type T and a function that takes two Ts.
@ -1424,9 +1423,11 @@ end
TEST_CASE_FIXTURE(BuiltinsFixture, "function_decl_non_self_sealed_overwrite")
{
CheckResult result = check(R"(
function string.len(): number
return 1
end
function string.len(): number
return 1
end
local s = string
)");
LUAU_REQUIRE_NO_ERRORS(result);
@ -1434,11 +1435,11 @@ end
// if 'string' library property was replaced with an internal module type, it will be freed and the next check will crash
frontend.clear();
result = check(R"(
print(string.len('hello'))
CheckResult result2 = check(R"(
print(string.len('hello'))
)");
LUAU_REQUIRE_NO_ERRORS(result);
LUAU_REQUIRE_NO_ERRORS(result2);
}
TEST_CASE_FIXTURE(BuiltinsFixture, "function_decl_non_self_sealed_overwrite_2")

View file

@ -1404,6 +1404,40 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknowns")
}
}
TEST_CASE_FIXTURE(BuiltinsFixture, "refine_boolean")
{
CheckResult result = check(R"(
local function f(x: number | boolean)
if typeof(x) == "boolean" then
local foo = x
else
local foo = x
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("boolean", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("number", toString(requireTypeAtPosition({5, 28})));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "refine_thread")
{
CheckResult result = check(R"(
local function f(x: number | thread)
if typeof(x) == "thread" then
local foo = x
else
local foo = x
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("thread", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("number", toString(requireTypeAtPosition({5, 28})));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "falsiness_of_TruthyPredicate_narrows_into_nil")
{
CheckResult result = check(R"(

View file

@ -347,8 +347,8 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification_3")
const TableType* arg0Table = get<TableType>(follow(arg0));
REQUIRE(arg0Table != nullptr);
REQUIRE(arg0Table->props.find("bar") != arg0Table->props.end());
REQUIRE(arg0Table->props.find("baz") != arg0Table->props.end());
CHECK(arg0Table->props.count("bar"));
CHECK(arg0Table->props.count("baz"));
}
TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_1")
@ -2482,12 +2482,18 @@ TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_indexer")
TEST_CASE_FIXTURE(Fixture, "wrong_assign_does_hit_indexer")
{
CheckResult result = check("local a = {} a[0] = 7 a[0] = 't'");
CheckResult result = check(R"(
local a = {}
a[0] = 7
a[0] = 't'
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 30}, Position{0, 33}}, TypeMismatch{
typeChecker.numberType,
typeChecker.stringType,
}}));
CHECK((Location{Position{3, 15}, Position{3, 18}}) == result.errors[0].location);
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK(tm->wantedType == typeChecker.numberType);
CHECK(tm->givenType == typeChecker.stringType);
}
TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_no_indexer")
@ -2673,7 +2679,10 @@ TEST_CASE_FIXTURE(Fixture, "inferring_crazy_table_should_also_be_quick")
)");
ModulePtr module = getMainModule();
CHECK_GE(100, module->internalTypes.types.size());
if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK_GE(500, module->internalTypes.types.size());
else
CHECK_GE(100, module->internalTypes.types.size());
}
TEST_CASE_FIXTURE(Fixture, "MixedPropertiesAndIndexers")

View file

@ -1,5 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Common.h"
#include "Luau/Scope.h"
#include "Luau/Symbol.h"
#include "Luau/TypeInfer.h"
#include "Luau/Type.h"
@ -9,6 +11,8 @@
using namespace Luau;
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
struct TryUnifyFixture : Fixture
{
TypeArena arena;
@ -254,7 +258,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "cli_41095_concat_log_in_sealed_table_unifica
LUAU_REQUIRE_ERROR_COUNT(2, result);
CHECK_EQ(toString(result.errors[0]), "No overload for function accepts 0 arguments.");
CHECK_EQ(toString(result.errors[1]), "Available overloads: ({a}, a) -> (); and ({a}, number, a) -> ()");
if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK_EQ(toString(result.errors[1]), "Available overloads: <V>({V}, V) -> (); and <V>({V}, number, V) -> ()");
else
CHECK_EQ(toString(result.errors[1]), "Available overloads: ({a}, a) -> (); and ({a}, number, a) -> ()");
}
TEST_CASE_FIXTURE(TryUnifyFixture, "free_tail_is_grown_properly")

View file

@ -230,7 +230,7 @@ TEST_CASE_FIXTURE(Fixture, "assign_to_subscript_which_is_never")
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "assign_to_subscript_which_is_never")
TEST_CASE_FIXTURE(Fixture, "for_loop_over_never")
{
CheckResult result = check(R"(
for i, v in (5 :: never) do

View file

@ -5,23 +5,16 @@ AstQuery.last_argument_function_call_type
AstQuery::getDocumentationSymbolAtPosition.overloaded_class_method
AstQuery::getDocumentationSymbolAtPosition.overloaded_fn
AstQuery::getDocumentationSymbolAtPosition.table_overloaded_function_prop
AutocompleteTest.autocomplete_first_function_arg_expected_type
AutocompleteTest.autocomplete_oop_implicit_self
AutocompleteTest.autocomplete_string_singleton_equality
AutocompleteTest.do_compatible_self_calls
AutocompleteTest.do_wrong_compatible_self_calls
AutocompleteTest.type_correct_expected_return_type_suggestion
AutocompleteTest.type_correct_suggestion_for_overloads
BuiltinTests.aliased_string_format
BuiltinTests.assert_removes_falsy_types
BuiltinTests.assert_removes_falsy_types2
BuiltinTests.assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_the_first_type
BuiltinTests.assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy
BuiltinTests.bad_select_should_not_crash
BuiltinTests.coroutine_wrap_anything_goes
BuiltinTests.debug_info_is_crazy
BuiltinTests.debug_traceback_is_crazy
BuiltinTests.dont_add_definitions_to_persistent_types
BuiltinTests.find_capture_types3
BuiltinTests.gmatch_definition
BuiltinTests.match_capture_types
BuiltinTests.match_capture_types2
@ -34,16 +27,14 @@ BuiltinTests.sort_with_bad_predicate
BuiltinTests.string_format_as_method
BuiltinTests.string_format_correctly_ordered_types
BuiltinTests.string_format_report_all_type_errors_at_correct_positions
BuiltinTests.string_format_tostring_specifier_type_constraint
BuiltinTests.string_format_use_correct_argument2
BuiltinTests.table_freeze_is_generic
BuiltinTests.table_insert_correctly_infers_type_of_array_2_args_overload
BuiltinTests.table_insert_correctly_infers_type_of_array_3_args_overload
BuiltinTests.table_pack
BuiltinTests.table_pack_reduce
BuiltinTests.table_pack_variadic
DefinitionTests.class_definition_overload_metamethods
DefinitionTests.class_definition_string_props
DefinitionTests.definition_file_classes
FrontendTest.environments
FrontendTest.nocheck_cycle_used_by_checked
GenericsTests.apply_type_function_nested_generics2
@ -52,6 +43,7 @@ GenericsTests.bound_tables_do_not_clone_original_fields
GenericsTests.check_mutual_generic_functions
GenericsTests.correctly_instantiate_polymorphic_member_functions
GenericsTests.do_not_infer_generic_functions
GenericsTests.dont_unify_bound_types
GenericsTests.generic_argument_count_too_few
GenericsTests.generic_argument_count_too_many
GenericsTests.generic_functions_should_be_memory_safe
@ -62,16 +54,13 @@ GenericsTests.infer_generic_function_function_argument_3
GenericsTests.infer_generic_function_function_argument_overloaded
GenericsTests.infer_generic_lib_function_function_argument
GenericsTests.instantiated_function_argument_names
GenericsTests.instantiation_sharing_types
GenericsTests.no_stack_overflow_from_quantifying
GenericsTests.self_recursive_instantiated_param
IntersectionTypes.select_correct_union_fn
IntersectionTypes.should_still_pick_an_overload_whose_arguments_are_unions
IntersectionTypes.overload_is_not_a_function
IntersectionTypes.table_intersection_write_sealed
IntersectionTypes.table_intersection_write_sealed_indirect
IntersectionTypes.table_write_sealed_indirect
ModuleTests.clone_self_property
ModuleTests.deepClone_cyclic_table
NonstrictModeTests.for_in_iterator_variables_are_any
NonstrictModeTests.function_parameters_are_any
NonstrictModeTests.inconsistent_module_return_types_are_ok
@ -85,7 +74,6 @@ NonstrictModeTests.offer_a_hint_if_you_use_a_dot_instead_of_a_colon
NonstrictModeTests.parameters_having_type_any_are_optional
NonstrictModeTests.table_dot_insert_and_recursive_calls
NonstrictModeTests.table_props_are_any
Normalize.cyclic_table_normalizes_sensibly
ProvisionalTests.assign_table_with_refined_property_with_a_similar_type_is_illegal
ProvisionalTests.bail_early_if_unification_is_too_complicated
ProvisionalTests.do_not_ice_when_trying_to_pick_first_of_generic_type_pack
@ -93,31 +81,28 @@ ProvisionalTests.error_on_eq_metamethod_returning_a_type_other_than_boolean
ProvisionalTests.free_options_cannot_be_unified_together
ProvisionalTests.generic_type_leak_to_module_interface_variadic
ProvisionalTests.greedy_inference_with_shared_self_triggers_function_with_no_returns
ProvisionalTests.pcall_returns_at_least_two_value_but_function_returns_nothing
ProvisionalTests.setmetatable_constrains_free_type_into_free_table
ProvisionalTests.specialization_binds_with_prototypes_too_early
ProvisionalTests.table_insert_with_a_singleton_argument
ProvisionalTests.typeguard_inference_incomplete
ProvisionalTests.weirditer_should_not_loop_forever
RefinementTest.apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string
RefinementTest.discriminate_tag
RefinementTest.discriminate_from_isa_of_x
RefinementTest.falsiness_of_TruthyPredicate_narrows_into_nil
RefinementTest.narrow_property_of_a_bounded_variable
RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true
RefinementTest.refine_a_param_that_got_resolved_during_constraint_solving_stage
RefinementTest.refine_param_of_type_folder_or_part_without_using_typeof
RefinementTest.refine_unknowns
RefinementTest.type_guard_can_filter_for_intersection_of_tables
RefinementTest.type_narrow_for_all_the_userdata
RefinementTest.type_narrow_to_vector
RefinementTest.typeguard_cast_free_table_to_vector
RefinementTest.typeguard_in_assert_position
RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table
RefinementTest.x_is_not_instance_or_else_not_part
RuntimeLimits.typescript_port_of_Result_type
TableTests.a_free_shape_can_turn_into_a_scalar_directly
TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible
TableTests.accidentally_checked_prop_in_opposite_branch
TableTests.any_when_indexing_into_an_unsealed_table_with_no_indexer_in_nonstrict_mode
TableTests.call_method
TableTests.casting_tables_with_props_into_table_with_indexer3
TableTests.casting_tables_with_props_into_table_with_indexer4
TableTests.checked_prop_too_early
@ -135,7 +120,6 @@ TableTests.explicitly_typed_table_with_indexer
TableTests.found_like_key_in_table_function_call
TableTests.found_like_key_in_table_property_access
TableTests.found_multiple_like_keys
TableTests.function_calls_produces_sealed_table_given_unsealed_table
TableTests.fuzz_table_unify_instantiated_table
TableTests.generic_table_instantiation_potential_regression
TableTests.give_up_after_one_metatable_index_look_up
@ -144,21 +128,16 @@ TableTests.indexing_from_a_table_should_prefer_properties_when_possible
TableTests.inequality_operators_imply_exactly_matching_types
TableTests.infer_array_2
TableTests.inferred_return_type_of_free_table
TableTests.inferring_crazy_table_should_also_be_quick
TableTests.instantiate_table_cloning_3
TableTests.invariant_table_properties_means_instantiating_tables_in_assignment_is_unsound
TableTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound
TableTests.leaking_bad_metatable_errors
TableTests.less_exponential_blowup_please
TableTests.missing_metatable_for_sealed_tables_do_not_get_inferred
TableTests.mixed_tables_with_implicit_numbered_keys
TableTests.nil_assign_doesnt_hit_indexer
TableTests.nil_assign_doesnt_hit_no_indexer
TableTests.okay_to_add_property_to_unsealed_tables_by_function_call
TableTests.ok_to_set_nil_even_on_non_lvalue_base_expr
TableTests.only_ascribe_synthetic_names_at_module_scope
TableTests.oop_indexer_works
TableTests.oop_polymorphic
TableTests.open_table_unification_2
TableTests.quantify_even_that_table_was_never_exported_at_all
TableTests.quantify_metatables_of_metatables_of_table
TableTests.reasonable_error_when_adding_a_nonexistent_property_to_an_array_like_table
@ -169,32 +148,21 @@ TableTests.shared_selfs
TableTests.shared_selfs_from_free_param
TableTests.shared_selfs_through_metatables
TableTests.table_call_metamethod_basic
TableTests.table_indexing_error_location
TableTests.table_insert_should_cope_with_optional_properties_in_nonstrict
TableTests.table_insert_should_cope_with_optional_properties_in_strict
TableTests.table_param_row_polymorphism_3
TableTests.table_simple_call
TableTests.table_subtyping_with_extra_props_dont_report_multiple_errors
TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors
TableTests.table_unification_4
TableTests.unifying_tables_shouldnt_uaf2
TableTests.used_colon_instead_of_dot
TableTests.used_dot_instead_of_colon
ToString.exhaustive_toString_of_cyclic_table
TableTests.when_augmenting_an_unsealed_table_with_an_indexer_apply_the_correct_scope_to_the_indexer_type
ToString.named_metatable_toStringNamedFunction
ToString.toStringDetailed2
ToString.toStringErrorPack
ToString.toStringNamedFunction_generic_pack
ToString.toStringNamedFunction_hide_self_param
ToString.toStringNamedFunction_include_self_param
ToString.toStringNamedFunction_map
TryUnifyTests.cli_41095_concat_log_in_sealed_table_unification
TryUnifyTests.members_of_failed_typepack_unification_are_unified_with_errorType
TryUnifyTests.result_of_failed_typepack_unification_is_constrained
TryUnifyTests.typepack_unification_should_trim_free_tails
TryUnifyTests.variadics_should_use_reversed_properly
TypeAliases.cannot_create_cyclic_type_with_unknown_module
TypeAliases.forward_declared_alias_is_not_clobbered_by_prior_unification_with_any
TypeAliases.generic_param_remap
TypeAliases.mismatched_generic_type_param
TypeAliases.mutually_recursive_types_restriction_not_ok_1
@ -218,11 +186,9 @@ TypeInfer.it_is_ok_to_have_inconsistent_number_of_return_values_in_nonstrict
TypeInfer.no_stack_overflow_from_isoptional
TypeInfer.no_stack_overflow_from_isoptional2
TypeInfer.tc_after_error_recovery_no_replacement_name_in_error
TypeInfer.tc_if_else_expressions_expected_type_3
TypeInfer.type_infer_recursion_limit_no_ice
TypeInfer.type_infer_recursion_limit_normalizer
TypeInferAnyError.for_in_loop_iterator_is_any2
TypeInferClasses.can_read_prop_of_base_class_using_string
TypeInferClasses.class_type_mismatch_with_name_conflict
TypeInferClasses.classes_without_overloaded_operators_cannot_be_added
TypeInferClasses.higher_order_function_arguments_are_contravariant
@ -232,6 +198,7 @@ TypeInferClasses.table_class_unification_reports_sane_errors_for_missing_propert
TypeInferClasses.warn_when_prop_almost_matches
TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types
TypeInferFunctions.cannot_hoist_interior_defns_into_signature
TypeInferFunctions.check_function_before_lambda_that_uses_it
TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists
TypeInferFunctions.dont_infer_parameter_types_for_functions_from_their_call_site
TypeInferFunctions.duplicate_functions_with_different_signatures_not_allowed_in_nonstrict
@ -243,10 +210,7 @@ TypeInferFunctions.function_statement_sealed_table_assignment_through_indexer
TypeInferFunctions.improved_function_arg_mismatch_error_nonstrict
TypeInferFunctions.improved_function_arg_mismatch_errors
TypeInferFunctions.infer_anonymous_function_arguments
TypeInferFunctions.infer_return_type_from_selected_overload
TypeInferFunctions.infer_that_function_does_not_return_a_table
TypeInferFunctions.list_all_overloads_if_no_overload_takes_given_argument_count
TypeInferFunctions.list_only_alternative_overloads_that_match_argument_count
TypeInferFunctions.luau_subtyping_is_np_hard
TypeInferFunctions.no_lossy_function_type
TypeInferFunctions.occurs_check_failure_in_function_return_type
@ -273,13 +237,11 @@ TypeInferModules.do_not_modify_imported_types_5
TypeInferModules.module_type_conflict
TypeInferModules.module_type_conflict_instantiated
TypeInferModules.type_error_of_unknown_qualified_type
TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_another_overload_works
TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory
TypeInferOOP.methods_are_topologically_sorted
TypeInferOOP.object_constructor_can_refer_to_method_of_self
TypeInferOperators.CallAndOrOfFunctions
TypeInferOperators.CallOrOfFunctions
TypeInferOperators.cannot_compare_tables_that_do_not_have_the_same_metatable
TypeInferOperators.cannot_indirectly_compare_types_that_do_not_have_a_metatable
TypeInferOperators.cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators
TypeInferOperators.cli_38355_recursive_union
@ -303,25 +265,21 @@ TypeInferUnknownNever.math_operators_and_never
TypePackTests.detect_cyclic_typepacks2
TypePackTests.pack_tail_unification_check
TypePackTests.type_alias_backwards_compatible
TypePackTests.type_alias_default_export
TypePackTests.type_alias_default_mixed_self
TypePackTests.type_alias_default_type_chained
TypePackTests.type_alias_default_type_errors
TypePackTests.type_alias_default_type_pack_self_chained_tp
TypePackTests.type_alias_default_type_pack_self_tp
TypePackTests.type_alias_default_type_self
TypePackTests.type_alias_defaults_confusing_types
TypePackTests.type_alias_defaults_recursive_type
TypePackTests.type_alias_type_pack_multi
TypePackTests.type_alias_type_pack_variadic
TypePackTests.type_alias_type_packs_errors
TypePackTests.type_alias_type_packs_nested
TypePackTests.unify_variadic_tails_in_arguments
TypePackTests.unify_variadic_tails_in_arguments_free
TypePackTests.variadic_packs
TypeSingletons.function_call_with_singletons
TypeSingletons.function_call_with_singletons_mismatch
TypeSingletons.indexing_on_union_of_string_singletons
TypeSingletons.no_widening_from_callsites
TypeSingletons.overloaded_function_call_with_singletons
TypeSingletons.overloaded_function_call_with_singletons_mismatch
TypeSingletons.return_type_of_f_is_not_widened