Merge branch 'upstream' into merge

This commit is contained in:
Andy Friesen 2024-06-07 10:20:14 -07:00
commit 40e03164f7
68 changed files with 2712 additions and 687 deletions

View file

@ -57,7 +57,7 @@ struct GeneralizationConstraint
struct IterableConstraint
{
TypePackId iterator;
TypePackId variables;
std::vector<TypeId> variables;
const AstNode* nextAstFragment;
DenseHashMap<const AstNode*, TypeId>* astForInNextTypes;
@ -192,13 +192,7 @@ struct HasIndexerConstraint
TypeId indexType;
};
struct AssignConstraint
{
TypeId lhsType;
TypeId rhsType;
};
// assign lhsType propName rhsType
// assignProp lhsType propName rhsType
//
// Assign a value of type rhsType into the named property of lhsType.
@ -212,6 +206,12 @@ struct AssignPropConstraint
/// populate astTypes during constraint resolution. Nothing should ever
/// block on it.
TypeId propType;
// When we generate constraints, we increment the remaining prop count on
// the table if we are able. This flag informs the solver as to whether or
// not it should in turn decrement the prop count when this constraint is
// dispatched.
bool decrementPropCount = false;
};
struct AssignIndexConstraint
@ -226,13 +226,13 @@ struct AssignIndexConstraint
TypeId propType;
};
// resultType ~ unpack sourceTypePack
// resultTypes ~ unpack sourceTypePack
//
// Similar to PackSubtypeConstraint, but with one important difference: If the
// sourcePack is blocked, this constraint blocks.
struct UnpackConstraint
{
TypePackId resultPack;
std::vector<TypeId> resultPack;
TypePackId sourcePack;
};
@ -254,7 +254,7 @@ struct ReducePackConstraint
using ConstraintV = Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, IterableConstraint, NameConstraint,
TypeAliasExpansionConstraint, FunctionCallConstraint, FunctionCheckConstraint, PrimitiveTypeConstraint, HasPropConstraint, HasIndexerConstraint,
AssignConstraint, AssignPropConstraint, AssignIndexConstraint, UnpackConstraint, ReduceConstraint, ReducePackConstraint, EqualityConstraint>;
AssignPropConstraint, AssignIndexConstraint, UnpackConstraint, ReduceConstraint, ReducePackConstraint, EqualityConstraint>;
struct Constraint
{

View file

@ -118,6 +118,8 @@ struct ConstraintGenerator
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope;
std::vector<RequireCycle> requireCycles;
DenseHashMap<TypeId, std::vector<TypeId>> localTypes{nullptr};
DcrLogger* logger;
ConstraintGenerator(ModulePtr module, NotNull<Normalizer> normalizer, NotNull<ModuleResolver> moduleResolver, NotNull<BuiltinTypes> builtinTypes,
@ -354,6 +356,8 @@ private:
*/
void prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program);
bool recordPropertyAssignment(TypeId ty);
// Record the fact that a particular local has a particular type in at least
// one of its states.
void recordInferredBinding(AstLocal* local, TypeId ty);

View file

@ -142,7 +142,6 @@ struct ConstraintSolver
std::pair<bool, std::optional<TypeId>> tryDispatchSetIndexer(
NotNull<const Constraint> constraint, TypeId subjectType, TypeId indexType, TypeId propType, bool expandFreeTypeBounds);
bool tryDispatch(const AssignConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const AssignPropConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const AssignIndexConstraint& c, NotNull<const Constraint> constraint);
@ -158,8 +157,7 @@ struct ConstraintSolver
bool tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
// for a, ... in next_function, t, ... do
bool tryDispatchIterableFunction(
TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(NotNull<const Constraint> constraint, TypeId subjectType,
const std::string& propName, ValueContext context, bool inConditional = false, bool suppressSimplification = false);
@ -168,14 +166,18 @@ struct ConstraintSolver
/**
* Generate constraints to unpack the types of srcTypes and assign each
* value to the corresponding LocalType in destTypes.
* value to the corresponding BlockedType in destTypes.
*
* @param destTypes A finite TypePack comprised of LocalTypes.
* This function also overwrites the owners of each BlockedType. This is
* okay because this function is only used to decompose IterableConstraint
* into an UnpackConstraint.
*
* @param destTypes A vector of types comprised of BlockedTypes.
* @param srcTypes A TypePack that represents rvalues to be assigned.
* @returns The underlying UnpackConstraint. There's a bit of code in
* iteration that needs to pass blocks on to this constraint.
*/
NotNull<const Constraint> unpackAndAssign(TypePackId destTypes, TypePackId srcTypes, NotNull<const Constraint> constraint);
NotNull<const Constraint> unpackAndAssign(const std::vector<TypeId> destTypes, TypePackId srcTypes, NotNull<const Constraint> constraint);
void block(NotNull<const Constraint> target, NotNull<const Constraint> constraint);
/**

View file

@ -86,24 +86,6 @@ struct FreeType
TypeId upperBound = nullptr;
};
/** A type that tracks the domain of a local variable.
*
* We consider each local's domain to be the union of all types assigned to it.
* We accomplish this with LocalType. Each time we dispatch an assignment to a
* local, we accumulate this union and decrement blockCount.
*
* When blockCount reaches 0, we can consider the LocalType to be "fully baked"
* and replace it with the union we've built.
*/
struct LocalType
{
TypeId domain;
int blockCount = 0;
// Used for debugging
std::string name;
};
struct GenericType
{
// By default, generics are global, with a synthetic name
@ -148,6 +130,7 @@ struct BlockedType
Constraint* getOwner() const;
void setOwner(Constraint* newOwner);
void replaceOwner(Constraint* newOwner);
private:
// The constraint that is intended to unblock this type. Other constraints
@ -471,6 +454,11 @@ struct TableType
// Methods of this table that have an untyped self will use the same shared self type.
std::optional<TypeId> selfTy;
// We track the number of as-yet-unadded properties to unsealed tables.
// Some constraints will use this information to decide whether or not they
// are able to dispatch.
size_t remainingProps = 0;
};
// Represents a metatable attached to a table type. Somewhat analogous to a bound type.
@ -669,9 +657,9 @@ struct NegationType
using ErrorType = Unifiable::Error;
using TypeVariant = Unifiable::Variant<TypeId, FreeType, LocalType, GenericType, PrimitiveType, BlockedType, PendingExpansionType, SingletonType,
FunctionType, TableType, MetatableType, ClassType, AnyType, UnionType, IntersectionType, LazyType, UnknownType, NeverType, NegationType,
TypeFamilyInstanceType>;
using TypeVariant =
Unifiable::Variant<TypeId, FreeType, GenericType, PrimitiveType, BlockedType, PendingExpansionType, SingletonType, FunctionType, TableType,
MetatableType, ClassType, AnyType, UnionType, IntersectionType, LazyType, UnknownType, NeverType, NegationType, TypeFamilyInstanceType>;
struct Type final
{

View file

@ -69,7 +69,6 @@ struct Unifier2
*/
bool unify(TypeId subTy, TypeId superTy);
bool unifyFreeWithType(TypeId subTy, TypeId superTy);
bool unify(const LocalType* subTy, TypeId superFn);
bool unify(TypeId subTy, const FunctionType* superFn);
bool unify(const UnionType* subUnion, TypeId superTy);
bool unify(TypeId subTy, const UnionType* superUnion);

View file

@ -100,10 +100,6 @@ struct GenericTypeVisitor
{
return visit(ty);
}
virtual bool visit(TypeId ty, const LocalType& ftv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const GenericType& gtv)
{
return visit(ty);
@ -248,11 +244,6 @@ struct GenericTypeVisitor
else
visit(ty, *ftv);
}
else if (auto lt = get<LocalType>(ty))
{
if (visit(ty, *lt))
traverse(lt->domain);
}
else if (auto gtv = get<GenericType>(ty))
visit(ty, *gtv);
else if (auto etv = get<ErrorType>(ty))

View file

@ -271,11 +271,6 @@ private:
t->upperBound = shallowClone(t->upperBound);
}
void cloneChildren(LocalType* t)
{
t->domain = shallowClone(t->domain);
}
void cloneChildren(GenericType* t)
{
// TOOD: clone upper bounds.

View file

@ -81,7 +81,8 @@ DenseHashSet<TypeId> Constraint::getMaybeMutatedFreeTypes() const
}
else if (auto itc = get<IterableConstraint>(*this))
{
rci.traverse(itc->variables);
for (TypeId ty : itc->variables)
rci.traverse(ty);
// `IterableConstraints` should not mutate `iterator`.
}
else if (auto nc = get<NameConstraint>(*this))
@ -106,11 +107,6 @@ DenseHashSet<TypeId> Constraint::getMaybeMutatedFreeTypes() const
rci.traverse(hic->resultType);
// `HasIndexerConstraint` should not mutate `subjectType` or `indexType`.
}
else if (auto ac = get<AssignConstraint>(*this))
{
rci.traverse(ac->lhsType);
rci.traverse(ac->rhsType);
}
else if (auto apc = get<AssignPropConstraint>(*this))
{
rci.traverse(apc->lhsType);
@ -124,7 +120,8 @@ DenseHashSet<TypeId> Constraint::getMaybeMutatedFreeTypes() const
}
else if (auto uc = get<UnpackConstraint>(*this))
{
rci.traverse(uc->resultPack);
for (TypeId ty : uc->resultPack)
rci.traverse(ty);
// `UnpackConstraint` should not mutate `sourcePack`.
}
else if (auto rpc = get<ReducePackConstraint>(*this))

View file

@ -28,6 +28,7 @@
LUAU_FASTINT(LuauCheckRecursionLimit);
LUAU_FASTFLAG(DebugLuauLogSolverToJson);
LUAU_FASTFLAG(DebugLuauMagicTypes);
LUAU_FASTFLAG(LuauAttributeSyntax);
namespace Luau
{
@ -246,6 +247,17 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block)
if (logger)
logger->captureGenerationModule(module);
for (const auto& [ty, domain] : localTypes)
{
// FIXME: This isn't the most efficient thing.
TypeId domainTy = builtinTypes->neverType;
for (TypeId d : domain)
domainTy = simplifyUnion(builtinTypes, arena, domainTy, d).result;
LUAU_ASSERT(get<BlockedType>(ty));
asMutable(ty)->ty.emplace<BoundType>(domainTy);
}
}
TypeId ConstraintGenerator::freshType(const ScopePtr& scope)
@ -310,7 +322,8 @@ std::optional<TypeId> ConstraintGenerator::lookup(const ScopePtr& scope, Locatio
std::optional<TypeId> ty = lookup(scope, location, operand, /*prototype*/ false);
if (!ty)
{
ty = arena->addType(LocalType{builtinTypes->neverType});
ty = arena->addType(BlockedType{});
localTypes[*ty] = {};
rootScope->lvalueTypes[operand] = *ty;
}
@ -703,7 +716,8 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat
{
const Location location = local->location;
TypeId assignee = arena->addType(LocalType{builtinTypes->neverType, /* blockCount */ 1, local->name.value});
TypeId assignee = arena->addType(BlockedType{});
localTypes[assignee] = {};
assignees.push_back(assignee);
@ -740,7 +754,12 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat
if (hasAnnotation)
{
for (size_t i = 0; i < statLocal->vars.size; ++i)
addConstraint(scope, statLocal->location, AssignConstraint{assignees[i], annotatedTypes[i]});
{
LUAU_ASSERT(get<BlockedType>(assignees[i]));
std::vector<TypeId>* localDomain = localTypes.find(assignees[i]);
LUAU_ASSERT(localDomain);
localDomain->push_back(annotatedTypes[i]);
}
TypePackId annotatedPack = arena->addTypePack(std::move(annotatedTypes));
addConstraint(scope, statLocal->location, PackSubtypeConstraint{rvaluePack, annotatedPack});
@ -750,15 +769,30 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat
std::vector<TypeId> valueTypes;
valueTypes.reserve(statLocal->vars.size);
for (size_t i = 0; i < statLocal->vars.size; ++i)
valueTypes.push_back(arena->addType(BlockedType{}));
auto [head, tail] = flatten(rvaluePack);
auto uc = addConstraint(scope, statLocal->location, UnpackConstraint{arena->addTypePack(valueTypes), rvaluePack});
if (head.size() >= statLocal->vars.size)
{
for (size_t i = 0; i < statLocal->vars.size; ++i)
valueTypes.push_back(head[i]);
}
else
{
for (size_t i = 0; i < statLocal->vars.size; ++i)
valueTypes.push_back(arena->addType(BlockedType{}));
auto uc = addConstraint(scope, statLocal->location, UnpackConstraint{valueTypes, rvaluePack});
for (TypeId t: valueTypes)
getMutable<BlockedType>(t)->setOwner(uc);
}
for (size_t i = 0; i < statLocal->vars.size; ++i)
{
getMutable<BlockedType>(valueTypes[i])->setOwner(uc);
addConstraint(scope, statLocal->location, AssignConstraint{assignees[i], valueTypes[i]});
LUAU_ASSERT(get<BlockedType>(assignees[i]));
std::vector<TypeId>* localDomain = localTypes.find(assignees[i]);
LUAU_ASSERT(localDomain);
localDomain->push_back(valueTypes[i]);
}
}
@ -860,25 +894,34 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatForIn* forI
for (AstLocal* var : forIn->vars)
{
TypeId assignee = arena->addType(LocalType{builtinTypes->neverType, /* blockCount */ 1, var->name.value});
TypeId assignee = arena->addType(BlockedType{});
variableTypes.push_back(assignee);
TypeId loopVar = arena->addType(BlockedType{});
localTypes[loopVar].push_back(assignee);
if (var->annotation)
{
TypeId annotationTy = resolveType(loopScope, var->annotation, /*inTypeArguments*/ false);
loopScope->bindings[var] = Binding{annotationTy, var->location};
addConstraint(scope, var->location, SubtypeConstraint{assignee, annotationTy});
addConstraint(scope, var->location, SubtypeConstraint{loopVar, annotationTy});
}
else
loopScope->bindings[var] = Binding{assignee, var->location};
loopScope->bindings[var] = Binding{loopVar, var->location};
DefId def = dfg->getDef(var);
loopScope->lvalueTypes[def] = assignee;
loopScope->lvalueTypes[def] = loopVar;
}
TypePackId variablePack = arena->addTypePack(std::move(variableTypes));
auto iterable = addConstraint(
loopScope, getLocation(forIn->values), IterableConstraint{iterator, variablePack, forIn->values.data[0], &module->astForInNextTypes});
loopScope, getLocation(forIn->values), IterableConstraint{iterator, variableTypes, forIn->values.data[0], &module->astForInNextTypes});
for (TypeId var: variableTypes)
{
auto bt = getMutable<BlockedType>(var);
LUAU_ASSERT(bt);
bt->setOwner(iterable);
}
Checkpoint start = checkpoint(this);
visit(loopScope, forIn->body);
@ -1105,14 +1148,31 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatAssign* ass
std::vector<TypeId> valueTypes;
valueTypes.reserve(assign->vars.size);
for (size_t i = 0; i < assign->vars.size; ++i)
valueTypes.push_back(arena->addType(BlockedType{}));
auto [head, tail] = flatten(resultPack);
if (head.size() >= assign->vars.size)
{
// If the resultPack is definitely long enough for each variable, we can
// skip the UnpackConstraint and use the result types directly.
auto uc = addConstraint(scope, assign->location, UnpackConstraint{arena->addTypePack(valueTypes), resultPack});
for (size_t i = 0; i < assign->vars.size; ++i)
valueTypes.push_back(head[i]);
}
else
{
// We're not sure how many types are produced by the right-side
// expressions. We'll use an UnpackConstraint to defer this until
// later.
for (size_t i = 0; i < assign->vars.size; ++i)
valueTypes.push_back(arena->addType(BlockedType{}));
auto uc = addConstraint(scope, assign->location, UnpackConstraint{valueTypes, resultPack});
for (TypeId t: valueTypes)
getMutable<BlockedType>(t)->setOwner(uc);
}
for (size_t i = 0; i < assign->vars.size; ++i)
{
getMutable<BlockedType>(valueTypes[i])->setOwner(uc);
visitLValue(scope, assign->vars.data[i], valueTypes[i]);
}
@ -1393,7 +1453,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareFunc
TypePackId retPack = resolveTypePack(funScope, global->retTypes, /* inTypeArguments */ false);
TypeId fnType = arena->addType(FunctionType{TypeLevel{}, funScope.get(), std::move(genericTys), std::move(genericTps), paramPack, retPack});
FunctionType* ftv = getMutable<FunctionType>(fnType);
ftv->isCheckedFunction = global->checkedFunction;
ftv->isCheckedFunction = FFlag::LuauAttributeSyntax ? global->isCheckedFunction() : false;
ftv->argNames.reserve(global->paramNames.size);
for (const auto& el : global->paramNames)
@ -1599,9 +1659,8 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall*
mt = arena->addType(BlockedType{});
unpackedTypes.emplace_back(mt);
TypePackId mtPack = arena->addTypePack(std::move(unpackedTypes));
auto c = addConstraint(scope, call->location, UnpackConstraint{mtPack, *argTail});
auto c = addConstraint(scope, call->location, UnpackConstraint{unpackedTypes, *argTail});
getMutable<BlockedType>(mt)->setOwner(c);
if (auto b = getMutable<BlockedType>(target); b && b->getOwner() == nullptr)
b->setOwner(c);
@ -1842,7 +1901,37 @@ Inference ConstraintGenerator::checkIndexName(
const ScopePtr& scope, const RefinementKey* key, AstExpr* indexee, const std::string& index, Location indexLocation)
{
TypeId obj = check(scope, indexee).ty;
TypeId result = arena->addType(BlockedType{});
TypeId result = nullptr;
// We optimize away the HasProp constraint in simple cases so that we can
// reason about updates to unsealed tables more accurately.
const TableType* tt = getTableType(obj);
// This is a little bit iffy but I *believe* it is okay because, if the
// local's domain is going to be extended at all, it will be someplace after
// the current lexical position within the script.
if (!tt)
{
if (auto localDomain = localTypes.find(obj); localDomain && 1 == localDomain->size())
tt = getTableType(localDomain->front());
}
if (tt)
{
auto it = tt->props.find(index);
if (it != tt->props.end() && it->second.readTy.has_value())
result = *it->second.readTy;
}
if (!result)
{
result = arena->addType(BlockedType{});
auto c = addConstraint(
scope, indexee->location, HasPropConstraint{result, obj, std::move(index), ValueContext::RValue, inConditional(typeContext)});
getMutable<BlockedType>(result)->setOwner(c);
}
if (key)
{
@ -1852,10 +1941,6 @@ Inference ConstraintGenerator::checkIndexName(
scope->rvalueRefinements[key->def] = result;
}
auto c =
addConstraint(scope, indexee->location, HasPropConstraint{result, obj, std::move(index), ValueContext::RValue, inConditional(typeContext)});
getMutable<BlockedType>(result)->setOwner(c);
if (key)
return Inference{result, refinementArena.proposition(key, builtinTypes->truthyType)};
else
@ -2242,18 +2327,14 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprLocal* local
if (ty)
{
if (auto lt = getMutable<LocalType>(*ty))
++lt->blockCount;
else if (auto ut = getMutable<UnionType>(*ty))
{
for (TypeId optTy : ut->options)
if (auto lt = getMutable<LocalType>(optTy))
++lt->blockCount;
}
std::vector<TypeId>* localDomain = localTypes.find(*ty);
if (localDomain)
localDomain->push_back(rhsType);
}
else
{
ty = arena->addType(LocalType{builtinTypes->neverType, /* blockCount */ 1, local->local->name.value});
ty = arena->addType(BlockedType{});
localTypes[*ty].push_back(rhsType);
if (annotatedTy)
{
@ -2277,7 +2358,9 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprLocal* local
if (annotatedTy)
addConstraint(scope, local->location, SubtypeConstraint{rhsType, *annotatedTy});
addConstraint(scope, local->location, AssignConstraint{*ty, rhsType});
if (auto localDomain = localTypes.find(*ty))
localDomain->push_back(rhsType);
}
void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprGlobal* global, TypeId rhsType)
@ -2289,7 +2372,6 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprGlobal* glob
rootScope->lvalueTypes[def] = rhsType;
addConstraint(scope, global->location, SubtypeConstraint{rhsType, *annotatedTy});
addConstraint(scope, global->location, AssignConstraint{*annotatedTy, rhsType});
}
}
@ -2298,7 +2380,10 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexName* e
TypeId lhsTy = check(scope, expr->expr).ty;
TypeId propTy = arena->addType(BlockedType{});
module->astTypes[expr] = propTy;
addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, expr->index.value, rhsType, propTy});
bool incremented = recordPropertyAssignment(lhsTy);
addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, expr->index.value, rhsType, propTy, incremented});
}
void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexExpr* expr, TypeId rhsType)
@ -2310,7 +2395,10 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexExpr* e
module->astTypes[expr] = propTy;
module->astTypes[expr->index] = builtinTypes->stringType; // FIXME? Singleton strings exist.
std::string propName{constantString->value.data, constantString->value.size};
addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, std::move(propName), rhsType, propTy});
bool incremented = recordPropertyAssignment(lhsTy);
addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, std::move(propName), rhsType, propTy, incremented});
return;
}
@ -2775,7 +2863,7 @@ TypeId ConstraintGenerator::resolveType(const ScopePtr& scope, AstType* ty, bool
// TODO: FunctionType needs a pointer to the scope so that we know
// how to quantify/instantiate it.
FunctionType ftv{TypeLevel{}, scope.get(), {}, {}, argTypes, returnTypes};
ftv.isCheckedFunction = fn->checkedFunction;
ftv.isCheckedFunction = FFlag::LuauAttributeSyntax ? fn->isCheckedFunction() : false;
// This replicates the behavior of the appropriate FunctionType
// constructors.
@ -2977,8 +3065,7 @@ Inference ConstraintGenerator::flattenPack(const ScopePtr& scope, Location locat
return Inference{*f, refinement};
TypeId typeResult = arena->addType(BlockedType{});
TypePackId resultPack = arena->addTypePack({typeResult}, arena->freshTypePack(scope.get()));
auto c = addConstraint(scope, location, UnpackConstraint{resultPack, tp});
auto c = addConstraint(scope, location, UnpackConstraint{{typeResult}, tp});
getMutable<BlockedType>(typeResult)->setOwner(c);
return Inference{typeResult, refinement};
@ -3075,6 +3162,46 @@ void ConstraintGenerator::prepopulateGlobalScope(const ScopePtr& globalScope, As
program->visit(&gp);
}
bool ConstraintGenerator::recordPropertyAssignment(TypeId ty)
{
DenseHashSet<TypeId> seen{nullptr};
VecDeque<TypeId> queue;
queue.push_back(ty);
bool incremented = false;
while (!queue.empty())
{
const TypeId t = follow(queue.front());
queue.pop_front();
if (seen.find(t))
continue;
seen.insert(t);
if (auto tt = getMutable<TableType>(t); tt && tt->state == TableState::Unsealed)
{
tt->remainingProps += 1;
incremented = true;
}
else if (auto mt = get<MetatableType>(t))
queue.push_back(mt->table);
else if (auto localDomain = localTypes.find(t))
{
for (TypeId domainTy : *localDomain)
queue.push_back(domainTy);
}
else if (auto ut = get<UnionType>(t))
{
for (TypeId part : ut)
queue.push_back(part);
}
}
return incremented;
}
void ConstraintGenerator::recordInferredBinding(AstLocal* local, TypeId ty)
{
if (InferredBinding* ib = inferredBindings.find(local))

View file

@ -532,8 +532,6 @@ bool ConstraintSolver::tryDispatch(NotNull<const Constraint> constraint, bool fo
success = tryDispatch(*hpc, constraint);
else if (auto spc = get<HasIndexerConstraint>(*constraint))
success = tryDispatch(*spc, constraint);
else if (auto uc = get<AssignConstraint>(*constraint))
success = tryDispatch(*uc, constraint);
else if (auto uc = get<AssignPropConstraint>(*constraint))
success = tryDispatch(*uc, constraint);
else if (auto uc = get<AssignIndexConstraint>(*constraint))
@ -686,7 +684,8 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNull<const Co
if (0 == iterator.head.size())
{
unify(constraint, builtinTypes->anyTypePack, c.variables);
for (TypeId ty : c.variables)
unify(constraint, builtinTypes->errorRecoveryType(), ty);
return true;
}
@ -696,21 +695,35 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNull<const Co
{
TypeId keyTy = freshType(arena, builtinTypes, constraint->scope);
TypeId valueTy = freshType(arena, builtinTypes, constraint->scope);
TypeId tableTy = arena->addType(TableType{TableState::Sealed, {}, constraint->scope});
getMutable<TableType>(tableTy)->indexer = TableIndexer{keyTy, valueTy};
TypeId tableTy = arena->addType(TableType{
TableType::Props{},
TableIndexer{keyTy, valueTy},
TypeLevel{},
constraint->scope,
TableState::Free
});
pushConstraint(constraint->scope, constraint->location, SubtypeConstraint{nextTy, tableTy});
unify(constraint, nextTy, tableTy);
auto it = begin(c.variables);
auto endIt = end(c.variables);
if (it != endIt)
{
pushConstraint(constraint->scope, constraint->location, AssignConstraint{*it, keyTy});
bindBlockedType(*it, keyTy, keyTy, constraint);
++it;
}
if (it != endIt)
pushConstraint(constraint->scope, constraint->location, AssignConstraint{*it, valueTy});
{
bindBlockedType(*it, valueTy, valueTy, constraint);
++it;
}
while (it != endIt)
{
bindBlockedType(*it, builtinTypes->nilType, builtinTypes->nilType, constraint);
++it;
}
return true;
}
@ -721,11 +734,7 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNull<const Co
if (iterator.head.size() >= 2)
tableTy = iterator.head[1];
TypeId firstIndexTy = builtinTypes->nilType;
if (iterator.head.size() >= 3)
firstIndexTy = iterator.head[2];
return tryDispatchIterableFunction(nextTy, tableTy, firstIndexTy, c, constraint, force);
return tryDispatchIterableFunction(nextTy, tableTy, c, constraint, force);
}
else
@ -1310,6 +1319,14 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull<const Con
if (isBlocked(subjectType) || get<PendingExpansionType>(subjectType) || get<TypeFamilyInstanceType>(subjectType))
return block(subjectType, constraint);
if (const TableType* subjectTable = getTableType(subjectType))
{
if (subjectTable->state == TableState::Unsealed && subjectTable->remainingProps > 0 && subjectTable->props.count(c.prop) == 0)
{
return block(subjectType, constraint);
}
}
auto [blocked, result] = lookupTableProp(constraint, subjectType, c.prop, c.context, c.inConditional, c.suppressSimplification);
if (!blocked.empty())
{
@ -1517,7 +1534,10 @@ bool ConstraintSolver::tryDispatch(const HasIndexerConstraint& c, NotNull<const
Set<TypeId> seen{nullptr};
return tryDispatchHasIndexer(recursionDepth, constraint, subjectType, indexType, c.resultType, seen);
bool ok = tryDispatchHasIndexer(recursionDepth, constraint, subjectType, indexType, c.resultType, seen);
if (ok)
unblock(c.resultType, constraint->location);
return ok;
}
std::pair<bool, std::optional<TypeId>> ConstraintSolver::tryDispatchSetIndexer(
@ -1596,46 +1616,6 @@ std::pair<bool, std::optional<TypeId>> ConstraintSolver::tryDispatchSetIndexer(
return {true, std::nullopt};
}
bool ConstraintSolver::tryDispatch(const AssignConstraint& c, NotNull<const Constraint> constraint)
{
const TypeId lhsTy = follow(c.lhsType);
const TypeId rhsTy = follow(c.rhsType);
if (!get<LocalType>(lhsTy) && isBlocked(lhsTy))
return block(lhsTy, constraint);
auto tryExpand = [&](TypeId ty) {
LocalType* lt = getMutable<LocalType>(ty);
if (!lt)
return;
lt->domain = simplifyUnion(builtinTypes, arena, lt->domain, rhsTy).result;
LUAU_ASSERT(lt->blockCount > 0);
--lt->blockCount;
if (0 == lt->blockCount)
{
shiftReferences(ty, lt->domain);
emplaceType<BoundType>(asMutable(ty), lt->domain);
}
};
if (auto ut = get<UnionType>(lhsTy))
{
// FIXME: I suspect there's a bug here where lhsTy is a union that contains no LocalTypes.
for (TypeId t : ut)
tryExpand(t);
}
else if (get<LocalType>(lhsTy))
tryExpand(lhsTy);
else
unify(constraint, rhsTy, lhsTy);
unblock(lhsTy, constraint->location);
return true;
}
bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull<const Constraint> constraint)
{
TypeId lhsType = follow(c.lhsType);
@ -1753,6 +1733,14 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull<const
{
emplaceType<BoundType>(asMutable(c.propType), rhsType);
lhsTable->props[propName] = Property::rw(rhsType);
if (lhsTable->state == TableState::Unsealed && c.decrementPropCount)
{
LUAU_ASSERT(lhsTable->remainingProps > 0);
lhsTable->remainingProps -= 1;
unblock(lhsType, constraint->location);
}
return true;
}
}
@ -1927,24 +1915,14 @@ bool ConstraintSolver::tryDispatchUnpack1(NotNull<const Constraint> constraint,
bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull<const Constraint> constraint)
{
TypePackId sourcePack = follow(c.sourcePack);
TypePackId resultPack = follow(c.resultPack);
if (isBlocked(sourcePack))
return block(sourcePack, constraint);
if (isBlocked(resultPack))
{
LUAU_ASSERT(canMutate(resultPack, constraint));
LUAU_ASSERT(resultPack != sourcePack);
emplaceTypePack<BoundTypePack>(asMutable(resultPack), sourcePack);
unblock(resultPack, constraint->location);
return true;
}
TypePack srcPack = extendTypePack(*arena, builtinTypes, sourcePack, c.resultPack.size());
TypePack srcPack = extendTypePack(*arena, builtinTypes, sourcePack, size(resultPack));
auto resultIter = begin(resultPack);
auto resultEnd = end(resultPack);
auto resultIter = begin(c.resultPack);
auto resultEnd = end(c.resultPack);
size_t i = 0;
while (resultIter != resultEnd)
@ -2080,18 +2058,22 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl
auto endIt = end(c.variables);
if (it != endIt)
{
pushConstraint(constraint->scope, constraint->location, AssignConstraint{*it, keyTy});
bindBlockedType(*it, keyTy, keyTy, constraint);
++it;
}
if (it != endIt)
pushConstraint(constraint->scope, constraint->location, AssignConstraint{*it, valueTy});
bindBlockedType(*it, valueTy, valueTy, constraint);
return true;
}
auto unpack = [&](TypeId ty) {
for (TypeId varTy : c.variables)
pushConstraint(constraint->scope, constraint->location, AssignConstraint{varTy, ty});
for (TypeId varTy: c.variables)
{
LUAU_ASSERT(get<BlockedType>(varTy));
LUAU_ASSERT(varTy != ty);
bindBlockedType(varTy, ty, ty, constraint);
}
};
if (get<AnyType>(iteratorTy))
@ -2129,27 +2111,18 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl
if (iteratorTable->indexer)
{
TypePackId expectedVariablePack = arena->addTypePack({iteratorTable->indexer->indexType, iteratorTable->indexer->indexResultType});
unify(constraint, c.variables, expectedVariablePack);
std::vector<TypeId> expectedVariables{iteratorTable->indexer->indexType, iteratorTable->indexer->indexResultType};
while (c.variables.size() >= expectedVariables.size())
expectedVariables.push_back(builtinTypes->errorRecoveryType());
auto [variableTys, variablesTail] = flatten(c.variables);
// the local types for the indexer _should_ be all set after unification
for (TypeId ty : variableTys)
for (size_t i = 0; i < c.variables.size(); ++i)
{
if (auto lt = getMutable<LocalType>(ty))
{
LUAU_ASSERT(lt->blockCount > 0);
--lt->blockCount;
LUAU_ASSERT(c.variables[i] != expectedVariables[i]);
LUAU_ASSERT(0 <= lt->blockCount);
unify(constraint, c.variables[i], expectedVariables[i]);
if (0 == lt->blockCount)
{
shiftReferences(ty, lt->domain);
emplaceType<BoundType>(asMutable(ty), lt->domain);
}
}
bindBlockedType(c.variables[i], expectedVariables[i], expectedVariables[i], constraint);
unblock(c.variables[i], constraint->location);
}
}
else
@ -2213,26 +2186,16 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl
else if (auto primitiveTy = get<PrimitiveType>(iteratorTy); primitiveTy && primitiveTy->type == PrimitiveType::Type::Table)
unpack(builtinTypes->unknownType);
else
{
unpack(builtinTypes->errorType);
}
return true;
}
bool ConstraintSolver::tryDispatchIterableFunction(
TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force)
TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force)
{
// We need to know whether or not this type is nil or not.
// If we don't know, block and reschedule ourselves.
firstIndexTy = follow(firstIndexTy);
if (get<FreeType>(firstIndexTy))
{
if (force)
LUAU_ASSERT(false);
else
block(firstIndexTy, constraint);
return false;
}
const FunctionType* nextFn = get<FunctionType>(nextTy);
// If this does not hold, we should've never called `tryDispatchIterableFunction` in the first place.
LUAU_ASSERT(nextFn);
@ -2267,27 +2230,18 @@ bool ConstraintSolver::tryDispatchIterableFunction(
return true;
}
NotNull<const Constraint> ConstraintSolver::unpackAndAssign(TypePackId destTypes, TypePackId srcTypes, NotNull<const Constraint> constraint)
NotNull<const Constraint> ConstraintSolver::unpackAndAssign(const std::vector<TypeId> destTypes, TypePackId srcTypes, NotNull<const Constraint> constraint)
{
std::vector<TypeId> unpackedTys;
for (TypeId _ty : destTypes)
auto c = pushConstraint(constraint->scope, constraint->location, UnpackConstraint{destTypes, srcTypes});
for (TypeId t: destTypes)
{
(void) _ty;
unpackedTys.push_back(arena->addType(BlockedType{}));
BlockedType* bt = getMutable<BlockedType>(t);
LUAU_ASSERT(bt);
bt->replaceOwner(c);
}
TypePackId unpackedTp = arena->addTypePack(TypePack{unpackedTys});
auto unpackConstraint = pushConstraint(constraint->scope, constraint->location, UnpackConstraint{unpackedTp, srcTypes});
size_t i = 0;
for (TypeId varTy : destTypes)
{
pushConstraint(constraint->scope, constraint->location, AssignConstraint{varTy, unpackedTys[i]});
getMutable<BlockedType>(unpackedTys[i])->setOwner(unpackConstraint);
++i;
}
return unpackConstraint;
return c;
}
std::pair<std::vector<TypeId>, std::optional<TypeId>> ConstraintSolver::lookupTableProp(NotNull<const Constraint> constraint, TypeId subjectType,
@ -2808,9 +2762,6 @@ bool ConstraintSolver::isBlocked(TypeId ty)
{
ty = follow(ty);
if (auto lt = get<LocalType>(ty))
return lt->blockCount > 0;
if (auto tfit = get<TypeFamilyInstanceType>(ty))
return uninhabitedTypeFamilies.contains(ty) == false;

View file

@ -2,6 +2,7 @@
#include "Luau/BuiltinDefinitions.h"
LUAU_FASTFLAGVARIABLE(LuauCheckedEmbeddedDefinitions2, false);
LUAU_FASTFLAG(LuauAttributeSyntax);
namespace Luau
{
@ -319,9 +320,9 @@ declare os: {
clock: () -> number,
}
declare function @checked require(target: any): any
@checked declare function require(target: any): any
declare function @checked getfenv(target: any): { [string]: any }
@checked declare function getfenv(target: any): { [string]: any }
declare _G: any
declare _VERSION: string
@ -363,7 +364,7 @@ declare function select<A...>(i: string | number, ...: A...): ...any
-- (nil, string).
declare function loadstring<A...>(src: string, chunkname: string?): (((A...) -> any)?, string?)
declare function @checked newproxy(mt: boolean?): any
@checked declare function newproxy(mt: boolean?): any
declare coroutine: {
create: <A..., R...>(f: (A...) -> R...) -> thread,
@ -451,7 +452,7 @@ std::string getBuiltinDefinitionSource()
std::string result = kBuiltinDefinitionLuaSrc;
// Annotates each non generic function as checked
if (FFlag::LuauCheckedEmbeddedDefinitions2)
if (FFlag::LuauCheckedEmbeddedDefinitions2 && FFlag::LuauAttributeSyntax)
result = kBuiltinDefinitionLuaSrcChecked;
return result;

View file

@ -1196,12 +1196,6 @@ struct InternalTypeFinder : TypeOnceVisitor
return false;
}
bool visit(TypeId, const LocalType&) override
{
LUAU_ASSERT(false);
return false;
}
bool visit(TypePackId, const BlockedTypePack&) override
{
LUAU_ASSERT(false);

View file

@ -1815,12 +1815,6 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
if (!isCacheable(there))
here.isCacheable = false;
}
else if (auto lt = get<LocalType>(there))
{
// FIXME? This is somewhat questionable.
// Maybe we should assert because this should never happen?
unionNormalWithTy(here, lt->domain, seenSetTypes, ignoreSmallerTyvars);
}
else if (get<FunctionType>(there))
unionFunctionsWithFunction(here.functions, there);
else if (get<TableType>(there) || get<MetatableType>(there))
@ -3095,7 +3089,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
return NormalizationResult::True;
}
else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) ||
get<TypeFamilyInstanceType>(there) || get<LocalType>(there))
get<TypeFamilyInstanceType>(there))
{
NormalizedType thereNorm{builtinTypes};
NormalizedType topNorm{builtinTypes};
@ -3104,10 +3098,6 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
here.isCacheable = false;
return intersectNormals(here, thereNorm);
}
else if (auto lt = get<LocalType>(there))
{
return intersectNormalWithTy(here, lt->domain, seenSetTypes);
}
NormalizedTyvars tyvars = std::move(here.tyvars);

View file

@ -24,8 +24,6 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a
// We decline to copy them.
if constexpr (std::is_same_v<T, FreeType>)
return ty;
else if constexpr (std::is_same_v<T, LocalType>)
return ty;
else if constexpr (std::is_same_v<T, BoundType>)
{
// This should never happen, but visit() cannot see it.

View file

@ -262,14 +262,6 @@ void StateDot::visitChildren(TypeId ty, int index)
visitChild(t.upperBound, index, "[upperBound]");
}
}
else if constexpr (std::is_same_v<T, LocalType>)
{
formatAppend(result, "LocalType");
finishNodeLabel(ty);
finishNode();
visitChild(t.domain, 1, "[domain]");
}
else if constexpr (std::is_same_v<T, AnyType>)
{
formatAppend(result, "AnyType %d", index);

View file

@ -100,16 +100,6 @@ struct FindCyclicTypes final : TypeVisitor
return false;
}
bool visit(TypeId ty, const LocalType& lt) override
{
if (!visited.insert(ty))
return false;
traverse(lt.domain);
return false;
}
bool visit(TypeId ty, const TableType& ttv) override
{
if (!visited.insert(ty))
@ -525,21 +515,6 @@ struct TypeStringifier
}
}
void operator()(TypeId ty, const LocalType& lt)
{
state.emit("l-");
state.emit(lt.name);
if (FInt::DebugLuauVerboseTypeNames >= 1)
{
state.emit("[");
state.emit(lt.blockCount);
state.emit("]");
}
state.emit("=[");
stringify(lt.domain);
state.emit("]");
}
void operator()(TypeId, const BoundType& btv)
{
stringify(btv.boundTo);
@ -1724,6 +1699,18 @@ std::string generateName(size_t i)
return n;
}
std::string toStringVector(const std::vector<TypeId>& types, ToStringOptions& opts)
{
std::string s;
for (TypeId ty : types)
{
if (!s.empty())
s += ", ";
s += toString(ty, opts);
}
return s;
}
std::string toString(const Constraint& constraint, ToStringOptions& opts)
{
auto go = [&opts](auto&& c) -> std::string {
@ -1754,7 +1741,7 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts)
else if constexpr (std::is_same_v<T, IterableConstraint>)
{
std::string iteratorStr = tos(c.iterator);
std::string variableStr = tos(c.variables);
std::string variableStr = toStringVector(c.variables, opts);
return variableStr + " ~ iterate " + iteratorStr;
}
@ -1791,14 +1778,12 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts)
{
return tos(c.resultType) + " ~ hasIndexer " + tos(c.subjectType) + " " + tos(c.indexType);
}
else if constexpr (std::is_same_v<T, AssignConstraint>)
return "assign " + tos(c.lhsType) + " " + tos(c.rhsType);
else if constexpr (std::is_same_v<T, AssignPropConstraint>)
return "assignProp " + tos(c.lhsType) + " " + c.propName + " " + tos(c.rhsType);
else if constexpr (std::is_same_v<T, AssignIndexConstraint>)
return "assignIndex " + tos(c.lhsType) + " " + tos(c.indexType) + " " + tos(c.rhsType);
else if constexpr (std::is_same_v<T, UnpackConstraint>)
return tos(c.resultPack) + " ~ ...unpack " + tos(c.sourcePack);
return toStringVector(c.resultPack, opts) + " ~ ...unpack " + tos(c.sourcePack);
else if constexpr (std::is_same_v<T, ReduceConstraint>)
return "reduce " + tos(c.ty);
else if constexpr (std::is_same_v<T, ReducePackConstraint>)

View file

@ -1182,11 +1182,11 @@ std::string toString(AstNode* node)
Printer printer(writer);
printer.writeTypes = true;
if (auto statNode = dynamic_cast<AstStat*>(node))
if (auto statNode = node->asStat())
printer.visualize(*statNode);
else if (auto exprNode = dynamic_cast<AstExpr*>(node))
else if (auto exprNode = node->asExpr())
printer.visualize(*exprNode);
else if (auto typeNode = dynamic_cast<AstType*>(node))
else if (auto typeNode = node->asType())
printer.visualizeTypeAnnotation(*typeNode);
return writer.str();

View file

@ -561,6 +561,11 @@ void BlockedType::setOwner(Constraint* newOwner)
owner = newOwner;
}
void BlockedType::replaceOwner(Constraint* newOwner)
{
owner = newOwner;
}
PendingExpansionType::PendingExpansionType(
std::optional<AstName> prefix, AstName name, std::vector<TypeId> typeArguments, std::vector<TypePackId> packArguments)
: prefix(prefix)

View file

@ -338,10 +338,6 @@ public:
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("free"), std::nullopt, Location());
}
AstType* operator()(const LocalType& lt)
{
return Luau::visit(*this, lt.domain->ty);
}
AstType* operator()(const UnionType& uv)
{
AstArray<AstType*> unionTypes;

View file

@ -447,7 +447,7 @@ FamilyGraphReductionResult reduceFamilies(TypePackId entrypoint, Location locati
bool isPending(TypeId ty, ConstraintSolver* solver)
{
return is<BlockedType, PendingExpansionType, TypeFamilyInstanceType, LocalType>(ty) || (solver && solver->hasUnresolvedConstraints(ty));
return is<BlockedType, PendingExpansionType, TypeFamilyInstanceType>(ty) || (solver && solver->hasUnresolvedConstraints(ty));
}
template<typename F, typename... Args>
@ -567,7 +567,7 @@ TypeFamilyReductionResult<TypeId> lenFamilyFn(TypeId instance, const std::vector
// check to see if the operand type is resolved enough, and wait to reduce if not
// the use of `typeFromNormal` later necessitates blocking on local types.
if (isPending(operandTy, ctx->solver) || get<LocalType>(operandTy))
if (isPending(operandTy, ctx->solver))
return {std::nullopt, false, {operandTy}, {}};
// if the type is free but has only one remaining reference, we can generalize it to its upper bound here.
@ -1427,12 +1427,6 @@ struct FindRefinementBlockers : TypeOnceVisitor
return false;
}
bool visit(TypeId ty, const LocalType&) override
{
found.insert(ty);
return false;
}
bool visit(TypeId ty, const ClassType&) override
{
return false;

View file

@ -158,12 +158,6 @@ bool Unifier2::unify(TypeId subTy, TypeId superTy)
if (subFree || superFree)
return true;
if (auto subLocal = getMutable<LocalType>(subTy))
{
subLocal->domain = mkUnion(subLocal->domain, superTy);
expandedFreeTypes[subTy].push_back(superTy);
}
auto subFn = get<FunctionType>(subTy);
auto superFn = get<FunctionType>(superTy);
if (subFn && superFn)

View file

@ -60,6 +60,8 @@ class AstStat;
class AstStatBlock;
class AstExpr;
class AstTypePack;
class AstAttr;
class AstExprTable;
struct AstLocal
{
@ -172,6 +174,10 @@ public:
{
return nullptr;
}
virtual AstAttr* asAttr()
{
return nullptr;
}
template<typename T>
bool is() const
@ -193,6 +199,28 @@ public:
Location location;
};
class AstAttr : public AstNode
{
public:
LUAU_RTTI(AstAttr)
enum Type
{
Checked,
};
AstAttr(const Location& location, Type type);
AstAttr* asAttr() override
{
return this;
}
void visit(AstVisitor* visitor) override;
Type type;
};
class AstExpr : public AstNode
{
public:
@ -384,13 +412,15 @@ class AstExprFunction : public AstExpr
public:
LUAU_RTTI(AstExprFunction)
AstExprFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
AstLocal* self, const AstArray<AstLocal*>& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth,
const AstName& debugname, const std::optional<AstTypeList>& returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr,
AstExprFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, AstLocal* self, const AstArray<AstLocal*>& args, bool vararg,
const Location& varargLocation, AstStatBlock* body, size_t functionDepth, const AstName& debugname,
const std::optional<AstTypeList>& returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr,
const std::optional<Location>& argLocation = std::nullopt);
void visit(AstVisitor* visitor) override;
AstArray<AstAttr*> attributes;
AstArray<AstGenericType> generics;
AstArray<AstGenericTypePack> genericPacks;
AstLocal* self;
@ -810,20 +840,22 @@ public:
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params, const AstArray<AstArgumentName>& paramNames,
const AstTypeList& retTypes);
AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params, const AstArray<AstArgumentName>& paramNames,
const AstTypeList& retTypes, bool checkedFunction);
AstStatDeclareFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstName& name,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params,
const AstArray<AstArgumentName>& paramNames, const AstTypeList& retTypes);
void visit(AstVisitor* visitor) override;
bool isCheckedFunction() const;
AstArray<AstAttr*> attributes;
AstName name;
AstArray<AstGenericType> generics;
AstArray<AstGenericTypePack> genericPacks;
AstTypeList params;
AstArray<AstArgumentName> paramNames;
AstTypeList retTypes;
bool checkedFunction;
};
struct AstDeclaredClassProp
@ -936,17 +968,20 @@ public:
AstTypeFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const AstTypeList& returnTypes);
AstTypeFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const AstTypeList& returnTypes, bool checkedFunction);
AstTypeFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames,
const AstTypeList& returnTypes);
void visit(AstVisitor* visitor) override;
bool isCheckedFunction() const;
AstArray<AstAttr*> attributes;
AstArray<AstGenericType> generics;
AstArray<AstGenericTypePack> genericPacks;
AstTypeList argTypes;
AstArray<std::optional<AstArgumentName>> argNames;
AstTypeList returnTypes;
bool checkedFunction;
};
class AstTypeTypeof : public AstType
@ -1105,6 +1140,11 @@ public:
return true;
}
virtual bool visit(class AstAttr* node)
{
return visit(static_cast<AstNode*>(node));
}
virtual bool visit(class AstExpr* node)
{
return visit(static_cast<AstNode*>(node));

View file

@ -87,6 +87,8 @@ struct Lexeme
Comment,
BlockComment,
Attribute,
BrokenString,
BrokenComment,
BrokenUnicode,
@ -115,14 +117,20 @@ struct Lexeme
ReservedTrue,
ReservedUntil,
ReservedWhile,
ReservedChecked,
Reserved_END
};
Type type;
Location location;
// Field declared here, before the union, to ensure that Lexeme size is 32 bytes.
private:
// length is used to extract a slice from the input buffer.
// This field is only valid for certain lexeme types which don't duplicate portions of input
// but instead store a pointer to a location in the input buffer and the length of lexeme.
unsigned int length;
public:
union
{
const char* data; // String, Number, Comment
@ -135,9 +143,13 @@ struct Lexeme
Lexeme(const Location& location, Type type, const char* data, size_t size);
Lexeme(const Location& location, Type type, const char* name);
unsigned int getLength() const;
std::string toString() const;
};
static_assert(sizeof(Lexeme) <= 32, "Size of `Lexeme` struct should be up to 32 bytes.");
class AstNameTable
{
public:

View file

@ -82,8 +82,8 @@ private:
// if exp then block {elseif exp then block} [else block] end |
// for Name `=' exp `,' exp [`,' exp] do block end |
// for namelist in explist do block end |
// function funcname funcbody |
// local function Name funcbody |
// [attributes] function funcname funcbody |
// [attributes] local function Name funcbody |
// local namelist [`=' explist]
// laststat ::= return [explist] | break
AstStat* parseStat();
@ -114,11 +114,25 @@ private:
AstExpr* parseFunctionName(Location start, bool& hasself, AstName& debugname);
// function funcname funcbody
AstStat* parseFunctionStat();
LUAU_FORCEINLINE AstStat* parseFunctionStat(const AstArray<AstAttr*>& attributes = {nullptr, 0});
std::pair<bool, AstAttr::Type> validateAttribute(const char* attributeName, const TempVector<AstAttr*>& attributes);
// attribute ::= '@' NAME
void parseAttribute(TempVector<AstAttr*>& attribute);
// attributes ::= {attribute}
AstArray<AstAttr*> parseAttributes();
// attributes local function Name funcbody
// attributes function funcname funcbody
// attributes `declare function' Name`(' [parlist] `)' [`:` Type]
// declare Name '{' Name ':' attributes `(' [parlist] `)' [`:` Type] '}'
AstStat* parseAttributeStat();
// local function Name funcbody |
// local namelist [`=' explist]
AstStat* parseLocal();
AstStat* parseLocal(const AstArray<AstAttr*>& attributes);
// return [explist]
AstStat* parseReturn();
@ -130,7 +144,7 @@ private:
// `declare global' Name: Type |
// `declare function' Name`(' [parlist] `)' [`:` Type]
AstStat* parseDeclaration(const Location& start);
AstStat* parseDeclaration(const Location& start, const AstArray<AstAttr*>& attributes);
// varlist `=' explist
AstStat* parseAssignment(AstExpr* initial);
@ -143,7 +157,7 @@ private:
// funcbodyhead ::= `(' [namelist [`,' `...'] | `...'] `)' [`:` Type]
// funcbody ::= funcbodyhead block end
std::pair<AstExprFunction*, AstLocal*> parseFunctionBody(
bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName);
bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName, const AstArray<AstAttr*>& attributes);
// explist ::= {exp `,'} exp
void parseExprList(TempVector<AstExpr*>& result);
@ -176,10 +190,10 @@ private:
AstTableIndexer* parseTableIndexer(AstTableAccess access, std::optional<Location> accessLocation);
AstTypeOrPack parseFunctionType(bool allowPack, bool isCheckedFunction = false);
AstType* parseFunctionTypeTail(const Lexeme& begin, AstArray<AstGenericType> generics, AstArray<AstGenericTypePack> genericPacks,
AstArray<AstType*> params, AstArray<std::optional<AstArgumentName>> paramNames, AstTypePack* varargAnnotation,
bool isCheckedFunction = false);
AstTypeOrPack parseFunctionType(bool allowPack, const AstArray<AstAttr*>& attributes);
AstType* parseFunctionTypeTail(const Lexeme& begin, const AstArray<AstAttr*>& attributes, AstArray<AstGenericType> generics,
AstArray<AstGenericTypePack> genericPacks, AstArray<AstType*> params, AstArray<std::optional<AstArgumentName>> paramNames,
AstTypePack* varargAnnotation);
AstType* parseTableType(bool inDeclarationContext = false);
AstTypeOrPack parseSimpleType(bool allowPack, bool inDeclarationContext = false);
@ -393,6 +407,7 @@ private:
std::vector<unsigned int> matchRecoveryStopOnToken;
std::vector<AstAttr*> scratchAttr;
std::vector<AstStat*> scratchStat;
std::vector<AstArray<char>> scratchString;
std::vector<AstExpr*> scratchExpr;

View file

@ -3,6 +3,7 @@
#include "Luau/Common.h"
LUAU_FASTFLAG(LuauAttributeSyntax);
namespace Luau
{
@ -16,6 +17,17 @@ static void visitTypeList(AstVisitor* visitor, const AstTypeList& list)
list.tailType->visit(visitor);
}
AstAttr::AstAttr(const Location& location, Type type)
: AstNode(ClassIndex(), location)
, type(type)
{
}
void AstAttr::visit(AstVisitor* visitor)
{
visitor->visit(this);
}
int gAstRttiIndex = 0;
AstExprGroup::AstExprGroup(const Location& location, AstExpr* expr)
@ -161,11 +173,12 @@ void AstExprIndexExpr::visit(AstVisitor* visitor)
}
}
AstExprFunction::AstExprFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
AstLocal* self, const AstArray<AstLocal*>& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth,
const AstName& debugname, const std::optional<AstTypeList>& returnAnnotation, AstTypePack* varargAnnotation,
const std::optional<Location>& argLocation)
AstExprFunction::AstExprFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, AstLocal* self, const AstArray<AstLocal*>& args, bool vararg, const Location& varargLocation,
AstStatBlock* body, size_t functionDepth, const AstName& debugname, const std::optional<AstTypeList>& returnAnnotation,
AstTypePack* varargAnnotation, const std::optional<Location>& argLocation)
: AstExpr(ClassIndex(), location)
, attributes(attributes)
, generics(generics)
, genericPacks(genericPacks)
, self(self)
@ -696,27 +709,27 @@ AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const A
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params, const AstArray<AstArgumentName>& paramNames,
const AstTypeList& retTypes)
: AstStat(ClassIndex(), location)
, attributes()
, name(name)
, generics(generics)
, genericPacks(genericPacks)
, params(params)
, paramNames(paramNames)
, retTypes(retTypes)
, checkedFunction(false)
{
}
AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params, const AstArray<AstArgumentName>& paramNames,
const AstTypeList& retTypes, bool checkedFunction)
AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstName& name,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params,
const AstArray<AstArgumentName>& paramNames, const AstTypeList& retTypes)
: AstStat(ClassIndex(), location)
, attributes(attributes)
, name(name)
, generics(generics)
, genericPacks(genericPacks)
, params(params)
, paramNames(paramNames)
, retTypes(retTypes)
, checkedFunction(checkedFunction)
{
}
@ -729,6 +742,19 @@ void AstStatDeclareFunction::visit(AstVisitor* visitor)
}
}
bool AstStatDeclareFunction::isCheckedFunction() const
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
for (const AstAttr* attr : attributes)
{
if (attr->type == AstAttr::Type::Checked)
return true;
}
return false;
}
AstStatDeclareClass::AstStatDeclareClass(const Location& location, const AstName& name, std::optional<AstName> superName,
const AstArray<AstDeclaredClassProp>& props, AstTableIndexer* indexer)
: AstStat(ClassIndex(), location)
@ -820,25 +846,26 @@ void AstTypeTable::visit(AstVisitor* visitor)
AstTypeFunction::AstTypeFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const AstTypeList& returnTypes)
: AstType(ClassIndex(), location)
, attributes()
, generics(generics)
, genericPacks(genericPacks)
, argTypes(argTypes)
, argNames(argNames)
, returnTypes(returnTypes)
, checkedFunction(false)
{
LUAU_ASSERT(argNames.size == 0 || argNames.size == argTypes.types.size);
}
AstTypeFunction::AstTypeFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const AstTypeList& returnTypes, bool checkedFunction)
AstTypeFunction::AstTypeFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames,
const AstTypeList& returnTypes)
: AstType(ClassIndex(), location)
, attributes(attributes)
, generics(generics)
, genericPacks(genericPacks)
, argTypes(argTypes)
, argNames(argNames)
, returnTypes(returnTypes)
, checkedFunction(checkedFunction)
{
LUAU_ASSERT(argNames.size == 0 || argNames.size == argTypes.types.size);
}
@ -852,6 +879,19 @@ void AstTypeFunction::visit(AstVisitor* visitor)
}
}
bool AstTypeFunction::isCheckedFunction() const
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
for (const AstAttr* attr : attributes)
{
if (attr->type == AstAttr::Type::Checked)
return true;
}
return false;
}
AstTypeTypeof::AstTypeTypeof(const Location& location, AstExpr* expr)
: AstType(ClassIndex(), location)
, expr(expr)

View file

@ -8,6 +8,7 @@
#include <limits.h>
LUAU_FASTFLAGVARIABLE(LuauLexerLookaheadRemembersBraceType, false)
LUAU_FASTFLAGVARIABLE(LuauAttributeSyntax, false)
namespace Luau
{
@ -102,11 +103,19 @@ Lexeme::Lexeme(const Location& location, Type type, const char* name)
, length(0)
, name(name)
{
LUAU_ASSERT(type == Name || (type >= Reserved_BEGIN && type < Lexeme::Reserved_END));
LUAU_ASSERT(type == Name || type == Attribute || (type >= Reserved_BEGIN && type < Lexeme::Reserved_END));
}
unsigned int Lexeme::getLength() const
{
LUAU_ASSERT(type == RawString || type == QuotedString || type == InterpStringBegin || type == InterpStringMid || type == InterpStringEnd ||
type == InterpStringSimple || type == BrokenInterpDoubleBrace || type == Number || type == Comment || type == BlockComment);
return length;
}
static const char* kReserved[] = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in", "local", "nil", "not", "or",
"repeat", "return", "then", "true", "until", "while", "@checked"};
"repeat", "return", "then", "true", "until", "while"};
std::string Lexeme::toString() const
{
@ -191,6 +200,10 @@ std::string Lexeme::toString() const
case Comment:
return "comment";
case Attribute:
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
return name ? format("'%s'", name) : "attribute";
case BrokenString:
return "malformed string";
@ -278,7 +291,7 @@ std::pair<AstName, Lexeme::Type> AstNameTable::getOrAddWithType(const char* name
nameData[length] = 0;
const_cast<Entry&>(entry).value = AstName(nameData);
const_cast<Entry&>(entry).type = Lexeme::Name;
const_cast<Entry&>(entry).type = (name[0] == '@' ? Lexeme::Attribute : Lexeme::Name);
return std::make_pair(entry.value, entry.type);
}
@ -994,14 +1007,11 @@ Lexeme Lexer::readNext()
}
case '@':
{
// We're trying to lex the token @checked
LUAU_ASSERT(peekch() == '@');
std::pair<AstName, Lexeme::Type> maybeChecked = readName();
if (maybeChecked.second != Lexeme::ReservedChecked)
return Lexeme(Location(start, position()), Lexeme::Error);
return Lexeme(Location(start, position()), maybeChecked.second, maybeChecked.first.value);
if (FFlag::LuauAttributeSyntax)
{
std::pair<AstName, Lexeme::Type> attribute = readName();
return Lexeme(Location(start, position()), Lexeme::Attribute, attribute.first.value);
}
}
default:
if (isDigit(peekch()))

View file

@ -17,11 +17,20 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100)
// flag so that we don't break production games by reverting syntax changes.
// See docs/SyntaxChanges.md for an explanation.
LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false)
LUAU_FASTFLAG(LuauAttributeSyntax)
LUAU_FASTFLAGVARIABLE(LuauLeadingBarAndAmpersand, false)
namespace Luau
{
struct AttributeEntry
{
const char* name;
AstAttr::Type type;
};
AttributeEntry kAttributeEntries[] = {{"@checked", AstAttr::Type::Checked}, {nullptr, AstAttr::Type::Checked}};
ParseError::ParseError(const Location& location, const std::string& message)
: location(location)
, message(message)
@ -280,7 +289,9 @@ AstStatBlock* Parser::parseBlockNoScope()
// for binding `=' exp `,' exp [`,' exp] do block end |
// for namelist in explist do block end |
// function funcname funcbody |
// attributes function funcname funcbody |
// local function Name funcbody |
// local attributes function Name funcbody |
// local namelist [`=' explist]
// laststat ::= return [explist] | break
AstStat* Parser::parseStat()
@ -299,13 +310,16 @@ AstStat* Parser::parseStat()
case Lexeme::ReservedRepeat:
return parseRepeat();
case Lexeme::ReservedFunction:
return parseFunctionStat();
return parseFunctionStat(AstArray<AstAttr*>({nullptr, 0}));
case Lexeme::ReservedLocal:
return parseLocal();
return parseLocal(AstArray<AstAttr*>({nullptr, 0}));
case Lexeme::ReservedReturn:
return parseReturn();
case Lexeme::ReservedBreak:
return parseBreak();
case Lexeme::Attribute:
if (FFlag::LuauAttributeSyntax)
return parseAttributeStat();
default:;
}
@ -343,7 +357,7 @@ AstStat* Parser::parseStat()
if (options.allowDeclarationSyntax)
{
if (ident == "declare")
return parseDeclaration(expr->location);
return parseDeclaration(expr->location, AstArray<AstAttr*>({nullptr, 0}));
}
// skip unexpected symbol if lexer couldn't advance at all (statements are parsed in a loop)
@ -652,7 +666,7 @@ AstExpr* Parser::parseFunctionName(Location start, bool& hasself, AstName& debug
}
// function funcname funcbody
AstStat* Parser::parseFunctionStat()
AstStat* Parser::parseFunctionStat(const AstArray<AstAttr*>& attributes)
{
Location start = lexer.current().location;
@ -665,16 +679,125 @@ AstStat* Parser::parseFunctionStat()
matchRecoveryStopOnToken[Lexeme::ReservedEnd]++;
AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, nullptr).first;
AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, nullptr, attributes).first;
matchRecoveryStopOnToken[Lexeme::ReservedEnd]--;
return allocator.alloc<AstStatFunction>(Location(start, body->location), expr, body);
}
std::pair<bool, AstAttr::Type> Parser::validateAttribute(const char* attributeName, const TempVector<AstAttr*>& attributes)
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
AstAttr::Type type;
// check if the attribute name is valid
bool found = false;
for (int i = 0; kAttributeEntries[i].name; ++i)
{
found = !strcmp(attributeName, kAttributeEntries[i].name);
if (found)
{
type = kAttributeEntries[i].type;
break;
}
}
if (!found)
{
if (strlen(attributeName) == 1)
report(lexer.current().location, "Attribute name is missing");
else
report(lexer.current().location, "Invalid attribute '%s'", attributeName);
}
else
{
// check that attribute is not duplicated
for (const AstAttr* attr : attributes)
{
if (attr->type == type)
{
report(lexer.current().location, "Cannot duplicate attribute '%s'", attributeName);
}
}
}
return {found, type};
}
// attribute ::= '@' NAME
void Parser::parseAttribute(TempVector<AstAttr*>& attributes)
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
LUAU_ASSERT(lexer.current().type == Lexeme::Type::Attribute);
Location loc = lexer.current().location;
const char* name = lexer.current().name;
const auto [found, type] = validateAttribute(name, attributes);
nextLexeme();
if (found)
attributes.push_back(allocator.alloc<AstAttr>(loc, type));
}
// attributes ::= {attribute}
AstArray<AstAttr*> Parser::parseAttributes()
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
Lexeme::Type type = lexer.current().type;
LUAU_ASSERT(type == Lexeme::Attribute);
TempVector<AstAttr*> attributes(scratchAttr);
while (lexer.current().type == Lexeme::Attribute)
parseAttribute(attributes);
return copy(attributes);
}
// attributes local function Name funcbody
// attributes function funcname funcbody
// attributes `declare function' Name`(' [parlist] `)' [`:` Type]
// declare Name '{' Name ':' attributes `(' [parlist] `)' [`:` Type] '}'
AstStat* Parser::parseAttributeStat()
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
AstArray<AstAttr*> attributes = Parser::parseAttributes();
Lexeme::Type type = lexer.current().type;
switch (type)
{
case Lexeme::Type::ReservedFunction:
return parseFunctionStat(attributes);
case Lexeme::Type::ReservedLocal:
return parseLocal(attributes);
case Lexeme::Type::Name:
if (options.allowDeclarationSyntax && !strcmp("declare", lexer.current().data))
{
AstExpr* expr = parsePrimaryExpr(/* asStatement= */ true);
return parseDeclaration(expr->location, attributes);
}
default:
return reportStatError(lexer.current().location, {}, {},
"Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got %s intead",
lexer.current().toString().c_str());
}
}
// local function Name funcbody |
// local bindinglist [`=' explist]
AstStat* Parser::parseLocal()
AstStat* Parser::parseLocal(const AstArray<AstAttr*>& attributes)
{
Location start = lexer.current().location;
@ -694,7 +817,7 @@ AstStat* Parser::parseLocal()
matchRecoveryStopOnToken[Lexeme::ReservedEnd]++;
auto [body, var] = parseFunctionBody(false, matchFunction, name.name, &name);
auto [body, var] = parseFunctionBody(false, matchFunction, name.name, &name, attributes);
matchRecoveryStopOnToken[Lexeme::ReservedEnd]--;
@ -704,6 +827,12 @@ AstStat* Parser::parseLocal()
}
else
{
if (FFlag::LuauAttributeSyntax && attributes.size != 0)
{
return reportStatError(lexer.current().location, {}, {}, "Expected 'function' after local declaration with attribute, but got %s intead",
lexer.current().toString().c_str());
}
matchRecoveryStopOnToken['=']++;
TempVector<Binding> names(scratchBinding);
@ -831,18 +960,17 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod()
return AstDeclaredClassProp{fnName.name, fnType, true};
}
AstStat* Parser::parseDeclaration(const Location& start)
AstStat* Parser::parseDeclaration(const Location& start, const AstArray<AstAttr*>& attributes)
{
// `declare` token is already parsed at this point
if (FFlag::LuauAttributeSyntax && (attributes.size != 0) && (lexer.current().type != Lexeme::ReservedFunction))
return reportStatError(lexer.current().location, {}, {}, "Expected a function type declaration after attribute, but got %s intead",
lexer.current().toString().c_str());
if (lexer.current().type == Lexeme::ReservedFunction)
{
nextLexeme();
bool checkedFunction = false;
if (lexer.current().type == Lexeme::ReservedChecked)
{
checkedFunction = true;
nextLexeme();
}
Name globalName = parseName("global function name");
auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false);
@ -880,8 +1008,8 @@ AstStat* Parser::parseDeclaration(const Location& start)
if (vararg && !varargAnnotation)
return reportStatError(Location(start, end), {}, {}, "All declaration parameters must be annotated");
return allocator.alloc<AstStatDeclareFunction>(Location(start, end), globalName.name, generics, genericPacks,
AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes, checkedFunction);
return allocator.alloc<AstStatDeclareFunction>(Location(start, end), attributes, globalName.name, generics, genericPacks,
AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes);
}
else if (AstName(lexer.current().name) == "class")
{
@ -1035,7 +1163,7 @@ std::pair<AstLocal*, AstArray<AstLocal*>> Parser::prepareFunctionArguments(const
// funcbody ::= `(' [parlist] `)' [`:' ReturnType] block end
// parlist ::= bindinglist [`,' `...'] | `...'
std::pair<AstExprFunction*, AstLocal*> Parser::parseFunctionBody(
bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName)
bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName, const AstArray<AstAttr*>& attributes)
{
Location start = matchFunction.location;
@ -1087,7 +1215,7 @@ std::pair<AstExprFunction*, AstLocal*> Parser::parseFunctionBody(
bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchFunction);
body->hasEnd = hasEnd;
return {allocator.alloc<AstExprFunction>(Location(start, end), generics, genericPacks, self, vars, vararg, varargLocation, body,
return {allocator.alloc<AstExprFunction>(Location(start, end), attributes, generics, genericPacks, self, vars, vararg, varargLocation, body,
functionStack.size(), debugname, typelist, varargAnnotation, argLocation),
funLocal};
}
@ -1296,7 +1424,7 @@ std::pair<Location, AstTypeList> Parser::parseReturnType()
return {location, AstTypeList{copy(result), varargAnnotation}};
}
AstType* tail = parseFunctionTypeTail(begin, {}, {}, copy(result), copy(resultNames), varargAnnotation);
AstType* tail = parseFunctionTypeTail(begin, {nullptr, 0}, {}, {}, copy(result), copy(resultNames), varargAnnotation);
return {Location{location, tail->location}, AstTypeList{copy(&tail, 1), varargAnnotation}};
}
@ -1435,7 +1563,7 @@ AstType* Parser::parseTableType(bool inDeclarationContext)
// ReturnType ::= Type | `(' TypeList `)'
// FunctionType ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType
AstTypeOrPack Parser::parseFunctionType(bool allowPack, bool isCheckedFunction)
AstTypeOrPack Parser::parseFunctionType(bool allowPack, const AstArray<AstAttr*>& attributes)
{
incrementRecursionCounter("type annotation");
@ -1483,11 +1611,12 @@ AstTypeOrPack Parser::parseFunctionType(bool allowPack, bool isCheckedFunction)
AstArray<std::optional<AstArgumentName>> paramNames = copy(names);
return {parseFunctionTypeTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation, isCheckedFunction), {}};
return {parseFunctionTypeTail(begin, attributes, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}};
}
AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray<AstGenericType> generics, AstArray<AstGenericTypePack> genericPacks,
AstArray<AstType*> params, AstArray<std::optional<AstArgumentName>> paramNames, AstTypePack* varargAnnotation, bool isCheckedFunction)
AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, const AstArray<AstAttr*>& attributes, AstArray<AstGenericType> generics,
AstArray<AstGenericTypePack> genericPacks, AstArray<AstType*> params, AstArray<std::optional<AstArgumentName>> paramNames,
AstTypePack* varargAnnotation)
{
incrementRecursionCounter("type annotation");
@ -1512,7 +1641,7 @@ AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray<AstGenericT
AstTypeList paramTypes = AstTypeList{params, varargAnnotation};
return allocator.alloc<AstTypeFunction>(
Location(begin.location, endLocation), generics, genericPacks, paramTypes, paramNames, returnTypeList, isCheckedFunction);
Location(begin.location, endLocation), attributes, generics, genericPacks, paramTypes, paramNames, returnTypeList);
}
// Type ::=
@ -1666,7 +1795,21 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext)
Location start = lexer.current().location;
if (lexer.current().type == Lexeme::ReservedNil)
AstArray<AstAttr*> attributes{nullptr, 0};
if (lexer.current().type == Lexeme::Attribute)
{
if (!inDeclarationContext || !FFlag::LuauAttributeSyntax)
{
return {reportTypeError(start, {}, "attributes are not allowed in declaration context")};
}
else
{
attributes = Parser::parseAttributes();
return parseFunctionType(allowPack, attributes);
}
}
else if (lexer.current().type == Lexeme::ReservedNil)
{
nextLexeme();
return {allocator.alloc<AstTypeReference>(start, std::nullopt, nameNil, std::nullopt, start), {}};
@ -1754,14 +1897,9 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext)
{
return {parseTableType(/* inDeclarationContext */ inDeclarationContext), {}};
}
else if (inDeclarationContext && lexer.current().type == Lexeme::ReservedChecked)
{
nextLexeme();
return parseFunctionType(allowPack, /* isCheckedFunction */ true);
}
else if (lexer.current().type == '(' || lexer.current().type == '<')
{
return parseFunctionType(allowPack);
return parseFunctionType(allowPack, AstArray<AstAttr*>({nullptr, 0}));
}
else if (lexer.current().type == Lexeme::ReservedFunction)
{
@ -2259,7 +2397,7 @@ AstExpr* Parser::parseSimpleExpr()
Lexeme matchFunction = lexer.current();
nextLexeme();
return parseFunctionBody(false, matchFunction, AstName(), nullptr).first;
return parseFunctionBody(false, matchFunction, AstName(), nullptr, AstArray<AstAttr*>({nullptr, 0})).first;
}
else if (lexer.current().type == Lexeme::Number)
{
@ -2689,7 +2827,7 @@ std::optional<AstArray<char>> Parser::parseCharArray()
LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString ||
lexer.current().type == Lexeme::InterpStringSimple);
scratchData.assign(lexer.current().data, lexer.current().length);
scratchData.assign(lexer.current().data, lexer.current().getLength());
if (lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::InterpStringSimple)
{
@ -2734,7 +2872,7 @@ AstExpr* Parser::parseInterpString()
endLocation = currentLexeme.location;
scratchData.assign(currentLexeme.data, currentLexeme.length);
scratchData.assign(currentLexeme.data, currentLexeme.getLength());
if (!Lexer::fixupQuotedString(scratchData))
{
@ -2807,7 +2945,7 @@ AstExpr* Parser::parseNumber()
{
Location start = lexer.current().location;
scratchData.assign(lexer.current().data, lexer.current().length);
scratchData.assign(lexer.current().data, lexer.current().getLength());
// Remove all internal _ - they don't hold any meaning and this allows parsing code to just pass the string pointer to strtod et al
if (scratchData.find('_') != std::string::npos)
@ -3162,11 +3300,11 @@ void Parser::nextLexeme()
return;
// Comments starting with ! are called "hot comments" and contain directives for type checking / linting / compiling
if (lexeme.type == Lexeme::Comment && lexeme.length && lexeme.data[0] == '!')
if (lexeme.type == Lexeme::Comment && lexeme.getLength() && lexeme.data[0] == '!')
{
const char* text = lexeme.data;
unsigned int end = lexeme.length;
unsigned int end = lexeme.getLength();
while (end > 0 && isSpace(text[end - 1]))
--end;

View file

@ -73,12 +73,39 @@ struct CompilationResult
};
struct IrBuilder;
struct IrOp;
using HostVectorOperationBytecodeType = uint8_t (*)(const char* member, size_t memberLength);
using HostVectorAccessHandler = bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos);
using HostVectorNamecallHandler = bool (*)(
IrBuilder& builder, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos);
enum class HostMetamethod
{
Add,
Sub,
Mul,
Div,
Idiv,
Mod,
Pow,
Minus,
Equal,
LessThan,
LessEqual,
Length,
Concat,
};
using HostUserdataOperationBytecodeType = uint8_t (*)(uint8_t type, const char* member, size_t memberLength);
using HostUserdataMetamethodBytecodeType = uint8_t (*)(uint8_t lhsTy, uint8_t rhsTy, HostMetamethod method);
using HostUserdataAccessHandler = bool (*)(
IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos);
using HostUserdataMetamethodHandler = bool (*)(
IrBuilder& builder, uint8_t lhsTy, uint8_t rhsTy, int resultReg, IrOp lhs, IrOp rhs, HostMetamethod method, int pcpos);
using HostUserdataNamecallHandler = bool (*)(
IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos);
struct HostIrHooks
{
// Suggest result type of a vector field access
@ -97,6 +124,34 @@ struct HostIrHooks
// All other arguments can be of any type
// Guards should take a VM exit to 'pcpos'
HostVectorNamecallHandler vectorNamecall = nullptr;
// Suggest result type of a userdata field access
HostUserdataOperationBytecodeType userdataAccessBytecodeType = nullptr;
// Suggest result type of a metamethod call
HostUserdataMetamethodBytecodeType userdataMetamethodBytecodeType = nullptr;
// Suggest result type of a userdata namecall
HostUserdataOperationBytecodeType userdataNamecallBytecodeType = nullptr;
// Handle userdata value field access
// 'sourceReg' is guaranteed to be a userdata, but tag has to be checked
// Write to 'resultReg' might invalidate 'sourceReg'
// Guards should take a VM exit to 'pcpos'
HostUserdataAccessHandler userdataAccess = nullptr;
// Handle metamethod operation on a userdata value
// 'lhs' and 'rhs' operands can be VM registers of constants
// Operand types have to be checked and userdata operand tags have to be checked
// Write to 'resultReg' might invalidate source operands
// Guards should take a VM exit to 'pcpos'
HostUserdataMetamethodHandler userdataMetamethod = nullptr;
// Handle namecall performed on a userdata value
// 'sourceReg' (self argument) is guaranteed to be a userdata, but tag has to be checked
// All other arguments can be of any type
// Guards should take a VM exit to 'pcpos'
HostUserdataNamecallHandler userdataNamecall = nullptr;
};
struct CompilationOptions

View file

@ -290,6 +290,11 @@ enum class IrCmd : uint8_t
// C: block
TRY_CALL_FASTGETTM,
// Create new tagged userdata
// A: int (size)
// B: int (tag)
NEW_USERDATA,
// Convert integer into a double number
// A: int
INT_TO_NUM,
@ -460,6 +465,13 @@ enum class IrCmd : uint8_t
// When undef is specified instead of a block, execution is aborted on check failure
CHECK_BUFFER_LEN,
// Guard against userdata tag mismatch
// A: pointer (userdata)
// B: int (tag)
// C: block/vmexit/undef
// When undef is specified instead of a block, execution is aborted on check failure
CHECK_USERDATA_TAG,
// Special operations
// Check interrupt handler

View file

@ -11,6 +11,7 @@ namespace CodeGen
{
struct IrBuilder;
enum class HostMetamethod;
inline bool isJumpD(LuauOpcode op)
{
@ -129,6 +130,7 @@ inline bool isNonTerminatingJump(IrCmd cmd)
case IrCmd::CHECK_NODE_NO_NEXT:
case IrCmd::CHECK_NODE_VALUE:
case IrCmd::CHECK_BUFFER_LEN:
case IrCmd::CHECK_USERDATA_TAG:
return true;
default:
break;
@ -182,6 +184,7 @@ inline bool hasResult(IrCmd cmd)
case IrCmd::DUP_TABLE:
case IrCmd::TRY_NUM_TO_INDEX:
case IrCmd::TRY_CALL_FASTGETTM:
case IrCmd::NEW_USERDATA:
case IrCmd::INT_TO_NUM:
case IrCmd::UINT_TO_NUM:
case IrCmd::NUM_TO_INT:
@ -245,6 +248,8 @@ bool isGCO(uint8_t tag);
bool isUserdataBytecodeType(uint8_t ty);
bool isCustomUserdataBytecodeType(uint8_t ty);
HostMetamethod tmToHostMetamethod(int tm);
// Manually add or remove use of an operand
void addUse(IrFunction& function, IrOp op);
void removeUse(IrFunction& function, IrOp op);

View file

@ -16,6 +16,7 @@ LUAU_FASTFLAG(LuauLoadTypeInfo) // Because new VM typeinfo loa
LUAU_FASTFLAGVARIABLE(LuauCodegenTypeInfo, false) // New analysis is flagged separately
LUAU_FASTFLAGVARIABLE(LuauCodegenAnalyzeHostVectorOps, false)
LUAU_FASTFLAGVARIABLE(LuauCodegenLoadTypeUpvalCheck, false)
LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataOps, false)
namespace Luau
{
@ -546,6 +547,49 @@ static void applyBuiltinCall(int bfid, BytecodeTypes& types)
}
}
static HostMetamethod opcodeToHostMetamethod(LuauOpcode op)
{
switch (op)
{
case LOP_ADD:
return HostMetamethod::Add;
case LOP_SUB:
return HostMetamethod::Sub;
case LOP_MUL:
return HostMetamethod::Mul;
case LOP_DIV:
return HostMetamethod::Div;
case LOP_IDIV:
return HostMetamethod::Idiv;
case LOP_MOD:
return HostMetamethod::Mod;
case LOP_POW:
return HostMetamethod::Pow;
case LOP_ADDK:
return HostMetamethod::Add;
case LOP_SUBK:
return HostMetamethod::Sub;
case LOP_MULK:
return HostMetamethod::Mul;
case LOP_DIVK:
return HostMetamethod::Div;
case LOP_IDIVK:
return HostMetamethod::Idiv;
case LOP_MODK:
return HostMetamethod::Mod;
case LOP_POWK:
return HostMetamethod::Pow;
case LOP_SUBRK:
return HostMetamethod::Sub;
case LOP_DIVRK:
return HostMetamethod::Div;
default:
CODEGEN_ASSERT(!"opcode is not assigned to a host metamethod");
}
return HostMetamethod::Add;
}
void buildBytecodeBlocks(IrFunction& function, const std::vector<uint8_t>& jumpTargets)
{
Proto* proto = function.proto;
@ -760,22 +804,50 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
regTags[ra] = LBC_TYPE_ANY;
if (bcType.a == LBC_TYPE_VECTOR)
if (FFlag::LuauCodegenUserdataOps)
{
TString* str = gco2ts(function.proto->k[kc].value.gc);
const char* field = getstr(str);
if (str->len == 1)
if (bcType.a == LBC_TYPE_VECTOR)
{
// Same handling as LOP_GETTABLEKS block in lvmexecute.cpp - case-insensitive comparison with "X" / "Y" / "Z"
char ch = field[0] | ' ';
if (str->len == 1)
{
// Same handling as LOP_GETTABLEKS block in lvmexecute.cpp - case-insensitive comparison with "X" / "Y" / "Z"
char ch = field[0] | ' ';
if (ch == 'x' || ch == 'y' || ch == 'z')
regTags[ra] = LBC_TYPE_NUMBER;
if (ch == 'x' || ch == 'y' || ch == 'z')
regTags[ra] = LBC_TYPE_NUMBER;
}
if (FFlag::LuauCodegenAnalyzeHostVectorOps && regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType)
regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len);
}
else if (isCustomUserdataBytecodeType(bcType.a))
{
if (regTags[ra] == LBC_TYPE_ANY && hostHooks.userdataAccessBytecodeType)
regTags[ra] = hostHooks.userdataAccessBytecodeType(bcType.a, field, str->len);
}
}
else
{
if (bcType.a == LBC_TYPE_VECTOR)
{
TString* str = gco2ts(function.proto->k[kc].value.gc);
const char* field = getstr(str);
if (FFlag::LuauCodegenAnalyzeHostVectorOps && regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType)
regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len);
if (str->len == 1)
{
// Same handling as LOP_GETTABLEKS block in lvmexecute.cpp - case-insensitive comparison with "X" / "Y" / "Z"
char ch = field[0] | ' ';
if (ch == 'x' || ch == 'y' || ch == 'z')
regTags[ra] = LBC_TYPE_NUMBER;
}
if (FFlag::LuauCodegenAnalyzeHostVectorOps && regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType)
regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len);
}
}
bcType.result = regTags[ra];
@ -812,6 +884,9 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
regTags[ra] = LBC_TYPE_NUMBER;
else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
bcType.result = regTags[ra];
break;
@ -841,6 +916,11 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
}
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
{
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
}
bcType.result = regTags[ra];
break;
@ -859,6 +939,9 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER)
regTags[ra] = LBC_TYPE_NUMBER;
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
bcType.result = regTags[ra];
break;
@ -879,6 +962,9 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
regTags[ra] = LBC_TYPE_NUMBER;
else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
bcType.result = regTags[ra];
break;
@ -908,6 +994,11 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
}
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
{
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
}
bcType.result = regTags[ra];
break;
@ -926,6 +1017,9 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER)
regTags[ra] = LBC_TYPE_NUMBER;
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
bcType.result = regTags[ra];
break;
@ -945,6 +1039,9 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
regTags[ra] = LBC_TYPE_NUMBER;
else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
bcType.result = regTags[ra];
break;
@ -972,6 +1069,11 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
}
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
{
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
}
bcType.result = regTags[ra];
break;
@ -1000,6 +1102,8 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
regTags[ra] = LBC_TYPE_NUMBER;
else if (bcType.a == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && isCustomUserdataBytecodeType(bcType.a))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, LBC_TYPE_ANY, HostMetamethod::Minus);
bcType.result = regTags[ra];
break;
@ -1140,12 +1244,25 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
bcType.result = LBC_TYPE_FUNCTION;
if (FFlag::LuauCodegenAnalyzeHostVectorOps && bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType)
if (FFlag::LuauCodegenUserdataOps)
{
TString* str = gco2ts(function.proto->k[kc].value.gc);
const char* field = getstr(str);
knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len));
if (FFlag::LuauCodegenAnalyzeHostVectorOps && bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType)
knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len));
else if (isCustomUserdataBytecodeType(bcType.a) && hostHooks.userdataNamecallBytecodeType)
knownNextCallResult = LuauBytecodeType(hostHooks.userdataNamecallBytecodeType(bcType.a, field, str->len));
}
else
{
if (FFlag::LuauCodegenAnalyzeHostVectorOps && bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType)
{
TString* str = gco2ts(function.proto->k[kc].value.gc);
const char* field = getstr(str);
knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len));
}
}
}
break;

View file

@ -258,39 +258,6 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde
return locations;
}
bool initHeaderFunctions(NativeState& data)
{
AssemblyBuilderA64 build(/* logText= */ false);
UnwindBuilder& unwind = *data.unwindBuilder.get();
unwind.startInfo(UnwindBuilder::A64);
EntryLocations entryLocations = buildEntryFunction(build, unwind);
build.finalize();
unwind.finishInfo();
CODEGEN_ASSERT(build.data.empty());
uint8_t* codeStart = nullptr;
if (!data.codeAllocator.allocate(build.data.data(), int(build.data.size()), reinterpret_cast<const uint8_t*>(build.code.data()),
int(build.code.size() * sizeof(build.code[0])), data.gateData, data.gateDataSize, codeStart))
{
CODEGEN_ASSERT(!"Failed to create entry function");
return false;
}
// Set the offset at the begining so that functions in new blocks will not overlay the locations
// specified by the unwind information of the entry function
unwind.setBeginOffset(build.getLabelOffset(entryLocations.prologueEnd));
data.context.gateEntry = codeStart + build.getLabelOffset(entryLocations.start);
data.context.gateExit = codeStart + build.getLabelOffset(entryLocations.epilogueStart);
return true;
}
bool initHeaderFunctions(BaseCodeGenContext& codeGenContext)
{
AssemblyBuilderA64 build(/* logText= */ false);

View file

@ -7,7 +7,6 @@ namespace CodeGen
{
class BaseCodeGenContext;
struct NativeState;
struct ModuleHelpers;
namespace A64
@ -15,7 +14,6 @@ namespace A64
class AssemblyBuilderA64;
bool initHeaderFunctions(NativeState& data);
bool initHeaderFunctions(BaseCodeGenContext& codeGenContext);
void assembleHelpers(AssemblyBuilderA64& build, ModuleHelpers& helpers);

View file

@ -14,8 +14,8 @@
LUAU_FASTFLAGVARIABLE(LuauCodegenCheckNullContext, false)
LUAU_FASTINT(LuauCodeGenBlockSize)
LUAU_FASTINT(LuauCodeGenMaxTotalSize)
LUAU_FASTINTVARIABLE(LuauCodeGenBlockSize, 4 * 1024 * 1024)
LUAU_FASTINTVARIABLE(LuauCodeGenMaxTotalSize, 256 * 1024 * 1024)
namespace Luau
{

View file

@ -14,6 +14,7 @@
#include "lstate.h"
#include "lstring.h"
#include "ltable.h"
#include "ludata.h"
#include <string.h>
@ -219,6 +220,20 @@ void callEpilogC(lua_State* L, int nresults, int n)
L->top = (nresults == LUA_MULTRET) ? res : cip->top;
}
Udata* newUserdata(lua_State* L, size_t s, int tag)
{
Udata* u = luaU_newudata(L, s, tag);
if (Table* h = L->global->udatamt[tag])
{
u->metatable = h;
luaC_objbarrier(L, u, h);
}
return u;
}
// Extracted as-is from lvmexecute.cpp with the exception of control flow (reentry) and removed interrupts/savedpc
Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults)
{

View file

@ -17,6 +17,8 @@ void forgPrepXnextFallback(lua_State* L, TValue* ra, int pc);
Closure* callProlog(lua_State* L, TValue* ra, StkId argtop, int nresults);
void callEpilogC(lua_State* L, int nresults, int n);
Udata* newUserdata(lua_State* L, size_t s, int tag);
#define CALL_FALLBACK_YIELD 1
Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults);

View file

@ -186,39 +186,6 @@ static EntryLocations buildEntryFunction(AssemblyBuilderX64& build, UnwindBuilde
return locations;
}
bool initHeaderFunctions(NativeState& data)
{
AssemblyBuilderX64 build(/* logText= */ false);
UnwindBuilder& unwind = *data.unwindBuilder.get();
unwind.startInfo(UnwindBuilder::X64);
EntryLocations entryLocations = buildEntryFunction(build, unwind);
build.finalize();
unwind.finishInfo();
CODEGEN_ASSERT(build.data.empty());
uint8_t* codeStart = nullptr;
if (!data.codeAllocator.allocate(
build.data.data(), int(build.data.size()), build.code.data(), int(build.code.size()), data.gateData, data.gateDataSize, codeStart))
{
CODEGEN_ASSERT(!"Failed to create entry function");
return false;
}
// Set the offset at the begining so that functions in new blocks will not overlay the locations
// specified by the unwind information of the entry function
unwind.setBeginOffset(build.getLabelOffset(entryLocations.prologueEnd));
data.context.gateEntry = codeStart + build.getLabelOffset(entryLocations.start);
data.context.gateExit = codeStart + build.getLabelOffset(entryLocations.epilogueStart);
return true;
}
bool initHeaderFunctions(BaseCodeGenContext& codeGenContext)
{
AssemblyBuilderX64 build(/* logText= */ false);

View file

@ -7,7 +7,6 @@ namespace CodeGen
{
class BaseCodeGenContext;
struct NativeState;
struct ModuleHelpers;
namespace X64
@ -15,7 +14,6 @@ namespace X64
class AssemblyBuilderX64;
bool initHeaderFunctions(NativeState& data);
bool initHeaderFunctions(BaseCodeGenContext& codeGenContext);
void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers);

View file

@ -22,8 +22,6 @@ namespace Luau
namespace CodeGen
{
struct NativeState;
namespace A64
{

View file

@ -26,7 +26,6 @@ namespace CodeGen
{
enum class IrCondition : uint8_t;
struct NativeState;
struct IrOp;
namespace X64

View file

@ -199,6 +199,8 @@ const char* getCmdName(IrCmd cmd)
return "TRY_NUM_TO_INDEX";
case IrCmd::TRY_CALL_FASTGETTM:
return "TRY_CALL_FASTGETTM";
case IrCmd::NEW_USERDATA:
return "NEW_USERDATA";
case IrCmd::INT_TO_NUM:
return "INT_TO_NUM";
case IrCmd::UINT_TO_NUM:
@ -257,6 +259,8 @@ const char* getCmdName(IrCmd cmd)
return "CHECK_NODE_VALUE";
case IrCmd::CHECK_BUFFER_LEN:
return "CHECK_BUFFER_LEN";
case IrCmd::CHECK_USERDATA_TAG:
return "CHECK_USERDATA_TAG";
case IrCmd::INTERRUPT:
return "INTERRUPT";
case IrCmd::CHECK_GC:

View file

@ -13,6 +13,9 @@
LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5)
LUAU_FASTFLAG(LuauCodegenSplitDoarith)
LUAU_FASTFLAG(LuauCodegenUserdataOps)
LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataAlloc, false)
LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataOpsFixA64, false)
namespace Luau
{
@ -1083,6 +1086,19 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
inst.regA64 = regs.takeReg(x0, index);
break;
}
case IrCmd::NEW_USERDATA:
{
CODEGEN_ASSERT(FFlag::LuauCodegenUserdataAlloc);
regs.spill(build, index);
build.mov(x0, rState);
build.mov(x1, intOp(inst.a));
build.mov(x2, intOp(inst.b));
build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, newUserdata)));
build.blr(x3);
inst.regA64 = regs.takeReg(x0, index);
break;
}
case IrCmd::INT_TO_NUM:
{
inst.regA64 = regs.allocReg(KindA64::d, index);
@ -1677,6 +1693,24 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
finalizeTargetLabel(inst.d, fresh);
break;
}
case IrCmd::CHECK_USERDATA_TAG:
{
CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps);
Label fresh; // used when guard aborts execution or jumps to a VM exit
Label& fail = getTargetLabel(inst.c, fresh);
RegisterA64 temp = regs.allocTemp(KindA64::w);
build.ldrb(temp, mem(regOp(inst.a), offsetof(Udata, tag)));
if (FFlag::LuauCodegenUserdataOpsFixA64)
build.cmp(temp, intOp(inst.b));
else
build.cmp(temp, tagOp(inst.b));
build.b(ConditionA64::NotEqual, fail);
finalizeTargetLabel(inst.c, fresh);
break;
}
case IrCmd::INTERRUPT:
{
regs.spill(build, index);
@ -2308,7 +2342,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_READI8:
{
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b});
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
build.ldrsb(inst.regA64, addr);
break;
@ -2317,7 +2351,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_READU8:
{
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b});
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
build.ldrb(inst.regA64, addr);
break;
@ -2326,7 +2360,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_WRITEI8:
{
RegisterA64 temp = tempInt(inst.c);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d));
build.strb(temp, addr);
break;
@ -2335,7 +2369,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_READI16:
{
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b});
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
build.ldrsh(inst.regA64, addr);
break;
@ -2344,7 +2378,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_READU16:
{
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b});
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
build.ldrh(inst.regA64, addr);
break;
@ -2353,7 +2387,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_WRITEI16:
{
RegisterA64 temp = tempInt(inst.c);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d));
build.strh(temp, addr);
break;
@ -2362,7 +2396,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_READI32:
{
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b});
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
build.ldr(inst.regA64, addr);
break;
@ -2371,7 +2405,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_WRITEI32:
{
RegisterA64 temp = tempInt(inst.c);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d));
build.str(temp, addr);
break;
@ -2381,7 +2415,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{
inst.regA64 = regs.allocReg(KindA64::d, index);
RegisterA64 temp = castReg(KindA64::s, inst.regA64); // safe to alias a fresh register
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
build.ldr(temp, addr);
build.fcvt(inst.regA64, temp);
@ -2392,7 +2426,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{
RegisterA64 temp1 = tempDouble(inst.c);
RegisterA64 temp2 = regs.allocTemp(KindA64::s);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d));
build.fcvt(temp2, temp1);
build.str(temp2, addr);
@ -2402,7 +2436,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_READF64:
{
inst.regA64 = regs.allocReg(KindA64::d, index);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
build.ldr(inst.regA64, addr);
break;
@ -2411,7 +2445,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_WRITEF64:
{
RegisterA64 temp = tempDouble(inst.c);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d));
build.str(temp, addr);
break;
@ -2639,32 +2673,68 @@ AddressA64 IrLoweringA64::tempAddr(IrOp op, int offset)
}
}
AddressA64 IrLoweringA64::tempAddrBuffer(IrOp bufferOp, IrOp indexOp)
AddressA64 IrLoweringA64::tempAddrBuffer(IrOp bufferOp, IrOp indexOp, uint8_t tag)
{
if (indexOp.kind == IrOpKind::Inst)
if (FFlag::LuauCodegenUserdataOps)
{
RegisterA64 temp = regs.allocTemp(KindA64::x);
build.add(temp, regOp(bufferOp), regOp(indexOp)); // implicit uxtw
return mem(temp, offsetof(Buffer, data));
}
else if (indexOp.kind == IrOpKind::Constant)
{
// Since the resulting address may be used to load any size, including 1 byte, from an unaligned offset, we are limited by unscaled encoding
if (unsigned(intOp(indexOp)) + offsetof(Buffer, data) <= 255)
return mem(regOp(bufferOp), int(intOp(indexOp) + offsetof(Buffer, data)));
CODEGEN_ASSERT(tag == LUA_TUSERDATA || tag == LUA_TBUFFER);
int dataOffset = tag == LUA_TBUFFER ? offsetof(Buffer, data) : offsetof(Udata, data);
// indexOp can only be negative in dead code (since offsets are checked); this avoids assertion in emitAddOffset
if (intOp(indexOp) < 0)
return mem(regOp(bufferOp), offsetof(Buffer, data));
if (indexOp.kind == IrOpKind::Inst)
{
RegisterA64 temp = regs.allocTemp(KindA64::x);
build.add(temp, regOp(bufferOp), regOp(indexOp)); // implicit uxtw
return mem(temp, dataOffset);
}
else if (indexOp.kind == IrOpKind::Constant)
{
// Since the resulting address may be used to load any size, including 1 byte, from an unaligned offset, we are limited by unscaled
// encoding
if (unsigned(intOp(indexOp)) + dataOffset <= 255)
return mem(regOp(bufferOp), int(intOp(indexOp) + dataOffset));
RegisterA64 temp = regs.allocTemp(KindA64::x);
emitAddOffset(build, temp, regOp(bufferOp), size_t(intOp(indexOp)));
return mem(temp, offsetof(Buffer, data));
// indexOp can only be negative in dead code (since offsets are checked); this avoids assertion in emitAddOffset
if (intOp(indexOp) < 0)
return mem(regOp(bufferOp), dataOffset);
RegisterA64 temp = regs.allocTemp(KindA64::x);
emitAddOffset(build, temp, regOp(bufferOp), size_t(intOp(indexOp)));
return mem(temp, dataOffset);
}
else
{
CODEGEN_ASSERT(!"Unsupported instruction form");
return noreg;
}
}
else
{
CODEGEN_ASSERT(!"Unsupported instruction form");
return noreg;
if (indexOp.kind == IrOpKind::Inst)
{
RegisterA64 temp = regs.allocTemp(KindA64::x);
build.add(temp, regOp(bufferOp), regOp(indexOp)); // implicit uxtw
return mem(temp, offsetof(Buffer, data));
}
else if (indexOp.kind == IrOpKind::Constant)
{
// Since the resulting address may be used to load any size, including 1 byte, from an unaligned offset, we are limited by unscaled
// encoding
if (unsigned(intOp(indexOp)) + offsetof(Buffer, data) <= 255)
return mem(regOp(bufferOp), int(intOp(indexOp) + offsetof(Buffer, data)));
// indexOp can only be negative in dead code (since offsets are checked); this avoids assertion in emitAddOffset
if (intOp(indexOp) < 0)
return mem(regOp(bufferOp), offsetof(Buffer, data));
RegisterA64 temp = regs.allocTemp(KindA64::x);
emitAddOffset(build, temp, regOp(bufferOp), size_t(intOp(indexOp)));
return mem(temp, offsetof(Buffer, data));
}
else
{
CODEGEN_ASSERT(!"Unsupported instruction form");
return noreg;
}
}
}

View file

@ -44,7 +44,7 @@ struct IrLoweringA64
RegisterA64 tempInt(IrOp op);
RegisterA64 tempUint(IrOp op);
AddressA64 tempAddr(IrOp op, int offset);
AddressA64 tempAddrBuffer(IrOp bufferOp, IrOp indexOp);
AddressA64 tempAddrBuffer(IrOp bufferOp, IrOp indexOp, uint8_t tag);
// May emit restore instructions
RegisterA64 regOp(IrOp op);

View file

@ -15,6 +15,9 @@
#include "lstate.h"
#include "lgc.h"
LUAU_FASTFLAG(LuauCodegenUserdataOps)
LUAU_FASTFLAG(LuauCodegenUserdataAlloc)
namespace Luau
{
namespace CodeGen
@ -905,6 +908,18 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
inst.regX64 = regs.takeReg(rax, index);
break;
}
case IrCmd::NEW_USERDATA:
{
CODEGEN_ASSERT(FFlag::LuauCodegenUserdataAlloc);
IrCallWrapperX64 callWrap(regs, build, index);
callWrap.addArgument(SizeX64::qword, rState);
callWrap.addArgument(SizeX64::qword, intOp(inst.a));
callWrap.addArgument(SizeX64::dword, intOp(inst.b));
callWrap.call(qword[rNativeContext + offsetof(NativeContext, newUserdata)]);
inst.regX64 = regs.takeReg(rax, index);
break;
}
case IrCmd::INT_TO_NUM:
inst.regX64 = regs.allocReg(SizeX64::xmmword, index);
@ -1350,6 +1365,14 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
}
break;
}
case IrCmd::CHECK_USERDATA_TAG:
{
CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps);
build.cmp(byte[regOp(inst.a) + offsetof(Udata, tag)], intOp(inst.b));
jumpOrAbortOnUndef(ConditionX64::NotEqual, inst.c, next);
break;
}
case IrCmd::INTERRUPT:
{
unsigned pcpos = uintOp(inst.a);
@ -1895,71 +1918,71 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_READI8:
inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
build.movsx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b)]);
build.movsx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
break;
case IrCmd::BUFFER_READU8:
inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
build.movzx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b)]);
build.movzx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
break;
case IrCmd::BUFFER_WRITEI8:
{
OperandX64 value = inst.c.kind == IrOpKind::Inst ? byteReg(regOp(inst.c)) : OperandX64(int8_t(intOp(inst.c)));
build.mov(byte[bufferAddrOp(inst.a, inst.b)], value);
build.mov(byte[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], value);
break;
}
case IrCmd::BUFFER_READI16:
inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
build.movsx(inst.regX64, word[bufferAddrOp(inst.a, inst.b)]);
build.movsx(inst.regX64, word[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
break;
case IrCmd::BUFFER_READU16:
inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
build.movzx(inst.regX64, word[bufferAddrOp(inst.a, inst.b)]);
build.movzx(inst.regX64, word[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
break;
case IrCmd::BUFFER_WRITEI16:
{
OperandX64 value = inst.c.kind == IrOpKind::Inst ? wordReg(regOp(inst.c)) : OperandX64(int16_t(intOp(inst.c)));
build.mov(word[bufferAddrOp(inst.a, inst.b)], value);
build.mov(word[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], value);
break;
}
case IrCmd::BUFFER_READI32:
inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
build.mov(inst.regX64, dword[bufferAddrOp(inst.a, inst.b)]);
build.mov(inst.regX64, dword[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
break;
case IrCmd::BUFFER_WRITEI32:
{
OperandX64 value = inst.c.kind == IrOpKind::Inst ? regOp(inst.c) : OperandX64(intOp(inst.c));
build.mov(dword[bufferAddrOp(inst.a, inst.b)], value);
build.mov(dword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], value);
break;
}
case IrCmd::BUFFER_READF32:
inst.regX64 = regs.allocReg(SizeX64::xmmword, index);
build.vcvtss2sd(inst.regX64, inst.regX64, dword[bufferAddrOp(inst.a, inst.b)]);
build.vcvtss2sd(inst.regX64, inst.regX64, dword[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
break;
case IrCmd::BUFFER_WRITEF32:
storeDoubleAsFloat(dword[bufferAddrOp(inst.a, inst.b)], inst.c);
storeDoubleAsFloat(dword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], inst.c);
break;
case IrCmd::BUFFER_READF64:
inst.regX64 = regs.allocReg(SizeX64::xmmword, index);
build.vmovsd(inst.regX64, qword[bufferAddrOp(inst.a, inst.b)]);
build.vmovsd(inst.regX64, qword[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
break;
case IrCmd::BUFFER_WRITEF64:
@ -1967,11 +1990,11 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{
ScopedRegX64 tmp{regs, SizeX64::xmmword};
build.vmovsd(tmp.reg, build.f64(doubleOp(inst.c)));
build.vmovsd(qword[bufferAddrOp(inst.a, inst.b)], tmp.reg);
build.vmovsd(qword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], tmp.reg);
}
else if (inst.c.kind == IrOpKind::Inst)
{
build.vmovsd(qword[bufferAddrOp(inst.a, inst.b)], regOp(inst.c));
build.vmovsd(qword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], regOp(inst.c));
}
else
{
@ -2190,12 +2213,25 @@ RegisterX64 IrLoweringX64::regOp(IrOp op)
return inst.regX64;
}
OperandX64 IrLoweringX64::bufferAddrOp(IrOp bufferOp, IrOp indexOp)
OperandX64 IrLoweringX64::bufferAddrOp(IrOp bufferOp, IrOp indexOp, uint8_t tag)
{
if (indexOp.kind == IrOpKind::Inst)
return regOp(bufferOp) + qwordReg(regOp(indexOp)) + offsetof(Buffer, data);
else if (indexOp.kind == IrOpKind::Constant)
return regOp(bufferOp) + intOp(indexOp) + offsetof(Buffer, data);
if (FFlag::LuauCodegenUserdataOps)
{
CODEGEN_ASSERT(tag == LUA_TUSERDATA || tag == LUA_TBUFFER);
int dataOffset = tag == LUA_TBUFFER ? offsetof(Buffer, data) : offsetof(Udata, data);
if (indexOp.kind == IrOpKind::Inst)
return regOp(bufferOp) + qwordReg(regOp(indexOp)) + dataOffset;
else if (indexOp.kind == IrOpKind::Constant)
return regOp(bufferOp) + intOp(indexOp) + dataOffset;
}
else
{
if (indexOp.kind == IrOpKind::Inst)
return regOp(bufferOp) + qwordReg(regOp(indexOp)) + offsetof(Buffer, data);
else if (indexOp.kind == IrOpKind::Constant)
return regOp(bufferOp) + intOp(indexOp) + offsetof(Buffer, data);
}
CODEGEN_ASSERT(!"Unsupported instruction form");
return noreg;

View file

@ -50,7 +50,7 @@ struct IrLoweringX64
OperandX64 memRegUintOp(IrOp op);
OperandX64 memRegTagOp(IrOp op);
RegisterX64 regOp(IrOp op);
OperandX64 bufferAddrOp(IrOp bufferOp, IrOp indexOp);
OperandX64 bufferAddrOp(IrOp bufferOp, IrOp indexOp, uint8_t tag);
RegisterX64 vecOp(IrOp op, ScopedRegX64& tmp);
IrConst constOp(IrOp op) const;

View file

@ -15,6 +15,7 @@
LUAU_FASTFLAGVARIABLE(LuauCodegenDirectUserdataFlow, false)
LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps)
LUAU_FASTFLAG(LuauCodegenUserdataOps)
namespace Luau
{
@ -444,6 +445,17 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc,
return;
}
if (FFlag::LuauCodegenUserdataOps && (isUserdataBytecodeType(bcTypes.a) || isUserdataBytecodeType(bcTypes.b)))
{
if (build.hostHooks.userdataMetamethod &&
build.hostHooks.userdataMetamethod(build, bcTypes.a, bcTypes.b, ra, opb, opc, tmToHostMetamethod(tm), pcpos))
return;
build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1));
build.inst(IrCmd::DO_ARITH, build.vmReg(ra), opb, opc, build.constInt(tm));
return;
}
IrOp fallback;
// fast-path: number
@ -585,6 +597,17 @@ void translateInstMinus(IrBuilder& build, const Instruction* pc, int pcpos)
return;
}
if (FFlag::LuauCodegenUserdataOps && isUserdataBytecodeType(bcTypes.a))
{
if (build.hostHooks.userdataMetamethod &&
build.hostHooks.userdataMetamethod(build, bcTypes.a, bcTypes.b, ra, build.vmReg(rb), {}, tmToHostMetamethod(TM_UNM), pcpos))
return;
build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1));
build.inst(IrCmd::DO_ARITH, build.vmReg(ra), build.vmReg(rb), build.vmReg(rb), build.constInt(TM_UNM));
return;
}
IrOp fallback;
IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb));
@ -606,8 +629,17 @@ void translateInstMinus(IrBuilder& build, const Instruction* pc, int pcpos)
FallbackStreamScope scope(build, fallback, next);
build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1));
build.inst(
IrCmd::DO_ARITH, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.constInt(TM_UNM));
if (FFlag::LuauCodegenUserdataOps)
{
build.inst(IrCmd::DO_ARITH, build.vmReg(ra), build.vmReg(rb), build.vmReg(rb), build.constInt(TM_UNM));
}
else
{
build.inst(
IrCmd::DO_ARITH, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.constInt(TM_UNM));
}
build.inst(IrCmd::JUMP, next);
}
}
@ -619,6 +651,17 @@ void translateInstLength(IrBuilder& build, const Instruction* pc, int pcpos)
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
if (FFlag::LuauCodegenUserdataOps && isUserdataBytecodeType(bcTypes.a))
{
if (build.hostHooks.userdataMetamethod &&
build.hostHooks.userdataMetamethod(build, bcTypes.a, bcTypes.b, ra, build.vmReg(rb), {}, tmToHostMetamethod(TM_LEN), pcpos))
return;
build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1));
build.inst(IrCmd::DO_LEN, build.vmReg(ra), build.vmReg(rb));
return;
}
IrOp fallback = build.block(IrBlockKind::Fallback);
IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb));
@ -638,7 +681,12 @@ void translateInstLength(IrBuilder& build, const Instruction* pc, int pcpos)
FallbackStreamScope scope(build, fallback, next);
build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1));
build.inst(IrCmd::DO_LEN, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc)));
if (FFlag::LuauCodegenUserdataOps)
build.inst(IrCmd::DO_LEN, build.vmReg(ra), build.vmReg(rb));
else
build.inst(IrCmd::DO_LEN, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc)));
build.inst(IrCmd::JUMP, next);
}
@ -1229,10 +1277,19 @@ void translateInstGetTableKS(IrBuilder& build, const Instruction* pc, int pcpos)
return;
}
if (FFlag::LuauCodegenDirectUserdataFlow && bcTypes.a == LBC_TYPE_USERDATA)
if (FFlag::LuauCodegenDirectUserdataFlow && (FFlag::LuauCodegenUserdataOps ? isUserdataBytecodeType(bcTypes.a) : bcTypes.a == LBC_TYPE_USERDATA))
{
build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TUSERDATA), build.vmExit(pcpos));
if (FFlag::LuauCodegenUserdataOps && build.hostHooks.userdataAccess)
{
TString* str = gco2ts(build.function.proto->k[aux].value.gc);
const char* field = getstr(str);
if (build.hostHooks.userdataAccess(build, bcTypes.a, field, str->len, ra, rb, pcpos))
return;
}
build.inst(IrCmd::FALLBACK_GETTABLEKS, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux));
return;
}
@ -1267,7 +1324,7 @@ void translateInstSetTableKS(IrBuilder& build, const Instruction* pc, int pcpos)
IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb));
if (FFlag::LuauCodegenDirectUserdataFlow && bcTypes.a == LBC_TYPE_USERDATA)
if (FFlag::LuauCodegenDirectUserdataFlow && (FFlag::LuauCodegenUserdataOps ? isUserdataBytecodeType(bcTypes.a) : bcTypes.a == LBC_TYPE_USERDATA))
{
build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TUSERDATA), build.vmExit(pcpos));
@ -1413,10 +1470,26 @@ bool translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos)
return false;
}
if (FFlag::LuauCodegenDirectUserdataFlow && bcTypes.a == LBC_TYPE_USERDATA)
if (FFlag::LuauCodegenDirectUserdataFlow && (FFlag::LuauCodegenUserdataOps ? isUserdataBytecodeType(bcTypes.a) : bcTypes.a == LBC_TYPE_USERDATA))
{
build.loadAndCheckTag(build.vmReg(rb), LUA_TUSERDATA, build.vmExit(pcpos));
if (FFlag::LuauCodegenUserdataOps && build.hostHooks.userdataNamecall)
{
Instruction call = pc[2];
CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL);
int callra = LUAU_INSN_A(call);
int nparams = LUAU_INSN_B(call) - 1;
int nresults = LUAU_INSN_C(call) - 1;
TString* str = gco2ts(build.function.proto->k[aux].value.gc);
const char* field = getstr(str);
if (build.hostHooks.userdataNamecall(build, bcTypes.a, field, str->len, callra, rb, nparams, nresults, pcpos))
return true;
}
build.inst(IrCmd::FALLBACK_NAMECALL, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux));
return false;
}

View file

@ -99,6 +99,7 @@ IrValueKind getCmdValueKind(IrCmd cmd)
case IrCmd::TRY_NUM_TO_INDEX:
return IrValueKind::Int;
case IrCmd::TRY_CALL_FASTGETTM:
case IrCmd::NEW_USERDATA:
return IrValueKind::Pointer;
case IrCmd::INT_TO_NUM:
case IrCmd::UINT_TO_NUM:
@ -135,6 +136,7 @@ IrValueKind getCmdValueKind(IrCmd cmd)
case IrCmd::CHECK_NODE_NO_NEXT:
case IrCmd::CHECK_NODE_VALUE:
case IrCmd::CHECK_BUFFER_LEN:
case IrCmd::CHECK_USERDATA_TAG:
case IrCmd::INTERRUPT:
case IrCmd::CHECK_GC:
case IrCmd::BARRIER_OBJ:
@ -262,6 +264,44 @@ bool isCustomUserdataBytecodeType(uint8_t ty)
return ty >= LBC_TYPE_TAGGED_USERDATA_BASE && ty < LBC_TYPE_TAGGED_USERDATA_END;
}
HostMetamethod tmToHostMetamethod(int tm)
{
switch (TMS(tm))
{
case TM_ADD:
return HostMetamethod::Add;
case TM_SUB:
return HostMetamethod::Sub;
case TM_MUL:
return HostMetamethod::Mul;
case TM_DIV:
return HostMetamethod::Div;
case TM_IDIV:
return HostMetamethod::Idiv;
case TM_MOD:
return HostMetamethod::Mod;
case TM_POW:
return HostMetamethod::Pow;
case TM_UNM:
return HostMetamethod::Minus;
case TM_EQ:
return HostMetamethod::Equal;
case TM_LT:
return HostMetamethod::LessThan;
case TM_LE:
return HostMetamethod::LessEqual;
case TM_LEN:
return HostMetamethod::Length;
case TM_CONCAT:
return HostMetamethod::Concat;
default:
CODEGEN_ASSERT(!"invalid tag method for host");
break;
}
return HostMetamethod::Add;
}
void kill(IrFunction& function, IrInst& inst)
{
CODEGEN_ASSERT(inst.useCount == 0);

View file

@ -14,114 +14,13 @@
#include <math.h>
#include <string.h>
LUAU_FASTINTVARIABLE(LuauCodeGenBlockSize, 4 * 1024 * 1024)
LUAU_FASTINTVARIABLE(LuauCodeGenMaxTotalSize, 256 * 1024 * 1024)
LUAU_FASTFLAG(LuauCodegenUserdataAlloc)
namespace Luau
{
namespace CodeGen
{
NativeState::NativeState()
: NativeState(nullptr, nullptr)
{
}
NativeState::NativeState(AllocationCallback* allocationCallback, void* allocationCallbackContext)
: codeAllocator{size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), allocationCallback, allocationCallbackContext}
{
}
NativeState::~NativeState() = default;
void initFunctions(NativeState& data)
{
static_assert(sizeof(data.context.luauF_table) == sizeof(luauF_table), "fastcall tables are not of the same length");
memcpy(data.context.luauF_table, luauF_table, sizeof(luauF_table));
data.context.luaV_lessthan = luaV_lessthan;
data.context.luaV_lessequal = luaV_lessequal;
data.context.luaV_equalval = luaV_equalval;
data.context.luaV_doarith = luaV_doarith;
data.context.luaV_doarithadd = luaV_doarithimpl<TM_ADD>;
data.context.luaV_doarithsub = luaV_doarithimpl<TM_SUB>;
data.context.luaV_doarithmul = luaV_doarithimpl<TM_MUL>;
data.context.luaV_doarithdiv = luaV_doarithimpl<TM_DIV>;
data.context.luaV_doarithidiv = luaV_doarithimpl<TM_IDIV>;
data.context.luaV_doarithmod = luaV_doarithimpl<TM_MOD>;
data.context.luaV_doarithpow = luaV_doarithimpl<TM_POW>;
data.context.luaV_doarithunm = luaV_doarithimpl<TM_UNM>;
data.context.luaV_dolen = luaV_dolen;
data.context.luaV_gettable = luaV_gettable;
data.context.luaV_settable = luaV_settable;
data.context.luaV_getimport = luaV_getimport;
data.context.luaV_concat = luaV_concat;
data.context.luaH_getn = luaH_getn;
data.context.luaH_new = luaH_new;
data.context.luaH_clone = luaH_clone;
data.context.luaH_resizearray = luaH_resizearray;
data.context.luaH_setnum = luaH_setnum;
data.context.luaC_barriertable = luaC_barriertable;
data.context.luaC_barrierf = luaC_barrierf;
data.context.luaC_barrierback = luaC_barrierback;
data.context.luaC_step = luaC_step;
data.context.luaF_close = luaF_close;
data.context.luaF_findupval = luaF_findupval;
data.context.luaF_newLclosure = luaF_newLclosure;
data.context.luaT_gettm = luaT_gettm;
data.context.luaT_objtypenamestr = luaT_objtypenamestr;
data.context.libm_exp = exp;
data.context.libm_pow = pow;
data.context.libm_fmod = fmod;
data.context.libm_log = log;
data.context.libm_log2 = log2;
data.context.libm_log10 = log10;
data.context.libm_ldexp = ldexp;
data.context.libm_round = round;
data.context.libm_frexp = frexp;
data.context.libm_modf = modf;
data.context.libm_asin = asin;
data.context.libm_sin = sin;
data.context.libm_sinh = sinh;
data.context.libm_acos = acos;
data.context.libm_cos = cos;
data.context.libm_cosh = cosh;
data.context.libm_atan = atan;
data.context.libm_atan2 = atan2;
data.context.libm_tan = tan;
data.context.libm_tanh = tanh;
data.context.forgLoopTableIter = forgLoopTableIter;
data.context.forgLoopNodeIter = forgLoopNodeIter;
data.context.forgLoopNonTableFallback = forgLoopNonTableFallback;
data.context.forgPrepXnextFallback = forgPrepXnextFallback;
data.context.callProlog = callProlog;
data.context.callEpilogC = callEpilogC;
data.context.callFallback = callFallback;
data.context.executeGETGLOBAL = executeGETGLOBAL;
data.context.executeSETGLOBAL = executeSETGLOBAL;
data.context.executeGETTABLEKS = executeGETTABLEKS;
data.context.executeSETTABLEKS = executeSETTABLEKS;
data.context.executeNAMECALL = executeNAMECALL;
data.context.executeFORGPREP = executeFORGPREP;
data.context.executeGETVARARGSMultRet = executeGETVARARGSMultRet;
data.context.executeGETVARARGSConst = executeGETVARARGSConst;
data.context.executeDUPCLOSURE = executeDUPCLOSURE;
data.context.executePREPVARARGS = executePREPVARARGS;
data.context.executeSETLIST = executeSETLIST;
}
void initFunctions(NativeContext& context)
{
static_assert(sizeof(context.luauF_table) == sizeof(luauF_table), "fastcall tables are not of the same length");
@ -194,6 +93,9 @@ void initFunctions(NativeContext& context)
context.callProlog = callProlog;
context.callEpilogC = callEpilogC;
if (FFlag::LuauCodegenUserdataAlloc)
context.newUserdata = newUserdata;
context.callFallback = callFallback;
context.executeGETGLOBAL = executeGETGLOBAL;

View file

@ -94,6 +94,7 @@ struct NativeContext
void (*forgPrepXnextFallback)(lua_State* L, TValue* ra, int pc) = nullptr;
Closure* (*callProlog)(lua_State* L, TValue* ra, StkId argtop, int nresults) = nullptr;
void (*callEpilogC)(lua_State* L, int nresults, int n) = nullptr;
Udata* (*newUserdata)(lua_State* L, size_t s, int tag) = nullptr;
Closure* (*callFallback)(lua_State* L, StkId ra, StkId argtop, int nresults) = nullptr;
@ -116,22 +117,6 @@ struct NativeContext
using GateFn = int (*)(lua_State*, Proto*, uintptr_t, NativeContext*);
struct NativeState
{
NativeState();
NativeState(AllocationCallback* allocationCallback, void* allocationCallbackContext);
~NativeState();
CodeAllocator codeAllocator;
std::unique_ptr<UnwindBuilder> unwindBuilder;
uint8_t* gateData = nullptr;
size_t gateDataSize = 0;
NativeContext context;
};
void initFunctions(NativeState& data);
void initFunctions(NativeContext& context);
} // namespace CodeGen

View file

@ -16,9 +16,12 @@
LUAU_FASTINTVARIABLE(LuauCodeGenMinLinearBlockPath, 3)
LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64)
LUAU_FASTINTVARIABLE(LuauCodeGenReuseUdataTagLimit, 64)
LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks, false)
LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5)
LUAU_FASTFLAGVARIABLE(LuauCodegenFixSplitStoreConstMismatch, false)
LUAU_FASTFLAG(LuauCodegenUserdataOps)
LUAU_FASTFLAG(LuauCodegenUserdataAlloc)
namespace Luau
{
@ -200,6 +203,11 @@ struct ConstPropState
checkBufferLenCache.clear();
}
void invalidateUserdataData()
{
useradataTagCache.clear();
}
void invalidateHeap()
{
for (int i = 0; i <= maxReg; ++i)
@ -417,6 +425,9 @@ struct ConstPropState
invalidateValuePropagation();
invalidateHeapTableData();
invalidateHeapBufferData();
if (FFlag::LuauCodegenUserdataOps)
invalidateUserdataData();
}
IrFunction& function;
@ -446,6 +457,9 @@ struct ConstPropState
std::vector<uint32_t> checkArraySizeCache; // Additionally, fallback block argument might be different
std::vector<uint32_t> checkBufferLenCache; // Additionally, fallback block argument might be different
// Userdata tag cache can point to both NEW_USERDATA and CHECK_USERDATA_TAG instructions
std::vector<uint32_t> useradataTagCache; // Additionally, fallback block argument might be different
};
static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid, uint32_t firstReturnReg, int nresults)
@ -1061,6 +1075,37 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
state.checkBufferLenCache.push_back(index);
break;
}
case IrCmd::CHECK_USERDATA_TAG:
{
CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps);
for (uint32_t prevIdx : state.useradataTagCache)
{
IrInst& prev = function.instructions[prevIdx];
if (prev.cmd == IrCmd::CHECK_USERDATA_TAG)
{
if (prev.a != inst.a || prev.b != inst.b)
continue;
}
else if (FFlag::LuauCodegenUserdataAlloc && prev.cmd == IrCmd::NEW_USERDATA)
{
if (inst.a.kind != IrOpKind::Inst || prevIdx != inst.a.index || prev.b != inst.b)
continue;
}
if (FFlag::DebugLuauAbortingChecks)
replace(function, inst.c, build.undef());
else
kill(function, inst);
return; // Break out from both the loop and the switch
}
if (int(state.useradataTagCache.size()) < FInt::LuauCodeGenReuseUdataTagLimit)
state.useradataTagCache.push_back(index);
break;
}
case IrCmd::BUFFER_READI8:
case IrCmd::BUFFER_READU8:
case IrCmd::BUFFER_WRITEI8:
@ -1228,6 +1273,12 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
break;
case IrCmd::TRY_CALL_FASTGETTM:
break;
case IrCmd::NEW_USERDATA:
CODEGEN_ASSERT(FFlag::LuauCodegenUserdataAlloc);
if (int(state.useradataTagCache.size()) < FInt::LuauCodeGenReuseUdataTagLimit)
state.useradataTagCache.push_back(index);
break;
case IrCmd::INT_TO_NUM:
case IrCmd::UINT_TO_NUM:
state.substituteOrRecord(inst, index);
@ -1512,6 +1563,9 @@ static void constPropInBlockChain(IrBuilder& build, std::vector<uint8_t>& visite
state.invalidateHeapTableData();
state.invalidateHeapBufferData();
if (FFlag::LuauCodegenUserdataOps)
state.invalidateUserdataData();
// Blocks in a chain are guaranteed to follow each other
// We force that by giving all blocks the same sorting key, but consecutive chain keys
block->sortkey = startSortkey;

View file

@ -10,6 +10,7 @@
#include "lobject.h"
LUAU_FASTFLAGVARIABLE(LuauCodegenRemoveDeadStores5, false)
LUAU_FASTFLAG(LuauCodegenUserdataOps)
// TODO: optimization can be improved by knowing which registers are live in at each VM exit
@ -595,6 +596,11 @@ static void markDeadStoresInInst(RemoveDeadStoreState& state, IrBuilder& build,
case IrCmd::CHECK_BUFFER_LEN:
state.checkLiveIns(inst.d);
break;
case IrCmd::CHECK_USERDATA_TAG:
CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps);
state.checkLiveIns(inst.c);
break;
case IrCmd::JUMP:
// Ideally, we would be able to remove stores to registers that are not live out from a block

View file

@ -4219,7 +4219,8 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c
for (AstExprFunction* expr : functions)
compiler.compileFunction(expr, 0);
AstExprFunction main(root->location, /*generics= */ AstArray<AstGenericType>(), /*genericPacks= */ AstArray<AstGenericTypePack>(),
AstExprFunction main(root->location, /*attributes=*/AstArray<AstAttr*>({nullptr, 0}), /*generics= */ AstArray<AstGenericType>(),
/*genericPacks= */ AstArray<AstGenericTypePack>(),
/* self= */ nullptr, AstArray<AstLocal*>(), /* vararg= */ true, /* varargLocation= */ Luau::Location(), root, /* functionDepth= */ 0,
/* debugname= */ AstName());
uint32_t mainid = compiler.compileFunction(&main, mainFlags);

View file

@ -195,7 +195,7 @@ static Error parseJson(const std::string& contents, Action action)
}
else if (lexer.current().type == Lexeme::QuotedString)
{
std::string value(lexer.current().data, lexer.current().length);
std::string value(lexer.current().data, lexer.current().getLength());
next(lexer);
if (Error err = action(keys, value))
@ -232,7 +232,7 @@ static Error parseJson(const std::string& contents, Action action)
}
else if (lexer.current().type == Lexeme::QuotedString)
{
std::string key(lexer.current().data, lexer.current().length);
std::string key(lexer.current().data, lexer.current().getLength());
next(lexer);
keys.push_back(key);
@ -250,7 +250,7 @@ static Error parseJson(const std::string& contents, Action action)
lexer.current().type == Lexeme::ReservedFalse)
{
std::string value = lexer.current().type == Lexeme::QuotedString
? std::string(lexer.current().data, lexer.current().length)
? std::string(lexer.current().data, lexer.current().getLength())
: (lexer.current().type == Lexeme::ReservedTrue ? "true" : "false");
next(lexer);

View file

@ -324,6 +324,10 @@ typedef void (*lua_Destructor)(lua_State* L, void* userdata);
LUA_API void lua_setuserdatadtor(lua_State* L, int tag, lua_Destructor dtor);
LUA_API lua_Destructor lua_getuserdatadtor(lua_State* L, int tag);
// alternative access for metatables already registered with luaL_newmetatable
LUA_API void lua_setuserdatametatable(lua_State* L, int tag, int idx);
LUA_API void lua_getuserdatametatable(lua_State* L, int tag);
LUA_API void lua_setlightuserdataname(lua_State* L, int tag, const char* name);
LUA_API const char* lua_getlightuserdataname(lua_State* L, int tag);

View file

@ -1427,6 +1427,33 @@ lua_Destructor lua_getuserdatadtor(lua_State* L, int tag)
return L->global->udatagc[tag];
}
void lua_setuserdatametatable(lua_State* L, int tag, int idx)
{
api_check(L, unsigned(tag) < LUA_UTAG_LIMIT);
api_check(L, !L->global->udatamt[tag]); // reassignment not supported
StkId o = index2addr(L, idx);
api_check(L, ttistable(o));
L->global->udatamt[tag] = hvalue(o);
L->top--;
}
void lua_getuserdatametatable(lua_State* L, int tag)
{
api_check(L, unsigned(tag) < LUA_UTAG_LIMIT);
luaC_threadbarrier(L);
if (Table* h = L->global->udatamt[tag])
{
sethvalue(L, L->top, h);
}
else
{
setnilvalue(L->top);
}
api_incr_top(L);
}
void lua_setlightuserdataname(lua_State* L, int tag, const char* name)
{
api_check(L, unsigned(tag) < LUA_LUTAG_LIMIT);

View file

@ -210,7 +210,10 @@ lua_State* lua_newstate(lua_Alloc f, void* ud)
for (i = 0; i < LUA_T_COUNT; i++)
g->mt[i] = NULL;
for (i = 0; i < LUA_UTAG_LIMIT; i++)
{
g->udatagc[i] = NULL;
g->udatamt[i] = NULL;
}
for (i = 0; i < LUA_LUTAG_LIMIT; i++)
g->lightuserdataname[i] = NULL;
for (i = 0; i < LUA_MEMORY_CATEGORIES; i++)

View file

@ -217,6 +217,7 @@ typedef struct global_State
lua_ExecutionCallbacks ecb;
void (*udatagc[LUA_UTAG_LIMIT])(lua_State*, void*); // for each userdata tag, a gc callback to be called immediately before freeing memory
Table* udatamt[LUA_LUTAG_LIMIT]; // metatables for tagged userdata
TString* lightuserdataname[LUA_LUTAG_LIMIT]; // names for tagged lightuserdata

View file

@ -342,6 +342,209 @@ void setupVectorHelpers(lua_State* L)
lua_pop(L, 1);
}
Vec2* lua_vec2_push(lua_State* L)
{
Vec2* data = (Vec2*)lua_newuserdatatagged(L, sizeof(Vec2), kTagVec2);
lua_getuserdatametatable(L, kTagVec2);
lua_setmetatable(L, -2);
return data;
}
Vec2* lua_vec2_get(lua_State* L, int idx)
{
Vec2* a = (Vec2*)lua_touserdatatagged(L, idx, kTagVec2);
if (a)
return a;
luaL_typeerror(L, idx, "vec2");
}
static int lua_vec2(lua_State* L)
{
double x = luaL_checknumber(L, 1);
double y = luaL_checknumber(L, 2);
Vec2* data = lua_vec2_push(L);
data->x = float(x);
data->y = float(y);
return 1;
}
static int lua_vec2_dot(lua_State* L)
{
Vec2* a = lua_vec2_get(L, 1);
Vec2* b = lua_vec2_get(L, 2);
lua_pushnumber(L, a->x * b->x + a->y * b->y);
return 1;
}
static int lua_vec2_min(lua_State* L)
{
Vec2* a = lua_vec2_get(L, 1);
Vec2* b = lua_vec2_get(L, 2);
Vec2* data = lua_vec2_push(L);
data->x = a->x < b->x ? a->x : b->x;
data->y = a->y < b->y ? a->y : b->y;
return 1;
}
static int lua_vec2_index(lua_State* L)
{
Vec2* v = lua_vec2_get(L, 1);
const char* name = luaL_checkstring(L, 2);
if (strcmp(name, "X") == 0)
{
lua_pushnumber(L, v->x);
return 1;
}
if (strcmp(name, "Y") == 0)
{
lua_pushnumber(L, v->y);
return 1;
}
if (strcmp(name, "Magnitude") == 0)
{
lua_pushnumber(L, sqrtf(v->x * v->x + v->y * v->y));
return 1;
}
if (strcmp(name, "Unit") == 0)
{
float invSqrt = 1.0f / sqrtf(v->x * v->x + v->y * v->y);
Vec2* data = lua_vec2_push(L);
data->x = v->x * invSqrt;
data->y = v->y * invSqrt;
return 1;
}
luaL_error(L, "%s is not a valid member of vector", name);
}
static int lua_vec2_namecall(lua_State* L)
{
if (const char* str = lua_namecallatom(L, nullptr))
{
if (strcmp(str, "Dot") == 0)
return lua_vec2_dot(L);
if (strcmp(str, "Min") == 0)
return lua_vec2_min(L);
}
luaL_error(L, "%s is not a valid method of vector", luaL_checkstring(L, 1));
}
void setupUserdataHelpers(lua_State* L)
{
// create metatable with all the metamethods
luaL_newmetatable(L, "vec2");
luaL_getmetatable(L, "vec2");
lua_pushvalue(L, -1);
lua_setuserdatametatable(L, kTagVec2, -1);
lua_pushcfunction(L, lua_vec2_index, nullptr);
lua_setfield(L, -2, "__index");
lua_pushcfunction(L, lua_vec2_namecall, nullptr);
lua_setfield(L, -2, "__namecall");
lua_pushcclosurek(
L,
[](lua_State* L) {
Vec2* a = lua_vec2_get(L, 1);
Vec2* b = lua_vec2_get(L, 2);
Vec2* data = lua_vec2_push(L);
data->x = a->x + b->x;
data->y = a->y + b->y;
return 1;
},
nullptr, 0, nullptr);
lua_setfield(L, -2, "__add");
lua_pushcclosurek(
L,
[](lua_State* L) {
Vec2* a = lua_vec2_get(L, 1);
Vec2* b = lua_vec2_get(L, 2);
Vec2* data = lua_vec2_push(L);
data->x = a->x - b->x;
data->y = a->y - b->y;
return 1;
},
nullptr, 0, nullptr);
lua_setfield(L, -2, "__sub");
lua_pushcclosurek(
L,
[](lua_State* L) {
Vec2* a = lua_vec2_get(L, 1);
Vec2* b = lua_vec2_get(L, 2);
Vec2* data = lua_vec2_push(L);
data->x = a->x * b->x;
data->y = a->y * b->y;
return 1;
},
nullptr, 0, nullptr);
lua_setfield(L, -2, "__mul");
lua_pushcclosurek(
L,
[](lua_State* L) {
Vec2* a = lua_vec2_get(L, 1);
Vec2* b = lua_vec2_get(L, 2);
Vec2* data = lua_vec2_push(L);
data->x = a->x / b->x;
data->y = a->y / b->y;
return 1;
},
nullptr, 0, nullptr);
lua_setfield(L, -2, "__div");
lua_pushcclosurek(
L,
[](lua_State* L) {
Vec2* a = lua_vec2_get(L, 1);
Vec2* data = lua_vec2_push(L);
data->x = -a->x;
data->y = -a->y;
return 1;
},
nullptr, 0, nullptr);
lua_setfield(L, -2, "__unm");
lua_setreadonly(L, -1, true);
// ctor
lua_pushcfunction(L, lua_vec2, "vec2");
lua_setglobal(L, "vec2");
lua_pop(L, 1);
}
static void setupNativeHelpers(lua_State* L)
{
lua_pushcclosurek(
@ -1828,16 +2031,36 @@ TEST_CASE("UserdataApi")
luaL_newmetatable(L, "udata2");
void* ud5 = lua_newuserdata(L, 0);
lua_getfield(L, LUA_REGISTRYINDEX, "udata1");
luaL_getmetatable(L, "udata1");
lua_setmetatable(L, -2);
void* ud6 = lua_newuserdata(L, 0);
lua_getfield(L, LUA_REGISTRYINDEX, "udata2");
luaL_getmetatable(L, "udata2");
lua_setmetatable(L, -2);
CHECK(luaL_checkudata(L, -2, "udata1") == ud5);
CHECK(luaL_checkudata(L, -1, "udata2") == ud6);
// tagged user data with fast metatable access
luaL_newmetatable(L, "udata3");
luaL_getmetatable(L, "udata3");
lua_setuserdatametatable(L, 50, -1);
luaL_newmetatable(L, "udata4");
luaL_getmetatable(L, "udata4");
lua_setuserdatametatable(L, 51, -1);
void* ud7 = lua_newuserdatatagged(L, 16, 50);
lua_getuserdatametatable(L, 50);
lua_setmetatable(L, -2);
void* ud8 = lua_newuserdatatagged(L, 16, 51);
lua_getuserdatametatable(L, 51);
lua_setmetatable(L, -2);
CHECK(luaL_checkudata(L, -2, "udata3") == ud7);
CHECK(luaL_checkudata(L, -1, "udata4") == ud8);
globalState.reset();
CHECK(dtorhits == 42);
@ -1911,7 +2134,6 @@ TEST_CASE("Iter")
}
const int kInt64Tag = 1;
static int gInt64MT = -1;
static int64_t getInt64(lua_State* L, int idx)
{
@ -1928,7 +2150,7 @@ static void pushInt64(lua_State* L, int64_t value)
{
void* p = lua_newuserdatatagged(L, sizeof(int64_t), kInt64Tag);
lua_getref(L, gInt64MT);
luaL_getmetatable(L, "int64");
lua_setmetatable(L, -2);
*static_cast<int64_t*>(p) = value;
@ -1938,8 +2160,7 @@ TEST_CASE("Userdata")
{
runConformance("userdata.lua", [](lua_State* L) {
// create metatable with all the metamethods
lua_newtable(L);
gInt64MT = lua_ref(L, -1);
luaL_newmetatable(L, "int64");
// __index
lua_pushcfunction(
@ -2164,6 +2385,86 @@ TEST_CASE("NativeTypeAnnotations")
});
}
TEST_CASE("NativeUserdata")
{
lua_CompileOptions copts = defaultOptions();
Luau::CodeGen::CompilationOptions nativeOpts = defaultCodegenOptions();
static const char* kUserdataCompileTypes[] = {"vec2", "color", "mat3", nullptr};
copts.userdataTypes = kUserdataCompileTypes;
SUBCASE("NoIrHooks")
{
SUBCASE("O0")
{
copts.optimizationLevel = 0;
}
SUBCASE("O1")
{
copts.optimizationLevel = 1;
}
SUBCASE("O2")
{
copts.optimizationLevel = 2;
}
}
SUBCASE("IrHooks")
{
nativeOpts.hooks.vectorAccessBytecodeType = vectorAccessBytecodeType;
nativeOpts.hooks.vectorNamecallBytecodeType = vectorNamecallBytecodeType;
nativeOpts.hooks.vectorAccess = vectorAccess;
nativeOpts.hooks.vectorNamecall = vectorNamecall;
nativeOpts.hooks.userdataAccessBytecodeType = userdataAccessBytecodeType;
nativeOpts.hooks.userdataMetamethodBytecodeType = userdataMetamethodBytecodeType;
nativeOpts.hooks.userdataNamecallBytecodeType = userdataNamecallBytecodeType;
nativeOpts.hooks.userdataAccess = userdataAccess;
nativeOpts.hooks.userdataMetamethod = userdataMetamethod;
nativeOpts.hooks.userdataNamecall = userdataNamecall;
nativeOpts.userdataTypes = kUserdataRunTypes;
SUBCASE("O0")
{
copts.optimizationLevel = 0;
}
SUBCASE("O1")
{
copts.optimizationLevel = 1;
}
SUBCASE("O2")
{
copts.optimizationLevel = 2;
}
}
runConformance(
"native_userdata.lua",
[](lua_State* L) {
Luau::CodeGen::setUserdataRemapper(L, kUserdataRunTypes, [](void* context, const char* str, size_t len) -> uint8_t {
const char** types = (const char**)context;
uint8_t index = 0;
std::string_view sv{str, len};
for (; *types; ++types)
{
if (sv == *types)
return index;
index++;
}
return 0xff;
});
setupVectorHelpers(L);
setupUserdataHelpers(L);
},
nullptr, nullptr, &copts, false, &nativeOpts);
}
[[nodiscard]] static std::string makeHugeFunctionSource()
{
std::string source;

View file

@ -5,14 +5,44 @@
static const char* kUserdataRunTypes[] = {"extra", "color", "vec2", "mat3", nullptr};
constexpr uint8_t kUserdataExtra = 0;
constexpr uint8_t kUserdataColor = 1;
constexpr uint8_t kUserdataVec2 = 2;
constexpr uint8_t kUserdataMat3 = 3;
// Userdata tags can be different from userdata bytecode type indices
constexpr uint8_t kTagVec2 = 12;
struct Vec2
{
float x;
float y;
};
inline bool compareMemberName(const char* member, size_t memberLength, const char* str)
{
return memberLength == strlen(str) && strcmp(member, str) == 0;
}
inline uint8_t typeToUserdataIndex(uint8_t type)
{
// Underflow will push the type into a value that is not comparable to any kUserdata* constants
return type - LBC_TYPE_TAGGED_USERDATA_BASE;
}
inline uint8_t userdataIndexToType(uint8_t userdataIndex)
{
return LBC_TYPE_TAGGED_USERDATA_BASE + userdataIndex;
}
inline uint8_t vectorAccessBytecodeType(const char* member, size_t memberLength)
{
using namespace Luau::CodeGen;
if (memberLength == strlen("Magnitude") && strcmp(member, "Magnitude") == 0)
if (compareMemberName(member, memberLength, "Magnitude"))
return LBC_TYPE_NUMBER;
if (memberLength == strlen("Unit") && strcmp(member, "Unit") == 0)
if (compareMemberName(member, memberLength, "Unit"))
return LBC_TYPE_VECTOR;
return LBC_TYPE_ANY;
@ -22,7 +52,7 @@ inline bool vectorAccess(Luau::CodeGen::IrBuilder& build, const char* member, si
{
using namespace Luau::CodeGen;
if (memberLength == strlen("Magnitude") && strcmp(member, "Magnitude") == 0)
if (compareMemberName(member, memberLength, "Magnitude"))
{
IrOp x = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0));
IrOp y = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4));
@ -42,7 +72,7 @@ inline bool vectorAccess(Luau::CodeGen::IrBuilder& build, const char* member, si
return true;
}
if (memberLength == strlen("Unit") && strcmp(member, "Unit") == 0)
if (compareMemberName(member, memberLength, "Unit"))
{
IrOp x = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0));
IrOp y = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4));
@ -72,10 +102,10 @@ inline bool vectorAccess(Luau::CodeGen::IrBuilder& build, const char* member, si
inline uint8_t vectorNamecallBytecodeType(const char* member, size_t memberLength)
{
if (memberLength == strlen("Dot") && strcmp(member, "Dot") == 0)
if (compareMemberName(member, memberLength, "Dot"))
return LBC_TYPE_NUMBER;
if (memberLength == strlen("Cross") && strcmp(member, "Cross") == 0)
if (compareMemberName(member, memberLength, "Cross"))
return LBC_TYPE_VECTOR;
return LBC_TYPE_ANY;
@ -86,7 +116,7 @@ inline bool vectorNamecall(
{
using namespace Luau::CodeGen;
if (memberLength == strlen("Dot") && strcmp(member, "Dot") == 0 && params == 2 && results <= 1)
if (compareMemberName(member, memberLength, "Dot") && params == 2 && results <= 1)
{
build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TVECTOR, build.vmExit(pcpos));
@ -114,7 +144,7 @@ inline bool vectorNamecall(
return true;
}
if (memberLength == strlen("Cross") && strcmp(member, "Cross") == 0 && params == 2 && results <= 1)
if (compareMemberName(member, memberLength, "Cross") && params == 2 && results <= 1)
{
build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TVECTOR, build.vmExit(pcpos));
@ -151,3 +181,362 @@ inline bool vectorNamecall(
return false;
}
inline uint8_t userdataAccessBytecodeType(uint8_t type, const char* member, size_t memberLength)
{
switch (typeToUserdataIndex(type))
{
case kUserdataColor:
if (compareMemberName(member, memberLength, "R"))
return LBC_TYPE_NUMBER;
if (compareMemberName(member, memberLength, "G"))
return LBC_TYPE_NUMBER;
if (compareMemberName(member, memberLength, "B"))
return LBC_TYPE_NUMBER;
break;
case kUserdataVec2:
if (compareMemberName(member, memberLength, "X"))
return LBC_TYPE_NUMBER;
if (compareMemberName(member, memberLength, "Y"))
return LBC_TYPE_NUMBER;
if (compareMemberName(member, memberLength, "Magnitude"))
return LBC_TYPE_NUMBER;
if (compareMemberName(member, memberLength, "Unit"))
return userdataIndexToType(kUserdataVec2);
break;
case kUserdataMat3:
if (compareMemberName(member, memberLength, "Row1"))
return LBC_TYPE_VECTOR;
if (compareMemberName(member, memberLength, "Row2"))
return LBC_TYPE_VECTOR;
if (compareMemberName(member, memberLength, "Row3"))
return LBC_TYPE_VECTOR;
break;
}
return LBC_TYPE_ANY;
}
inline bool userdataAccess(
Luau::CodeGen::IrBuilder& build, uint8_t type, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos)
{
using namespace Luau::CodeGen;
switch (typeToUserdataIndex(type))
{
case kUserdataColor:
break;
case kUserdataVec2:
if (compareMemberName(member, memberLength, "X"))
{
IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg));
build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos));
IrOp value = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(resultReg), value);
build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TNUMBER));
return true;
}
if (compareMemberName(member, memberLength, "Y"))
{
IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg));
build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos));
IrOp value = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(resultReg), value);
build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TNUMBER));
return true;
}
if (compareMemberName(member, memberLength, "Magnitude"))
{
IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg));
build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos));
IrOp x = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA));
IrOp y = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA));
IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x);
IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y);
IrOp sum = build.inst(IrCmd::ADD_NUM, x2, y2);
IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(resultReg), mag);
build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TNUMBER));
return true;
}
if (compareMemberName(member, memberLength, "Unit"))
{
IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg));
build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos));
IrOp x = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA));
IrOp y = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA));
IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x);
IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y);
IrOp sum = build.inst(IrCmd::ADD_NUM, x2, y2);
IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);
IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag);
IrOp xr = build.inst(IrCmd::MUL_NUM, x, inv);
IrOp yr = build.inst(IrCmd::MUL_NUM, y, inv);
build.inst(IrCmd::CHECK_GC);
IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2));
build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), xr, build.constTag(LUA_TUSERDATA));
build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), yr, build.constTag(LUA_TUSERDATA));
build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar);
build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA));
return true;
}
break;
case kUserdataMat3:
break;
}
return false;
}
inline uint8_t userdataMetamethodBytecodeType(uint8_t lhsTy, uint8_t rhsTy, Luau::CodeGen::HostMetamethod method)
{
switch (method)
{
case Luau::CodeGen::HostMetamethod::Add:
case Luau::CodeGen::HostMetamethod::Sub:
case Luau::CodeGen::HostMetamethod::Mul:
case Luau::CodeGen::HostMetamethod::Div:
if (typeToUserdataIndex(lhsTy) == kUserdataVec2 || typeToUserdataIndex(rhsTy) == kUserdataVec2)
return userdataIndexToType(kUserdataVec2);
break;
case Luau::CodeGen::HostMetamethod::Minus:
if (typeToUserdataIndex(lhsTy) == kUserdataVec2)
return userdataIndexToType(kUserdataVec2);
break;
default:
break;
}
return LBC_TYPE_ANY;
}
inline bool userdataMetamethod(Luau::CodeGen::IrBuilder& build, uint8_t lhsTy, uint8_t rhsTy, int resultReg, Luau::CodeGen::IrOp lhs,
Luau::CodeGen::IrOp rhs, Luau::CodeGen::HostMetamethod method, int pcpos)
{
using namespace Luau::CodeGen;
switch (method)
{
case Luau::CodeGen::HostMetamethod::Add:
if (typeToUserdataIndex(lhsTy) == kUserdataVec2 && typeToUserdataIndex(rhsTy) == kUserdataVec2)
{
build.loadAndCheckTag(lhs, LUA_TUSERDATA, build.vmExit(pcpos));
build.loadAndCheckTag(rhs, LUA_TUSERDATA, build.vmExit(pcpos));
IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, lhs);
build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos));
IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, rhs);
build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos));
IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA));
IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA));
IrOp mx = build.inst(IrCmd::ADD_NUM, x1, x2);
IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA));
IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA));
IrOp my = build.inst(IrCmd::ADD_NUM, y1, y2);
build.inst(IrCmd::CHECK_GC);
IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2));
build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA));
build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA));
build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar);
build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA));
return true;
}
break;
case Luau::CodeGen::HostMetamethod::Mul:
if (typeToUserdataIndex(lhsTy) == kUserdataVec2 && typeToUserdataIndex(rhsTy) == kUserdataVec2)
{
build.loadAndCheckTag(lhs, LUA_TUSERDATA, build.vmExit(pcpos));
build.loadAndCheckTag(rhs, LUA_TUSERDATA, build.vmExit(pcpos));
IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, lhs);
build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos));
IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, rhs);
build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos));
IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA));
IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA));
IrOp mx = build.inst(IrCmd::MUL_NUM, x1, x2);
IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA));
IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA));
IrOp my = build.inst(IrCmd::MUL_NUM, y1, y2);
build.inst(IrCmd::CHECK_GC);
IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2));
build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA));
build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA));
build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar);
build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA));
return true;
}
break;
case Luau::CodeGen::HostMetamethod::Minus:
if (typeToUserdataIndex(lhsTy) == kUserdataVec2)
{
build.loadAndCheckTag(lhs, LUA_TUSERDATA, build.vmExit(pcpos));
IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, lhs);
build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos));
IrOp x = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA));
IrOp y = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA));
IrOp mx = build.inst(IrCmd::UNM_NUM, x);
IrOp my = build.inst(IrCmd::UNM_NUM, y);
build.inst(IrCmd::CHECK_GC);
IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2));
build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA));
build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA));
build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar);
build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA));
return true;
}
break;
default:
break;
}
return false;
}
inline uint8_t userdataNamecallBytecodeType(uint8_t type, const char* member, size_t memberLength)
{
switch (typeToUserdataIndex(type))
{
case kUserdataColor:
break;
case kUserdataVec2:
if (compareMemberName(member, memberLength, "Dot"))
return LBC_TYPE_NUMBER;
if (compareMemberName(member, memberLength, "Min"))
return userdataIndexToType(kUserdataVec2);
break;
case kUserdataMat3:
break;
}
return LBC_TYPE_ANY;
}
inline bool userdataNamecall(Luau::CodeGen::IrBuilder& build, uint8_t type, const char* member, size_t memberLength, int argResReg, int sourceReg,
int params, int results, int pcpos)
{
using namespace Luau::CodeGen;
switch (typeToUserdataIndex(type))
{
case kUserdataColor:
break;
case kUserdataVec2:
if (compareMemberName(member, memberLength, "Dot"))
{
IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg));
build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos));
build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TUSERDATA, build.vmExit(pcpos));
IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(argResReg + 2));
build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos));
IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA));
IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA));
IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2);
IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA));
IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA));
IrOp yy = build.inst(IrCmd::MUL_NUM, y1, y2);
IrOp sum = build.inst(IrCmd::ADD_NUM, xx, yy);
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(argResReg), sum);
build.inst(IrCmd::STORE_TAG, build.vmReg(argResReg), build.constTag(LUA_TNUMBER));
// If the function is called in multi-return context, stack has to be adjusted
if (results == LUA_MULTRET)
build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(argResReg), build.constInt(1));
return true;
}
if (compareMemberName(member, memberLength, "Min"))
{
IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg));
build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos));
build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TUSERDATA, build.vmExit(pcpos));
IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(argResReg + 2));
build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos));
IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA));
IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA));
IrOp mx = build.inst(IrCmd::MIN_NUM, x1, x2);
IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA));
IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA));
IrOp my = build.inst(IrCmd::MIN_NUM, y1, y2);
build.inst(IrCmd::CHECK_GC);
IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2));
build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA));
build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA));
build.inst(IrCmd::STORE_POINTER, build.vmReg(argResReg), udatar);
build.inst(IrCmd::STORE_TAG, build.vmReg(argResReg), build.constTag(LUA_TUSERDATA));
// If the function is called in multi-return context, stack has to be adjusted
if (results == LUA_MULTRET)
build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(argResReg), build.constInt(1));
return true;
}
break;
case kUserdataMat3:
break;
}
return false;
}

View file

@ -24,6 +24,8 @@ LUAU_FASTFLAG(LuauCompileTempTypeInfo)
LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps)
LUAU_FASTFLAG(LuauCompileUserdataInfo)
LUAU_FASTFLAG(LuauLoadUserdataInfo)
LUAU_FASTFLAG(LuauCodegenUserdataOps)
LUAU_FASTFLAG(LuauCodegenUserdataAlloc)
static std::string getCodegenAssembly(const char* source, bool includeIrTypes = false, int debugLevel = 1)
{
@ -34,6 +36,13 @@ static std::string getCodegenAssembly(const char* source, bool includeIrTypes =
options.compilationOptions.hooks.vectorAccess = vectorAccess;
options.compilationOptions.hooks.vectorNamecall = vectorNamecall;
options.compilationOptions.hooks.userdataAccessBytecodeType = userdataAccessBytecodeType;
options.compilationOptions.hooks.userdataMetamethodBytecodeType = userdataMetamethodBytecodeType;
options.compilationOptions.hooks.userdataNamecallBytecodeType = userdataNamecallBytecodeType;
options.compilationOptions.hooks.userdataAccess = userdataAccess;
options.compilationOptions.hooks.userdataMetamethod = userdataMetamethod;
options.compilationOptions.hooks.userdataNamecall = userdataNamecall;
// For IR, we don't care about assembly, but we want a stable target
options.target = Luau::CodeGen::AssemblyOptions::Target::X64_SystemV;
@ -1690,4 +1699,352 @@ end
)");
}
TEST_CASE("CustomUserdataPropertyAccess")
{
// This test requires runtime component to be present
if (!Luau::CodeGen::isSupported())
return;
ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true},
{FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true},
{FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}};
CHECK_EQ("\n" + getCodegenAssembly(R"(
local function foo(v: vec2)
return v.X + v.Y
end
)",
/* includeIrTypes */ true),
R"(
; function foo($arg0) line 2
; R0: vec2 [argument]
bb_0:
CHECK_TAG R0, tuserdata, exit(entry)
JUMP bb_2
bb_2:
JUMP bb_bytecode_1
bb_bytecode_1:
%6 = LOAD_POINTER R0
CHECK_USERDATA_TAG %6, 12i, exit(0)
%8 = BUFFER_READF32 %6, 0i, tuserdata
%15 = BUFFER_READF32 %6, 4i, tuserdata
%24 = ADD_NUM %8, %15
STORE_DOUBLE R1, %24
STORE_TAG R1, tnumber
INTERRUPT 5u
RETURN R1, 1i
)");
}
TEST_CASE("CustomUserdataPropertyAccess2")
{
// This test requires runtime component to be present
if (!Luau::CodeGen::isSupported())
return;
ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true},
{FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true},
{FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}};
CHECK_EQ("\n" + getCodegenAssembly(R"(
local function foo(a: mat3)
return a.Row1 * a.Row2
end
)",
/* includeIrTypes */ true),
R"(
; function foo($arg0) line 2
; R0: mat3 [argument]
bb_0:
CHECK_TAG R0, tuserdata, exit(entry)
JUMP bb_2
bb_2:
JUMP bb_bytecode_1
bb_bytecode_1:
FALLBACK_GETTABLEKS 0u, R2, R0, K0
FALLBACK_GETTABLEKS 2u, R3, R0, K1
CHECK_TAG R2, tvector, exit(4)
CHECK_TAG R3, tvector, exit(4)
%14 = LOAD_TVALUE R2
%15 = LOAD_TVALUE R3
%16 = MUL_VEC %14, %15
%17 = TAG_VECTOR %16
STORE_TVALUE R1, %17
INTERRUPT 5u
RETURN R1, 1i
)");
}
TEST_CASE("CustomUserdataNamecall1")
{
// This test requires runtime component to be present
if (!Luau::CodeGen::isSupported())
return;
ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true},
{FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true},
{FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenAnalyzeHostVectorOps, true},
{FFlag::LuauCodegenUserdataOps, true}};
CHECK_EQ("\n" + getCodegenAssembly(R"(
local function foo(a: vec2, b: vec2)
return a:Dot(b)
end
)",
/* includeIrTypes */ true),
R"(
; function foo($arg0, $arg1) line 2
; R0: vec2 [argument]
; R1: vec2 [argument]
bb_0:
CHECK_TAG R0, tuserdata, exit(entry)
CHECK_TAG R1, tuserdata, exit(entry)
JUMP bb_2
bb_2:
JUMP bb_bytecode_1
bb_bytecode_1:
%6 = LOAD_TVALUE R1
STORE_TVALUE R4, %6
%10 = LOAD_POINTER R0
CHECK_USERDATA_TAG %10, 12i, exit(1)
%14 = LOAD_POINTER R4
CHECK_USERDATA_TAG %14, 12i, exit(1)
%16 = BUFFER_READF32 %10, 0i, tuserdata
%17 = BUFFER_READF32 %14, 0i, tuserdata
%18 = MUL_NUM %16, %17
%19 = BUFFER_READF32 %10, 4i, tuserdata
%20 = BUFFER_READF32 %14, 4i, tuserdata
%21 = MUL_NUM %19, %20
%22 = ADD_NUM %18, %21
STORE_DOUBLE R2, %22
STORE_TAG R2, tnumber
ADJUST_STACK_TO_REG R2, 1i
INTERRUPT 4u
RETURN R2, -1i
)");
}
TEST_CASE("CustomUserdataNamecall2")
{
// This test requires runtime component to be present
if (!Luau::CodeGen::isSupported())
return;
ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true},
{FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true},
{FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenAnalyzeHostVectorOps, true},
{FFlag::LuauCodegenUserdataOps, true}, {FFlag::LuauCodegenUserdataAlloc, true}};
CHECK_EQ("\n" + getCodegenAssembly(R"(
local function foo(a: vec2, b: vec2)
return a:Min(b)
end
)",
/* includeIrTypes */ true),
R"(
; function foo($arg0, $arg1) line 2
; R0: vec2 [argument]
; R1: vec2 [argument]
bb_0:
CHECK_TAG R0, tuserdata, exit(entry)
CHECK_TAG R1, tuserdata, exit(entry)
JUMP bb_2
bb_2:
JUMP bb_bytecode_1
bb_bytecode_1:
%6 = LOAD_TVALUE R1
STORE_TVALUE R4, %6
%10 = LOAD_POINTER R0
CHECK_USERDATA_TAG %10, 12i, exit(1)
%14 = LOAD_POINTER R4
CHECK_USERDATA_TAG %14, 12i, exit(1)
%16 = BUFFER_READF32 %10, 0i, tuserdata
%17 = BUFFER_READF32 %14, 0i, tuserdata
%18 = MIN_NUM %16, %17
%19 = BUFFER_READF32 %10, 4i, tuserdata
%20 = BUFFER_READF32 %14, 4i, tuserdata
%21 = MIN_NUM %19, %20
CHECK_GC
%23 = NEW_USERDATA 8i, 12i
BUFFER_WRITEF32 %23, 0i, %18, tuserdata
BUFFER_WRITEF32 %23, 4i, %21, tuserdata
STORE_POINTER R2, %23
STORE_TAG R2, tuserdata
ADJUST_STACK_TO_REG R2, 1i
INTERRUPT 4u
RETURN R2, -1i
)");
}
TEST_CASE("CustomUserdataMetamethodDirectFlow")
{
// This test requires runtime component to be present
if (!Luau::CodeGen::isSupported())
return;
ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true},
{FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true},
{FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}};
CHECK_EQ("\n" + getCodegenAssembly(R"(
local function foo(a: mat3, b: mat3)
return a * b
end
)",
/* includeIrTypes */ true),
R"(
; function foo($arg0, $arg1) line 2
; R0: mat3 [argument]
; R1: mat3 [argument]
bb_0:
CHECK_TAG R0, tuserdata, exit(entry)
CHECK_TAG R1, tuserdata, exit(entry)
JUMP bb_2
bb_2:
JUMP bb_bytecode_1
bb_bytecode_1:
SET_SAVEDPC 1u
DO_ARITH R2, R0, R1, 10i
INTERRUPT 1u
RETURN R2, 1i
)");
}
TEST_CASE("CustomUserdataMetamethodDirectFlow2")
{
// This test requires runtime component to be present
if (!Luau::CodeGen::isSupported())
return;
ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true},
{FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true},
{FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}};
CHECK_EQ("\n" + getCodegenAssembly(R"(
local function foo(a: mat3)
return -a
end
)",
/* includeIrTypes */ true),
R"(
; function foo($arg0) line 2
; R0: mat3 [argument]
bb_0:
CHECK_TAG R0, tuserdata, exit(entry)
JUMP bb_2
bb_2:
JUMP bb_bytecode_1
bb_bytecode_1:
SET_SAVEDPC 1u
DO_ARITH R1, R0, R0, 15i
INTERRUPT 1u
RETURN R1, 1i
)");
}
TEST_CASE("CustomUserdataMetamethodDirectFlow3")
{
// This test requires runtime component to be present
if (!Luau::CodeGen::isSupported())
return;
ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true},
{FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true},
{FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}};
CHECK_EQ("\n" + getCodegenAssembly(R"(
local function foo(a: sequence)
return #a
end
)",
/* includeIrTypes */ true),
R"(
; function foo($arg0) line 2
; R0: userdata [argument]
bb_0:
CHECK_TAG R0, tuserdata, exit(entry)
JUMP bb_2
bb_2:
JUMP bb_bytecode_1
bb_bytecode_1:
SET_SAVEDPC 1u
DO_LEN R1, R0
INTERRUPT 1u
RETURN R1, 1i
)");
}
TEST_CASE("CustomUserdataMetamethod")
{
// This test requires runtime component to be present
if (!Luau::CodeGen::isSupported())
return;
ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true},
{FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true},
{FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true},
{FFlag::LuauCodegenUserdataAlloc, true}};
CHECK_EQ("\n" + getCodegenAssembly(R"(
local function foo(a: vec2, b: vec2, c: vec2)
return -c + a * b
end
)",
/* includeIrTypes */ true),
R"(
; function foo($arg0, $arg1, $arg2) line 2
; R0: vec2 [argument]
; R1: vec2 [argument]
; R2: vec2 [argument]
bb_0:
CHECK_TAG R0, tuserdata, exit(entry)
CHECK_TAG R1, tuserdata, exit(entry)
CHECK_TAG R2, tuserdata, exit(entry)
JUMP bb_2
bb_2:
JUMP bb_bytecode_1
bb_bytecode_1:
%10 = LOAD_POINTER R2
CHECK_USERDATA_TAG %10, 12i, exit(0)
%12 = BUFFER_READF32 %10, 0i, tuserdata
%13 = BUFFER_READF32 %10, 4i, tuserdata
%14 = UNM_NUM %12
%15 = UNM_NUM %13
CHECK_GC
%17 = NEW_USERDATA 8i, 12i
BUFFER_WRITEF32 %17, 0i, %14, tuserdata
BUFFER_WRITEF32 %17, 4i, %15, tuserdata
STORE_POINTER R4, %17
STORE_TAG R4, tuserdata
%26 = LOAD_POINTER R0
CHECK_USERDATA_TAG %26, 12i, exit(1)
%28 = LOAD_POINTER R1
CHECK_USERDATA_TAG %28, 12i, exit(1)
%30 = BUFFER_READF32 %26, 0i, tuserdata
%31 = BUFFER_READF32 %28, 0i, tuserdata
%32 = MUL_NUM %30, %31
%33 = BUFFER_READF32 %26, 4i, tuserdata
%34 = BUFFER_READF32 %28, 4i, tuserdata
%35 = MUL_NUM %33, %34
%37 = NEW_USERDATA 8i, 12i
BUFFER_WRITEF32 %37, 0i, %32, tuserdata
BUFFER_WRITEF32 %37, 4i, %35, tuserdata
STORE_POINTER R5, %37
STORE_TAG R5, tuserdata
%50 = BUFFER_READF32 %17, 0i, tuserdata
%51 = BUFFER_READF32 %37, 0i, tuserdata
%52 = ADD_NUM %50, %51
%53 = BUFFER_READF32 %17, 4i, tuserdata
%54 = BUFFER_READF32 %37, 4i, tuserdata
%55 = ADD_NUM %53, %54
%57 = NEW_USERDATA 8i, 12i
BUFFER_WRITEF32 %57, 0i, %52, tuserdata
BUFFER_WRITEF32 %57, 4i, %55, tuserdata
STORE_POINTER R3, %57
STORE_TAG R3, tuserdata
INTERRUPT 3u
RETURN R3, 1i
)");
}
TEST_SUITE_END();

View file

@ -192,13 +192,13 @@ TEST_CASE("string_interpolation_double_brace")
auto brokenInterpBegin = lexer.next();
CHECK_EQ(brokenInterpBegin.type, Lexeme::BrokenInterpDoubleBrace);
CHECK_EQ(std::string(brokenInterpBegin.data, brokenInterpBegin.length), std::string("foo"));
CHECK_EQ(std::string(brokenInterpBegin.data, brokenInterpBegin.getLength()), std::string("foo"));
CHECK_EQ(lexer.next().type, Lexeme::Name);
auto interpEnd = lexer.next();
CHECK_EQ(interpEnd.type, Lexeme::InterpStringEnd);
CHECK_EQ(std::string(interpEnd.data, interpEnd.length), std::string("}bar"));
CHECK_EQ(std::string(interpEnd.data, interpEnd.getLength()), std::string("}bar"));
}
TEST_CASE("string_interpolation_double_but_unmatched_brace")

View file

@ -15,6 +15,8 @@
using namespace Luau;
LUAU_FASTFLAG(LuauAttributeSyntax);
#define NONSTRICT_REQUIRE_ERR_AT_POS(pos, result, idx) \
do \
{ \
@ -68,6 +70,7 @@ struct NonStrictTypeCheckerFixture : Fixture
{
ScopedFastFlag flags[] = {
{FFlag::DebugLuauDeferredConstraintResolution, true},
{FFlag::LuauAttributeSyntax, true},
};
LoadDefinitionFileResult res = loadDefinition(definitions);
LUAU_ASSERT(res.success);
@ -78,6 +81,7 @@ struct NonStrictTypeCheckerFixture : Fixture
{
ScopedFastFlag flags[] = {
{FFlag::DebugLuauDeferredConstraintResolution, true},
{FFlag::LuauAttributeSyntax, true},
};
LoadDefinitionFileResult res = loadDefinition(definitions);
LUAU_ASSERT(res.success);
@ -85,21 +89,21 @@ struct NonStrictTypeCheckerFixture : Fixture
}
std::string definitions = R"BUILTIN_SRC(
declare function @checked abs(n: number): number
declare function @checked lower(s: string): string
@checked declare function abs(n: number): number
@checked declare function lower(s: string): string
declare function cond() : boolean
declare function @checked contrived(n : Not<number>) : number
@checked declare function contrived(n : Not<number>) : number
-- interesting types of things that we would like to mark as checked
declare function @checked onlyNums(...: number) : number
declare function @checked mixedArgs(x: string, ...: number) : number
declare function @checked optionalArg(x: string?) : number
@checked declare function onlyNums(...: number) : number
@checked declare function mixedArgs(x: string, ...: number) : number
@checked declare function optionalArg(x: string?) : number
declare foo: {
bar: @checked (number) -> number,
}
declare function @checked optionalArgsAtTheEnd1(x: string, y: number?, z: number?) : number
declare function @checked optionalArgsAtTheEnd2(x: string, y: number?, z: string) : number
@checked declare function optionalArgsAtTheEnd1(x: string, y: number?, z: number?) : number
@checked declare function optionalArgsAtTheEnd2(x: string, y: number?, z: string) : number
type DateTypeArg = {
year: number,
@ -115,7 +119,7 @@ declare os : {
time: @checked (time: DateTypeArg?) -> number
}
declare function @checked require(target : any) : any
@checked declare function require(target : any) : any
)BUILTIN_SRC";
};
@ -558,6 +562,10 @@ local E = require(script.Parent.A)
TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "nonstrict_shouldnt_warn_on_valid_buffer_use")
{
ScopedFastFlag flags[] = {
{FFlag::LuauAttributeSyntax, true},
};
loadDefinition(R"(
declare buffer: {
create: @checked (size: number) -> buffer,

View file

@ -16,6 +16,7 @@ LUAU_FASTINT(LuauRecursionLimit);
LUAU_FASTINT(LuauTypeLengthLimit);
LUAU_FASTINT(LuauParseErrorLimit);
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTFLAG(LuauAttributeSyntax);
LUAU_FASTFLAG(LuauLeadingBarAndAmpersand);
namespace
@ -3051,9 +3052,10 @@ TEST_CASE_FIXTURE(Fixture, "parse_top_level_checked_fn")
{
ParseOptions opts;
opts.allowDeclarationSyntax = true;
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
std::string src = R"BUILTIN_SRC(
declare function @checked abs(n: number): number
@checked declare function abs(n: number): number
)BUILTIN_SRC";
ParseResult pr = tryParse(src, opts);
@ -3063,13 +3065,14 @@ declare function @checked abs(n: number): number
AstStat* root = *(pr.root->body.data);
auto func = root->as<AstStatDeclareFunction>();
LUAU_ASSERT(func);
LUAU_ASSERT(func->checkedFunction);
LUAU_ASSERT(func->isCheckedFunction());
}
TEST_CASE_FIXTURE(Fixture, "parse_declared_table_checked_member")
{
ParseOptions opts;
opts.allowDeclarationSyntax = true;
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
const std::string src = R"BUILTIN_SRC(
declare math : {
@ -3090,13 +3093,14 @@ TEST_CASE_FIXTURE(Fixture, "parse_declared_table_checked_member")
auto prop = *tbl->props.data;
auto func = prop.type->as<AstTypeFunction>();
LUAU_ASSERT(func);
LUAU_ASSERT(func->checkedFunction);
LUAU_ASSERT(func->isCheckedFunction());
}
TEST_CASE_FIXTURE(Fixture, "parse_checked_outside_decl_fails")
{
ParseOptions opts;
opts.allowDeclarationSyntax = true;
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
ParseResult pr = tryParse(R"(
local @checked = 3
@ -3110,10 +3114,11 @@ TEST_CASE_FIXTURE(Fixture, "parse_checked_in_and_out_of_decl_fails")
{
ParseOptions opts;
opts.allowDeclarationSyntax = true;
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
auto pr = tryParse(R"(
local @checked = 3
declare function @checked abs(n: number): number
@checked declare function abs(n: number): number
)",
opts);
LUAU_ASSERT(pr.errors.size() == 2);
@ -3125,9 +3130,10 @@ TEST_CASE_FIXTURE(Fixture, "parse_checked_as_function_name_fails")
{
ParseOptions opts;
opts.allowDeclarationSyntax = true;
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
auto pr = tryParse(R"(
function @checked(x: number) : number
@checked function(x: number) : number
end
)",
opts);
@ -3138,6 +3144,7 @@ TEST_CASE_FIXTURE(Fixture, "cannot_use_@_as_variable_name")
{
ParseOptions opts;
opts.allowDeclarationSyntax = true;
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
auto pr = tryParse(R"(
local @blah = 3
@ -3190,4 +3197,300 @@ TEST_CASE_FIXTURE(Fixture, "mixed_leading_intersection_and_union_not_allowed")
matchParseError("type A = | number & string & boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses.");
}
void checkAttribute(const AstAttr* attr, const AstAttr::Type type, const Location& location)
{
CHECK_EQ(attr->type, type);
CHECK_EQ(attr->location, location);
}
void checkFirstErrorForAttributes(const std::vector<ParseError>& errors, const size_t minSize, const Location& location, const std::string& message)
{
LUAU_ASSERT(minSize >= 1);
CHECK_GE(errors.size(), minSize);
CHECK_EQ(errors[0].getLocation(), location);
CHECK_EQ(errors[0].getMessage(), message);
}
TEST_CASE_FIXTURE(Fixture, "parse_attribute_on_function_stat")
{
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
AstStatBlock* stat = parse(R"(
@checked
function hello(x, y)
return x + y
end)");
LUAU_ASSERT(stat != nullptr);
AstStatFunction* statFun = stat->body.data[0]->as<AstStatFunction>();
LUAU_ASSERT(statFun != nullptr);
AstArray<AstAttr*> attributes = statFun->func->attributes;
CHECK_EQ(attributes.size, 1);
checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(1, 0), Position(1, 8)));
}
TEST_CASE_FIXTURE(Fixture, "parse_attribute_on_local_function_stat")
{
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
AstStatBlock* stat = parse(R"(
@checked
local function hello(x, y)
return x + y
end)");
LUAU_ASSERT(stat != nullptr);
AstStatLocalFunction* statFun = stat->body.data[0]->as<AstStatLocalFunction>();
LUAU_ASSERT(statFun != nullptr);
AstArray<AstAttr*> attributes = statFun->func->attributes;
CHECK_EQ(attributes.size, 1);
checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(1, 4), Position(1, 12)));
}
TEST_CASE_FIXTURE(Fixture, "empty_attribute_name_is_not_allowed")
{
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
ParseResult result = tryParse(R"(
@
function hello(x, y)
return x + y
end)");
checkFirstErrorForAttributes(result.errors, 1, Location(Position(1, 0), Position(1, 1)), "Attribute name is missing");
}
TEST_CASE_FIXTURE(Fixture, "dont_parse_attributes_on_non_function_stat")
{
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
ParseResult pr1 = tryParse(R"(
@checked
if a<0 then a = 0 end)");
checkFirstErrorForAttributes(pr1.errors, 1, Location(Position(2, 0), Position(2, 2)),
"Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'if' intead");
ParseResult pr2 = tryParse(R"(
local i = 1
@checked
while a[i] do
print(a[i])
i = i + 1
end)");
checkFirstErrorForAttributes(pr2.errors, 1, Location(Position(3, 0), Position(3, 5)),
"Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'while' intead");
ParseResult pr3 = tryParse(R"(
@checked
do
local a2 = 2*a
local d = sqrt(b^2 - 4*a*c)
x1 = (-b + d)/a2
x2 = (-b - d)/a2
end)");
checkFirstErrorForAttributes(pr3.errors, 1, Location(Position(2, 0), Position(2, 2)),
"Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'do' intead");
ParseResult pr4 = tryParse(R"(
@checked
for i=1,10 do print(i) end
)");
checkFirstErrorForAttributes(pr4.errors, 1, Location(Position(2, 0), Position(2, 3)),
"Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'for' intead");
ParseResult pr5 = tryParse(R"(
@checked
repeat
line = io.read()
until line ~= ""
)");
checkFirstErrorForAttributes(pr5.errors, 1, Location(Position(2, 0), Position(2, 6)),
"Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'repeat' intead");
ParseResult pr6 = tryParse(R"(
@checked
local x = 10
)");
checkFirstErrorForAttributes(
pr6.errors, 1, Location(Position(2, 6), Position(2, 7)), "Expected 'function' after local declaration with attribute, but got 'x' intead");
ParseResult pr7 = tryParse(R"(
local i = 1
while a[i] do
if a[i] == v then @checked break end
i = i + 1
end
)");
checkFirstErrorForAttributes(pr7.errors, 1, Location(Position(3, 31), Position(3, 36)),
"Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'break' intead");
ParseResult pr8 = tryParse(R"(
function foo1 () @checked return 'a' end
)");
checkFirstErrorForAttributes(pr8.errors, 1, Location(Position(1, 26), Position(1, 32)),
"Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'return' intead");
}
TEST_CASE_FIXTURE(Fixture, "parse_attribute_on_function_type_declaration")
{
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
ParseOptions opts;
opts.allowDeclarationSyntax = true;
std::string src = R"(
@checked declare function abs(n: number): number
)";
ParseResult pr = tryParse(src, opts);
CHECK_EQ(pr.errors.size(), 0);
LUAU_ASSERT(pr.root->body.size == 1);
AstStat* root = *(pr.root->body.data);
auto func = root->as<AstStatDeclareFunction>();
LUAU_ASSERT(func != nullptr);
CHECK(func->isCheckedFunction());
AstArray<AstAttr*> attributes = func->attributes;
checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(1, 0), Position(1, 8)));
}
TEST_CASE_FIXTURE(Fixture, "parse_attributes_on_function_type_declaration_in_table")
{
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
ParseOptions opts;
opts.allowDeclarationSyntax = true;
std::string src = R"(
declare bit32: {
band: @checked (...number) -> number
})";
ParseResult pr = tryParse(src, opts);
CHECK_EQ(pr.errors.size(), 0);
LUAU_ASSERT(pr.root->body.size == 1);
AstStat* root = *(pr.root->body.data);
AstStatDeclareGlobal* glob = root->as<AstStatDeclareGlobal>();
LUAU_ASSERT(glob);
auto tbl = glob->type->as<AstTypeTable>();
LUAU_ASSERT(tbl);
LUAU_ASSERT(tbl->props.size == 1);
AstTableProp prop = tbl->props.data[0];
AstTypeFunction* func = prop.type->as<AstTypeFunction>();
LUAU_ASSERT(func);
AstArray<AstAttr*> attributes = func->attributes;
CHECK_EQ(attributes.size, 1);
checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(2, 10), Position(2, 18)));
}
TEST_CASE_FIXTURE(Fixture, "dont_parse_attributes_on_non_function_type_declarations")
{
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
ParseOptions opts;
opts.allowDeclarationSyntax = true;
ParseResult pr1 = tryParse(R"(
@checked declare foo: number
)",
opts);
checkFirstErrorForAttributes(
pr1.errors, 1, Location(Position(1, 17), Position(1, 20)), "Expected a function type declaration after attribute, but got 'foo' intead");
ParseResult pr2 = tryParse(R"(
@checked declare class Foo
prop: number
function method(self, foo: number): string
end)",
opts);
checkFirstErrorForAttributes(
pr2.errors, 1, Location(Position(1, 17), Position(1, 22)), "Expected a function type declaration after attribute, but got 'class' intead");
ParseResult pr3 = tryParse(R"(
declare bit32: {
band: @checked number
})",
opts);
checkFirstErrorForAttributes(
pr3.errors, 1, Location(Position(2, 19), Position(2, 25)), "Expected '(' when parsing function parameters, got 'number'");
}
TEST_CASE_FIXTURE(Fixture, "attributes_cannot_be_duplicated")
{
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
ParseResult result = tryParse(R"(
@checked
@checked
function hello(x, y)
return x + y
end)");
checkFirstErrorForAttributes(result.errors, 1, Location(Position(2, 4), Position(2, 12)), "Cannot duplicate attribute '@checked'");
}
TEST_CASE_FIXTURE(Fixture, "unsupported_attributes_are_not_allowed")
{
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
ParseResult result = tryParse(R"(
@checked
@cool_attribute
function hello(x, y)
return x + y
end)");
checkFirstErrorForAttributes(result.errors, 1, Location(Position(2, 4), Position(2, 19)), "Invalid attribute '@cool_attribute'");
}
TEST_CASE_FIXTURE(Fixture, "can_parse_leading_bar_unions_successfully")
{
ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true};
parse(R"(type A = | "Hello" | "World")");
}
TEST_CASE_FIXTURE(Fixture, "can_parse_leading_ampersand_intersections_successfully")
{
ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true};
parse(R"(type A = & { string } & { number })");
}
TEST_CASE_FIXTURE(Fixture, "mixed_leading_intersection_and_union_not_allowed")
{
ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true};
matchParseError("type A = & number | string | boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses.");
matchParseError("type A = | number & string & boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses.");
}
TEST_SUITE_END();

View file

@ -13,6 +13,7 @@ using namespace Luau;
LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction);
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTFLAG(DebugLuauSharedSelf);
LUAU_FASTFLAG(LuauAttributeSyntax);
TEST_SUITE_BEGIN("ToString");
@ -1010,10 +1011,11 @@ TEST_CASE_FIXTURE(Fixture, "checked_fn_toString")
{
ScopedFastFlag flags[] = {
{FFlag::DebugLuauDeferredConstraintResolution, true},
{FFlag::LuauAttributeSyntax, true},
};
auto _result = loadDefinition(R"(
declare function @checked abs(n: number) : number
@checked declare function abs(n: number) : number
)");
auto result = check(Mode::Nonstrict, R"(

View file

@ -701,7 +701,7 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_basic")
REQUIRE(ut);
REQUIRE(ut->options.size() == 2);
CHECK_EQ(builtinTypes->nilType, ut->options[0]);
CHECK_EQ(builtinTypes->nilType, follow(ut->options[0]));
CHECK_EQ(*builtinTypes->numberType, *ut->options[1]);
}
else
@ -1179,4 +1179,14 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "iteration_preserves_error_suppression")
CHECK("any" == toString(requireTypeAtPosition({3, 25})));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "tryDispatchIterableFunction_under_constrained_loop_should_not_assert")
{
CheckResult result = check(R"(
local function foo(Instance)
for _, Child in next, Instance:GetChildren() do
end
end
)");
}
TEST_SUITE_END();

View file

@ -3153,7 +3153,7 @@ TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch")
LUAU_REQUIRE_ERROR_COUNT(1, result);
if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK_EQ("Value of type '{ x: number? }?' could be nil", toString(result.errors[0]));
CHECK_EQ("Type 'nil' does not have key 'x'", toString(result.errors[0]));
else
CHECK_EQ("Value of type '{| x: number? |}?' could be nil", toString(result.errors[0]));
CHECK_EQ("boolean", toString(requireType("u")));
@ -4439,7 +4439,13 @@ TEST_CASE_FIXTURE(Fixture, "insert_a_and_f_of_a_into_table_res_in_a_loop")
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
if (FFlag::DebugLuauDeferredConstraintResolution)
{
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK(get<FunctionExitsWithoutReturning>(result.errors[0]));
}
else
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(BuiltinsFixture, "ipairs_adds_an_unbounded_indexer")

View file

@ -0,0 +1,42 @@
-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
print('testing userdata')
function ecall(fn, ...)
local ok, err = pcall(fn, ...)
assert(not ok)
return err:sub((err:find(": ") or -1) + 2, #err)
end
local function realmad(a: vec2, b: vec2, c: vec2): vec2
return -c + a * b;
end
local function dm(s: vec2, t: vec2, u: vec2)
local x = s:Dot(t)
assert(x == 13)
local t = u:Min(s)
assert(t.X == 5)
assert(t.Y == 4)
end
local s: vec2 = vec2(5, 4)
local t: vec2 = vec2(1, 2)
local u: vec2 = vec2(10, 20)
local x: vec2 = realmad(s, t, u)
assert(x.X == -5)
assert(x.Y == -12)
dm(s, t, u)
local function mu(v: vec2)
assert(v.Magnitude == 2)
assert(v.Unit.X == 0)
assert(v.Unit.Y == 1)
end
mu(vec2(0, 2))
return 'OK'

View file

@ -4,7 +4,6 @@ AutocompleteTest.anonymous_autofilled_generic_type_pack_vararg
AutocompleteTest.autocomplete_string_singletons
AutocompleteTest.do_wrong_compatible_nonself_calls
AutocompleteTest.string_singleton_as_table_key
AutocompleteTest.string_singleton_in_if_statement2
AutocompleteTest.suggest_table_keys
AutocompleteTest.type_correct_suggestion_for_overloads
AutocompleteTest.type_correct_suggestion_in_table
@ -33,6 +32,15 @@ BuiltinTests.string_format_report_all_type_errors_at_correct_positions
BuiltinTests.string_format_use_correct_argument2
BuiltinTests.table_freeze_is_generic
BuiltinTests.tonumber_returns_optional_number_type
ControlFlowAnalysis.for_record_do_if_not_x_break
ControlFlowAnalysis.for_record_do_if_not_x_continue
ControlFlowAnalysis.if_not_x_break_elif_not_y_break
ControlFlowAnalysis.if_not_x_break_elif_not_y_continue
ControlFlowAnalysis.if_not_x_break_elif_rand_break_elif_not_y_break
ControlFlowAnalysis.if_not_x_continue_elif_not_y_continue
ControlFlowAnalysis.if_not_x_continue_elif_not_y_throw_elif_not_z_fallthrough
ControlFlowAnalysis.if_not_x_continue_elif_rand_continue_elif_not_y_continue
ControlFlowAnalysis.if_not_x_return_elif_not_y_break
DefinitionTests.class_definition_overload_metamethods
Differ.metatable_metamissing_left
Differ.metatable_metamissing_right
@ -46,7 +54,6 @@ FrontendTest.trace_requires_in_nonstrict_mode
GenericsTests.apply_type_function_nested_generics1
GenericsTests.better_mismatch_error_messages
GenericsTests.bound_tables_do_not_clone_original_fields
GenericsTests.correctly_instantiate_polymorphic_member_functions
GenericsTests.do_not_always_instantiate_generic_intersection_types
GenericsTests.do_not_infer_generic_functions
GenericsTests.dont_substitute_bound_types
@ -135,6 +142,7 @@ RefinementTest.discriminate_from_isa_of_x
RefinementTest.discriminate_from_truthiness_of_x
RefinementTest.globals_can_be_narrowed_too
RefinementTest.isa_type_refinement_must_be_known_ahead_of_time
RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true
RefinementTest.not_t_or_some_prop_of_t
RefinementTest.refine_a_param_that_got_resolved_during_constraint_solving_stage
RefinementTest.refine_a_property_of_some_global
@ -278,7 +286,9 @@ TypeInferAnyError.can_subscript_any
TypeInferAnyError.for_in_loop_iterator_is_any
TypeInferAnyError.for_in_loop_iterator_is_any2
TypeInferAnyError.for_in_loop_iterator_is_any_pack
TypeInferAnyError.for_in_loop_iterator_returns_any
TypeInferAnyError.for_in_loop_iterator_returns_any2
TypeInferAnyError.replace_every_free_type_when_unifying_a_complex_function_with_any
TypeInferClasses.callable_classes
TypeInferClasses.cannot_unify_class_instance_with_primitive
TypeInferClasses.class_type_mismatch_with_name_conflict
@ -337,6 +347,7 @@ TypeInferFunctions.too_many_arguments
TypeInferFunctions.too_many_arguments_error_location
TypeInferFunctions.too_many_return_values_in_parentheses
TypeInferFunctions.too_many_return_values_no_function
TypeInferFunctions.unifier_should_not_bind_free_types
TypeInferLoops.cli_68448_iterators_need_not_accept_nil
TypeInferLoops.dcr_iteration_on_never_gives_never
TypeInferLoops.dcr_xpath_candidates
@ -363,7 +374,6 @@ TypeInferModules.require
TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2
TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon
TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory
TypeInferOOP.methods_are_topologically_sorted
TypeInferOOP.promise_type_error_too_complex
TypeInferOperators.add_type_family_works
TypeInferOperators.cli_38355_recursive_union