Merge branch 'upstream' into merge

This commit is contained in:
Andy Friesen 2024-06-07 10:20:14 -07:00
commit 40e03164f7
68 changed files with 2712 additions and 687 deletions

View file

@ -57,7 +57,7 @@ struct GeneralizationConstraint
struct IterableConstraint struct IterableConstraint
{ {
TypePackId iterator; TypePackId iterator;
TypePackId variables; std::vector<TypeId> variables;
const AstNode* nextAstFragment; const AstNode* nextAstFragment;
DenseHashMap<const AstNode*, TypeId>* astForInNextTypes; DenseHashMap<const AstNode*, TypeId>* astForInNextTypes;
@ -192,13 +192,7 @@ struct HasIndexerConstraint
TypeId indexType; TypeId indexType;
}; };
struct AssignConstraint // assignProp lhsType propName rhsType
{
TypeId lhsType;
TypeId rhsType;
};
// assign lhsType propName rhsType
// //
// Assign a value of type rhsType into the named property of lhsType. // Assign a value of type rhsType into the named property of lhsType.
@ -212,6 +206,12 @@ struct AssignPropConstraint
/// populate astTypes during constraint resolution. Nothing should ever /// populate astTypes during constraint resolution. Nothing should ever
/// block on it. /// block on it.
TypeId propType; TypeId propType;
// When we generate constraints, we increment the remaining prop count on
// the table if we are able. This flag informs the solver as to whether or
// not it should in turn decrement the prop count when this constraint is
// dispatched.
bool decrementPropCount = false;
}; };
struct AssignIndexConstraint struct AssignIndexConstraint
@ -226,13 +226,13 @@ struct AssignIndexConstraint
TypeId propType; TypeId propType;
}; };
// resultType ~ unpack sourceTypePack // resultTypes ~ unpack sourceTypePack
// //
// Similar to PackSubtypeConstraint, but with one important difference: If the // Similar to PackSubtypeConstraint, but with one important difference: If the
// sourcePack is blocked, this constraint blocks. // sourcePack is blocked, this constraint blocks.
struct UnpackConstraint struct UnpackConstraint
{ {
TypePackId resultPack; std::vector<TypeId> resultPack;
TypePackId sourcePack; TypePackId sourcePack;
}; };
@ -254,7 +254,7 @@ struct ReducePackConstraint
using ConstraintV = Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, IterableConstraint, NameConstraint, using ConstraintV = Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, IterableConstraint, NameConstraint,
TypeAliasExpansionConstraint, FunctionCallConstraint, FunctionCheckConstraint, PrimitiveTypeConstraint, HasPropConstraint, HasIndexerConstraint, TypeAliasExpansionConstraint, FunctionCallConstraint, FunctionCheckConstraint, PrimitiveTypeConstraint, HasPropConstraint, HasIndexerConstraint,
AssignConstraint, AssignPropConstraint, AssignIndexConstraint, UnpackConstraint, ReduceConstraint, ReducePackConstraint, EqualityConstraint>; AssignPropConstraint, AssignIndexConstraint, UnpackConstraint, ReduceConstraint, ReducePackConstraint, EqualityConstraint>;
struct Constraint struct Constraint
{ {

View file

@ -118,6 +118,8 @@ struct ConstraintGenerator
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope; std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope;
std::vector<RequireCycle> requireCycles; std::vector<RequireCycle> requireCycles;
DenseHashMap<TypeId, std::vector<TypeId>> localTypes{nullptr};
DcrLogger* logger; DcrLogger* logger;
ConstraintGenerator(ModulePtr module, NotNull<Normalizer> normalizer, NotNull<ModuleResolver> moduleResolver, NotNull<BuiltinTypes> builtinTypes, ConstraintGenerator(ModulePtr module, NotNull<Normalizer> normalizer, NotNull<ModuleResolver> moduleResolver, NotNull<BuiltinTypes> builtinTypes,
@ -354,6 +356,8 @@ private:
*/ */
void prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program); void prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program);
bool recordPropertyAssignment(TypeId ty);
// Record the fact that a particular local has a particular type in at least // Record the fact that a particular local has a particular type in at least
// one of its states. // one of its states.
void recordInferredBinding(AstLocal* local, TypeId ty); void recordInferredBinding(AstLocal* local, TypeId ty);

View file

@ -142,7 +142,6 @@ struct ConstraintSolver
std::pair<bool, std::optional<TypeId>> tryDispatchSetIndexer( std::pair<bool, std::optional<TypeId>> tryDispatchSetIndexer(
NotNull<const Constraint> constraint, TypeId subjectType, TypeId indexType, TypeId propType, bool expandFreeTypeBounds); NotNull<const Constraint> constraint, TypeId subjectType, TypeId indexType, TypeId propType, bool expandFreeTypeBounds);
bool tryDispatch(const AssignConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const AssignPropConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const AssignPropConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const AssignIndexConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const AssignIndexConstraint& c, NotNull<const Constraint> constraint);
@ -158,8 +157,7 @@ struct ConstraintSolver
bool tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
// for a, ... in next_function, t, ... do // for a, ... in next_function, t, ... do
bool tryDispatchIterableFunction( bool tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(NotNull<const Constraint> constraint, TypeId subjectType, std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(NotNull<const Constraint> constraint, TypeId subjectType,
const std::string& propName, ValueContext context, bool inConditional = false, bool suppressSimplification = false); const std::string& propName, ValueContext context, bool inConditional = false, bool suppressSimplification = false);
@ -168,14 +166,18 @@ struct ConstraintSolver
/** /**
* Generate constraints to unpack the types of srcTypes and assign each * Generate constraints to unpack the types of srcTypes and assign each
* value to the corresponding LocalType in destTypes. * value to the corresponding BlockedType in destTypes.
* *
* @param destTypes A finite TypePack comprised of LocalTypes. * This function also overwrites the owners of each BlockedType. This is
* okay because this function is only used to decompose IterableConstraint
* into an UnpackConstraint.
*
* @param destTypes A vector of types comprised of BlockedTypes.
* @param srcTypes A TypePack that represents rvalues to be assigned. * @param srcTypes A TypePack that represents rvalues to be assigned.
* @returns The underlying UnpackConstraint. There's a bit of code in * @returns The underlying UnpackConstraint. There's a bit of code in
* iteration that needs to pass blocks on to this constraint. * iteration that needs to pass blocks on to this constraint.
*/ */
NotNull<const Constraint> unpackAndAssign(TypePackId destTypes, TypePackId srcTypes, NotNull<const Constraint> constraint); NotNull<const Constraint> unpackAndAssign(const std::vector<TypeId> destTypes, TypePackId srcTypes, NotNull<const Constraint> constraint);
void block(NotNull<const Constraint> target, NotNull<const Constraint> constraint); void block(NotNull<const Constraint> target, NotNull<const Constraint> constraint);
/** /**

View file

@ -86,24 +86,6 @@ struct FreeType
TypeId upperBound = nullptr; TypeId upperBound = nullptr;
}; };
/** A type that tracks the domain of a local variable.
*
* We consider each local's domain to be the union of all types assigned to it.
* We accomplish this with LocalType. Each time we dispatch an assignment to a
* local, we accumulate this union and decrement blockCount.
*
* When blockCount reaches 0, we can consider the LocalType to be "fully baked"
* and replace it with the union we've built.
*/
struct LocalType
{
TypeId domain;
int blockCount = 0;
// Used for debugging
std::string name;
};
struct GenericType struct GenericType
{ {
// By default, generics are global, with a synthetic name // By default, generics are global, with a synthetic name
@ -148,6 +130,7 @@ struct BlockedType
Constraint* getOwner() const; Constraint* getOwner() const;
void setOwner(Constraint* newOwner); void setOwner(Constraint* newOwner);
void replaceOwner(Constraint* newOwner);
private: private:
// The constraint that is intended to unblock this type. Other constraints // The constraint that is intended to unblock this type. Other constraints
@ -471,6 +454,11 @@ struct TableType
// Methods of this table that have an untyped self will use the same shared self type. // Methods of this table that have an untyped self will use the same shared self type.
std::optional<TypeId> selfTy; std::optional<TypeId> selfTy;
// We track the number of as-yet-unadded properties to unsealed tables.
// Some constraints will use this information to decide whether or not they
// are able to dispatch.
size_t remainingProps = 0;
}; };
// Represents a metatable attached to a table type. Somewhat analogous to a bound type. // Represents a metatable attached to a table type. Somewhat analogous to a bound type.
@ -669,9 +657,9 @@ struct NegationType
using ErrorType = Unifiable::Error; using ErrorType = Unifiable::Error;
using TypeVariant = Unifiable::Variant<TypeId, FreeType, LocalType, GenericType, PrimitiveType, BlockedType, PendingExpansionType, SingletonType, using TypeVariant =
FunctionType, TableType, MetatableType, ClassType, AnyType, UnionType, IntersectionType, LazyType, UnknownType, NeverType, NegationType, Unifiable::Variant<TypeId, FreeType, GenericType, PrimitiveType, BlockedType, PendingExpansionType, SingletonType, FunctionType, TableType,
TypeFamilyInstanceType>; MetatableType, ClassType, AnyType, UnionType, IntersectionType, LazyType, UnknownType, NeverType, NegationType, TypeFamilyInstanceType>;
struct Type final struct Type final
{ {

View file

@ -69,7 +69,6 @@ struct Unifier2
*/ */
bool unify(TypeId subTy, TypeId superTy); bool unify(TypeId subTy, TypeId superTy);
bool unifyFreeWithType(TypeId subTy, TypeId superTy); bool unifyFreeWithType(TypeId subTy, TypeId superTy);
bool unify(const LocalType* subTy, TypeId superFn);
bool unify(TypeId subTy, const FunctionType* superFn); bool unify(TypeId subTy, const FunctionType* superFn);
bool unify(const UnionType* subUnion, TypeId superTy); bool unify(const UnionType* subUnion, TypeId superTy);
bool unify(TypeId subTy, const UnionType* superUnion); bool unify(TypeId subTy, const UnionType* superUnion);

View file

@ -100,10 +100,6 @@ struct GenericTypeVisitor
{ {
return visit(ty); return visit(ty);
} }
virtual bool visit(TypeId ty, const LocalType& ftv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const GenericType& gtv) virtual bool visit(TypeId ty, const GenericType& gtv)
{ {
return visit(ty); return visit(ty);
@ -248,11 +244,6 @@ struct GenericTypeVisitor
else else
visit(ty, *ftv); visit(ty, *ftv);
} }
else if (auto lt = get<LocalType>(ty))
{
if (visit(ty, *lt))
traverse(lt->domain);
}
else if (auto gtv = get<GenericType>(ty)) else if (auto gtv = get<GenericType>(ty))
visit(ty, *gtv); visit(ty, *gtv);
else if (auto etv = get<ErrorType>(ty)) else if (auto etv = get<ErrorType>(ty))

View file

@ -271,11 +271,6 @@ private:
t->upperBound = shallowClone(t->upperBound); t->upperBound = shallowClone(t->upperBound);
} }
void cloneChildren(LocalType* t)
{
t->domain = shallowClone(t->domain);
}
void cloneChildren(GenericType* t) void cloneChildren(GenericType* t)
{ {
// TOOD: clone upper bounds. // TOOD: clone upper bounds.

View file

@ -81,7 +81,8 @@ DenseHashSet<TypeId> Constraint::getMaybeMutatedFreeTypes() const
} }
else if (auto itc = get<IterableConstraint>(*this)) else if (auto itc = get<IterableConstraint>(*this))
{ {
rci.traverse(itc->variables); for (TypeId ty : itc->variables)
rci.traverse(ty);
// `IterableConstraints` should not mutate `iterator`. // `IterableConstraints` should not mutate `iterator`.
} }
else if (auto nc = get<NameConstraint>(*this)) else if (auto nc = get<NameConstraint>(*this))
@ -106,11 +107,6 @@ DenseHashSet<TypeId> Constraint::getMaybeMutatedFreeTypes() const
rci.traverse(hic->resultType); rci.traverse(hic->resultType);
// `HasIndexerConstraint` should not mutate `subjectType` or `indexType`. // `HasIndexerConstraint` should not mutate `subjectType` or `indexType`.
} }
else if (auto ac = get<AssignConstraint>(*this))
{
rci.traverse(ac->lhsType);
rci.traverse(ac->rhsType);
}
else if (auto apc = get<AssignPropConstraint>(*this)) else if (auto apc = get<AssignPropConstraint>(*this))
{ {
rci.traverse(apc->lhsType); rci.traverse(apc->lhsType);
@ -124,7 +120,8 @@ DenseHashSet<TypeId> Constraint::getMaybeMutatedFreeTypes() const
} }
else if (auto uc = get<UnpackConstraint>(*this)) else if (auto uc = get<UnpackConstraint>(*this))
{ {
rci.traverse(uc->resultPack); for (TypeId ty : uc->resultPack)
rci.traverse(ty);
// `UnpackConstraint` should not mutate `sourcePack`. // `UnpackConstraint` should not mutate `sourcePack`.
} }
else if (auto rpc = get<ReducePackConstraint>(*this)) else if (auto rpc = get<ReducePackConstraint>(*this))

View file

@ -28,6 +28,7 @@
LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTINT(LuauCheckRecursionLimit);
LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAG(DebugLuauLogSolverToJson);
LUAU_FASTFLAG(DebugLuauMagicTypes); LUAU_FASTFLAG(DebugLuauMagicTypes);
LUAU_FASTFLAG(LuauAttributeSyntax);
namespace Luau namespace Luau
{ {
@ -246,6 +247,17 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block)
if (logger) if (logger)
logger->captureGenerationModule(module); logger->captureGenerationModule(module);
for (const auto& [ty, domain] : localTypes)
{
// FIXME: This isn't the most efficient thing.
TypeId domainTy = builtinTypes->neverType;
for (TypeId d : domain)
domainTy = simplifyUnion(builtinTypes, arena, domainTy, d).result;
LUAU_ASSERT(get<BlockedType>(ty));
asMutable(ty)->ty.emplace<BoundType>(domainTy);
}
} }
TypeId ConstraintGenerator::freshType(const ScopePtr& scope) TypeId ConstraintGenerator::freshType(const ScopePtr& scope)
@ -310,7 +322,8 @@ std::optional<TypeId> ConstraintGenerator::lookup(const ScopePtr& scope, Locatio
std::optional<TypeId> ty = lookup(scope, location, operand, /*prototype*/ false); std::optional<TypeId> ty = lookup(scope, location, operand, /*prototype*/ false);
if (!ty) if (!ty)
{ {
ty = arena->addType(LocalType{builtinTypes->neverType}); ty = arena->addType(BlockedType{});
localTypes[*ty] = {};
rootScope->lvalueTypes[operand] = *ty; rootScope->lvalueTypes[operand] = *ty;
} }
@ -703,7 +716,8 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat
{ {
const Location location = local->location; const Location location = local->location;
TypeId assignee = arena->addType(LocalType{builtinTypes->neverType, /* blockCount */ 1, local->name.value}); TypeId assignee = arena->addType(BlockedType{});
localTypes[assignee] = {};
assignees.push_back(assignee); assignees.push_back(assignee);
@ -740,7 +754,12 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat
if (hasAnnotation) if (hasAnnotation)
{ {
for (size_t i = 0; i < statLocal->vars.size; ++i) for (size_t i = 0; i < statLocal->vars.size; ++i)
addConstraint(scope, statLocal->location, AssignConstraint{assignees[i], annotatedTypes[i]}); {
LUAU_ASSERT(get<BlockedType>(assignees[i]));
std::vector<TypeId>* localDomain = localTypes.find(assignees[i]);
LUAU_ASSERT(localDomain);
localDomain->push_back(annotatedTypes[i]);
}
TypePackId annotatedPack = arena->addTypePack(std::move(annotatedTypes)); TypePackId annotatedPack = arena->addTypePack(std::move(annotatedTypes));
addConstraint(scope, statLocal->location, PackSubtypeConstraint{rvaluePack, annotatedPack}); addConstraint(scope, statLocal->location, PackSubtypeConstraint{rvaluePack, annotatedPack});
@ -750,15 +769,30 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat
std::vector<TypeId> valueTypes; std::vector<TypeId> valueTypes;
valueTypes.reserve(statLocal->vars.size); valueTypes.reserve(statLocal->vars.size);
for (size_t i = 0; i < statLocal->vars.size; ++i) auto [head, tail] = flatten(rvaluePack);
valueTypes.push_back(arena->addType(BlockedType{}));
auto uc = addConstraint(scope, statLocal->location, UnpackConstraint{arena->addTypePack(valueTypes), rvaluePack}); if (head.size() >= statLocal->vars.size)
{
for (size_t i = 0; i < statLocal->vars.size; ++i)
valueTypes.push_back(head[i]);
}
else
{
for (size_t i = 0; i < statLocal->vars.size; ++i)
valueTypes.push_back(arena->addType(BlockedType{}));
auto uc = addConstraint(scope, statLocal->location, UnpackConstraint{valueTypes, rvaluePack});
for (TypeId t: valueTypes)
getMutable<BlockedType>(t)->setOwner(uc);
}
for (size_t i = 0; i < statLocal->vars.size; ++i) for (size_t i = 0; i < statLocal->vars.size; ++i)
{ {
getMutable<BlockedType>(valueTypes[i])->setOwner(uc); LUAU_ASSERT(get<BlockedType>(assignees[i]));
addConstraint(scope, statLocal->location, AssignConstraint{assignees[i], valueTypes[i]}); std::vector<TypeId>* localDomain = localTypes.find(assignees[i]);
LUAU_ASSERT(localDomain);
localDomain->push_back(valueTypes[i]);
} }
} }
@ -860,25 +894,34 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatForIn* forI
for (AstLocal* var : forIn->vars) for (AstLocal* var : forIn->vars)
{ {
TypeId assignee = arena->addType(LocalType{builtinTypes->neverType, /* blockCount */ 1, var->name.value}); TypeId assignee = arena->addType(BlockedType{});
variableTypes.push_back(assignee); variableTypes.push_back(assignee);
TypeId loopVar = arena->addType(BlockedType{});
localTypes[loopVar].push_back(assignee);
if (var->annotation) if (var->annotation)
{ {
TypeId annotationTy = resolveType(loopScope, var->annotation, /*inTypeArguments*/ false); TypeId annotationTy = resolveType(loopScope, var->annotation, /*inTypeArguments*/ false);
loopScope->bindings[var] = Binding{annotationTy, var->location}; loopScope->bindings[var] = Binding{annotationTy, var->location};
addConstraint(scope, var->location, SubtypeConstraint{assignee, annotationTy}); addConstraint(scope, var->location, SubtypeConstraint{loopVar, annotationTy});
} }
else else
loopScope->bindings[var] = Binding{assignee, var->location}; loopScope->bindings[var] = Binding{loopVar, var->location};
DefId def = dfg->getDef(var); DefId def = dfg->getDef(var);
loopScope->lvalueTypes[def] = assignee; loopScope->lvalueTypes[def] = loopVar;
} }
TypePackId variablePack = arena->addTypePack(std::move(variableTypes));
auto iterable = addConstraint( auto iterable = addConstraint(
loopScope, getLocation(forIn->values), IterableConstraint{iterator, variablePack, forIn->values.data[0], &module->astForInNextTypes}); loopScope, getLocation(forIn->values), IterableConstraint{iterator, variableTypes, forIn->values.data[0], &module->astForInNextTypes});
for (TypeId var: variableTypes)
{
auto bt = getMutable<BlockedType>(var);
LUAU_ASSERT(bt);
bt->setOwner(iterable);
}
Checkpoint start = checkpoint(this); Checkpoint start = checkpoint(this);
visit(loopScope, forIn->body); visit(loopScope, forIn->body);
@ -1105,14 +1148,31 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatAssign* ass
std::vector<TypeId> valueTypes; std::vector<TypeId> valueTypes;
valueTypes.reserve(assign->vars.size); valueTypes.reserve(assign->vars.size);
for (size_t i = 0; i < assign->vars.size; ++i) auto [head, tail] = flatten(resultPack);
valueTypes.push_back(arena->addType(BlockedType{})); if (head.size() >= assign->vars.size)
{
// If the resultPack is definitely long enough for each variable, we can
// skip the UnpackConstraint and use the result types directly.
auto uc = addConstraint(scope, assign->location, UnpackConstraint{arena->addTypePack(valueTypes), resultPack}); for (size_t i = 0; i < assign->vars.size; ++i)
valueTypes.push_back(head[i]);
}
else
{
// We're not sure how many types are produced by the right-side
// expressions. We'll use an UnpackConstraint to defer this until
// later.
for (size_t i = 0; i < assign->vars.size; ++i)
valueTypes.push_back(arena->addType(BlockedType{}));
auto uc = addConstraint(scope, assign->location, UnpackConstraint{valueTypes, resultPack});
for (TypeId t: valueTypes)
getMutable<BlockedType>(t)->setOwner(uc);
}
for (size_t i = 0; i < assign->vars.size; ++i) for (size_t i = 0; i < assign->vars.size; ++i)
{ {
getMutable<BlockedType>(valueTypes[i])->setOwner(uc);
visitLValue(scope, assign->vars.data[i], valueTypes[i]); visitLValue(scope, assign->vars.data[i], valueTypes[i]);
} }
@ -1393,7 +1453,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareFunc
TypePackId retPack = resolveTypePack(funScope, global->retTypes, /* inTypeArguments */ false); TypePackId retPack = resolveTypePack(funScope, global->retTypes, /* inTypeArguments */ false);
TypeId fnType = arena->addType(FunctionType{TypeLevel{}, funScope.get(), std::move(genericTys), std::move(genericTps), paramPack, retPack}); TypeId fnType = arena->addType(FunctionType{TypeLevel{}, funScope.get(), std::move(genericTys), std::move(genericTps), paramPack, retPack});
FunctionType* ftv = getMutable<FunctionType>(fnType); FunctionType* ftv = getMutable<FunctionType>(fnType);
ftv->isCheckedFunction = global->checkedFunction; ftv->isCheckedFunction = FFlag::LuauAttributeSyntax ? global->isCheckedFunction() : false;
ftv->argNames.reserve(global->paramNames.size); ftv->argNames.reserve(global->paramNames.size);
for (const auto& el : global->paramNames) for (const auto& el : global->paramNames)
@ -1599,9 +1659,8 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall*
mt = arena->addType(BlockedType{}); mt = arena->addType(BlockedType{});
unpackedTypes.emplace_back(mt); unpackedTypes.emplace_back(mt);
TypePackId mtPack = arena->addTypePack(std::move(unpackedTypes));
auto c = addConstraint(scope, call->location, UnpackConstraint{mtPack, *argTail}); auto c = addConstraint(scope, call->location, UnpackConstraint{unpackedTypes, *argTail});
getMutable<BlockedType>(mt)->setOwner(c); getMutable<BlockedType>(mt)->setOwner(c);
if (auto b = getMutable<BlockedType>(target); b && b->getOwner() == nullptr) if (auto b = getMutable<BlockedType>(target); b && b->getOwner() == nullptr)
b->setOwner(c); b->setOwner(c);
@ -1842,7 +1901,37 @@ Inference ConstraintGenerator::checkIndexName(
const ScopePtr& scope, const RefinementKey* key, AstExpr* indexee, const std::string& index, Location indexLocation) const ScopePtr& scope, const RefinementKey* key, AstExpr* indexee, const std::string& index, Location indexLocation)
{ {
TypeId obj = check(scope, indexee).ty; TypeId obj = check(scope, indexee).ty;
TypeId result = arena->addType(BlockedType{}); TypeId result = nullptr;
// We optimize away the HasProp constraint in simple cases so that we can
// reason about updates to unsealed tables more accurately.
const TableType* tt = getTableType(obj);
// This is a little bit iffy but I *believe* it is okay because, if the
// local's domain is going to be extended at all, it will be someplace after
// the current lexical position within the script.
if (!tt)
{
if (auto localDomain = localTypes.find(obj); localDomain && 1 == localDomain->size())
tt = getTableType(localDomain->front());
}
if (tt)
{
auto it = tt->props.find(index);
if (it != tt->props.end() && it->second.readTy.has_value())
result = *it->second.readTy;
}
if (!result)
{
result = arena->addType(BlockedType{});
auto c = addConstraint(
scope, indexee->location, HasPropConstraint{result, obj, std::move(index), ValueContext::RValue, inConditional(typeContext)});
getMutable<BlockedType>(result)->setOwner(c);
}
if (key) if (key)
{ {
@ -1852,10 +1941,6 @@ Inference ConstraintGenerator::checkIndexName(
scope->rvalueRefinements[key->def] = result; scope->rvalueRefinements[key->def] = result;
} }
auto c =
addConstraint(scope, indexee->location, HasPropConstraint{result, obj, std::move(index), ValueContext::RValue, inConditional(typeContext)});
getMutable<BlockedType>(result)->setOwner(c);
if (key) if (key)
return Inference{result, refinementArena.proposition(key, builtinTypes->truthyType)}; return Inference{result, refinementArena.proposition(key, builtinTypes->truthyType)};
else else
@ -2242,18 +2327,14 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprLocal* local
if (ty) if (ty)
{ {
if (auto lt = getMutable<LocalType>(*ty)) std::vector<TypeId>* localDomain = localTypes.find(*ty);
++lt->blockCount; if (localDomain)
else if (auto ut = getMutable<UnionType>(*ty)) localDomain->push_back(rhsType);
{
for (TypeId optTy : ut->options)
if (auto lt = getMutable<LocalType>(optTy))
++lt->blockCount;
}
} }
else else
{ {
ty = arena->addType(LocalType{builtinTypes->neverType, /* blockCount */ 1, local->local->name.value}); ty = arena->addType(BlockedType{});
localTypes[*ty].push_back(rhsType);
if (annotatedTy) if (annotatedTy)
{ {
@ -2277,7 +2358,9 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprLocal* local
if (annotatedTy) if (annotatedTy)
addConstraint(scope, local->location, SubtypeConstraint{rhsType, *annotatedTy}); addConstraint(scope, local->location, SubtypeConstraint{rhsType, *annotatedTy});
addConstraint(scope, local->location, AssignConstraint{*ty, rhsType});
if (auto localDomain = localTypes.find(*ty))
localDomain->push_back(rhsType);
} }
void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprGlobal* global, TypeId rhsType) void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprGlobal* global, TypeId rhsType)
@ -2289,7 +2372,6 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprGlobal* glob
rootScope->lvalueTypes[def] = rhsType; rootScope->lvalueTypes[def] = rhsType;
addConstraint(scope, global->location, SubtypeConstraint{rhsType, *annotatedTy}); addConstraint(scope, global->location, SubtypeConstraint{rhsType, *annotatedTy});
addConstraint(scope, global->location, AssignConstraint{*annotatedTy, rhsType});
} }
} }
@ -2298,7 +2380,10 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexName* e
TypeId lhsTy = check(scope, expr->expr).ty; TypeId lhsTy = check(scope, expr->expr).ty;
TypeId propTy = arena->addType(BlockedType{}); TypeId propTy = arena->addType(BlockedType{});
module->astTypes[expr] = propTy; module->astTypes[expr] = propTy;
addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, expr->index.value, rhsType, propTy});
bool incremented = recordPropertyAssignment(lhsTy);
addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, expr->index.value, rhsType, propTy, incremented});
} }
void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexExpr* expr, TypeId rhsType) void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexExpr* expr, TypeId rhsType)
@ -2310,7 +2395,10 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexExpr* e
module->astTypes[expr] = propTy; module->astTypes[expr] = propTy;
module->astTypes[expr->index] = builtinTypes->stringType; // FIXME? Singleton strings exist. module->astTypes[expr->index] = builtinTypes->stringType; // FIXME? Singleton strings exist.
std::string propName{constantString->value.data, constantString->value.size}; std::string propName{constantString->value.data, constantString->value.size};
addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, std::move(propName), rhsType, propTy});
bool incremented = recordPropertyAssignment(lhsTy);
addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, std::move(propName), rhsType, propTy, incremented});
return; return;
} }
@ -2775,7 +2863,7 @@ TypeId ConstraintGenerator::resolveType(const ScopePtr& scope, AstType* ty, bool
// TODO: FunctionType needs a pointer to the scope so that we know // TODO: FunctionType needs a pointer to the scope so that we know
// how to quantify/instantiate it. // how to quantify/instantiate it.
FunctionType ftv{TypeLevel{}, scope.get(), {}, {}, argTypes, returnTypes}; FunctionType ftv{TypeLevel{}, scope.get(), {}, {}, argTypes, returnTypes};
ftv.isCheckedFunction = fn->checkedFunction; ftv.isCheckedFunction = FFlag::LuauAttributeSyntax ? fn->isCheckedFunction() : false;
// This replicates the behavior of the appropriate FunctionType // This replicates the behavior of the appropriate FunctionType
// constructors. // constructors.
@ -2977,8 +3065,7 @@ Inference ConstraintGenerator::flattenPack(const ScopePtr& scope, Location locat
return Inference{*f, refinement}; return Inference{*f, refinement};
TypeId typeResult = arena->addType(BlockedType{}); TypeId typeResult = arena->addType(BlockedType{});
TypePackId resultPack = arena->addTypePack({typeResult}, arena->freshTypePack(scope.get())); auto c = addConstraint(scope, location, UnpackConstraint{{typeResult}, tp});
auto c = addConstraint(scope, location, UnpackConstraint{resultPack, tp});
getMutable<BlockedType>(typeResult)->setOwner(c); getMutable<BlockedType>(typeResult)->setOwner(c);
return Inference{typeResult, refinement}; return Inference{typeResult, refinement};
@ -3075,6 +3162,46 @@ void ConstraintGenerator::prepopulateGlobalScope(const ScopePtr& globalScope, As
program->visit(&gp); program->visit(&gp);
} }
bool ConstraintGenerator::recordPropertyAssignment(TypeId ty)
{
DenseHashSet<TypeId> seen{nullptr};
VecDeque<TypeId> queue;
queue.push_back(ty);
bool incremented = false;
while (!queue.empty())
{
const TypeId t = follow(queue.front());
queue.pop_front();
if (seen.find(t))
continue;
seen.insert(t);
if (auto tt = getMutable<TableType>(t); tt && tt->state == TableState::Unsealed)
{
tt->remainingProps += 1;
incremented = true;
}
else if (auto mt = get<MetatableType>(t))
queue.push_back(mt->table);
else if (auto localDomain = localTypes.find(t))
{
for (TypeId domainTy : *localDomain)
queue.push_back(domainTy);
}
else if (auto ut = get<UnionType>(t))
{
for (TypeId part : ut)
queue.push_back(part);
}
}
return incremented;
}
void ConstraintGenerator::recordInferredBinding(AstLocal* local, TypeId ty) void ConstraintGenerator::recordInferredBinding(AstLocal* local, TypeId ty)
{ {
if (InferredBinding* ib = inferredBindings.find(local)) if (InferredBinding* ib = inferredBindings.find(local))

View file

@ -532,8 +532,6 @@ bool ConstraintSolver::tryDispatch(NotNull<const Constraint> constraint, bool fo
success = tryDispatch(*hpc, constraint); success = tryDispatch(*hpc, constraint);
else if (auto spc = get<HasIndexerConstraint>(*constraint)) else if (auto spc = get<HasIndexerConstraint>(*constraint))
success = tryDispatch(*spc, constraint); success = tryDispatch(*spc, constraint);
else if (auto uc = get<AssignConstraint>(*constraint))
success = tryDispatch(*uc, constraint);
else if (auto uc = get<AssignPropConstraint>(*constraint)) else if (auto uc = get<AssignPropConstraint>(*constraint))
success = tryDispatch(*uc, constraint); success = tryDispatch(*uc, constraint);
else if (auto uc = get<AssignIndexConstraint>(*constraint)) else if (auto uc = get<AssignIndexConstraint>(*constraint))
@ -686,7 +684,8 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNull<const Co
if (0 == iterator.head.size()) if (0 == iterator.head.size())
{ {
unify(constraint, builtinTypes->anyTypePack, c.variables); for (TypeId ty : c.variables)
unify(constraint, builtinTypes->errorRecoveryType(), ty);
return true; return true;
} }
@ -696,21 +695,35 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNull<const Co
{ {
TypeId keyTy = freshType(arena, builtinTypes, constraint->scope); TypeId keyTy = freshType(arena, builtinTypes, constraint->scope);
TypeId valueTy = freshType(arena, builtinTypes, constraint->scope); TypeId valueTy = freshType(arena, builtinTypes, constraint->scope);
TypeId tableTy = arena->addType(TableType{TableState::Sealed, {}, constraint->scope}); TypeId tableTy = arena->addType(TableType{
getMutable<TableType>(tableTy)->indexer = TableIndexer{keyTy, valueTy}; TableType::Props{},
TableIndexer{keyTy, valueTy},
TypeLevel{},
constraint->scope,
TableState::Free
});
pushConstraint(constraint->scope, constraint->location, SubtypeConstraint{nextTy, tableTy}); unify(constraint, nextTy, tableTy);
auto it = begin(c.variables); auto it = begin(c.variables);
auto endIt = end(c.variables); auto endIt = end(c.variables);
if (it != endIt) if (it != endIt)
{ {
pushConstraint(constraint->scope, constraint->location, AssignConstraint{*it, keyTy}); bindBlockedType(*it, keyTy, keyTy, constraint);
++it; ++it;
} }
if (it != endIt) if (it != endIt)
pushConstraint(constraint->scope, constraint->location, AssignConstraint{*it, valueTy}); {
bindBlockedType(*it, valueTy, valueTy, constraint);
++it;
}
while (it != endIt)
{
bindBlockedType(*it, builtinTypes->nilType, builtinTypes->nilType, constraint);
++it;
}
return true; return true;
} }
@ -721,11 +734,7 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNull<const Co
if (iterator.head.size() >= 2) if (iterator.head.size() >= 2)
tableTy = iterator.head[1]; tableTy = iterator.head[1];
TypeId firstIndexTy = builtinTypes->nilType; return tryDispatchIterableFunction(nextTy, tableTy, c, constraint, force);
if (iterator.head.size() >= 3)
firstIndexTy = iterator.head[2];
return tryDispatchIterableFunction(nextTy, tableTy, firstIndexTy, c, constraint, force);
} }
else else
@ -1310,6 +1319,14 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull<const Con
if (isBlocked(subjectType) || get<PendingExpansionType>(subjectType) || get<TypeFamilyInstanceType>(subjectType)) if (isBlocked(subjectType) || get<PendingExpansionType>(subjectType) || get<TypeFamilyInstanceType>(subjectType))
return block(subjectType, constraint); return block(subjectType, constraint);
if (const TableType* subjectTable = getTableType(subjectType))
{
if (subjectTable->state == TableState::Unsealed && subjectTable->remainingProps > 0 && subjectTable->props.count(c.prop) == 0)
{
return block(subjectType, constraint);
}
}
auto [blocked, result] = lookupTableProp(constraint, subjectType, c.prop, c.context, c.inConditional, c.suppressSimplification); auto [blocked, result] = lookupTableProp(constraint, subjectType, c.prop, c.context, c.inConditional, c.suppressSimplification);
if (!blocked.empty()) if (!blocked.empty())
{ {
@ -1517,7 +1534,10 @@ bool ConstraintSolver::tryDispatch(const HasIndexerConstraint& c, NotNull<const
Set<TypeId> seen{nullptr}; Set<TypeId> seen{nullptr};
return tryDispatchHasIndexer(recursionDepth, constraint, subjectType, indexType, c.resultType, seen); bool ok = tryDispatchHasIndexer(recursionDepth, constraint, subjectType, indexType, c.resultType, seen);
if (ok)
unblock(c.resultType, constraint->location);
return ok;
} }
std::pair<bool, std::optional<TypeId>> ConstraintSolver::tryDispatchSetIndexer( std::pair<bool, std::optional<TypeId>> ConstraintSolver::tryDispatchSetIndexer(
@ -1596,46 +1616,6 @@ std::pair<bool, std::optional<TypeId>> ConstraintSolver::tryDispatchSetIndexer(
return {true, std::nullopt}; return {true, std::nullopt};
} }
bool ConstraintSolver::tryDispatch(const AssignConstraint& c, NotNull<const Constraint> constraint)
{
const TypeId lhsTy = follow(c.lhsType);
const TypeId rhsTy = follow(c.rhsType);
if (!get<LocalType>(lhsTy) && isBlocked(lhsTy))
return block(lhsTy, constraint);
auto tryExpand = [&](TypeId ty) {
LocalType* lt = getMutable<LocalType>(ty);
if (!lt)
return;
lt->domain = simplifyUnion(builtinTypes, arena, lt->domain, rhsTy).result;
LUAU_ASSERT(lt->blockCount > 0);
--lt->blockCount;
if (0 == lt->blockCount)
{
shiftReferences(ty, lt->domain);
emplaceType<BoundType>(asMutable(ty), lt->domain);
}
};
if (auto ut = get<UnionType>(lhsTy))
{
// FIXME: I suspect there's a bug here where lhsTy is a union that contains no LocalTypes.
for (TypeId t : ut)
tryExpand(t);
}
else if (get<LocalType>(lhsTy))
tryExpand(lhsTy);
else
unify(constraint, rhsTy, lhsTy);
unblock(lhsTy, constraint->location);
return true;
}
bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull<const Constraint> constraint) bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull<const Constraint> constraint)
{ {
TypeId lhsType = follow(c.lhsType); TypeId lhsType = follow(c.lhsType);
@ -1753,6 +1733,14 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull<const
{ {
emplaceType<BoundType>(asMutable(c.propType), rhsType); emplaceType<BoundType>(asMutable(c.propType), rhsType);
lhsTable->props[propName] = Property::rw(rhsType); lhsTable->props[propName] = Property::rw(rhsType);
if (lhsTable->state == TableState::Unsealed && c.decrementPropCount)
{
LUAU_ASSERT(lhsTable->remainingProps > 0);
lhsTable->remainingProps -= 1;
unblock(lhsType, constraint->location);
}
return true; return true;
} }
} }
@ -1927,24 +1915,14 @@ bool ConstraintSolver::tryDispatchUnpack1(NotNull<const Constraint> constraint,
bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull<const Constraint> constraint) bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull<const Constraint> constraint)
{ {
TypePackId sourcePack = follow(c.sourcePack); TypePackId sourcePack = follow(c.sourcePack);
TypePackId resultPack = follow(c.resultPack);
if (isBlocked(sourcePack)) if (isBlocked(sourcePack))
return block(sourcePack, constraint); return block(sourcePack, constraint);
if (isBlocked(resultPack)) TypePack srcPack = extendTypePack(*arena, builtinTypes, sourcePack, c.resultPack.size());
{
LUAU_ASSERT(canMutate(resultPack, constraint));
LUAU_ASSERT(resultPack != sourcePack);
emplaceTypePack<BoundTypePack>(asMutable(resultPack), sourcePack);
unblock(resultPack, constraint->location);
return true;
}
TypePack srcPack = extendTypePack(*arena, builtinTypes, sourcePack, size(resultPack)); auto resultIter = begin(c.resultPack);
auto resultEnd = end(c.resultPack);
auto resultIter = begin(resultPack);
auto resultEnd = end(resultPack);
size_t i = 0; size_t i = 0;
while (resultIter != resultEnd) while (resultIter != resultEnd)
@ -2080,18 +2058,22 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl
auto endIt = end(c.variables); auto endIt = end(c.variables);
if (it != endIt) if (it != endIt)
{ {
pushConstraint(constraint->scope, constraint->location, AssignConstraint{*it, keyTy}); bindBlockedType(*it, keyTy, keyTy, constraint);
++it; ++it;
} }
if (it != endIt) if (it != endIt)
pushConstraint(constraint->scope, constraint->location, AssignConstraint{*it, valueTy}); bindBlockedType(*it, valueTy, valueTy, constraint);
return true; return true;
} }
auto unpack = [&](TypeId ty) { auto unpack = [&](TypeId ty) {
for (TypeId varTy : c.variables) for (TypeId varTy: c.variables)
pushConstraint(constraint->scope, constraint->location, AssignConstraint{varTy, ty}); {
LUAU_ASSERT(get<BlockedType>(varTy));
LUAU_ASSERT(varTy != ty);
bindBlockedType(varTy, ty, ty, constraint);
}
}; };
if (get<AnyType>(iteratorTy)) if (get<AnyType>(iteratorTy))
@ -2129,27 +2111,18 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl
if (iteratorTable->indexer) if (iteratorTable->indexer)
{ {
TypePackId expectedVariablePack = arena->addTypePack({iteratorTable->indexer->indexType, iteratorTable->indexer->indexResultType}); std::vector<TypeId> expectedVariables{iteratorTable->indexer->indexType, iteratorTable->indexer->indexResultType};
unify(constraint, c.variables, expectedVariablePack); while (c.variables.size() >= expectedVariables.size())
expectedVariables.push_back(builtinTypes->errorRecoveryType());
auto [variableTys, variablesTail] = flatten(c.variables); for (size_t i = 0; i < c.variables.size(); ++i)
// the local types for the indexer _should_ be all set after unification
for (TypeId ty : variableTys)
{ {
if (auto lt = getMutable<LocalType>(ty)) LUAU_ASSERT(c.variables[i] != expectedVariables[i]);
{
LUAU_ASSERT(lt->blockCount > 0);
--lt->blockCount;
LUAU_ASSERT(0 <= lt->blockCount); unify(constraint, c.variables[i], expectedVariables[i]);
if (0 == lt->blockCount) bindBlockedType(c.variables[i], expectedVariables[i], expectedVariables[i], constraint);
{ unblock(c.variables[i], constraint->location);
shiftReferences(ty, lt->domain);
emplaceType<BoundType>(asMutable(ty), lt->domain);
}
}
} }
} }
else else
@ -2213,26 +2186,16 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl
else if (auto primitiveTy = get<PrimitiveType>(iteratorTy); primitiveTy && primitiveTy->type == PrimitiveType::Type::Table) else if (auto primitiveTy = get<PrimitiveType>(iteratorTy); primitiveTy && primitiveTy->type == PrimitiveType::Type::Table)
unpack(builtinTypes->unknownType); unpack(builtinTypes->unknownType);
else else
{
unpack(builtinTypes->errorType); unpack(builtinTypes->errorType);
}
return true; return true;
} }
bool ConstraintSolver::tryDispatchIterableFunction( bool ConstraintSolver::tryDispatchIterableFunction(
TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force) TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force)
{ {
// We need to know whether or not this type is nil or not.
// If we don't know, block and reschedule ourselves.
firstIndexTy = follow(firstIndexTy);
if (get<FreeType>(firstIndexTy))
{
if (force)
LUAU_ASSERT(false);
else
block(firstIndexTy, constraint);
return false;
}
const FunctionType* nextFn = get<FunctionType>(nextTy); const FunctionType* nextFn = get<FunctionType>(nextTy);
// If this does not hold, we should've never called `tryDispatchIterableFunction` in the first place. // If this does not hold, we should've never called `tryDispatchIterableFunction` in the first place.
LUAU_ASSERT(nextFn); LUAU_ASSERT(nextFn);
@ -2267,27 +2230,18 @@ bool ConstraintSolver::tryDispatchIterableFunction(
return true; return true;
} }
NotNull<const Constraint> ConstraintSolver::unpackAndAssign(TypePackId destTypes, TypePackId srcTypes, NotNull<const Constraint> constraint) NotNull<const Constraint> ConstraintSolver::unpackAndAssign(const std::vector<TypeId> destTypes, TypePackId srcTypes, NotNull<const Constraint> constraint)
{ {
std::vector<TypeId> unpackedTys; auto c = pushConstraint(constraint->scope, constraint->location, UnpackConstraint{destTypes, srcTypes});
for (TypeId _ty : destTypes)
for (TypeId t: destTypes)
{ {
(void) _ty; BlockedType* bt = getMutable<BlockedType>(t);
unpackedTys.push_back(arena->addType(BlockedType{})); LUAU_ASSERT(bt);
bt->replaceOwner(c);
} }
TypePackId unpackedTp = arena->addTypePack(TypePack{unpackedTys}); return c;
auto unpackConstraint = pushConstraint(constraint->scope, constraint->location, UnpackConstraint{unpackedTp, srcTypes});
size_t i = 0;
for (TypeId varTy : destTypes)
{
pushConstraint(constraint->scope, constraint->location, AssignConstraint{varTy, unpackedTys[i]});
getMutable<BlockedType>(unpackedTys[i])->setOwner(unpackConstraint);
++i;
}
return unpackConstraint;
} }
std::pair<std::vector<TypeId>, std::optional<TypeId>> ConstraintSolver::lookupTableProp(NotNull<const Constraint> constraint, TypeId subjectType, std::pair<std::vector<TypeId>, std::optional<TypeId>> ConstraintSolver::lookupTableProp(NotNull<const Constraint> constraint, TypeId subjectType,
@ -2808,9 +2762,6 @@ bool ConstraintSolver::isBlocked(TypeId ty)
{ {
ty = follow(ty); ty = follow(ty);
if (auto lt = get<LocalType>(ty))
return lt->blockCount > 0;
if (auto tfit = get<TypeFamilyInstanceType>(ty)) if (auto tfit = get<TypeFamilyInstanceType>(ty))
return uninhabitedTypeFamilies.contains(ty) == false; return uninhabitedTypeFamilies.contains(ty) == false;

View file

@ -2,6 +2,7 @@
#include "Luau/BuiltinDefinitions.h" #include "Luau/BuiltinDefinitions.h"
LUAU_FASTFLAGVARIABLE(LuauCheckedEmbeddedDefinitions2, false); LUAU_FASTFLAGVARIABLE(LuauCheckedEmbeddedDefinitions2, false);
LUAU_FASTFLAG(LuauAttributeSyntax);
namespace Luau namespace Luau
{ {
@ -319,9 +320,9 @@ declare os: {
clock: () -> number, clock: () -> number,
} }
declare function @checked require(target: any): any @checked declare function require(target: any): any
declare function @checked getfenv(target: any): { [string]: any } @checked declare function getfenv(target: any): { [string]: any }
declare _G: any declare _G: any
declare _VERSION: string declare _VERSION: string
@ -363,7 +364,7 @@ declare function select<A...>(i: string | number, ...: A...): ...any
-- (nil, string). -- (nil, string).
declare function loadstring<A...>(src: string, chunkname: string?): (((A...) -> any)?, string?) declare function loadstring<A...>(src: string, chunkname: string?): (((A...) -> any)?, string?)
declare function @checked newproxy(mt: boolean?): any @checked declare function newproxy(mt: boolean?): any
declare coroutine: { declare coroutine: {
create: <A..., R...>(f: (A...) -> R...) -> thread, create: <A..., R...>(f: (A...) -> R...) -> thread,
@ -451,7 +452,7 @@ std::string getBuiltinDefinitionSource()
std::string result = kBuiltinDefinitionLuaSrc; std::string result = kBuiltinDefinitionLuaSrc;
// Annotates each non generic function as checked // Annotates each non generic function as checked
if (FFlag::LuauCheckedEmbeddedDefinitions2) if (FFlag::LuauCheckedEmbeddedDefinitions2 && FFlag::LuauAttributeSyntax)
result = kBuiltinDefinitionLuaSrcChecked; result = kBuiltinDefinitionLuaSrcChecked;
return result; return result;

View file

@ -1196,12 +1196,6 @@ struct InternalTypeFinder : TypeOnceVisitor
return false; return false;
} }
bool visit(TypeId, const LocalType&) override
{
LUAU_ASSERT(false);
return false;
}
bool visit(TypePackId, const BlockedTypePack&) override bool visit(TypePackId, const BlockedTypePack&) override
{ {
LUAU_ASSERT(false); LUAU_ASSERT(false);

View file

@ -1815,12 +1815,6 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
if (!isCacheable(there)) if (!isCacheable(there))
here.isCacheable = false; here.isCacheable = false;
} }
else if (auto lt = get<LocalType>(there))
{
// FIXME? This is somewhat questionable.
// Maybe we should assert because this should never happen?
unionNormalWithTy(here, lt->domain, seenSetTypes, ignoreSmallerTyvars);
}
else if (get<FunctionType>(there)) else if (get<FunctionType>(there))
unionFunctionsWithFunction(here.functions, there); unionFunctionsWithFunction(here.functions, there);
else if (get<TableType>(there) || get<MetatableType>(there)) else if (get<TableType>(there) || get<MetatableType>(there))
@ -3095,7 +3089,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
return NormalizationResult::True; return NormalizationResult::True;
} }
else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) || else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) ||
get<TypeFamilyInstanceType>(there) || get<LocalType>(there)) get<TypeFamilyInstanceType>(there))
{ {
NormalizedType thereNorm{builtinTypes}; NormalizedType thereNorm{builtinTypes};
NormalizedType topNorm{builtinTypes}; NormalizedType topNorm{builtinTypes};
@ -3104,10 +3098,6 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
here.isCacheable = false; here.isCacheable = false;
return intersectNormals(here, thereNorm); return intersectNormals(here, thereNorm);
} }
else if (auto lt = get<LocalType>(there))
{
return intersectNormalWithTy(here, lt->domain, seenSetTypes);
}
NormalizedTyvars tyvars = std::move(here.tyvars); NormalizedTyvars tyvars = std::move(here.tyvars);

View file

@ -24,8 +24,6 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a
// We decline to copy them. // We decline to copy them.
if constexpr (std::is_same_v<T, FreeType>) if constexpr (std::is_same_v<T, FreeType>)
return ty; return ty;
else if constexpr (std::is_same_v<T, LocalType>)
return ty;
else if constexpr (std::is_same_v<T, BoundType>) else if constexpr (std::is_same_v<T, BoundType>)
{ {
// This should never happen, but visit() cannot see it. // This should never happen, but visit() cannot see it.

View file

@ -262,14 +262,6 @@ void StateDot::visitChildren(TypeId ty, int index)
visitChild(t.upperBound, index, "[upperBound]"); visitChild(t.upperBound, index, "[upperBound]");
} }
} }
else if constexpr (std::is_same_v<T, LocalType>)
{
formatAppend(result, "LocalType");
finishNodeLabel(ty);
finishNode();
visitChild(t.domain, 1, "[domain]");
}
else if constexpr (std::is_same_v<T, AnyType>) else if constexpr (std::is_same_v<T, AnyType>)
{ {
formatAppend(result, "AnyType %d", index); formatAppend(result, "AnyType %d", index);

View file

@ -100,16 +100,6 @@ struct FindCyclicTypes final : TypeVisitor
return false; return false;
} }
bool visit(TypeId ty, const LocalType& lt) override
{
if (!visited.insert(ty))
return false;
traverse(lt.domain);
return false;
}
bool visit(TypeId ty, const TableType& ttv) override bool visit(TypeId ty, const TableType& ttv) override
{ {
if (!visited.insert(ty)) if (!visited.insert(ty))
@ -525,21 +515,6 @@ struct TypeStringifier
} }
} }
void operator()(TypeId ty, const LocalType& lt)
{
state.emit("l-");
state.emit(lt.name);
if (FInt::DebugLuauVerboseTypeNames >= 1)
{
state.emit("[");
state.emit(lt.blockCount);
state.emit("]");
}
state.emit("=[");
stringify(lt.domain);
state.emit("]");
}
void operator()(TypeId, const BoundType& btv) void operator()(TypeId, const BoundType& btv)
{ {
stringify(btv.boundTo); stringify(btv.boundTo);
@ -1724,6 +1699,18 @@ std::string generateName(size_t i)
return n; return n;
} }
std::string toStringVector(const std::vector<TypeId>& types, ToStringOptions& opts)
{
std::string s;
for (TypeId ty : types)
{
if (!s.empty())
s += ", ";
s += toString(ty, opts);
}
return s;
}
std::string toString(const Constraint& constraint, ToStringOptions& opts) std::string toString(const Constraint& constraint, ToStringOptions& opts)
{ {
auto go = [&opts](auto&& c) -> std::string { auto go = [&opts](auto&& c) -> std::string {
@ -1754,7 +1741,7 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts)
else if constexpr (std::is_same_v<T, IterableConstraint>) else if constexpr (std::is_same_v<T, IterableConstraint>)
{ {
std::string iteratorStr = tos(c.iterator); std::string iteratorStr = tos(c.iterator);
std::string variableStr = tos(c.variables); std::string variableStr = toStringVector(c.variables, opts);
return variableStr + " ~ iterate " + iteratorStr; return variableStr + " ~ iterate " + iteratorStr;
} }
@ -1791,14 +1778,12 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts)
{ {
return tos(c.resultType) + " ~ hasIndexer " + tos(c.subjectType) + " " + tos(c.indexType); return tos(c.resultType) + " ~ hasIndexer " + tos(c.subjectType) + " " + tos(c.indexType);
} }
else if constexpr (std::is_same_v<T, AssignConstraint>)
return "assign " + tos(c.lhsType) + " " + tos(c.rhsType);
else if constexpr (std::is_same_v<T, AssignPropConstraint>) else if constexpr (std::is_same_v<T, AssignPropConstraint>)
return "assignProp " + tos(c.lhsType) + " " + c.propName + " " + tos(c.rhsType); return "assignProp " + tos(c.lhsType) + " " + c.propName + " " + tos(c.rhsType);
else if constexpr (std::is_same_v<T, AssignIndexConstraint>) else if constexpr (std::is_same_v<T, AssignIndexConstraint>)
return "assignIndex " + tos(c.lhsType) + " " + tos(c.indexType) + " " + tos(c.rhsType); return "assignIndex " + tos(c.lhsType) + " " + tos(c.indexType) + " " + tos(c.rhsType);
else if constexpr (std::is_same_v<T, UnpackConstraint>) else if constexpr (std::is_same_v<T, UnpackConstraint>)
return tos(c.resultPack) + " ~ ...unpack " + tos(c.sourcePack); return toStringVector(c.resultPack, opts) + " ~ ...unpack " + tos(c.sourcePack);
else if constexpr (std::is_same_v<T, ReduceConstraint>) else if constexpr (std::is_same_v<T, ReduceConstraint>)
return "reduce " + tos(c.ty); return "reduce " + tos(c.ty);
else if constexpr (std::is_same_v<T, ReducePackConstraint>) else if constexpr (std::is_same_v<T, ReducePackConstraint>)

View file

@ -1182,11 +1182,11 @@ std::string toString(AstNode* node)
Printer printer(writer); Printer printer(writer);
printer.writeTypes = true; printer.writeTypes = true;
if (auto statNode = dynamic_cast<AstStat*>(node)) if (auto statNode = node->asStat())
printer.visualize(*statNode); printer.visualize(*statNode);
else if (auto exprNode = dynamic_cast<AstExpr*>(node)) else if (auto exprNode = node->asExpr())
printer.visualize(*exprNode); printer.visualize(*exprNode);
else if (auto typeNode = dynamic_cast<AstType*>(node)) else if (auto typeNode = node->asType())
printer.visualizeTypeAnnotation(*typeNode); printer.visualizeTypeAnnotation(*typeNode);
return writer.str(); return writer.str();

View file

@ -561,6 +561,11 @@ void BlockedType::setOwner(Constraint* newOwner)
owner = newOwner; owner = newOwner;
} }
void BlockedType::replaceOwner(Constraint* newOwner)
{
owner = newOwner;
}
PendingExpansionType::PendingExpansionType( PendingExpansionType::PendingExpansionType(
std::optional<AstName> prefix, AstName name, std::vector<TypeId> typeArguments, std::vector<TypePackId> packArguments) std::optional<AstName> prefix, AstName name, std::vector<TypeId> typeArguments, std::vector<TypePackId> packArguments)
: prefix(prefix) : prefix(prefix)

View file

@ -338,10 +338,6 @@ public:
{ {
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("free"), std::nullopt, Location()); return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("free"), std::nullopt, Location());
} }
AstType* operator()(const LocalType& lt)
{
return Luau::visit(*this, lt.domain->ty);
}
AstType* operator()(const UnionType& uv) AstType* operator()(const UnionType& uv)
{ {
AstArray<AstType*> unionTypes; AstArray<AstType*> unionTypes;

View file

@ -447,7 +447,7 @@ FamilyGraphReductionResult reduceFamilies(TypePackId entrypoint, Location locati
bool isPending(TypeId ty, ConstraintSolver* solver) bool isPending(TypeId ty, ConstraintSolver* solver)
{ {
return is<BlockedType, PendingExpansionType, TypeFamilyInstanceType, LocalType>(ty) || (solver && solver->hasUnresolvedConstraints(ty)); return is<BlockedType, PendingExpansionType, TypeFamilyInstanceType>(ty) || (solver && solver->hasUnresolvedConstraints(ty));
} }
template<typename F, typename... Args> template<typename F, typename... Args>
@ -567,7 +567,7 @@ TypeFamilyReductionResult<TypeId> lenFamilyFn(TypeId instance, const std::vector
// check to see if the operand type is resolved enough, and wait to reduce if not // check to see if the operand type is resolved enough, and wait to reduce if not
// the use of `typeFromNormal` later necessitates blocking on local types. // the use of `typeFromNormal` later necessitates blocking on local types.
if (isPending(operandTy, ctx->solver) || get<LocalType>(operandTy)) if (isPending(operandTy, ctx->solver))
return {std::nullopt, false, {operandTy}, {}}; return {std::nullopt, false, {operandTy}, {}};
// if the type is free but has only one remaining reference, we can generalize it to its upper bound here. // if the type is free but has only one remaining reference, we can generalize it to its upper bound here.
@ -1427,12 +1427,6 @@ struct FindRefinementBlockers : TypeOnceVisitor
return false; return false;
} }
bool visit(TypeId ty, const LocalType&) override
{
found.insert(ty);
return false;
}
bool visit(TypeId ty, const ClassType&) override bool visit(TypeId ty, const ClassType&) override
{ {
return false; return false;

View file

@ -158,12 +158,6 @@ bool Unifier2::unify(TypeId subTy, TypeId superTy)
if (subFree || superFree) if (subFree || superFree)
return true; return true;
if (auto subLocal = getMutable<LocalType>(subTy))
{
subLocal->domain = mkUnion(subLocal->domain, superTy);
expandedFreeTypes[subTy].push_back(superTy);
}
auto subFn = get<FunctionType>(subTy); auto subFn = get<FunctionType>(subTy);
auto superFn = get<FunctionType>(superTy); auto superFn = get<FunctionType>(superTy);
if (subFn && superFn) if (subFn && superFn)

View file

@ -60,6 +60,8 @@ class AstStat;
class AstStatBlock; class AstStatBlock;
class AstExpr; class AstExpr;
class AstTypePack; class AstTypePack;
class AstAttr;
class AstExprTable;
struct AstLocal struct AstLocal
{ {
@ -172,6 +174,10 @@ public:
{ {
return nullptr; return nullptr;
} }
virtual AstAttr* asAttr()
{
return nullptr;
}
template<typename T> template<typename T>
bool is() const bool is() const
@ -193,6 +199,28 @@ public:
Location location; Location location;
}; };
class AstAttr : public AstNode
{
public:
LUAU_RTTI(AstAttr)
enum Type
{
Checked,
};
AstAttr(const Location& location, Type type);
AstAttr* asAttr() override
{
return this;
}
void visit(AstVisitor* visitor) override;
Type type;
};
class AstExpr : public AstNode class AstExpr : public AstNode
{ {
public: public:
@ -384,13 +412,15 @@ class AstExprFunction : public AstExpr
public: public:
LUAU_RTTI(AstExprFunction) LUAU_RTTI(AstExprFunction)
AstExprFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks, AstExprFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstArray<AstGenericType>& generics,
AstLocal* self, const AstArray<AstLocal*>& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth, const AstArray<AstGenericTypePack>& genericPacks, AstLocal* self, const AstArray<AstLocal*>& args, bool vararg,
const AstName& debugname, const std::optional<AstTypeList>& returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, const Location& varargLocation, AstStatBlock* body, size_t functionDepth, const AstName& debugname,
const std::optional<AstTypeList>& returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr,
const std::optional<Location>& argLocation = std::nullopt); const std::optional<Location>& argLocation = std::nullopt);
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
AstArray<AstAttr*> attributes;
AstArray<AstGenericType> generics; AstArray<AstGenericType> generics;
AstArray<AstGenericTypePack> genericPacks; AstArray<AstGenericTypePack> genericPacks;
AstLocal* self; AstLocal* self;
@ -810,20 +840,22 @@ public:
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params, const AstArray<AstArgumentName>& paramNames, const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params, const AstArray<AstArgumentName>& paramNames,
const AstTypeList& retTypes); const AstTypeList& retTypes);
AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray<AstGenericType>& generics, AstStatDeclareFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstName& name,
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params, const AstArray<AstArgumentName>& paramNames, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params,
const AstTypeList& retTypes, bool checkedFunction); const AstArray<AstArgumentName>& paramNames, const AstTypeList& retTypes);
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
bool isCheckedFunction() const;
AstArray<AstAttr*> attributes;
AstName name; AstName name;
AstArray<AstGenericType> generics; AstArray<AstGenericType> generics;
AstArray<AstGenericTypePack> genericPacks; AstArray<AstGenericTypePack> genericPacks;
AstTypeList params; AstTypeList params;
AstArray<AstArgumentName> paramNames; AstArray<AstArgumentName> paramNames;
AstTypeList retTypes; AstTypeList retTypes;
bool checkedFunction;
}; };
struct AstDeclaredClassProp struct AstDeclaredClassProp
@ -936,17 +968,20 @@ public:
AstTypeFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks, AstTypeFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const AstTypeList& returnTypes); const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const AstTypeList& returnTypes);
AstTypeFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks, AstTypeFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstArray<AstGenericType>& generics,
const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const AstTypeList& returnTypes, bool checkedFunction); const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames,
const AstTypeList& returnTypes);
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
bool isCheckedFunction() const;
AstArray<AstAttr*> attributes;
AstArray<AstGenericType> generics; AstArray<AstGenericType> generics;
AstArray<AstGenericTypePack> genericPacks; AstArray<AstGenericTypePack> genericPacks;
AstTypeList argTypes; AstTypeList argTypes;
AstArray<std::optional<AstArgumentName>> argNames; AstArray<std::optional<AstArgumentName>> argNames;
AstTypeList returnTypes; AstTypeList returnTypes;
bool checkedFunction;
}; };
class AstTypeTypeof : public AstType class AstTypeTypeof : public AstType
@ -1105,6 +1140,11 @@ public:
return true; return true;
} }
virtual bool visit(class AstAttr* node)
{
return visit(static_cast<AstNode*>(node));
}
virtual bool visit(class AstExpr* node) virtual bool visit(class AstExpr* node)
{ {
return visit(static_cast<AstNode*>(node)); return visit(static_cast<AstNode*>(node));

View file

@ -87,6 +87,8 @@ struct Lexeme
Comment, Comment,
BlockComment, BlockComment,
Attribute,
BrokenString, BrokenString,
BrokenComment, BrokenComment,
BrokenUnicode, BrokenUnicode,
@ -115,14 +117,20 @@ struct Lexeme
ReservedTrue, ReservedTrue,
ReservedUntil, ReservedUntil,
ReservedWhile, ReservedWhile,
ReservedChecked,
Reserved_END Reserved_END
}; };
Type type; Type type;
Location location; Location location;
// Field declared here, before the union, to ensure that Lexeme size is 32 bytes.
private:
// length is used to extract a slice from the input buffer.
// This field is only valid for certain lexeme types which don't duplicate portions of input
// but instead store a pointer to a location in the input buffer and the length of lexeme.
unsigned int length; unsigned int length;
public:
union union
{ {
const char* data; // String, Number, Comment const char* data; // String, Number, Comment
@ -135,9 +143,13 @@ struct Lexeme
Lexeme(const Location& location, Type type, const char* data, size_t size); Lexeme(const Location& location, Type type, const char* data, size_t size);
Lexeme(const Location& location, Type type, const char* name); Lexeme(const Location& location, Type type, const char* name);
unsigned int getLength() const;
std::string toString() const; std::string toString() const;
}; };
static_assert(sizeof(Lexeme) <= 32, "Size of `Lexeme` struct should be up to 32 bytes.");
class AstNameTable class AstNameTable
{ {
public: public:

View file

@ -82,8 +82,8 @@ private:
// if exp then block {elseif exp then block} [else block] end | // if exp then block {elseif exp then block} [else block] end |
// for Name `=' exp `,' exp [`,' exp] do block end | // for Name `=' exp `,' exp [`,' exp] do block end |
// for namelist in explist do block end | // for namelist in explist do block end |
// function funcname funcbody | // [attributes] function funcname funcbody |
// local function Name funcbody | // [attributes] local function Name funcbody |
// local namelist [`=' explist] // local namelist [`=' explist]
// laststat ::= return [explist] | break // laststat ::= return [explist] | break
AstStat* parseStat(); AstStat* parseStat();
@ -114,11 +114,25 @@ private:
AstExpr* parseFunctionName(Location start, bool& hasself, AstName& debugname); AstExpr* parseFunctionName(Location start, bool& hasself, AstName& debugname);
// function funcname funcbody // function funcname funcbody
AstStat* parseFunctionStat(); LUAU_FORCEINLINE AstStat* parseFunctionStat(const AstArray<AstAttr*>& attributes = {nullptr, 0});
std::pair<bool, AstAttr::Type> validateAttribute(const char* attributeName, const TempVector<AstAttr*>& attributes);
// attribute ::= '@' NAME
void parseAttribute(TempVector<AstAttr*>& attribute);
// attributes ::= {attribute}
AstArray<AstAttr*> parseAttributes();
// attributes local function Name funcbody
// attributes function funcname funcbody
// attributes `declare function' Name`(' [parlist] `)' [`:` Type]
// declare Name '{' Name ':' attributes `(' [parlist] `)' [`:` Type] '}'
AstStat* parseAttributeStat();
// local function Name funcbody | // local function Name funcbody |
// local namelist [`=' explist] // local namelist [`=' explist]
AstStat* parseLocal(); AstStat* parseLocal(const AstArray<AstAttr*>& attributes);
// return [explist] // return [explist]
AstStat* parseReturn(); AstStat* parseReturn();
@ -130,7 +144,7 @@ private:
// `declare global' Name: Type | // `declare global' Name: Type |
// `declare function' Name`(' [parlist] `)' [`:` Type] // `declare function' Name`(' [parlist] `)' [`:` Type]
AstStat* parseDeclaration(const Location& start); AstStat* parseDeclaration(const Location& start, const AstArray<AstAttr*>& attributes);
// varlist `=' explist // varlist `=' explist
AstStat* parseAssignment(AstExpr* initial); AstStat* parseAssignment(AstExpr* initial);
@ -143,7 +157,7 @@ private:
// funcbodyhead ::= `(' [namelist [`,' `...'] | `...'] `)' [`:` Type] // funcbodyhead ::= `(' [namelist [`,' `...'] | `...'] `)' [`:` Type]
// funcbody ::= funcbodyhead block end // funcbody ::= funcbodyhead block end
std::pair<AstExprFunction*, AstLocal*> parseFunctionBody( std::pair<AstExprFunction*, AstLocal*> parseFunctionBody(
bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName); bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName, const AstArray<AstAttr*>& attributes);
// explist ::= {exp `,'} exp // explist ::= {exp `,'} exp
void parseExprList(TempVector<AstExpr*>& result); void parseExprList(TempVector<AstExpr*>& result);
@ -176,10 +190,10 @@ private:
AstTableIndexer* parseTableIndexer(AstTableAccess access, std::optional<Location> accessLocation); AstTableIndexer* parseTableIndexer(AstTableAccess access, std::optional<Location> accessLocation);
AstTypeOrPack parseFunctionType(bool allowPack, bool isCheckedFunction = false); AstTypeOrPack parseFunctionType(bool allowPack, const AstArray<AstAttr*>& attributes);
AstType* parseFunctionTypeTail(const Lexeme& begin, AstArray<AstGenericType> generics, AstArray<AstGenericTypePack> genericPacks, AstType* parseFunctionTypeTail(const Lexeme& begin, const AstArray<AstAttr*>& attributes, AstArray<AstGenericType> generics,
AstArray<AstType*> params, AstArray<std::optional<AstArgumentName>> paramNames, AstTypePack* varargAnnotation, AstArray<AstGenericTypePack> genericPacks, AstArray<AstType*> params, AstArray<std::optional<AstArgumentName>> paramNames,
bool isCheckedFunction = false); AstTypePack* varargAnnotation);
AstType* parseTableType(bool inDeclarationContext = false); AstType* parseTableType(bool inDeclarationContext = false);
AstTypeOrPack parseSimpleType(bool allowPack, bool inDeclarationContext = false); AstTypeOrPack parseSimpleType(bool allowPack, bool inDeclarationContext = false);
@ -393,6 +407,7 @@ private:
std::vector<unsigned int> matchRecoveryStopOnToken; std::vector<unsigned int> matchRecoveryStopOnToken;
std::vector<AstAttr*> scratchAttr;
std::vector<AstStat*> scratchStat; std::vector<AstStat*> scratchStat;
std::vector<AstArray<char>> scratchString; std::vector<AstArray<char>> scratchString;
std::vector<AstExpr*> scratchExpr; std::vector<AstExpr*> scratchExpr;

View file

@ -3,6 +3,7 @@
#include "Luau/Common.h" #include "Luau/Common.h"
LUAU_FASTFLAG(LuauAttributeSyntax);
namespace Luau namespace Luau
{ {
@ -16,6 +17,17 @@ static void visitTypeList(AstVisitor* visitor, const AstTypeList& list)
list.tailType->visit(visitor); list.tailType->visit(visitor);
} }
AstAttr::AstAttr(const Location& location, Type type)
: AstNode(ClassIndex(), location)
, type(type)
{
}
void AstAttr::visit(AstVisitor* visitor)
{
visitor->visit(this);
}
int gAstRttiIndex = 0; int gAstRttiIndex = 0;
AstExprGroup::AstExprGroup(const Location& location, AstExpr* expr) AstExprGroup::AstExprGroup(const Location& location, AstExpr* expr)
@ -161,11 +173,12 @@ void AstExprIndexExpr::visit(AstVisitor* visitor)
} }
} }
AstExprFunction::AstExprFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks, AstExprFunction::AstExprFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstArray<AstGenericType>& generics,
AstLocal* self, const AstArray<AstLocal*>& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth, const AstArray<AstGenericTypePack>& genericPacks, AstLocal* self, const AstArray<AstLocal*>& args, bool vararg, const Location& varargLocation,
const AstName& debugname, const std::optional<AstTypeList>& returnAnnotation, AstTypePack* varargAnnotation, AstStatBlock* body, size_t functionDepth, const AstName& debugname, const std::optional<AstTypeList>& returnAnnotation,
const std::optional<Location>& argLocation) AstTypePack* varargAnnotation, const std::optional<Location>& argLocation)
: AstExpr(ClassIndex(), location) : AstExpr(ClassIndex(), location)
, attributes(attributes)
, generics(generics) , generics(generics)
, genericPacks(genericPacks) , genericPacks(genericPacks)
, self(self) , self(self)
@ -696,27 +709,27 @@ AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const A
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params, const AstArray<AstArgumentName>& paramNames, const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params, const AstArray<AstArgumentName>& paramNames,
const AstTypeList& retTypes) const AstTypeList& retTypes)
: AstStat(ClassIndex(), location) : AstStat(ClassIndex(), location)
, attributes()
, name(name) , name(name)
, generics(generics) , generics(generics)
, genericPacks(genericPacks) , genericPacks(genericPacks)
, params(params) , params(params)
, paramNames(paramNames) , paramNames(paramNames)
, retTypes(retTypes) , retTypes(retTypes)
, checkedFunction(false)
{ {
} }
AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray<AstGenericType>& generics, AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstName& name,
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params, const AstArray<AstArgumentName>& paramNames, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params,
const AstTypeList& retTypes, bool checkedFunction) const AstArray<AstArgumentName>& paramNames, const AstTypeList& retTypes)
: AstStat(ClassIndex(), location) : AstStat(ClassIndex(), location)
, attributes(attributes)
, name(name) , name(name)
, generics(generics) , generics(generics)
, genericPacks(genericPacks) , genericPacks(genericPacks)
, params(params) , params(params)
, paramNames(paramNames) , paramNames(paramNames)
, retTypes(retTypes) , retTypes(retTypes)
, checkedFunction(checkedFunction)
{ {
} }
@ -729,6 +742,19 @@ void AstStatDeclareFunction::visit(AstVisitor* visitor)
} }
} }
bool AstStatDeclareFunction::isCheckedFunction() const
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
for (const AstAttr* attr : attributes)
{
if (attr->type == AstAttr::Type::Checked)
return true;
}
return false;
}
AstStatDeclareClass::AstStatDeclareClass(const Location& location, const AstName& name, std::optional<AstName> superName, AstStatDeclareClass::AstStatDeclareClass(const Location& location, const AstName& name, std::optional<AstName> superName,
const AstArray<AstDeclaredClassProp>& props, AstTableIndexer* indexer) const AstArray<AstDeclaredClassProp>& props, AstTableIndexer* indexer)
: AstStat(ClassIndex(), location) : AstStat(ClassIndex(), location)
@ -820,25 +846,26 @@ void AstTypeTable::visit(AstVisitor* visitor)
AstTypeFunction::AstTypeFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks, AstTypeFunction::AstTypeFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const AstTypeList& returnTypes) const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const AstTypeList& returnTypes)
: AstType(ClassIndex(), location) : AstType(ClassIndex(), location)
, attributes()
, generics(generics) , generics(generics)
, genericPacks(genericPacks) , genericPacks(genericPacks)
, argTypes(argTypes) , argTypes(argTypes)
, argNames(argNames) , argNames(argNames)
, returnTypes(returnTypes) , returnTypes(returnTypes)
, checkedFunction(false)
{ {
LUAU_ASSERT(argNames.size == 0 || argNames.size == argTypes.types.size); LUAU_ASSERT(argNames.size == 0 || argNames.size == argTypes.types.size);
} }
AstTypeFunction::AstTypeFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks, AstTypeFunction::AstTypeFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstArray<AstGenericType>& generics,
const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const AstTypeList& returnTypes, bool checkedFunction) const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames,
const AstTypeList& returnTypes)
: AstType(ClassIndex(), location) : AstType(ClassIndex(), location)
, attributes(attributes)
, generics(generics) , generics(generics)
, genericPacks(genericPacks) , genericPacks(genericPacks)
, argTypes(argTypes) , argTypes(argTypes)
, argNames(argNames) , argNames(argNames)
, returnTypes(returnTypes) , returnTypes(returnTypes)
, checkedFunction(checkedFunction)
{ {
LUAU_ASSERT(argNames.size == 0 || argNames.size == argTypes.types.size); LUAU_ASSERT(argNames.size == 0 || argNames.size == argTypes.types.size);
} }
@ -852,6 +879,19 @@ void AstTypeFunction::visit(AstVisitor* visitor)
} }
} }
bool AstTypeFunction::isCheckedFunction() const
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
for (const AstAttr* attr : attributes)
{
if (attr->type == AstAttr::Type::Checked)
return true;
}
return false;
}
AstTypeTypeof::AstTypeTypeof(const Location& location, AstExpr* expr) AstTypeTypeof::AstTypeTypeof(const Location& location, AstExpr* expr)
: AstType(ClassIndex(), location) : AstType(ClassIndex(), location)
, expr(expr) , expr(expr)

View file

@ -8,6 +8,7 @@
#include <limits.h> #include <limits.h>
LUAU_FASTFLAGVARIABLE(LuauLexerLookaheadRemembersBraceType, false) LUAU_FASTFLAGVARIABLE(LuauLexerLookaheadRemembersBraceType, false)
LUAU_FASTFLAGVARIABLE(LuauAttributeSyntax, false)
namespace Luau namespace Luau
{ {
@ -102,11 +103,19 @@ Lexeme::Lexeme(const Location& location, Type type, const char* name)
, length(0) , length(0)
, name(name) , name(name)
{ {
LUAU_ASSERT(type == Name || (type >= Reserved_BEGIN && type < Lexeme::Reserved_END)); LUAU_ASSERT(type == Name || type == Attribute || (type >= Reserved_BEGIN && type < Lexeme::Reserved_END));
}
unsigned int Lexeme::getLength() const
{
LUAU_ASSERT(type == RawString || type == QuotedString || type == InterpStringBegin || type == InterpStringMid || type == InterpStringEnd ||
type == InterpStringSimple || type == BrokenInterpDoubleBrace || type == Number || type == Comment || type == BlockComment);
return length;
} }
static const char* kReserved[] = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in", "local", "nil", "not", "or", static const char* kReserved[] = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in", "local", "nil", "not", "or",
"repeat", "return", "then", "true", "until", "while", "@checked"}; "repeat", "return", "then", "true", "until", "while"};
std::string Lexeme::toString() const std::string Lexeme::toString() const
{ {
@ -191,6 +200,10 @@ std::string Lexeme::toString() const
case Comment: case Comment:
return "comment"; return "comment";
case Attribute:
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
return name ? format("'%s'", name) : "attribute";
case BrokenString: case BrokenString:
return "malformed string"; return "malformed string";
@ -278,7 +291,7 @@ std::pair<AstName, Lexeme::Type> AstNameTable::getOrAddWithType(const char* name
nameData[length] = 0; nameData[length] = 0;
const_cast<Entry&>(entry).value = AstName(nameData); const_cast<Entry&>(entry).value = AstName(nameData);
const_cast<Entry&>(entry).type = Lexeme::Name; const_cast<Entry&>(entry).type = (name[0] == '@' ? Lexeme::Attribute : Lexeme::Name);
return std::make_pair(entry.value, entry.type); return std::make_pair(entry.value, entry.type);
} }
@ -994,14 +1007,11 @@ Lexeme Lexer::readNext()
} }
case '@': case '@':
{ {
// We're trying to lex the token @checked if (FFlag::LuauAttributeSyntax)
LUAU_ASSERT(peekch() == '@'); {
std::pair<AstName, Lexeme::Type> attribute = readName();
std::pair<AstName, Lexeme::Type> maybeChecked = readName(); return Lexeme(Location(start, position()), Lexeme::Attribute, attribute.first.value);
if (maybeChecked.second != Lexeme::ReservedChecked) }
return Lexeme(Location(start, position()), Lexeme::Error);
return Lexeme(Location(start, position()), maybeChecked.second, maybeChecked.first.value);
} }
default: default:
if (isDigit(peekch())) if (isDigit(peekch()))

View file

@ -17,11 +17,20 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100)
// flag so that we don't break production games by reverting syntax changes. // flag so that we don't break production games by reverting syntax changes.
// See docs/SyntaxChanges.md for an explanation. // See docs/SyntaxChanges.md for an explanation.
LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false)
LUAU_FASTFLAG(LuauAttributeSyntax)
LUAU_FASTFLAGVARIABLE(LuauLeadingBarAndAmpersand, false) LUAU_FASTFLAGVARIABLE(LuauLeadingBarAndAmpersand, false)
namespace Luau namespace Luau
{ {
struct AttributeEntry
{
const char* name;
AstAttr::Type type;
};
AttributeEntry kAttributeEntries[] = {{"@checked", AstAttr::Type::Checked}, {nullptr, AstAttr::Type::Checked}};
ParseError::ParseError(const Location& location, const std::string& message) ParseError::ParseError(const Location& location, const std::string& message)
: location(location) : location(location)
, message(message) , message(message)
@ -280,7 +289,9 @@ AstStatBlock* Parser::parseBlockNoScope()
// for binding `=' exp `,' exp [`,' exp] do block end | // for binding `=' exp `,' exp [`,' exp] do block end |
// for namelist in explist do block end | // for namelist in explist do block end |
// function funcname funcbody | // function funcname funcbody |
// attributes function funcname funcbody |
// local function Name funcbody | // local function Name funcbody |
// local attributes function Name funcbody |
// local namelist [`=' explist] // local namelist [`=' explist]
// laststat ::= return [explist] | break // laststat ::= return [explist] | break
AstStat* Parser::parseStat() AstStat* Parser::parseStat()
@ -299,13 +310,16 @@ AstStat* Parser::parseStat()
case Lexeme::ReservedRepeat: case Lexeme::ReservedRepeat:
return parseRepeat(); return parseRepeat();
case Lexeme::ReservedFunction: case Lexeme::ReservedFunction:
return parseFunctionStat(); return parseFunctionStat(AstArray<AstAttr*>({nullptr, 0}));
case Lexeme::ReservedLocal: case Lexeme::ReservedLocal:
return parseLocal(); return parseLocal(AstArray<AstAttr*>({nullptr, 0}));
case Lexeme::ReservedReturn: case Lexeme::ReservedReturn:
return parseReturn(); return parseReturn();
case Lexeme::ReservedBreak: case Lexeme::ReservedBreak:
return parseBreak(); return parseBreak();
case Lexeme::Attribute:
if (FFlag::LuauAttributeSyntax)
return parseAttributeStat();
default:; default:;
} }
@ -343,7 +357,7 @@ AstStat* Parser::parseStat()
if (options.allowDeclarationSyntax) if (options.allowDeclarationSyntax)
{ {
if (ident == "declare") if (ident == "declare")
return parseDeclaration(expr->location); return parseDeclaration(expr->location, AstArray<AstAttr*>({nullptr, 0}));
} }
// skip unexpected symbol if lexer couldn't advance at all (statements are parsed in a loop) // skip unexpected symbol if lexer couldn't advance at all (statements are parsed in a loop)
@ -652,7 +666,7 @@ AstExpr* Parser::parseFunctionName(Location start, bool& hasself, AstName& debug
} }
// function funcname funcbody // function funcname funcbody
AstStat* Parser::parseFunctionStat() AstStat* Parser::parseFunctionStat(const AstArray<AstAttr*>& attributes)
{ {
Location start = lexer.current().location; Location start = lexer.current().location;
@ -665,16 +679,125 @@ AstStat* Parser::parseFunctionStat()
matchRecoveryStopOnToken[Lexeme::ReservedEnd]++; matchRecoveryStopOnToken[Lexeme::ReservedEnd]++;
AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, nullptr).first; AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, nullptr, attributes).first;
matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; matchRecoveryStopOnToken[Lexeme::ReservedEnd]--;
return allocator.alloc<AstStatFunction>(Location(start, body->location), expr, body); return allocator.alloc<AstStatFunction>(Location(start, body->location), expr, body);
} }
std::pair<bool, AstAttr::Type> Parser::validateAttribute(const char* attributeName, const TempVector<AstAttr*>& attributes)
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
AstAttr::Type type;
// check if the attribute name is valid
bool found = false;
for (int i = 0; kAttributeEntries[i].name; ++i)
{
found = !strcmp(attributeName, kAttributeEntries[i].name);
if (found)
{
type = kAttributeEntries[i].type;
break;
}
}
if (!found)
{
if (strlen(attributeName) == 1)
report(lexer.current().location, "Attribute name is missing");
else
report(lexer.current().location, "Invalid attribute '%s'", attributeName);
}
else
{
// check that attribute is not duplicated
for (const AstAttr* attr : attributes)
{
if (attr->type == type)
{
report(lexer.current().location, "Cannot duplicate attribute '%s'", attributeName);
}
}
}
return {found, type};
}
// attribute ::= '@' NAME
void Parser::parseAttribute(TempVector<AstAttr*>& attributes)
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
LUAU_ASSERT(lexer.current().type == Lexeme::Type::Attribute);
Location loc = lexer.current().location;
const char* name = lexer.current().name;
const auto [found, type] = validateAttribute(name, attributes);
nextLexeme();
if (found)
attributes.push_back(allocator.alloc<AstAttr>(loc, type));
}
// attributes ::= {attribute}
AstArray<AstAttr*> Parser::parseAttributes()
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
Lexeme::Type type = lexer.current().type;
LUAU_ASSERT(type == Lexeme::Attribute);
TempVector<AstAttr*> attributes(scratchAttr);
while (lexer.current().type == Lexeme::Attribute)
parseAttribute(attributes);
return copy(attributes);
}
// attributes local function Name funcbody
// attributes function funcname funcbody
// attributes `declare function' Name`(' [parlist] `)' [`:` Type]
// declare Name '{' Name ':' attributes `(' [parlist] `)' [`:` Type] '}'
AstStat* Parser::parseAttributeStat()
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
AstArray<AstAttr*> attributes = Parser::parseAttributes();
Lexeme::Type type = lexer.current().type;
switch (type)
{
case Lexeme::Type::ReservedFunction:
return parseFunctionStat(attributes);
case Lexeme::Type::ReservedLocal:
return parseLocal(attributes);
case Lexeme::Type::Name:
if (options.allowDeclarationSyntax && !strcmp("declare", lexer.current().data))
{
AstExpr* expr = parsePrimaryExpr(/* asStatement= */ true);
return parseDeclaration(expr->location, attributes);
}
default:
return reportStatError(lexer.current().location, {}, {},
"Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got %s intead",
lexer.current().toString().c_str());
}
}
// local function Name funcbody | // local function Name funcbody |
// local bindinglist [`=' explist] // local bindinglist [`=' explist]
AstStat* Parser::parseLocal() AstStat* Parser::parseLocal(const AstArray<AstAttr*>& attributes)
{ {
Location start = lexer.current().location; Location start = lexer.current().location;
@ -694,7 +817,7 @@ AstStat* Parser::parseLocal()
matchRecoveryStopOnToken[Lexeme::ReservedEnd]++; matchRecoveryStopOnToken[Lexeme::ReservedEnd]++;
auto [body, var] = parseFunctionBody(false, matchFunction, name.name, &name); auto [body, var] = parseFunctionBody(false, matchFunction, name.name, &name, attributes);
matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; matchRecoveryStopOnToken[Lexeme::ReservedEnd]--;
@ -704,6 +827,12 @@ AstStat* Parser::parseLocal()
} }
else else
{ {
if (FFlag::LuauAttributeSyntax && attributes.size != 0)
{
return reportStatError(lexer.current().location, {}, {}, "Expected 'function' after local declaration with attribute, but got %s intead",
lexer.current().toString().c_str());
}
matchRecoveryStopOnToken['=']++; matchRecoveryStopOnToken['=']++;
TempVector<Binding> names(scratchBinding); TempVector<Binding> names(scratchBinding);
@ -831,18 +960,17 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod()
return AstDeclaredClassProp{fnName.name, fnType, true}; return AstDeclaredClassProp{fnName.name, fnType, true};
} }
AstStat* Parser::parseDeclaration(const Location& start) AstStat* Parser::parseDeclaration(const Location& start, const AstArray<AstAttr*>& attributes)
{ {
// `declare` token is already parsed at this point // `declare` token is already parsed at this point
if (FFlag::LuauAttributeSyntax && (attributes.size != 0) && (lexer.current().type != Lexeme::ReservedFunction))
return reportStatError(lexer.current().location, {}, {}, "Expected a function type declaration after attribute, but got %s intead",
lexer.current().toString().c_str());
if (lexer.current().type == Lexeme::ReservedFunction) if (lexer.current().type == Lexeme::ReservedFunction)
{ {
nextLexeme(); nextLexeme();
bool checkedFunction = false;
if (lexer.current().type == Lexeme::ReservedChecked)
{
checkedFunction = true;
nextLexeme();
}
Name globalName = parseName("global function name"); Name globalName = parseName("global function name");
auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false);
@ -880,8 +1008,8 @@ AstStat* Parser::parseDeclaration(const Location& start)
if (vararg && !varargAnnotation) if (vararg && !varargAnnotation)
return reportStatError(Location(start, end), {}, {}, "All declaration parameters must be annotated"); return reportStatError(Location(start, end), {}, {}, "All declaration parameters must be annotated");
return allocator.alloc<AstStatDeclareFunction>(Location(start, end), globalName.name, generics, genericPacks, return allocator.alloc<AstStatDeclareFunction>(Location(start, end), attributes, globalName.name, generics, genericPacks,
AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes, checkedFunction); AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes);
} }
else if (AstName(lexer.current().name) == "class") else if (AstName(lexer.current().name) == "class")
{ {
@ -1035,7 +1163,7 @@ std::pair<AstLocal*, AstArray<AstLocal*>> Parser::prepareFunctionArguments(const
// funcbody ::= `(' [parlist] `)' [`:' ReturnType] block end // funcbody ::= `(' [parlist] `)' [`:' ReturnType] block end
// parlist ::= bindinglist [`,' `...'] | `...' // parlist ::= bindinglist [`,' `...'] | `...'
std::pair<AstExprFunction*, AstLocal*> Parser::parseFunctionBody( std::pair<AstExprFunction*, AstLocal*> Parser::parseFunctionBody(
bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName) bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName, const AstArray<AstAttr*>& attributes)
{ {
Location start = matchFunction.location; Location start = matchFunction.location;
@ -1087,7 +1215,7 @@ std::pair<AstExprFunction*, AstLocal*> Parser::parseFunctionBody(
bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchFunction); bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchFunction);
body->hasEnd = hasEnd; body->hasEnd = hasEnd;
return {allocator.alloc<AstExprFunction>(Location(start, end), generics, genericPacks, self, vars, vararg, varargLocation, body, return {allocator.alloc<AstExprFunction>(Location(start, end), attributes, generics, genericPacks, self, vars, vararg, varargLocation, body,
functionStack.size(), debugname, typelist, varargAnnotation, argLocation), functionStack.size(), debugname, typelist, varargAnnotation, argLocation),
funLocal}; funLocal};
} }
@ -1296,7 +1424,7 @@ std::pair<Location, AstTypeList> Parser::parseReturnType()
return {location, AstTypeList{copy(result), varargAnnotation}}; return {location, AstTypeList{copy(result), varargAnnotation}};
} }
AstType* tail = parseFunctionTypeTail(begin, {}, {}, copy(result), copy(resultNames), varargAnnotation); AstType* tail = parseFunctionTypeTail(begin, {nullptr, 0}, {}, {}, copy(result), copy(resultNames), varargAnnotation);
return {Location{location, tail->location}, AstTypeList{copy(&tail, 1), varargAnnotation}}; return {Location{location, tail->location}, AstTypeList{copy(&tail, 1), varargAnnotation}};
} }
@ -1435,7 +1563,7 @@ AstType* Parser::parseTableType(bool inDeclarationContext)
// ReturnType ::= Type | `(' TypeList `)' // ReturnType ::= Type | `(' TypeList `)'
// FunctionType ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType // FunctionType ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType
AstTypeOrPack Parser::parseFunctionType(bool allowPack, bool isCheckedFunction) AstTypeOrPack Parser::parseFunctionType(bool allowPack, const AstArray<AstAttr*>& attributes)
{ {
incrementRecursionCounter("type annotation"); incrementRecursionCounter("type annotation");
@ -1483,11 +1611,12 @@ AstTypeOrPack Parser::parseFunctionType(bool allowPack, bool isCheckedFunction)
AstArray<std::optional<AstArgumentName>> paramNames = copy(names); AstArray<std::optional<AstArgumentName>> paramNames = copy(names);
return {parseFunctionTypeTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation, isCheckedFunction), {}}; return {parseFunctionTypeTail(begin, attributes, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}};
} }
AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray<AstGenericType> generics, AstArray<AstGenericTypePack> genericPacks, AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, const AstArray<AstAttr*>& attributes, AstArray<AstGenericType> generics,
AstArray<AstType*> params, AstArray<std::optional<AstArgumentName>> paramNames, AstTypePack* varargAnnotation, bool isCheckedFunction) AstArray<AstGenericTypePack> genericPacks, AstArray<AstType*> params, AstArray<std::optional<AstArgumentName>> paramNames,
AstTypePack* varargAnnotation)
{ {
incrementRecursionCounter("type annotation"); incrementRecursionCounter("type annotation");
@ -1512,7 +1641,7 @@ AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray<AstGenericT
AstTypeList paramTypes = AstTypeList{params, varargAnnotation}; AstTypeList paramTypes = AstTypeList{params, varargAnnotation};
return allocator.alloc<AstTypeFunction>( return allocator.alloc<AstTypeFunction>(
Location(begin.location, endLocation), generics, genericPacks, paramTypes, paramNames, returnTypeList, isCheckedFunction); Location(begin.location, endLocation), attributes, generics, genericPacks, paramTypes, paramNames, returnTypeList);
} }
// Type ::= // Type ::=
@ -1666,7 +1795,21 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext)
Location start = lexer.current().location; Location start = lexer.current().location;
if (lexer.current().type == Lexeme::ReservedNil) AstArray<AstAttr*> attributes{nullptr, 0};
if (lexer.current().type == Lexeme::Attribute)
{
if (!inDeclarationContext || !FFlag::LuauAttributeSyntax)
{
return {reportTypeError(start, {}, "attributes are not allowed in declaration context")};
}
else
{
attributes = Parser::parseAttributes();
return parseFunctionType(allowPack, attributes);
}
}
else if (lexer.current().type == Lexeme::ReservedNil)
{ {
nextLexeme(); nextLexeme();
return {allocator.alloc<AstTypeReference>(start, std::nullopt, nameNil, std::nullopt, start), {}}; return {allocator.alloc<AstTypeReference>(start, std::nullopt, nameNil, std::nullopt, start), {}};
@ -1754,14 +1897,9 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext)
{ {
return {parseTableType(/* inDeclarationContext */ inDeclarationContext), {}}; return {parseTableType(/* inDeclarationContext */ inDeclarationContext), {}};
} }
else if (inDeclarationContext && lexer.current().type == Lexeme::ReservedChecked)
{
nextLexeme();
return parseFunctionType(allowPack, /* isCheckedFunction */ true);
}
else if (lexer.current().type == '(' || lexer.current().type == '<') else if (lexer.current().type == '(' || lexer.current().type == '<')
{ {
return parseFunctionType(allowPack); return parseFunctionType(allowPack, AstArray<AstAttr*>({nullptr, 0}));
} }
else if (lexer.current().type == Lexeme::ReservedFunction) else if (lexer.current().type == Lexeme::ReservedFunction)
{ {
@ -2259,7 +2397,7 @@ AstExpr* Parser::parseSimpleExpr()
Lexeme matchFunction = lexer.current(); Lexeme matchFunction = lexer.current();
nextLexeme(); nextLexeme();
return parseFunctionBody(false, matchFunction, AstName(), nullptr).first; return parseFunctionBody(false, matchFunction, AstName(), nullptr, AstArray<AstAttr*>({nullptr, 0})).first;
} }
else if (lexer.current().type == Lexeme::Number) else if (lexer.current().type == Lexeme::Number)
{ {
@ -2689,7 +2827,7 @@ std::optional<AstArray<char>> Parser::parseCharArray()
LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString || LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString ||
lexer.current().type == Lexeme::InterpStringSimple); lexer.current().type == Lexeme::InterpStringSimple);
scratchData.assign(lexer.current().data, lexer.current().length); scratchData.assign(lexer.current().data, lexer.current().getLength());
if (lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::InterpStringSimple) if (lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::InterpStringSimple)
{ {
@ -2734,7 +2872,7 @@ AstExpr* Parser::parseInterpString()
endLocation = currentLexeme.location; endLocation = currentLexeme.location;
scratchData.assign(currentLexeme.data, currentLexeme.length); scratchData.assign(currentLexeme.data, currentLexeme.getLength());
if (!Lexer::fixupQuotedString(scratchData)) if (!Lexer::fixupQuotedString(scratchData))
{ {
@ -2807,7 +2945,7 @@ AstExpr* Parser::parseNumber()
{ {
Location start = lexer.current().location; Location start = lexer.current().location;
scratchData.assign(lexer.current().data, lexer.current().length); scratchData.assign(lexer.current().data, lexer.current().getLength());
// Remove all internal _ - they don't hold any meaning and this allows parsing code to just pass the string pointer to strtod et al // Remove all internal _ - they don't hold any meaning and this allows parsing code to just pass the string pointer to strtod et al
if (scratchData.find('_') != std::string::npos) if (scratchData.find('_') != std::string::npos)
@ -3162,11 +3300,11 @@ void Parser::nextLexeme()
return; return;
// Comments starting with ! are called "hot comments" and contain directives for type checking / linting / compiling // Comments starting with ! are called "hot comments" and contain directives for type checking / linting / compiling
if (lexeme.type == Lexeme::Comment && lexeme.length && lexeme.data[0] == '!') if (lexeme.type == Lexeme::Comment && lexeme.getLength() && lexeme.data[0] == '!')
{ {
const char* text = lexeme.data; const char* text = lexeme.data;
unsigned int end = lexeme.length; unsigned int end = lexeme.getLength();
while (end > 0 && isSpace(text[end - 1])) while (end > 0 && isSpace(text[end - 1]))
--end; --end;

View file

@ -73,12 +73,39 @@ struct CompilationResult
}; };
struct IrBuilder; struct IrBuilder;
struct IrOp;
using HostVectorOperationBytecodeType = uint8_t (*)(const char* member, size_t memberLength); using HostVectorOperationBytecodeType = uint8_t (*)(const char* member, size_t memberLength);
using HostVectorAccessHandler = bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos); using HostVectorAccessHandler = bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos);
using HostVectorNamecallHandler = bool (*)( using HostVectorNamecallHandler = bool (*)(
IrBuilder& builder, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos); IrBuilder& builder, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos);
enum class HostMetamethod
{
Add,
Sub,
Mul,
Div,
Idiv,
Mod,
Pow,
Minus,
Equal,
LessThan,
LessEqual,
Length,
Concat,
};
using HostUserdataOperationBytecodeType = uint8_t (*)(uint8_t type, const char* member, size_t memberLength);
using HostUserdataMetamethodBytecodeType = uint8_t (*)(uint8_t lhsTy, uint8_t rhsTy, HostMetamethod method);
using HostUserdataAccessHandler = bool (*)(
IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos);
using HostUserdataMetamethodHandler = bool (*)(
IrBuilder& builder, uint8_t lhsTy, uint8_t rhsTy, int resultReg, IrOp lhs, IrOp rhs, HostMetamethod method, int pcpos);
using HostUserdataNamecallHandler = bool (*)(
IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos);
struct HostIrHooks struct HostIrHooks
{ {
// Suggest result type of a vector field access // Suggest result type of a vector field access
@ -97,6 +124,34 @@ struct HostIrHooks
// All other arguments can be of any type // All other arguments can be of any type
// Guards should take a VM exit to 'pcpos' // Guards should take a VM exit to 'pcpos'
HostVectorNamecallHandler vectorNamecall = nullptr; HostVectorNamecallHandler vectorNamecall = nullptr;
// Suggest result type of a userdata field access
HostUserdataOperationBytecodeType userdataAccessBytecodeType = nullptr;
// Suggest result type of a metamethod call
HostUserdataMetamethodBytecodeType userdataMetamethodBytecodeType = nullptr;
// Suggest result type of a userdata namecall
HostUserdataOperationBytecodeType userdataNamecallBytecodeType = nullptr;
// Handle userdata value field access
// 'sourceReg' is guaranteed to be a userdata, but tag has to be checked
// Write to 'resultReg' might invalidate 'sourceReg'
// Guards should take a VM exit to 'pcpos'
HostUserdataAccessHandler userdataAccess = nullptr;
// Handle metamethod operation on a userdata value
// 'lhs' and 'rhs' operands can be VM registers of constants
// Operand types have to be checked and userdata operand tags have to be checked
// Write to 'resultReg' might invalidate source operands
// Guards should take a VM exit to 'pcpos'
HostUserdataMetamethodHandler userdataMetamethod = nullptr;
// Handle namecall performed on a userdata value
// 'sourceReg' (self argument) is guaranteed to be a userdata, but tag has to be checked
// All other arguments can be of any type
// Guards should take a VM exit to 'pcpos'
HostUserdataNamecallHandler userdataNamecall = nullptr;
}; };
struct CompilationOptions struct CompilationOptions

View file

@ -290,6 +290,11 @@ enum class IrCmd : uint8_t
// C: block // C: block
TRY_CALL_FASTGETTM, TRY_CALL_FASTGETTM,
// Create new tagged userdata
// A: int (size)
// B: int (tag)
NEW_USERDATA,
// Convert integer into a double number // Convert integer into a double number
// A: int // A: int
INT_TO_NUM, INT_TO_NUM,
@ -460,6 +465,13 @@ enum class IrCmd : uint8_t
// When undef is specified instead of a block, execution is aborted on check failure // When undef is specified instead of a block, execution is aborted on check failure
CHECK_BUFFER_LEN, CHECK_BUFFER_LEN,
// Guard against userdata tag mismatch
// A: pointer (userdata)
// B: int (tag)
// C: block/vmexit/undef
// When undef is specified instead of a block, execution is aborted on check failure
CHECK_USERDATA_TAG,
// Special operations // Special operations
// Check interrupt handler // Check interrupt handler

View file

@ -11,6 +11,7 @@ namespace CodeGen
{ {
struct IrBuilder; struct IrBuilder;
enum class HostMetamethod;
inline bool isJumpD(LuauOpcode op) inline bool isJumpD(LuauOpcode op)
{ {
@ -129,6 +130,7 @@ inline bool isNonTerminatingJump(IrCmd cmd)
case IrCmd::CHECK_NODE_NO_NEXT: case IrCmd::CHECK_NODE_NO_NEXT:
case IrCmd::CHECK_NODE_VALUE: case IrCmd::CHECK_NODE_VALUE:
case IrCmd::CHECK_BUFFER_LEN: case IrCmd::CHECK_BUFFER_LEN:
case IrCmd::CHECK_USERDATA_TAG:
return true; return true;
default: default:
break; break;
@ -182,6 +184,7 @@ inline bool hasResult(IrCmd cmd)
case IrCmd::DUP_TABLE: case IrCmd::DUP_TABLE:
case IrCmd::TRY_NUM_TO_INDEX: case IrCmd::TRY_NUM_TO_INDEX:
case IrCmd::TRY_CALL_FASTGETTM: case IrCmd::TRY_CALL_FASTGETTM:
case IrCmd::NEW_USERDATA:
case IrCmd::INT_TO_NUM: case IrCmd::INT_TO_NUM:
case IrCmd::UINT_TO_NUM: case IrCmd::UINT_TO_NUM:
case IrCmd::NUM_TO_INT: case IrCmd::NUM_TO_INT:
@ -245,6 +248,8 @@ bool isGCO(uint8_t tag);
bool isUserdataBytecodeType(uint8_t ty); bool isUserdataBytecodeType(uint8_t ty);
bool isCustomUserdataBytecodeType(uint8_t ty); bool isCustomUserdataBytecodeType(uint8_t ty);
HostMetamethod tmToHostMetamethod(int tm);
// Manually add or remove use of an operand // Manually add or remove use of an operand
void addUse(IrFunction& function, IrOp op); void addUse(IrFunction& function, IrOp op);
void removeUse(IrFunction& function, IrOp op); void removeUse(IrFunction& function, IrOp op);

View file

@ -16,6 +16,7 @@ LUAU_FASTFLAG(LuauLoadTypeInfo) // Because new VM typeinfo loa
LUAU_FASTFLAGVARIABLE(LuauCodegenTypeInfo, false) // New analysis is flagged separately LUAU_FASTFLAGVARIABLE(LuauCodegenTypeInfo, false) // New analysis is flagged separately
LUAU_FASTFLAGVARIABLE(LuauCodegenAnalyzeHostVectorOps, false) LUAU_FASTFLAGVARIABLE(LuauCodegenAnalyzeHostVectorOps, false)
LUAU_FASTFLAGVARIABLE(LuauCodegenLoadTypeUpvalCheck, false) LUAU_FASTFLAGVARIABLE(LuauCodegenLoadTypeUpvalCheck, false)
LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataOps, false)
namespace Luau namespace Luau
{ {
@ -546,6 +547,49 @@ static void applyBuiltinCall(int bfid, BytecodeTypes& types)
} }
} }
static HostMetamethod opcodeToHostMetamethod(LuauOpcode op)
{
switch (op)
{
case LOP_ADD:
return HostMetamethod::Add;
case LOP_SUB:
return HostMetamethod::Sub;
case LOP_MUL:
return HostMetamethod::Mul;
case LOP_DIV:
return HostMetamethod::Div;
case LOP_IDIV:
return HostMetamethod::Idiv;
case LOP_MOD:
return HostMetamethod::Mod;
case LOP_POW:
return HostMetamethod::Pow;
case LOP_ADDK:
return HostMetamethod::Add;
case LOP_SUBK:
return HostMetamethod::Sub;
case LOP_MULK:
return HostMetamethod::Mul;
case LOP_DIVK:
return HostMetamethod::Div;
case LOP_IDIVK:
return HostMetamethod::Idiv;
case LOP_MODK:
return HostMetamethod::Mod;
case LOP_POWK:
return HostMetamethod::Pow;
case LOP_SUBRK:
return HostMetamethod::Sub;
case LOP_DIVRK:
return HostMetamethod::Div;
default:
CODEGEN_ASSERT(!"opcode is not assigned to a host metamethod");
}
return HostMetamethod::Add;
}
void buildBytecodeBlocks(IrFunction& function, const std::vector<uint8_t>& jumpTargets) void buildBytecodeBlocks(IrFunction& function, const std::vector<uint8_t>& jumpTargets)
{ {
Proto* proto = function.proto; Proto* proto = function.proto;
@ -760,22 +804,50 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
regTags[ra] = LBC_TYPE_ANY; regTags[ra] = LBC_TYPE_ANY;
if (bcType.a == LBC_TYPE_VECTOR) if (FFlag::LuauCodegenUserdataOps)
{ {
TString* str = gco2ts(function.proto->k[kc].value.gc); TString* str = gco2ts(function.proto->k[kc].value.gc);
const char* field = getstr(str); const char* field = getstr(str);
if (str->len == 1) if (bcType.a == LBC_TYPE_VECTOR)
{ {
// Same handling as LOP_GETTABLEKS block in lvmexecute.cpp - case-insensitive comparison with "X" / "Y" / "Z" if (str->len == 1)
char ch = field[0] | ' '; {
// Same handling as LOP_GETTABLEKS block in lvmexecute.cpp - case-insensitive comparison with "X" / "Y" / "Z"
char ch = field[0] | ' ';
if (ch == 'x' || ch == 'y' || ch == 'z') if (ch == 'x' || ch == 'y' || ch == 'z')
regTags[ra] = LBC_TYPE_NUMBER; regTags[ra] = LBC_TYPE_NUMBER;
}
if (FFlag::LuauCodegenAnalyzeHostVectorOps && regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType)
regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len);
} }
else if (isCustomUserdataBytecodeType(bcType.a))
{
if (regTags[ra] == LBC_TYPE_ANY && hostHooks.userdataAccessBytecodeType)
regTags[ra] = hostHooks.userdataAccessBytecodeType(bcType.a, field, str->len);
}
}
else
{
if (bcType.a == LBC_TYPE_VECTOR)
{
TString* str = gco2ts(function.proto->k[kc].value.gc);
const char* field = getstr(str);
if (FFlag::LuauCodegenAnalyzeHostVectorOps && regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType) if (str->len == 1)
regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len); {
// Same handling as LOP_GETTABLEKS block in lvmexecute.cpp - case-insensitive comparison with "X" / "Y" / "Z"
char ch = field[0] | ' ';
if (ch == 'x' || ch == 'y' || ch == 'z')
regTags[ra] = LBC_TYPE_NUMBER;
}
if (FFlag::LuauCodegenAnalyzeHostVectorOps && regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType)
regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len);
}
} }
bcType.result = regTags[ra]; bcType.result = regTags[ra];
@ -812,6 +884,9 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
regTags[ra] = LBC_TYPE_NUMBER; regTags[ra] = LBC_TYPE_NUMBER;
else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR) else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR; regTags[ra] = LBC_TYPE_VECTOR;
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
bcType.result = regTags[ra]; bcType.result = regTags[ra];
break; break;
@ -841,6 +916,11 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR) if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR; regTags[ra] = LBC_TYPE_VECTOR;
} }
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
{
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
}
bcType.result = regTags[ra]; bcType.result = regTags[ra];
break; break;
@ -859,6 +939,9 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER) if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER)
regTags[ra] = LBC_TYPE_NUMBER; regTags[ra] = LBC_TYPE_NUMBER;
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
bcType.result = regTags[ra]; bcType.result = regTags[ra];
break; break;
@ -879,6 +962,9 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
regTags[ra] = LBC_TYPE_NUMBER; regTags[ra] = LBC_TYPE_NUMBER;
else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR) else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR; regTags[ra] = LBC_TYPE_VECTOR;
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
bcType.result = regTags[ra]; bcType.result = regTags[ra];
break; break;
@ -908,6 +994,11 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR) if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR; regTags[ra] = LBC_TYPE_VECTOR;
} }
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
{
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
}
bcType.result = regTags[ra]; bcType.result = regTags[ra];
break; break;
@ -926,6 +1017,9 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER) if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER)
regTags[ra] = LBC_TYPE_NUMBER; regTags[ra] = LBC_TYPE_NUMBER;
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
bcType.result = regTags[ra]; bcType.result = regTags[ra];
break; break;
@ -945,6 +1039,9 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
regTags[ra] = LBC_TYPE_NUMBER; regTags[ra] = LBC_TYPE_NUMBER;
else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR) else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR; regTags[ra] = LBC_TYPE_VECTOR;
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
bcType.result = regTags[ra]; bcType.result = regTags[ra];
break; break;
@ -972,6 +1069,11 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR) if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR; regTags[ra] = LBC_TYPE_VECTOR;
} }
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
{
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
}
bcType.result = regTags[ra]; bcType.result = regTags[ra];
break; break;
@ -1000,6 +1102,8 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
regTags[ra] = LBC_TYPE_NUMBER; regTags[ra] = LBC_TYPE_NUMBER;
else if (bcType.a == LBC_TYPE_VECTOR) else if (bcType.a == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR; regTags[ra] = LBC_TYPE_VECTOR;
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && isCustomUserdataBytecodeType(bcType.a))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, LBC_TYPE_ANY, HostMetamethod::Minus);
bcType.result = regTags[ra]; bcType.result = regTags[ra];
break; break;
@ -1140,12 +1244,25 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
bcType.result = LBC_TYPE_FUNCTION; bcType.result = LBC_TYPE_FUNCTION;
if (FFlag::LuauCodegenAnalyzeHostVectorOps && bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType) if (FFlag::LuauCodegenUserdataOps)
{ {
TString* str = gco2ts(function.proto->k[kc].value.gc); TString* str = gco2ts(function.proto->k[kc].value.gc);
const char* field = getstr(str); const char* field = getstr(str);
knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len)); if (FFlag::LuauCodegenAnalyzeHostVectorOps && bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType)
knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len));
else if (isCustomUserdataBytecodeType(bcType.a) && hostHooks.userdataNamecallBytecodeType)
knownNextCallResult = LuauBytecodeType(hostHooks.userdataNamecallBytecodeType(bcType.a, field, str->len));
}
else
{
if (FFlag::LuauCodegenAnalyzeHostVectorOps && bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType)
{
TString* str = gco2ts(function.proto->k[kc].value.gc);
const char* field = getstr(str);
knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len));
}
} }
} }
break; break;

View file

@ -258,39 +258,6 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde
return locations; return locations;
} }
bool initHeaderFunctions(NativeState& data)
{
AssemblyBuilderA64 build(/* logText= */ false);
UnwindBuilder& unwind = *data.unwindBuilder.get();
unwind.startInfo(UnwindBuilder::A64);
EntryLocations entryLocations = buildEntryFunction(build, unwind);
build.finalize();
unwind.finishInfo();
CODEGEN_ASSERT(build.data.empty());
uint8_t* codeStart = nullptr;
if (!data.codeAllocator.allocate(build.data.data(), int(build.data.size()), reinterpret_cast<const uint8_t*>(build.code.data()),
int(build.code.size() * sizeof(build.code[0])), data.gateData, data.gateDataSize, codeStart))
{
CODEGEN_ASSERT(!"Failed to create entry function");
return false;
}
// Set the offset at the begining so that functions in new blocks will not overlay the locations
// specified by the unwind information of the entry function
unwind.setBeginOffset(build.getLabelOffset(entryLocations.prologueEnd));
data.context.gateEntry = codeStart + build.getLabelOffset(entryLocations.start);
data.context.gateExit = codeStart + build.getLabelOffset(entryLocations.epilogueStart);
return true;
}
bool initHeaderFunctions(BaseCodeGenContext& codeGenContext) bool initHeaderFunctions(BaseCodeGenContext& codeGenContext)
{ {
AssemblyBuilderA64 build(/* logText= */ false); AssemblyBuilderA64 build(/* logText= */ false);

View file

@ -7,7 +7,6 @@ namespace CodeGen
{ {
class BaseCodeGenContext; class BaseCodeGenContext;
struct NativeState;
struct ModuleHelpers; struct ModuleHelpers;
namespace A64 namespace A64
@ -15,7 +14,6 @@ namespace A64
class AssemblyBuilderA64; class AssemblyBuilderA64;
bool initHeaderFunctions(NativeState& data);
bool initHeaderFunctions(BaseCodeGenContext& codeGenContext); bool initHeaderFunctions(BaseCodeGenContext& codeGenContext);
void assembleHelpers(AssemblyBuilderA64& build, ModuleHelpers& helpers); void assembleHelpers(AssemblyBuilderA64& build, ModuleHelpers& helpers);

View file

@ -14,8 +14,8 @@
LUAU_FASTFLAGVARIABLE(LuauCodegenCheckNullContext, false) LUAU_FASTFLAGVARIABLE(LuauCodegenCheckNullContext, false)
LUAU_FASTINT(LuauCodeGenBlockSize) LUAU_FASTINTVARIABLE(LuauCodeGenBlockSize, 4 * 1024 * 1024)
LUAU_FASTINT(LuauCodeGenMaxTotalSize) LUAU_FASTINTVARIABLE(LuauCodeGenMaxTotalSize, 256 * 1024 * 1024)
namespace Luau namespace Luau
{ {

View file

@ -14,6 +14,7 @@
#include "lstate.h" #include "lstate.h"
#include "lstring.h" #include "lstring.h"
#include "ltable.h" #include "ltable.h"
#include "ludata.h"
#include <string.h> #include <string.h>
@ -219,6 +220,20 @@ void callEpilogC(lua_State* L, int nresults, int n)
L->top = (nresults == LUA_MULTRET) ? res : cip->top; L->top = (nresults == LUA_MULTRET) ? res : cip->top;
} }
Udata* newUserdata(lua_State* L, size_t s, int tag)
{
Udata* u = luaU_newudata(L, s, tag);
if (Table* h = L->global->udatamt[tag])
{
u->metatable = h;
luaC_objbarrier(L, u, h);
}
return u;
}
// Extracted as-is from lvmexecute.cpp with the exception of control flow (reentry) and removed interrupts/savedpc // Extracted as-is from lvmexecute.cpp with the exception of control flow (reentry) and removed interrupts/savedpc
Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults) Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults)
{ {

View file

@ -17,6 +17,8 @@ void forgPrepXnextFallback(lua_State* L, TValue* ra, int pc);
Closure* callProlog(lua_State* L, TValue* ra, StkId argtop, int nresults); Closure* callProlog(lua_State* L, TValue* ra, StkId argtop, int nresults);
void callEpilogC(lua_State* L, int nresults, int n); void callEpilogC(lua_State* L, int nresults, int n);
Udata* newUserdata(lua_State* L, size_t s, int tag);
#define CALL_FALLBACK_YIELD 1 #define CALL_FALLBACK_YIELD 1
Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults); Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults);

View file

@ -186,39 +186,6 @@ static EntryLocations buildEntryFunction(AssemblyBuilderX64& build, UnwindBuilde
return locations; return locations;
} }
bool initHeaderFunctions(NativeState& data)
{
AssemblyBuilderX64 build(/* logText= */ false);
UnwindBuilder& unwind = *data.unwindBuilder.get();
unwind.startInfo(UnwindBuilder::X64);
EntryLocations entryLocations = buildEntryFunction(build, unwind);
build.finalize();
unwind.finishInfo();
CODEGEN_ASSERT(build.data.empty());
uint8_t* codeStart = nullptr;
if (!data.codeAllocator.allocate(
build.data.data(), int(build.data.size()), build.code.data(), int(build.code.size()), data.gateData, data.gateDataSize, codeStart))
{
CODEGEN_ASSERT(!"Failed to create entry function");
return false;
}
// Set the offset at the begining so that functions in new blocks will not overlay the locations
// specified by the unwind information of the entry function
unwind.setBeginOffset(build.getLabelOffset(entryLocations.prologueEnd));
data.context.gateEntry = codeStart + build.getLabelOffset(entryLocations.start);
data.context.gateExit = codeStart + build.getLabelOffset(entryLocations.epilogueStart);
return true;
}
bool initHeaderFunctions(BaseCodeGenContext& codeGenContext) bool initHeaderFunctions(BaseCodeGenContext& codeGenContext)
{ {
AssemblyBuilderX64 build(/* logText= */ false); AssemblyBuilderX64 build(/* logText= */ false);

View file

@ -7,7 +7,6 @@ namespace CodeGen
{ {
class BaseCodeGenContext; class BaseCodeGenContext;
struct NativeState;
struct ModuleHelpers; struct ModuleHelpers;
namespace X64 namespace X64
@ -15,7 +14,6 @@ namespace X64
class AssemblyBuilderX64; class AssemblyBuilderX64;
bool initHeaderFunctions(NativeState& data);
bool initHeaderFunctions(BaseCodeGenContext& codeGenContext); bool initHeaderFunctions(BaseCodeGenContext& codeGenContext);
void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers); void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers);

View file

@ -22,8 +22,6 @@ namespace Luau
namespace CodeGen namespace CodeGen
{ {
struct NativeState;
namespace A64 namespace A64
{ {

View file

@ -26,7 +26,6 @@ namespace CodeGen
{ {
enum class IrCondition : uint8_t; enum class IrCondition : uint8_t;
struct NativeState;
struct IrOp; struct IrOp;
namespace X64 namespace X64

View file

@ -199,6 +199,8 @@ const char* getCmdName(IrCmd cmd)
return "TRY_NUM_TO_INDEX"; return "TRY_NUM_TO_INDEX";
case IrCmd::TRY_CALL_FASTGETTM: case IrCmd::TRY_CALL_FASTGETTM:
return "TRY_CALL_FASTGETTM"; return "TRY_CALL_FASTGETTM";
case IrCmd::NEW_USERDATA:
return "NEW_USERDATA";
case IrCmd::INT_TO_NUM: case IrCmd::INT_TO_NUM:
return "INT_TO_NUM"; return "INT_TO_NUM";
case IrCmd::UINT_TO_NUM: case IrCmd::UINT_TO_NUM:
@ -257,6 +259,8 @@ const char* getCmdName(IrCmd cmd)
return "CHECK_NODE_VALUE"; return "CHECK_NODE_VALUE";
case IrCmd::CHECK_BUFFER_LEN: case IrCmd::CHECK_BUFFER_LEN:
return "CHECK_BUFFER_LEN"; return "CHECK_BUFFER_LEN";
case IrCmd::CHECK_USERDATA_TAG:
return "CHECK_USERDATA_TAG";
case IrCmd::INTERRUPT: case IrCmd::INTERRUPT:
return "INTERRUPT"; return "INTERRUPT";
case IrCmd::CHECK_GC: case IrCmd::CHECK_GC:

View file

@ -13,6 +13,9 @@
LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5)
LUAU_FASTFLAG(LuauCodegenSplitDoarith) LUAU_FASTFLAG(LuauCodegenSplitDoarith)
LUAU_FASTFLAG(LuauCodegenUserdataOps)
LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataAlloc, false)
LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataOpsFixA64, false)
namespace Luau namespace Luau
{ {
@ -1083,6 +1086,19 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
inst.regA64 = regs.takeReg(x0, index); inst.regA64 = regs.takeReg(x0, index);
break; break;
} }
case IrCmd::NEW_USERDATA:
{
CODEGEN_ASSERT(FFlag::LuauCodegenUserdataAlloc);
regs.spill(build, index);
build.mov(x0, rState);
build.mov(x1, intOp(inst.a));
build.mov(x2, intOp(inst.b));
build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, newUserdata)));
build.blr(x3);
inst.regA64 = regs.takeReg(x0, index);
break;
}
case IrCmd::INT_TO_NUM: case IrCmd::INT_TO_NUM:
{ {
inst.regA64 = regs.allocReg(KindA64::d, index); inst.regA64 = regs.allocReg(KindA64::d, index);
@ -1677,6 +1693,24 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
finalizeTargetLabel(inst.d, fresh); finalizeTargetLabel(inst.d, fresh);
break; break;
} }
case IrCmd::CHECK_USERDATA_TAG:
{
CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps);
Label fresh; // used when guard aborts execution or jumps to a VM exit
Label& fail = getTargetLabel(inst.c, fresh);
RegisterA64 temp = regs.allocTemp(KindA64::w);
build.ldrb(temp, mem(regOp(inst.a), offsetof(Udata, tag)));
if (FFlag::LuauCodegenUserdataOpsFixA64)
build.cmp(temp, intOp(inst.b));
else
build.cmp(temp, tagOp(inst.b));
build.b(ConditionA64::NotEqual, fail);
finalizeTargetLabel(inst.c, fresh);
break;
}
case IrCmd::INTERRUPT: case IrCmd::INTERRUPT:
{ {
regs.spill(build, index); regs.spill(build, index);
@ -2308,7 +2342,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_READI8: case IrCmd::BUFFER_READI8:
{ {
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b}); inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b});
AddressA64 addr = tempAddrBuffer(inst.a, inst.b); AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
build.ldrsb(inst.regA64, addr); build.ldrsb(inst.regA64, addr);
break; break;
@ -2317,7 +2351,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_READU8: case IrCmd::BUFFER_READU8:
{ {
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b}); inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b});
AddressA64 addr = tempAddrBuffer(inst.a, inst.b); AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
build.ldrb(inst.regA64, addr); build.ldrb(inst.regA64, addr);
break; break;
@ -2326,7 +2360,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_WRITEI8: case IrCmd::BUFFER_WRITEI8:
{ {
RegisterA64 temp = tempInt(inst.c); RegisterA64 temp = tempInt(inst.c);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b); AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d));
build.strb(temp, addr); build.strb(temp, addr);
break; break;
@ -2335,7 +2369,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_READI16: case IrCmd::BUFFER_READI16:
{ {
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b}); inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b});
AddressA64 addr = tempAddrBuffer(inst.a, inst.b); AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
build.ldrsh(inst.regA64, addr); build.ldrsh(inst.regA64, addr);
break; break;
@ -2344,7 +2378,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_READU16: case IrCmd::BUFFER_READU16:
{ {
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b}); inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b});
AddressA64 addr = tempAddrBuffer(inst.a, inst.b); AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
build.ldrh(inst.regA64, addr); build.ldrh(inst.regA64, addr);
break; break;
@ -2353,7 +2387,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_WRITEI16: case IrCmd::BUFFER_WRITEI16:
{ {
RegisterA64 temp = tempInt(inst.c); RegisterA64 temp = tempInt(inst.c);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b); AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d));
build.strh(temp, addr); build.strh(temp, addr);
break; break;
@ -2362,7 +2396,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_READI32: case IrCmd::BUFFER_READI32:
{ {
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b}); inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b});
AddressA64 addr = tempAddrBuffer(inst.a, inst.b); AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
build.ldr(inst.regA64, addr); build.ldr(inst.regA64, addr);
break; break;
@ -2371,7 +2405,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_WRITEI32: case IrCmd::BUFFER_WRITEI32:
{ {
RegisterA64 temp = tempInt(inst.c); RegisterA64 temp = tempInt(inst.c);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b); AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d));
build.str(temp, addr); build.str(temp, addr);
break; break;
@ -2381,7 +2415,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{ {
inst.regA64 = regs.allocReg(KindA64::d, index); inst.regA64 = regs.allocReg(KindA64::d, index);
RegisterA64 temp = castReg(KindA64::s, inst.regA64); // safe to alias a fresh register RegisterA64 temp = castReg(KindA64::s, inst.regA64); // safe to alias a fresh register
AddressA64 addr = tempAddrBuffer(inst.a, inst.b); AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
build.ldr(temp, addr); build.ldr(temp, addr);
build.fcvt(inst.regA64, temp); build.fcvt(inst.regA64, temp);
@ -2392,7 +2426,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{ {
RegisterA64 temp1 = tempDouble(inst.c); RegisterA64 temp1 = tempDouble(inst.c);
RegisterA64 temp2 = regs.allocTemp(KindA64::s); RegisterA64 temp2 = regs.allocTemp(KindA64::s);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b); AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d));
build.fcvt(temp2, temp1); build.fcvt(temp2, temp1);
build.str(temp2, addr); build.str(temp2, addr);
@ -2402,7 +2436,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_READF64: case IrCmd::BUFFER_READF64:
{ {
inst.regA64 = regs.allocReg(KindA64::d, index); inst.regA64 = regs.allocReg(KindA64::d, index);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b); AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
build.ldr(inst.regA64, addr); build.ldr(inst.regA64, addr);
break; break;
@ -2411,7 +2445,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_WRITEF64: case IrCmd::BUFFER_WRITEF64:
{ {
RegisterA64 temp = tempDouble(inst.c); RegisterA64 temp = tempDouble(inst.c);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b); AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d));
build.str(temp, addr); build.str(temp, addr);
break; break;
@ -2639,32 +2673,68 @@ AddressA64 IrLoweringA64::tempAddr(IrOp op, int offset)
} }
} }
AddressA64 IrLoweringA64::tempAddrBuffer(IrOp bufferOp, IrOp indexOp) AddressA64 IrLoweringA64::tempAddrBuffer(IrOp bufferOp, IrOp indexOp, uint8_t tag)
{ {
if (indexOp.kind == IrOpKind::Inst) if (FFlag::LuauCodegenUserdataOps)
{ {
RegisterA64 temp = regs.allocTemp(KindA64::x); CODEGEN_ASSERT(tag == LUA_TUSERDATA || tag == LUA_TBUFFER);
build.add(temp, regOp(bufferOp), regOp(indexOp)); // implicit uxtw int dataOffset = tag == LUA_TBUFFER ? offsetof(Buffer, data) : offsetof(Udata, data);
return mem(temp, offsetof(Buffer, data));
}
else if (indexOp.kind == IrOpKind::Constant)
{
// Since the resulting address may be used to load any size, including 1 byte, from an unaligned offset, we are limited by unscaled encoding
if (unsigned(intOp(indexOp)) + offsetof(Buffer, data) <= 255)
return mem(regOp(bufferOp), int(intOp(indexOp) + offsetof(Buffer, data)));
// indexOp can only be negative in dead code (since offsets are checked); this avoids assertion in emitAddOffset if (indexOp.kind == IrOpKind::Inst)
if (intOp(indexOp) < 0) {
return mem(regOp(bufferOp), offsetof(Buffer, data)); RegisterA64 temp = regs.allocTemp(KindA64::x);
build.add(temp, regOp(bufferOp), regOp(indexOp)); // implicit uxtw
return mem(temp, dataOffset);
}
else if (indexOp.kind == IrOpKind::Constant)
{
// Since the resulting address may be used to load any size, including 1 byte, from an unaligned offset, we are limited by unscaled
// encoding
if (unsigned(intOp(indexOp)) + dataOffset <= 255)
return mem(regOp(bufferOp), int(intOp(indexOp) + dataOffset));
RegisterA64 temp = regs.allocTemp(KindA64::x); // indexOp can only be negative in dead code (since offsets are checked); this avoids assertion in emitAddOffset
emitAddOffset(build, temp, regOp(bufferOp), size_t(intOp(indexOp))); if (intOp(indexOp) < 0)
return mem(temp, offsetof(Buffer, data)); return mem(regOp(bufferOp), dataOffset);
RegisterA64 temp = regs.allocTemp(KindA64::x);
emitAddOffset(build, temp, regOp(bufferOp), size_t(intOp(indexOp)));
return mem(temp, dataOffset);
}
else
{
CODEGEN_ASSERT(!"Unsupported instruction form");
return noreg;
}
} }
else else
{ {
CODEGEN_ASSERT(!"Unsupported instruction form"); if (indexOp.kind == IrOpKind::Inst)
return noreg; {
RegisterA64 temp = regs.allocTemp(KindA64::x);
build.add(temp, regOp(bufferOp), regOp(indexOp)); // implicit uxtw
return mem(temp, offsetof(Buffer, data));
}
else if (indexOp.kind == IrOpKind::Constant)
{
// Since the resulting address may be used to load any size, including 1 byte, from an unaligned offset, we are limited by unscaled
// encoding
if (unsigned(intOp(indexOp)) + offsetof(Buffer, data) <= 255)
return mem(regOp(bufferOp), int(intOp(indexOp) + offsetof(Buffer, data)));
// indexOp can only be negative in dead code (since offsets are checked); this avoids assertion in emitAddOffset
if (intOp(indexOp) < 0)
return mem(regOp(bufferOp), offsetof(Buffer, data));
RegisterA64 temp = regs.allocTemp(KindA64::x);
emitAddOffset(build, temp, regOp(bufferOp), size_t(intOp(indexOp)));
return mem(temp, offsetof(Buffer, data));
}
else
{
CODEGEN_ASSERT(!"Unsupported instruction form");
return noreg;
}
} }
} }

View file

@ -44,7 +44,7 @@ struct IrLoweringA64
RegisterA64 tempInt(IrOp op); RegisterA64 tempInt(IrOp op);
RegisterA64 tempUint(IrOp op); RegisterA64 tempUint(IrOp op);
AddressA64 tempAddr(IrOp op, int offset); AddressA64 tempAddr(IrOp op, int offset);
AddressA64 tempAddrBuffer(IrOp bufferOp, IrOp indexOp); AddressA64 tempAddrBuffer(IrOp bufferOp, IrOp indexOp, uint8_t tag);
// May emit restore instructions // May emit restore instructions
RegisterA64 regOp(IrOp op); RegisterA64 regOp(IrOp op);

View file

@ -15,6 +15,9 @@
#include "lstate.h" #include "lstate.h"
#include "lgc.h" #include "lgc.h"
LUAU_FASTFLAG(LuauCodegenUserdataOps)
LUAU_FASTFLAG(LuauCodegenUserdataAlloc)
namespace Luau namespace Luau
{ {
namespace CodeGen namespace CodeGen
@ -905,6 +908,18 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
inst.regX64 = regs.takeReg(rax, index); inst.regX64 = regs.takeReg(rax, index);
break; break;
} }
case IrCmd::NEW_USERDATA:
{
CODEGEN_ASSERT(FFlag::LuauCodegenUserdataAlloc);
IrCallWrapperX64 callWrap(regs, build, index);
callWrap.addArgument(SizeX64::qword, rState);
callWrap.addArgument(SizeX64::qword, intOp(inst.a));
callWrap.addArgument(SizeX64::dword, intOp(inst.b));
callWrap.call(qword[rNativeContext + offsetof(NativeContext, newUserdata)]);
inst.regX64 = regs.takeReg(rax, index);
break;
}
case IrCmd::INT_TO_NUM: case IrCmd::INT_TO_NUM:
inst.regX64 = regs.allocReg(SizeX64::xmmword, index); inst.regX64 = regs.allocReg(SizeX64::xmmword, index);
@ -1350,6 +1365,14 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
} }
break; break;
} }
case IrCmd::CHECK_USERDATA_TAG:
{
CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps);
build.cmp(byte[regOp(inst.a) + offsetof(Udata, tag)], intOp(inst.b));
jumpOrAbortOnUndef(ConditionX64::NotEqual, inst.c, next);
break;
}
case IrCmd::INTERRUPT: case IrCmd::INTERRUPT:
{ {
unsigned pcpos = uintOp(inst.a); unsigned pcpos = uintOp(inst.a);
@ -1895,71 +1918,71 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_READI8: case IrCmd::BUFFER_READI8:
inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b}); inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
build.movsx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b)]); build.movsx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
break; break;
case IrCmd::BUFFER_READU8: case IrCmd::BUFFER_READU8:
inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b}); inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
build.movzx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b)]); build.movzx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
break; break;
case IrCmd::BUFFER_WRITEI8: case IrCmd::BUFFER_WRITEI8:
{ {
OperandX64 value = inst.c.kind == IrOpKind::Inst ? byteReg(regOp(inst.c)) : OperandX64(int8_t(intOp(inst.c))); OperandX64 value = inst.c.kind == IrOpKind::Inst ? byteReg(regOp(inst.c)) : OperandX64(int8_t(intOp(inst.c)));
build.mov(byte[bufferAddrOp(inst.a, inst.b)], value); build.mov(byte[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], value);
break; break;
} }
case IrCmd::BUFFER_READI16: case IrCmd::BUFFER_READI16:
inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b}); inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
build.movsx(inst.regX64, word[bufferAddrOp(inst.a, inst.b)]); build.movsx(inst.regX64, word[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
break; break;
case IrCmd::BUFFER_READU16: case IrCmd::BUFFER_READU16:
inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b}); inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
build.movzx(inst.regX64, word[bufferAddrOp(inst.a, inst.b)]); build.movzx(inst.regX64, word[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
break; break;
case IrCmd::BUFFER_WRITEI16: case IrCmd::BUFFER_WRITEI16:
{ {
OperandX64 value = inst.c.kind == IrOpKind::Inst ? wordReg(regOp(inst.c)) : OperandX64(int16_t(intOp(inst.c))); OperandX64 value = inst.c.kind == IrOpKind::Inst ? wordReg(regOp(inst.c)) : OperandX64(int16_t(intOp(inst.c)));
build.mov(word[bufferAddrOp(inst.a, inst.b)], value); build.mov(word[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], value);
break; break;
} }
case IrCmd::BUFFER_READI32: case IrCmd::BUFFER_READI32:
inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b}); inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
build.mov(inst.regX64, dword[bufferAddrOp(inst.a, inst.b)]); build.mov(inst.regX64, dword[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
break; break;
case IrCmd::BUFFER_WRITEI32: case IrCmd::BUFFER_WRITEI32:
{ {
OperandX64 value = inst.c.kind == IrOpKind::Inst ? regOp(inst.c) : OperandX64(intOp(inst.c)); OperandX64 value = inst.c.kind == IrOpKind::Inst ? regOp(inst.c) : OperandX64(intOp(inst.c));
build.mov(dword[bufferAddrOp(inst.a, inst.b)], value); build.mov(dword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], value);
break; break;
} }
case IrCmd::BUFFER_READF32: case IrCmd::BUFFER_READF32:
inst.regX64 = regs.allocReg(SizeX64::xmmword, index); inst.regX64 = regs.allocReg(SizeX64::xmmword, index);
build.vcvtss2sd(inst.regX64, inst.regX64, dword[bufferAddrOp(inst.a, inst.b)]); build.vcvtss2sd(inst.regX64, inst.regX64, dword[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
break; break;
case IrCmd::BUFFER_WRITEF32: case IrCmd::BUFFER_WRITEF32:
storeDoubleAsFloat(dword[bufferAddrOp(inst.a, inst.b)], inst.c); storeDoubleAsFloat(dword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], inst.c);
break; break;
case IrCmd::BUFFER_READF64: case IrCmd::BUFFER_READF64:
inst.regX64 = regs.allocReg(SizeX64::xmmword, index); inst.regX64 = regs.allocReg(SizeX64::xmmword, index);
build.vmovsd(inst.regX64, qword[bufferAddrOp(inst.a, inst.b)]); build.vmovsd(inst.regX64, qword[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
break; break;
case IrCmd::BUFFER_WRITEF64: case IrCmd::BUFFER_WRITEF64:
@ -1967,11 +1990,11 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{ {
ScopedRegX64 tmp{regs, SizeX64::xmmword}; ScopedRegX64 tmp{regs, SizeX64::xmmword};
build.vmovsd(tmp.reg, build.f64(doubleOp(inst.c))); build.vmovsd(tmp.reg, build.f64(doubleOp(inst.c)));
build.vmovsd(qword[bufferAddrOp(inst.a, inst.b)], tmp.reg); build.vmovsd(qword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], tmp.reg);
} }
else if (inst.c.kind == IrOpKind::Inst) else if (inst.c.kind == IrOpKind::Inst)
{ {
build.vmovsd(qword[bufferAddrOp(inst.a, inst.b)], regOp(inst.c)); build.vmovsd(qword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], regOp(inst.c));
} }
else else
{ {
@ -2190,12 +2213,25 @@ RegisterX64 IrLoweringX64::regOp(IrOp op)
return inst.regX64; return inst.regX64;
} }
OperandX64 IrLoweringX64::bufferAddrOp(IrOp bufferOp, IrOp indexOp) OperandX64 IrLoweringX64::bufferAddrOp(IrOp bufferOp, IrOp indexOp, uint8_t tag)
{ {
if (indexOp.kind == IrOpKind::Inst) if (FFlag::LuauCodegenUserdataOps)
return regOp(bufferOp) + qwordReg(regOp(indexOp)) + offsetof(Buffer, data); {
else if (indexOp.kind == IrOpKind::Constant) CODEGEN_ASSERT(tag == LUA_TUSERDATA || tag == LUA_TBUFFER);
return regOp(bufferOp) + intOp(indexOp) + offsetof(Buffer, data); int dataOffset = tag == LUA_TBUFFER ? offsetof(Buffer, data) : offsetof(Udata, data);
if (indexOp.kind == IrOpKind::Inst)
return regOp(bufferOp) + qwordReg(regOp(indexOp)) + dataOffset;
else if (indexOp.kind == IrOpKind::Constant)
return regOp(bufferOp) + intOp(indexOp) + dataOffset;
}
else
{
if (indexOp.kind == IrOpKind::Inst)
return regOp(bufferOp) + qwordReg(regOp(indexOp)) + offsetof(Buffer, data);
else if (indexOp.kind == IrOpKind::Constant)
return regOp(bufferOp) + intOp(indexOp) + offsetof(Buffer, data);
}
CODEGEN_ASSERT(!"Unsupported instruction form"); CODEGEN_ASSERT(!"Unsupported instruction form");
return noreg; return noreg;

View file

@ -50,7 +50,7 @@ struct IrLoweringX64
OperandX64 memRegUintOp(IrOp op); OperandX64 memRegUintOp(IrOp op);
OperandX64 memRegTagOp(IrOp op); OperandX64 memRegTagOp(IrOp op);
RegisterX64 regOp(IrOp op); RegisterX64 regOp(IrOp op);
OperandX64 bufferAddrOp(IrOp bufferOp, IrOp indexOp); OperandX64 bufferAddrOp(IrOp bufferOp, IrOp indexOp, uint8_t tag);
RegisterX64 vecOp(IrOp op, ScopedRegX64& tmp); RegisterX64 vecOp(IrOp op, ScopedRegX64& tmp);
IrConst constOp(IrOp op) const; IrConst constOp(IrOp op) const;

View file

@ -15,6 +15,7 @@
LUAU_FASTFLAGVARIABLE(LuauCodegenDirectUserdataFlow, false) LUAU_FASTFLAGVARIABLE(LuauCodegenDirectUserdataFlow, false)
LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps) LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps)
LUAU_FASTFLAG(LuauCodegenUserdataOps)
namespace Luau namespace Luau
{ {
@ -444,6 +445,17 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc,
return; return;
} }
if (FFlag::LuauCodegenUserdataOps && (isUserdataBytecodeType(bcTypes.a) || isUserdataBytecodeType(bcTypes.b)))
{
if (build.hostHooks.userdataMetamethod &&
build.hostHooks.userdataMetamethod(build, bcTypes.a, bcTypes.b, ra, opb, opc, tmToHostMetamethod(tm), pcpos))
return;
build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1));
build.inst(IrCmd::DO_ARITH, build.vmReg(ra), opb, opc, build.constInt(tm));
return;
}
IrOp fallback; IrOp fallback;
// fast-path: number // fast-path: number
@ -585,6 +597,17 @@ void translateInstMinus(IrBuilder& build, const Instruction* pc, int pcpos)
return; return;
} }
if (FFlag::LuauCodegenUserdataOps && isUserdataBytecodeType(bcTypes.a))
{
if (build.hostHooks.userdataMetamethod &&
build.hostHooks.userdataMetamethod(build, bcTypes.a, bcTypes.b, ra, build.vmReg(rb), {}, tmToHostMetamethod(TM_UNM), pcpos))
return;
build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1));
build.inst(IrCmd::DO_ARITH, build.vmReg(ra), build.vmReg(rb), build.vmReg(rb), build.constInt(TM_UNM));
return;
}
IrOp fallback; IrOp fallback;
IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb));
@ -606,8 +629,17 @@ void translateInstMinus(IrBuilder& build, const Instruction* pc, int pcpos)
FallbackStreamScope scope(build, fallback, next); FallbackStreamScope scope(build, fallback, next);
build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1));
build.inst(
IrCmd::DO_ARITH, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.constInt(TM_UNM)); if (FFlag::LuauCodegenUserdataOps)
{
build.inst(IrCmd::DO_ARITH, build.vmReg(ra), build.vmReg(rb), build.vmReg(rb), build.constInt(TM_UNM));
}
else
{
build.inst(
IrCmd::DO_ARITH, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.constInt(TM_UNM));
}
build.inst(IrCmd::JUMP, next); build.inst(IrCmd::JUMP, next);
} }
} }
@ -619,6 +651,17 @@ void translateInstLength(IrBuilder& build, const Instruction* pc, int pcpos)
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc); int rb = LUAU_INSN_B(*pc);
if (FFlag::LuauCodegenUserdataOps && isUserdataBytecodeType(bcTypes.a))
{
if (build.hostHooks.userdataMetamethod &&
build.hostHooks.userdataMetamethod(build, bcTypes.a, bcTypes.b, ra, build.vmReg(rb), {}, tmToHostMetamethod(TM_LEN), pcpos))
return;
build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1));
build.inst(IrCmd::DO_LEN, build.vmReg(ra), build.vmReg(rb));
return;
}
IrOp fallback = build.block(IrBlockKind::Fallback); IrOp fallback = build.block(IrBlockKind::Fallback);
IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb));
@ -638,7 +681,12 @@ void translateInstLength(IrBuilder& build, const Instruction* pc, int pcpos)
FallbackStreamScope scope(build, fallback, next); FallbackStreamScope scope(build, fallback, next);
build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1));
build.inst(IrCmd::DO_LEN, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc)));
if (FFlag::LuauCodegenUserdataOps)
build.inst(IrCmd::DO_LEN, build.vmReg(ra), build.vmReg(rb));
else
build.inst(IrCmd::DO_LEN, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc)));
build.inst(IrCmd::JUMP, next); build.inst(IrCmd::JUMP, next);
} }
@ -1229,10 +1277,19 @@ void translateInstGetTableKS(IrBuilder& build, const Instruction* pc, int pcpos)
return; return;
} }
if (FFlag::LuauCodegenDirectUserdataFlow && bcTypes.a == LBC_TYPE_USERDATA) if (FFlag::LuauCodegenDirectUserdataFlow && (FFlag::LuauCodegenUserdataOps ? isUserdataBytecodeType(bcTypes.a) : bcTypes.a == LBC_TYPE_USERDATA))
{ {
build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TUSERDATA), build.vmExit(pcpos)); build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TUSERDATA), build.vmExit(pcpos));
if (FFlag::LuauCodegenUserdataOps && build.hostHooks.userdataAccess)
{
TString* str = gco2ts(build.function.proto->k[aux].value.gc);
const char* field = getstr(str);
if (build.hostHooks.userdataAccess(build, bcTypes.a, field, str->len, ra, rb, pcpos))
return;
}
build.inst(IrCmd::FALLBACK_GETTABLEKS, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux)); build.inst(IrCmd::FALLBACK_GETTABLEKS, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux));
return; return;
} }
@ -1267,7 +1324,7 @@ void translateInstSetTableKS(IrBuilder& build, const Instruction* pc, int pcpos)
IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb));
if (FFlag::LuauCodegenDirectUserdataFlow && bcTypes.a == LBC_TYPE_USERDATA) if (FFlag::LuauCodegenDirectUserdataFlow && (FFlag::LuauCodegenUserdataOps ? isUserdataBytecodeType(bcTypes.a) : bcTypes.a == LBC_TYPE_USERDATA))
{ {
build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TUSERDATA), build.vmExit(pcpos)); build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TUSERDATA), build.vmExit(pcpos));
@ -1413,10 +1470,26 @@ bool translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos)
return false; return false;
} }
if (FFlag::LuauCodegenDirectUserdataFlow && bcTypes.a == LBC_TYPE_USERDATA) if (FFlag::LuauCodegenDirectUserdataFlow && (FFlag::LuauCodegenUserdataOps ? isUserdataBytecodeType(bcTypes.a) : bcTypes.a == LBC_TYPE_USERDATA))
{ {
build.loadAndCheckTag(build.vmReg(rb), LUA_TUSERDATA, build.vmExit(pcpos)); build.loadAndCheckTag(build.vmReg(rb), LUA_TUSERDATA, build.vmExit(pcpos));
if (FFlag::LuauCodegenUserdataOps && build.hostHooks.userdataNamecall)
{
Instruction call = pc[2];
CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL);
int callra = LUAU_INSN_A(call);
int nparams = LUAU_INSN_B(call) - 1;
int nresults = LUAU_INSN_C(call) - 1;
TString* str = gco2ts(build.function.proto->k[aux].value.gc);
const char* field = getstr(str);
if (build.hostHooks.userdataNamecall(build, bcTypes.a, field, str->len, callra, rb, nparams, nresults, pcpos))
return true;
}
build.inst(IrCmd::FALLBACK_NAMECALL, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux)); build.inst(IrCmd::FALLBACK_NAMECALL, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux));
return false; return false;
} }

View file

@ -99,6 +99,7 @@ IrValueKind getCmdValueKind(IrCmd cmd)
case IrCmd::TRY_NUM_TO_INDEX: case IrCmd::TRY_NUM_TO_INDEX:
return IrValueKind::Int; return IrValueKind::Int;
case IrCmd::TRY_CALL_FASTGETTM: case IrCmd::TRY_CALL_FASTGETTM:
case IrCmd::NEW_USERDATA:
return IrValueKind::Pointer; return IrValueKind::Pointer;
case IrCmd::INT_TO_NUM: case IrCmd::INT_TO_NUM:
case IrCmd::UINT_TO_NUM: case IrCmd::UINT_TO_NUM:
@ -135,6 +136,7 @@ IrValueKind getCmdValueKind(IrCmd cmd)
case IrCmd::CHECK_NODE_NO_NEXT: case IrCmd::CHECK_NODE_NO_NEXT:
case IrCmd::CHECK_NODE_VALUE: case IrCmd::CHECK_NODE_VALUE:
case IrCmd::CHECK_BUFFER_LEN: case IrCmd::CHECK_BUFFER_LEN:
case IrCmd::CHECK_USERDATA_TAG:
case IrCmd::INTERRUPT: case IrCmd::INTERRUPT:
case IrCmd::CHECK_GC: case IrCmd::CHECK_GC:
case IrCmd::BARRIER_OBJ: case IrCmd::BARRIER_OBJ:
@ -262,6 +264,44 @@ bool isCustomUserdataBytecodeType(uint8_t ty)
return ty >= LBC_TYPE_TAGGED_USERDATA_BASE && ty < LBC_TYPE_TAGGED_USERDATA_END; return ty >= LBC_TYPE_TAGGED_USERDATA_BASE && ty < LBC_TYPE_TAGGED_USERDATA_END;
} }
HostMetamethod tmToHostMetamethod(int tm)
{
switch (TMS(tm))
{
case TM_ADD:
return HostMetamethod::Add;
case TM_SUB:
return HostMetamethod::Sub;
case TM_MUL:
return HostMetamethod::Mul;
case TM_DIV:
return HostMetamethod::Div;
case TM_IDIV:
return HostMetamethod::Idiv;
case TM_MOD:
return HostMetamethod::Mod;
case TM_POW:
return HostMetamethod::Pow;
case TM_UNM:
return HostMetamethod::Minus;
case TM_EQ:
return HostMetamethod::Equal;
case TM_LT:
return HostMetamethod::LessThan;
case TM_LE:
return HostMetamethod::LessEqual;
case TM_LEN:
return HostMetamethod::Length;
case TM_CONCAT:
return HostMetamethod::Concat;
default:
CODEGEN_ASSERT(!"invalid tag method for host");
break;
}
return HostMetamethod::Add;
}
void kill(IrFunction& function, IrInst& inst) void kill(IrFunction& function, IrInst& inst)
{ {
CODEGEN_ASSERT(inst.useCount == 0); CODEGEN_ASSERT(inst.useCount == 0);

View file

@ -14,114 +14,13 @@
#include <math.h> #include <math.h>
#include <string.h> #include <string.h>
LUAU_FASTINTVARIABLE(LuauCodeGenBlockSize, 4 * 1024 * 1024) LUAU_FASTFLAG(LuauCodegenUserdataAlloc)
LUAU_FASTINTVARIABLE(LuauCodeGenMaxTotalSize, 256 * 1024 * 1024)
namespace Luau namespace Luau
{ {
namespace CodeGen namespace CodeGen
{ {
NativeState::NativeState()
: NativeState(nullptr, nullptr)
{
}
NativeState::NativeState(AllocationCallback* allocationCallback, void* allocationCallbackContext)
: codeAllocator{size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), allocationCallback, allocationCallbackContext}
{
}
NativeState::~NativeState() = default;
void initFunctions(NativeState& data)
{
static_assert(sizeof(data.context.luauF_table) == sizeof(luauF_table), "fastcall tables are not of the same length");
memcpy(data.context.luauF_table, luauF_table, sizeof(luauF_table));
data.context.luaV_lessthan = luaV_lessthan;
data.context.luaV_lessequal = luaV_lessequal;
data.context.luaV_equalval = luaV_equalval;
data.context.luaV_doarith = luaV_doarith;
data.context.luaV_doarithadd = luaV_doarithimpl<TM_ADD>;
data.context.luaV_doarithsub = luaV_doarithimpl<TM_SUB>;
data.context.luaV_doarithmul = luaV_doarithimpl<TM_MUL>;
data.context.luaV_doarithdiv = luaV_doarithimpl<TM_DIV>;
data.context.luaV_doarithidiv = luaV_doarithimpl<TM_IDIV>;
data.context.luaV_doarithmod = luaV_doarithimpl<TM_MOD>;
data.context.luaV_doarithpow = luaV_doarithimpl<TM_POW>;
data.context.luaV_doarithunm = luaV_doarithimpl<TM_UNM>;
data.context.luaV_dolen = luaV_dolen;
data.context.luaV_gettable = luaV_gettable;
data.context.luaV_settable = luaV_settable;
data.context.luaV_getimport = luaV_getimport;
data.context.luaV_concat = luaV_concat;
data.context.luaH_getn = luaH_getn;
data.context.luaH_new = luaH_new;
data.context.luaH_clone = luaH_clone;
data.context.luaH_resizearray = luaH_resizearray;
data.context.luaH_setnum = luaH_setnum;
data.context.luaC_barriertable = luaC_barriertable;
data.context.luaC_barrierf = luaC_barrierf;
data.context.luaC_barrierback = luaC_barrierback;
data.context.luaC_step = luaC_step;
data.context.luaF_close = luaF_close;
data.context.luaF_findupval = luaF_findupval;
data.context.luaF_newLclosure = luaF_newLclosure;
data.context.luaT_gettm = luaT_gettm;
data.context.luaT_objtypenamestr = luaT_objtypenamestr;
data.context.libm_exp = exp;
data.context.libm_pow = pow;
data.context.libm_fmod = fmod;
data.context.libm_log = log;
data.context.libm_log2 = log2;
data.context.libm_log10 = log10;
data.context.libm_ldexp = ldexp;
data.context.libm_round = round;
data.context.libm_frexp = frexp;
data.context.libm_modf = modf;
data.context.libm_asin = asin;
data.context.libm_sin = sin;
data.context.libm_sinh = sinh;
data.context.libm_acos = acos;
data.context.libm_cos = cos;
data.context.libm_cosh = cosh;
data.context.libm_atan = atan;
data.context.libm_atan2 = atan2;
data.context.libm_tan = tan;
data.context.libm_tanh = tanh;
data.context.forgLoopTableIter = forgLoopTableIter;
data.context.forgLoopNodeIter = forgLoopNodeIter;
data.context.forgLoopNonTableFallback = forgLoopNonTableFallback;
data.context.forgPrepXnextFallback = forgPrepXnextFallback;
data.context.callProlog = callProlog;
data.context.callEpilogC = callEpilogC;
data.context.callFallback = callFallback;
data.context.executeGETGLOBAL = executeGETGLOBAL;
data.context.executeSETGLOBAL = executeSETGLOBAL;
data.context.executeGETTABLEKS = executeGETTABLEKS;
data.context.executeSETTABLEKS = executeSETTABLEKS;
data.context.executeNAMECALL = executeNAMECALL;
data.context.executeFORGPREP = executeFORGPREP;
data.context.executeGETVARARGSMultRet = executeGETVARARGSMultRet;
data.context.executeGETVARARGSConst = executeGETVARARGSConst;
data.context.executeDUPCLOSURE = executeDUPCLOSURE;
data.context.executePREPVARARGS = executePREPVARARGS;
data.context.executeSETLIST = executeSETLIST;
}
void initFunctions(NativeContext& context) void initFunctions(NativeContext& context)
{ {
static_assert(sizeof(context.luauF_table) == sizeof(luauF_table), "fastcall tables are not of the same length"); static_assert(sizeof(context.luauF_table) == sizeof(luauF_table), "fastcall tables are not of the same length");
@ -194,6 +93,9 @@ void initFunctions(NativeContext& context)
context.callProlog = callProlog; context.callProlog = callProlog;
context.callEpilogC = callEpilogC; context.callEpilogC = callEpilogC;
if (FFlag::LuauCodegenUserdataAlloc)
context.newUserdata = newUserdata;
context.callFallback = callFallback; context.callFallback = callFallback;
context.executeGETGLOBAL = executeGETGLOBAL; context.executeGETGLOBAL = executeGETGLOBAL;

View file

@ -94,6 +94,7 @@ struct NativeContext
void (*forgPrepXnextFallback)(lua_State* L, TValue* ra, int pc) = nullptr; void (*forgPrepXnextFallback)(lua_State* L, TValue* ra, int pc) = nullptr;
Closure* (*callProlog)(lua_State* L, TValue* ra, StkId argtop, int nresults) = nullptr; Closure* (*callProlog)(lua_State* L, TValue* ra, StkId argtop, int nresults) = nullptr;
void (*callEpilogC)(lua_State* L, int nresults, int n) = nullptr; void (*callEpilogC)(lua_State* L, int nresults, int n) = nullptr;
Udata* (*newUserdata)(lua_State* L, size_t s, int tag) = nullptr;
Closure* (*callFallback)(lua_State* L, StkId ra, StkId argtop, int nresults) = nullptr; Closure* (*callFallback)(lua_State* L, StkId ra, StkId argtop, int nresults) = nullptr;
@ -116,22 +117,6 @@ struct NativeContext
using GateFn = int (*)(lua_State*, Proto*, uintptr_t, NativeContext*); using GateFn = int (*)(lua_State*, Proto*, uintptr_t, NativeContext*);
struct NativeState
{
NativeState();
NativeState(AllocationCallback* allocationCallback, void* allocationCallbackContext);
~NativeState();
CodeAllocator codeAllocator;
std::unique_ptr<UnwindBuilder> unwindBuilder;
uint8_t* gateData = nullptr;
size_t gateDataSize = 0;
NativeContext context;
};
void initFunctions(NativeState& data);
void initFunctions(NativeContext& context); void initFunctions(NativeContext& context);
} // namespace CodeGen } // namespace CodeGen

View file

@ -16,9 +16,12 @@
LUAU_FASTINTVARIABLE(LuauCodeGenMinLinearBlockPath, 3) LUAU_FASTINTVARIABLE(LuauCodeGenMinLinearBlockPath, 3)
LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64) LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64)
LUAU_FASTINTVARIABLE(LuauCodeGenReuseUdataTagLimit, 64)
LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks, false) LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks, false)
LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5)
LUAU_FASTFLAGVARIABLE(LuauCodegenFixSplitStoreConstMismatch, false) LUAU_FASTFLAGVARIABLE(LuauCodegenFixSplitStoreConstMismatch, false)
LUAU_FASTFLAG(LuauCodegenUserdataOps)
LUAU_FASTFLAG(LuauCodegenUserdataAlloc)
namespace Luau namespace Luau
{ {
@ -200,6 +203,11 @@ struct ConstPropState
checkBufferLenCache.clear(); checkBufferLenCache.clear();
} }
void invalidateUserdataData()
{
useradataTagCache.clear();
}
void invalidateHeap() void invalidateHeap()
{ {
for (int i = 0; i <= maxReg; ++i) for (int i = 0; i <= maxReg; ++i)
@ -417,6 +425,9 @@ struct ConstPropState
invalidateValuePropagation(); invalidateValuePropagation();
invalidateHeapTableData(); invalidateHeapTableData();
invalidateHeapBufferData(); invalidateHeapBufferData();
if (FFlag::LuauCodegenUserdataOps)
invalidateUserdataData();
} }
IrFunction& function; IrFunction& function;
@ -446,6 +457,9 @@ struct ConstPropState
std::vector<uint32_t> checkArraySizeCache; // Additionally, fallback block argument might be different std::vector<uint32_t> checkArraySizeCache; // Additionally, fallback block argument might be different
std::vector<uint32_t> checkBufferLenCache; // Additionally, fallback block argument might be different std::vector<uint32_t> checkBufferLenCache; // Additionally, fallback block argument might be different
// Userdata tag cache can point to both NEW_USERDATA and CHECK_USERDATA_TAG instructions
std::vector<uint32_t> useradataTagCache; // Additionally, fallback block argument might be different
}; };
static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid, uint32_t firstReturnReg, int nresults) static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid, uint32_t firstReturnReg, int nresults)
@ -1061,6 +1075,37 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
state.checkBufferLenCache.push_back(index); state.checkBufferLenCache.push_back(index);
break; break;
} }
case IrCmd::CHECK_USERDATA_TAG:
{
CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps);
for (uint32_t prevIdx : state.useradataTagCache)
{
IrInst& prev = function.instructions[prevIdx];
if (prev.cmd == IrCmd::CHECK_USERDATA_TAG)
{
if (prev.a != inst.a || prev.b != inst.b)
continue;
}
else if (FFlag::LuauCodegenUserdataAlloc && prev.cmd == IrCmd::NEW_USERDATA)
{
if (inst.a.kind != IrOpKind::Inst || prevIdx != inst.a.index || prev.b != inst.b)
continue;
}
if (FFlag::DebugLuauAbortingChecks)
replace(function, inst.c, build.undef());
else
kill(function, inst);
return; // Break out from both the loop and the switch
}
if (int(state.useradataTagCache.size()) < FInt::LuauCodeGenReuseUdataTagLimit)
state.useradataTagCache.push_back(index);
break;
}
case IrCmd::BUFFER_READI8: case IrCmd::BUFFER_READI8:
case IrCmd::BUFFER_READU8: case IrCmd::BUFFER_READU8:
case IrCmd::BUFFER_WRITEI8: case IrCmd::BUFFER_WRITEI8:
@ -1228,6 +1273,12 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
break; break;
case IrCmd::TRY_CALL_FASTGETTM: case IrCmd::TRY_CALL_FASTGETTM:
break; break;
case IrCmd::NEW_USERDATA:
CODEGEN_ASSERT(FFlag::LuauCodegenUserdataAlloc);
if (int(state.useradataTagCache.size()) < FInt::LuauCodeGenReuseUdataTagLimit)
state.useradataTagCache.push_back(index);
break;
case IrCmd::INT_TO_NUM: case IrCmd::INT_TO_NUM:
case IrCmd::UINT_TO_NUM: case IrCmd::UINT_TO_NUM:
state.substituteOrRecord(inst, index); state.substituteOrRecord(inst, index);
@ -1512,6 +1563,9 @@ static void constPropInBlockChain(IrBuilder& build, std::vector<uint8_t>& visite
state.invalidateHeapTableData(); state.invalidateHeapTableData();
state.invalidateHeapBufferData(); state.invalidateHeapBufferData();
if (FFlag::LuauCodegenUserdataOps)
state.invalidateUserdataData();
// Blocks in a chain are guaranteed to follow each other // Blocks in a chain are guaranteed to follow each other
// We force that by giving all blocks the same sorting key, but consecutive chain keys // We force that by giving all blocks the same sorting key, but consecutive chain keys
block->sortkey = startSortkey; block->sortkey = startSortkey;

View file

@ -10,6 +10,7 @@
#include "lobject.h" #include "lobject.h"
LUAU_FASTFLAGVARIABLE(LuauCodegenRemoveDeadStores5, false) LUAU_FASTFLAGVARIABLE(LuauCodegenRemoveDeadStores5, false)
LUAU_FASTFLAG(LuauCodegenUserdataOps)
// TODO: optimization can be improved by knowing which registers are live in at each VM exit // TODO: optimization can be improved by knowing which registers are live in at each VM exit
@ -595,6 +596,11 @@ static void markDeadStoresInInst(RemoveDeadStoreState& state, IrBuilder& build,
case IrCmd::CHECK_BUFFER_LEN: case IrCmd::CHECK_BUFFER_LEN:
state.checkLiveIns(inst.d); state.checkLiveIns(inst.d);
break; break;
case IrCmd::CHECK_USERDATA_TAG:
CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps);
state.checkLiveIns(inst.c);
break;
case IrCmd::JUMP: case IrCmd::JUMP:
// Ideally, we would be able to remove stores to registers that are not live out from a block // Ideally, we would be able to remove stores to registers that are not live out from a block

View file

@ -4219,7 +4219,8 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c
for (AstExprFunction* expr : functions) for (AstExprFunction* expr : functions)
compiler.compileFunction(expr, 0); compiler.compileFunction(expr, 0);
AstExprFunction main(root->location, /*generics= */ AstArray<AstGenericType>(), /*genericPacks= */ AstArray<AstGenericTypePack>(), AstExprFunction main(root->location, /*attributes=*/AstArray<AstAttr*>({nullptr, 0}), /*generics= */ AstArray<AstGenericType>(),
/*genericPacks= */ AstArray<AstGenericTypePack>(),
/* self= */ nullptr, AstArray<AstLocal*>(), /* vararg= */ true, /* varargLocation= */ Luau::Location(), root, /* functionDepth= */ 0, /* self= */ nullptr, AstArray<AstLocal*>(), /* vararg= */ true, /* varargLocation= */ Luau::Location(), root, /* functionDepth= */ 0,
/* debugname= */ AstName()); /* debugname= */ AstName());
uint32_t mainid = compiler.compileFunction(&main, mainFlags); uint32_t mainid = compiler.compileFunction(&main, mainFlags);

View file

@ -195,7 +195,7 @@ static Error parseJson(const std::string& contents, Action action)
} }
else if (lexer.current().type == Lexeme::QuotedString) else if (lexer.current().type == Lexeme::QuotedString)
{ {
std::string value(lexer.current().data, lexer.current().length); std::string value(lexer.current().data, lexer.current().getLength());
next(lexer); next(lexer);
if (Error err = action(keys, value)) if (Error err = action(keys, value))
@ -232,7 +232,7 @@ static Error parseJson(const std::string& contents, Action action)
} }
else if (lexer.current().type == Lexeme::QuotedString) else if (lexer.current().type == Lexeme::QuotedString)
{ {
std::string key(lexer.current().data, lexer.current().length); std::string key(lexer.current().data, lexer.current().getLength());
next(lexer); next(lexer);
keys.push_back(key); keys.push_back(key);
@ -250,7 +250,7 @@ static Error parseJson(const std::string& contents, Action action)
lexer.current().type == Lexeme::ReservedFalse) lexer.current().type == Lexeme::ReservedFalse)
{ {
std::string value = lexer.current().type == Lexeme::QuotedString std::string value = lexer.current().type == Lexeme::QuotedString
? std::string(lexer.current().data, lexer.current().length) ? std::string(lexer.current().data, lexer.current().getLength())
: (lexer.current().type == Lexeme::ReservedTrue ? "true" : "false"); : (lexer.current().type == Lexeme::ReservedTrue ? "true" : "false");
next(lexer); next(lexer);

View file

@ -324,6 +324,10 @@ typedef void (*lua_Destructor)(lua_State* L, void* userdata);
LUA_API void lua_setuserdatadtor(lua_State* L, int tag, lua_Destructor dtor); LUA_API void lua_setuserdatadtor(lua_State* L, int tag, lua_Destructor dtor);
LUA_API lua_Destructor lua_getuserdatadtor(lua_State* L, int tag); LUA_API lua_Destructor lua_getuserdatadtor(lua_State* L, int tag);
// alternative access for metatables already registered with luaL_newmetatable
LUA_API void lua_setuserdatametatable(lua_State* L, int tag, int idx);
LUA_API void lua_getuserdatametatable(lua_State* L, int tag);
LUA_API void lua_setlightuserdataname(lua_State* L, int tag, const char* name); LUA_API void lua_setlightuserdataname(lua_State* L, int tag, const char* name);
LUA_API const char* lua_getlightuserdataname(lua_State* L, int tag); LUA_API const char* lua_getlightuserdataname(lua_State* L, int tag);

View file

@ -1427,6 +1427,33 @@ lua_Destructor lua_getuserdatadtor(lua_State* L, int tag)
return L->global->udatagc[tag]; return L->global->udatagc[tag];
} }
void lua_setuserdatametatable(lua_State* L, int tag, int idx)
{
api_check(L, unsigned(tag) < LUA_UTAG_LIMIT);
api_check(L, !L->global->udatamt[tag]); // reassignment not supported
StkId o = index2addr(L, idx);
api_check(L, ttistable(o));
L->global->udatamt[tag] = hvalue(o);
L->top--;
}
void lua_getuserdatametatable(lua_State* L, int tag)
{
api_check(L, unsigned(tag) < LUA_UTAG_LIMIT);
luaC_threadbarrier(L);
if (Table* h = L->global->udatamt[tag])
{
sethvalue(L, L->top, h);
}
else
{
setnilvalue(L->top);
}
api_incr_top(L);
}
void lua_setlightuserdataname(lua_State* L, int tag, const char* name) void lua_setlightuserdataname(lua_State* L, int tag, const char* name)
{ {
api_check(L, unsigned(tag) < LUA_LUTAG_LIMIT); api_check(L, unsigned(tag) < LUA_LUTAG_LIMIT);

View file

@ -210,7 +210,10 @@ lua_State* lua_newstate(lua_Alloc f, void* ud)
for (i = 0; i < LUA_T_COUNT; i++) for (i = 0; i < LUA_T_COUNT; i++)
g->mt[i] = NULL; g->mt[i] = NULL;
for (i = 0; i < LUA_UTAG_LIMIT; i++) for (i = 0; i < LUA_UTAG_LIMIT; i++)
{
g->udatagc[i] = NULL; g->udatagc[i] = NULL;
g->udatamt[i] = NULL;
}
for (i = 0; i < LUA_LUTAG_LIMIT; i++) for (i = 0; i < LUA_LUTAG_LIMIT; i++)
g->lightuserdataname[i] = NULL; g->lightuserdataname[i] = NULL;
for (i = 0; i < LUA_MEMORY_CATEGORIES; i++) for (i = 0; i < LUA_MEMORY_CATEGORIES; i++)

View file

@ -217,6 +217,7 @@ typedef struct global_State
lua_ExecutionCallbacks ecb; lua_ExecutionCallbacks ecb;
void (*udatagc[LUA_UTAG_LIMIT])(lua_State*, void*); // for each userdata tag, a gc callback to be called immediately before freeing memory void (*udatagc[LUA_UTAG_LIMIT])(lua_State*, void*); // for each userdata tag, a gc callback to be called immediately before freeing memory
Table* udatamt[LUA_LUTAG_LIMIT]; // metatables for tagged userdata
TString* lightuserdataname[LUA_LUTAG_LIMIT]; // names for tagged lightuserdata TString* lightuserdataname[LUA_LUTAG_LIMIT]; // names for tagged lightuserdata

View file

@ -342,6 +342,209 @@ void setupVectorHelpers(lua_State* L)
lua_pop(L, 1); lua_pop(L, 1);
} }
Vec2* lua_vec2_push(lua_State* L)
{
Vec2* data = (Vec2*)lua_newuserdatatagged(L, sizeof(Vec2), kTagVec2);
lua_getuserdatametatable(L, kTagVec2);
lua_setmetatable(L, -2);
return data;
}
Vec2* lua_vec2_get(lua_State* L, int idx)
{
Vec2* a = (Vec2*)lua_touserdatatagged(L, idx, kTagVec2);
if (a)
return a;
luaL_typeerror(L, idx, "vec2");
}
static int lua_vec2(lua_State* L)
{
double x = luaL_checknumber(L, 1);
double y = luaL_checknumber(L, 2);
Vec2* data = lua_vec2_push(L);
data->x = float(x);
data->y = float(y);
return 1;
}
static int lua_vec2_dot(lua_State* L)
{
Vec2* a = lua_vec2_get(L, 1);
Vec2* b = lua_vec2_get(L, 2);
lua_pushnumber(L, a->x * b->x + a->y * b->y);
return 1;
}
static int lua_vec2_min(lua_State* L)
{
Vec2* a = lua_vec2_get(L, 1);
Vec2* b = lua_vec2_get(L, 2);
Vec2* data = lua_vec2_push(L);
data->x = a->x < b->x ? a->x : b->x;
data->y = a->y < b->y ? a->y : b->y;
return 1;
}
static int lua_vec2_index(lua_State* L)
{
Vec2* v = lua_vec2_get(L, 1);
const char* name = luaL_checkstring(L, 2);
if (strcmp(name, "X") == 0)
{
lua_pushnumber(L, v->x);
return 1;
}
if (strcmp(name, "Y") == 0)
{
lua_pushnumber(L, v->y);
return 1;
}
if (strcmp(name, "Magnitude") == 0)
{
lua_pushnumber(L, sqrtf(v->x * v->x + v->y * v->y));
return 1;
}
if (strcmp(name, "Unit") == 0)
{
float invSqrt = 1.0f / sqrtf(v->x * v->x + v->y * v->y);
Vec2* data = lua_vec2_push(L);
data->x = v->x * invSqrt;
data->y = v->y * invSqrt;
return 1;
}
luaL_error(L, "%s is not a valid member of vector", name);
}
static int lua_vec2_namecall(lua_State* L)
{
if (const char* str = lua_namecallatom(L, nullptr))
{
if (strcmp(str, "Dot") == 0)
return lua_vec2_dot(L);
if (strcmp(str, "Min") == 0)
return lua_vec2_min(L);
}
luaL_error(L, "%s is not a valid method of vector", luaL_checkstring(L, 1));
}
void setupUserdataHelpers(lua_State* L)
{
// create metatable with all the metamethods
luaL_newmetatable(L, "vec2");
luaL_getmetatable(L, "vec2");
lua_pushvalue(L, -1);
lua_setuserdatametatable(L, kTagVec2, -1);
lua_pushcfunction(L, lua_vec2_index, nullptr);
lua_setfield(L, -2, "__index");
lua_pushcfunction(L, lua_vec2_namecall, nullptr);
lua_setfield(L, -2, "__namecall");
lua_pushcclosurek(
L,
[](lua_State* L) {
Vec2* a = lua_vec2_get(L, 1);
Vec2* b = lua_vec2_get(L, 2);
Vec2* data = lua_vec2_push(L);
data->x = a->x + b->x;
data->y = a->y + b->y;
return 1;
},
nullptr, 0, nullptr);
lua_setfield(L, -2, "__add");
lua_pushcclosurek(
L,
[](lua_State* L) {
Vec2* a = lua_vec2_get(L, 1);
Vec2* b = lua_vec2_get(L, 2);
Vec2* data = lua_vec2_push(L);
data->x = a->x - b->x;
data->y = a->y - b->y;
return 1;
},
nullptr, 0, nullptr);
lua_setfield(L, -2, "__sub");
lua_pushcclosurek(
L,
[](lua_State* L) {
Vec2* a = lua_vec2_get(L, 1);
Vec2* b = lua_vec2_get(L, 2);
Vec2* data = lua_vec2_push(L);
data->x = a->x * b->x;
data->y = a->y * b->y;
return 1;
},
nullptr, 0, nullptr);
lua_setfield(L, -2, "__mul");
lua_pushcclosurek(
L,
[](lua_State* L) {
Vec2* a = lua_vec2_get(L, 1);
Vec2* b = lua_vec2_get(L, 2);
Vec2* data = lua_vec2_push(L);
data->x = a->x / b->x;
data->y = a->y / b->y;
return 1;
},
nullptr, 0, nullptr);
lua_setfield(L, -2, "__div");
lua_pushcclosurek(
L,
[](lua_State* L) {
Vec2* a = lua_vec2_get(L, 1);
Vec2* data = lua_vec2_push(L);
data->x = -a->x;
data->y = -a->y;
return 1;
},
nullptr, 0, nullptr);
lua_setfield(L, -2, "__unm");
lua_setreadonly(L, -1, true);
// ctor
lua_pushcfunction(L, lua_vec2, "vec2");
lua_setglobal(L, "vec2");
lua_pop(L, 1);
}
static void setupNativeHelpers(lua_State* L) static void setupNativeHelpers(lua_State* L)
{ {
lua_pushcclosurek( lua_pushcclosurek(
@ -1828,16 +2031,36 @@ TEST_CASE("UserdataApi")
luaL_newmetatable(L, "udata2"); luaL_newmetatable(L, "udata2");
void* ud5 = lua_newuserdata(L, 0); void* ud5 = lua_newuserdata(L, 0);
lua_getfield(L, LUA_REGISTRYINDEX, "udata1"); luaL_getmetatable(L, "udata1");
lua_setmetatable(L, -2); lua_setmetatable(L, -2);
void* ud6 = lua_newuserdata(L, 0); void* ud6 = lua_newuserdata(L, 0);
lua_getfield(L, LUA_REGISTRYINDEX, "udata2"); luaL_getmetatable(L, "udata2");
lua_setmetatable(L, -2); lua_setmetatable(L, -2);
CHECK(luaL_checkudata(L, -2, "udata1") == ud5); CHECK(luaL_checkudata(L, -2, "udata1") == ud5);
CHECK(luaL_checkudata(L, -1, "udata2") == ud6); CHECK(luaL_checkudata(L, -1, "udata2") == ud6);
// tagged user data with fast metatable access
luaL_newmetatable(L, "udata3");
luaL_getmetatable(L, "udata3");
lua_setuserdatametatable(L, 50, -1);
luaL_newmetatable(L, "udata4");
luaL_getmetatable(L, "udata4");
lua_setuserdatametatable(L, 51, -1);
void* ud7 = lua_newuserdatatagged(L, 16, 50);
lua_getuserdatametatable(L, 50);
lua_setmetatable(L, -2);
void* ud8 = lua_newuserdatatagged(L, 16, 51);
lua_getuserdatametatable(L, 51);
lua_setmetatable(L, -2);
CHECK(luaL_checkudata(L, -2, "udata3") == ud7);
CHECK(luaL_checkudata(L, -1, "udata4") == ud8);
globalState.reset(); globalState.reset();
CHECK(dtorhits == 42); CHECK(dtorhits == 42);
@ -1911,7 +2134,6 @@ TEST_CASE("Iter")
} }
const int kInt64Tag = 1; const int kInt64Tag = 1;
static int gInt64MT = -1;
static int64_t getInt64(lua_State* L, int idx) static int64_t getInt64(lua_State* L, int idx)
{ {
@ -1928,7 +2150,7 @@ static void pushInt64(lua_State* L, int64_t value)
{ {
void* p = lua_newuserdatatagged(L, sizeof(int64_t), kInt64Tag); void* p = lua_newuserdatatagged(L, sizeof(int64_t), kInt64Tag);
lua_getref(L, gInt64MT); luaL_getmetatable(L, "int64");
lua_setmetatable(L, -2); lua_setmetatable(L, -2);
*static_cast<int64_t*>(p) = value; *static_cast<int64_t*>(p) = value;
@ -1938,8 +2160,7 @@ TEST_CASE("Userdata")
{ {
runConformance("userdata.lua", [](lua_State* L) { runConformance("userdata.lua", [](lua_State* L) {
// create metatable with all the metamethods // create metatable with all the metamethods
lua_newtable(L); luaL_newmetatable(L, "int64");
gInt64MT = lua_ref(L, -1);
// __index // __index
lua_pushcfunction( lua_pushcfunction(
@ -2164,6 +2385,86 @@ TEST_CASE("NativeTypeAnnotations")
}); });
} }
TEST_CASE("NativeUserdata")
{
lua_CompileOptions copts = defaultOptions();
Luau::CodeGen::CompilationOptions nativeOpts = defaultCodegenOptions();
static const char* kUserdataCompileTypes[] = {"vec2", "color", "mat3", nullptr};
copts.userdataTypes = kUserdataCompileTypes;
SUBCASE("NoIrHooks")
{
SUBCASE("O0")
{
copts.optimizationLevel = 0;
}
SUBCASE("O1")
{
copts.optimizationLevel = 1;
}
SUBCASE("O2")
{
copts.optimizationLevel = 2;
}
}
SUBCASE("IrHooks")
{
nativeOpts.hooks.vectorAccessBytecodeType = vectorAccessBytecodeType;
nativeOpts.hooks.vectorNamecallBytecodeType = vectorNamecallBytecodeType;
nativeOpts.hooks.vectorAccess = vectorAccess;
nativeOpts.hooks.vectorNamecall = vectorNamecall;
nativeOpts.hooks.userdataAccessBytecodeType = userdataAccessBytecodeType;
nativeOpts.hooks.userdataMetamethodBytecodeType = userdataMetamethodBytecodeType;
nativeOpts.hooks.userdataNamecallBytecodeType = userdataNamecallBytecodeType;
nativeOpts.hooks.userdataAccess = userdataAccess;
nativeOpts.hooks.userdataMetamethod = userdataMetamethod;
nativeOpts.hooks.userdataNamecall = userdataNamecall;
nativeOpts.userdataTypes = kUserdataRunTypes;
SUBCASE("O0")
{
copts.optimizationLevel = 0;
}
SUBCASE("O1")
{
copts.optimizationLevel = 1;
}
SUBCASE("O2")
{
copts.optimizationLevel = 2;
}
}
runConformance(
"native_userdata.lua",
[](lua_State* L) {
Luau::CodeGen::setUserdataRemapper(L, kUserdataRunTypes, [](void* context, const char* str, size_t len) -> uint8_t {
const char** types = (const char**)context;
uint8_t index = 0;
std::string_view sv{str, len};
for (; *types; ++types)
{
if (sv == *types)
return index;
index++;
}
return 0xff;
});
setupVectorHelpers(L);
setupUserdataHelpers(L);
},
nullptr, nullptr, &copts, false, &nativeOpts);
}
[[nodiscard]] static std::string makeHugeFunctionSource() [[nodiscard]] static std::string makeHugeFunctionSource()
{ {
std::string source; std::string source;

View file

@ -5,14 +5,44 @@
static const char* kUserdataRunTypes[] = {"extra", "color", "vec2", "mat3", nullptr}; static const char* kUserdataRunTypes[] = {"extra", "color", "vec2", "mat3", nullptr};
constexpr uint8_t kUserdataExtra = 0;
constexpr uint8_t kUserdataColor = 1;
constexpr uint8_t kUserdataVec2 = 2;
constexpr uint8_t kUserdataMat3 = 3;
// Userdata tags can be different from userdata bytecode type indices
constexpr uint8_t kTagVec2 = 12;
struct Vec2
{
float x;
float y;
};
inline bool compareMemberName(const char* member, size_t memberLength, const char* str)
{
return memberLength == strlen(str) && strcmp(member, str) == 0;
}
inline uint8_t typeToUserdataIndex(uint8_t type)
{
// Underflow will push the type into a value that is not comparable to any kUserdata* constants
return type - LBC_TYPE_TAGGED_USERDATA_BASE;
}
inline uint8_t userdataIndexToType(uint8_t userdataIndex)
{
return LBC_TYPE_TAGGED_USERDATA_BASE + userdataIndex;
}
inline uint8_t vectorAccessBytecodeType(const char* member, size_t memberLength) inline uint8_t vectorAccessBytecodeType(const char* member, size_t memberLength)
{ {
using namespace Luau::CodeGen; using namespace Luau::CodeGen;
if (memberLength == strlen("Magnitude") && strcmp(member, "Magnitude") == 0) if (compareMemberName(member, memberLength, "Magnitude"))
return LBC_TYPE_NUMBER; return LBC_TYPE_NUMBER;
if (memberLength == strlen("Unit") && strcmp(member, "Unit") == 0) if (compareMemberName(member, memberLength, "Unit"))
return LBC_TYPE_VECTOR; return LBC_TYPE_VECTOR;
return LBC_TYPE_ANY; return LBC_TYPE_ANY;
@ -22,7 +52,7 @@ inline bool vectorAccess(Luau::CodeGen::IrBuilder& build, const char* member, si
{ {
using namespace Luau::CodeGen; using namespace Luau::CodeGen;
if (memberLength == strlen("Magnitude") && strcmp(member, "Magnitude") == 0) if (compareMemberName(member, memberLength, "Magnitude"))
{ {
IrOp x = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0)); IrOp x = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0));
IrOp y = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4)); IrOp y = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4));
@ -42,7 +72,7 @@ inline bool vectorAccess(Luau::CodeGen::IrBuilder& build, const char* member, si
return true; return true;
} }
if (memberLength == strlen("Unit") && strcmp(member, "Unit") == 0) if (compareMemberName(member, memberLength, "Unit"))
{ {
IrOp x = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0)); IrOp x = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0));
IrOp y = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4)); IrOp y = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4));
@ -72,10 +102,10 @@ inline bool vectorAccess(Luau::CodeGen::IrBuilder& build, const char* member, si
inline uint8_t vectorNamecallBytecodeType(const char* member, size_t memberLength) inline uint8_t vectorNamecallBytecodeType(const char* member, size_t memberLength)
{ {
if (memberLength == strlen("Dot") && strcmp(member, "Dot") == 0) if (compareMemberName(member, memberLength, "Dot"))
return LBC_TYPE_NUMBER; return LBC_TYPE_NUMBER;
if (memberLength == strlen("Cross") && strcmp(member, "Cross") == 0) if (compareMemberName(member, memberLength, "Cross"))
return LBC_TYPE_VECTOR; return LBC_TYPE_VECTOR;
return LBC_TYPE_ANY; return LBC_TYPE_ANY;
@ -86,7 +116,7 @@ inline bool vectorNamecall(
{ {
using namespace Luau::CodeGen; using namespace Luau::CodeGen;
if (memberLength == strlen("Dot") && strcmp(member, "Dot") == 0 && params == 2 && results <= 1) if (compareMemberName(member, memberLength, "Dot") && params == 2 && results <= 1)
{ {
build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TVECTOR, build.vmExit(pcpos)); build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TVECTOR, build.vmExit(pcpos));
@ -114,7 +144,7 @@ inline bool vectorNamecall(
return true; return true;
} }
if (memberLength == strlen("Cross") && strcmp(member, "Cross") == 0 && params == 2 && results <= 1) if (compareMemberName(member, memberLength, "Cross") && params == 2 && results <= 1)
{ {
build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TVECTOR, build.vmExit(pcpos)); build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TVECTOR, build.vmExit(pcpos));
@ -151,3 +181,362 @@ inline bool vectorNamecall(
return false; return false;
} }
inline uint8_t userdataAccessBytecodeType(uint8_t type, const char* member, size_t memberLength)
{
switch (typeToUserdataIndex(type))
{
case kUserdataColor:
if (compareMemberName(member, memberLength, "R"))
return LBC_TYPE_NUMBER;
if (compareMemberName(member, memberLength, "G"))
return LBC_TYPE_NUMBER;
if (compareMemberName(member, memberLength, "B"))
return LBC_TYPE_NUMBER;
break;
case kUserdataVec2:
if (compareMemberName(member, memberLength, "X"))
return LBC_TYPE_NUMBER;
if (compareMemberName(member, memberLength, "Y"))
return LBC_TYPE_NUMBER;
if (compareMemberName(member, memberLength, "Magnitude"))
return LBC_TYPE_NUMBER;
if (compareMemberName(member, memberLength, "Unit"))
return userdataIndexToType(kUserdataVec2);
break;
case kUserdataMat3:
if (compareMemberName(member, memberLength, "Row1"))
return LBC_TYPE_VECTOR;
if (compareMemberName(member, memberLength, "Row2"))
return LBC_TYPE_VECTOR;
if (compareMemberName(member, memberLength, "Row3"))
return LBC_TYPE_VECTOR;
break;
}
return LBC_TYPE_ANY;
}
inline bool userdataAccess(
Luau::CodeGen::IrBuilder& build, uint8_t type, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos)
{
using namespace Luau::CodeGen;
switch (typeToUserdataIndex(type))
{
case kUserdataColor:
break;
case kUserdataVec2:
if (compareMemberName(member, memberLength, "X"))
{
IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg));
build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos));
IrOp value = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(resultReg), value);
build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TNUMBER));
return true;
}
if (compareMemberName(member, memberLength, "Y"))
{
IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg));
build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos));
IrOp value = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(resultReg), value);
build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TNUMBER));
return true;
}
if (compareMemberName(member, memberLength, "Magnitude"))
{
IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg));
build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos));
IrOp x = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA));
IrOp y = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA));
IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x);
IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y);
IrOp sum = build.inst(IrCmd::ADD_NUM, x2, y2);
IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(resultReg), mag);
build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TNUMBER));
return true;
}
if (compareMemberName(member, memberLength, "Unit"))
{
IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg));
build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos));
IrOp x = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA));
IrOp y = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA));
IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x);
IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y);
IrOp sum = build.inst(IrCmd::ADD_NUM, x2, y2);
IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);
IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag);
IrOp xr = build.inst(IrCmd::MUL_NUM, x, inv);
IrOp yr = build.inst(IrCmd::MUL_NUM, y, inv);
build.inst(IrCmd::CHECK_GC);
IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2));
build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), xr, build.constTag(LUA_TUSERDATA));
build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), yr, build.constTag(LUA_TUSERDATA));
build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar);
build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA));
return true;
}
break;
case kUserdataMat3:
break;
}
return false;
}
inline uint8_t userdataMetamethodBytecodeType(uint8_t lhsTy, uint8_t rhsTy, Luau::CodeGen::HostMetamethod method)
{
switch (method)
{
case Luau::CodeGen::HostMetamethod::Add:
case Luau::CodeGen::HostMetamethod::Sub:
case Luau::CodeGen::HostMetamethod::Mul:
case Luau::CodeGen::HostMetamethod::Div:
if (typeToUserdataIndex(lhsTy) == kUserdataVec2 || typeToUserdataIndex(rhsTy) == kUserdataVec2)
return userdataIndexToType(kUserdataVec2);
break;
case Luau::CodeGen::HostMetamethod::Minus:
if (typeToUserdataIndex(lhsTy) == kUserdataVec2)
return userdataIndexToType(kUserdataVec2);
break;
default:
break;
}
return LBC_TYPE_ANY;
}
inline bool userdataMetamethod(Luau::CodeGen::IrBuilder& build, uint8_t lhsTy, uint8_t rhsTy, int resultReg, Luau::CodeGen::IrOp lhs,
Luau::CodeGen::IrOp rhs, Luau::CodeGen::HostMetamethod method, int pcpos)
{
using namespace Luau::CodeGen;
switch (method)
{
case Luau::CodeGen::HostMetamethod::Add:
if (typeToUserdataIndex(lhsTy) == kUserdataVec2 && typeToUserdataIndex(rhsTy) == kUserdataVec2)
{
build.loadAndCheckTag(lhs, LUA_TUSERDATA, build.vmExit(pcpos));
build.loadAndCheckTag(rhs, LUA_TUSERDATA, build.vmExit(pcpos));
IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, lhs);
build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos));
IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, rhs);
build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos));
IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA));
IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA));
IrOp mx = build.inst(IrCmd::ADD_NUM, x1, x2);
IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA));
IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA));
IrOp my = build.inst(IrCmd::ADD_NUM, y1, y2);
build.inst(IrCmd::CHECK_GC);
IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2));
build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA));
build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA));
build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar);
build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA));
return true;
}
break;
case Luau::CodeGen::HostMetamethod::Mul:
if (typeToUserdataIndex(lhsTy) == kUserdataVec2 && typeToUserdataIndex(rhsTy) == kUserdataVec2)
{
build.loadAndCheckTag(lhs, LUA_TUSERDATA, build.vmExit(pcpos));
build.loadAndCheckTag(rhs, LUA_TUSERDATA, build.vmExit(pcpos));
IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, lhs);
build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos));
IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, rhs);
build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos));
IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA));
IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA));
IrOp mx = build.inst(IrCmd::MUL_NUM, x1, x2);
IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA));
IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA));
IrOp my = build.inst(IrCmd::MUL_NUM, y1, y2);
build.inst(IrCmd::CHECK_GC);
IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2));
build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA));
build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA));
build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar);
build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA));
return true;
}
break;
case Luau::CodeGen::HostMetamethod::Minus:
if (typeToUserdataIndex(lhsTy) == kUserdataVec2)
{
build.loadAndCheckTag(lhs, LUA_TUSERDATA, build.vmExit(pcpos));
IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, lhs);
build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos));
IrOp x = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA));
IrOp y = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA));
IrOp mx = build.inst(IrCmd::UNM_NUM, x);
IrOp my = build.inst(IrCmd::UNM_NUM, y);
build.inst(IrCmd::CHECK_GC);
IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2));
build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA));
build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA));
build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar);
build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA));
return true;
}
break;
default:
break;
}
return false;
}
inline uint8_t userdataNamecallBytecodeType(uint8_t type, const char* member, size_t memberLength)
{
switch (typeToUserdataIndex(type))
{
case kUserdataColor:
break;
case kUserdataVec2:
if (compareMemberName(member, memberLength, "Dot"))
return LBC_TYPE_NUMBER;
if (compareMemberName(member, memberLength, "Min"))
return userdataIndexToType(kUserdataVec2);
break;
case kUserdataMat3:
break;
}
return LBC_TYPE_ANY;
}
inline bool userdataNamecall(Luau::CodeGen::IrBuilder& build, uint8_t type, const char* member, size_t memberLength, int argResReg, int sourceReg,
int params, int results, int pcpos)
{
using namespace Luau::CodeGen;
switch (typeToUserdataIndex(type))
{
case kUserdataColor:
break;
case kUserdataVec2:
if (compareMemberName(member, memberLength, "Dot"))
{
IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg));
build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos));
build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TUSERDATA, build.vmExit(pcpos));
IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(argResReg + 2));
build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos));
IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA));
IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA));
IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2);
IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA));
IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA));
IrOp yy = build.inst(IrCmd::MUL_NUM, y1, y2);
IrOp sum = build.inst(IrCmd::ADD_NUM, xx, yy);
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(argResReg), sum);
build.inst(IrCmd::STORE_TAG, build.vmReg(argResReg), build.constTag(LUA_TNUMBER));
// If the function is called in multi-return context, stack has to be adjusted
if (results == LUA_MULTRET)
build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(argResReg), build.constInt(1));
return true;
}
if (compareMemberName(member, memberLength, "Min"))
{
IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg));
build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos));
build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TUSERDATA, build.vmExit(pcpos));
IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(argResReg + 2));
build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos));
IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA));
IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA));
IrOp mx = build.inst(IrCmd::MIN_NUM, x1, x2);
IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA));
IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA));
IrOp my = build.inst(IrCmd::MIN_NUM, y1, y2);
build.inst(IrCmd::CHECK_GC);
IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2));
build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA));
build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA));
build.inst(IrCmd::STORE_POINTER, build.vmReg(argResReg), udatar);
build.inst(IrCmd::STORE_TAG, build.vmReg(argResReg), build.constTag(LUA_TUSERDATA));
// If the function is called in multi-return context, stack has to be adjusted
if (results == LUA_MULTRET)
build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(argResReg), build.constInt(1));
return true;
}
break;
case kUserdataMat3:
break;
}
return false;
}

View file

@ -24,6 +24,8 @@ LUAU_FASTFLAG(LuauCompileTempTypeInfo)
LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps) LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps)
LUAU_FASTFLAG(LuauCompileUserdataInfo) LUAU_FASTFLAG(LuauCompileUserdataInfo)
LUAU_FASTFLAG(LuauLoadUserdataInfo) LUAU_FASTFLAG(LuauLoadUserdataInfo)
LUAU_FASTFLAG(LuauCodegenUserdataOps)
LUAU_FASTFLAG(LuauCodegenUserdataAlloc)
static std::string getCodegenAssembly(const char* source, bool includeIrTypes = false, int debugLevel = 1) static std::string getCodegenAssembly(const char* source, bool includeIrTypes = false, int debugLevel = 1)
{ {
@ -34,6 +36,13 @@ static std::string getCodegenAssembly(const char* source, bool includeIrTypes =
options.compilationOptions.hooks.vectorAccess = vectorAccess; options.compilationOptions.hooks.vectorAccess = vectorAccess;
options.compilationOptions.hooks.vectorNamecall = vectorNamecall; options.compilationOptions.hooks.vectorNamecall = vectorNamecall;
options.compilationOptions.hooks.userdataAccessBytecodeType = userdataAccessBytecodeType;
options.compilationOptions.hooks.userdataMetamethodBytecodeType = userdataMetamethodBytecodeType;
options.compilationOptions.hooks.userdataNamecallBytecodeType = userdataNamecallBytecodeType;
options.compilationOptions.hooks.userdataAccess = userdataAccess;
options.compilationOptions.hooks.userdataMetamethod = userdataMetamethod;
options.compilationOptions.hooks.userdataNamecall = userdataNamecall;
// For IR, we don't care about assembly, but we want a stable target // For IR, we don't care about assembly, but we want a stable target
options.target = Luau::CodeGen::AssemblyOptions::Target::X64_SystemV; options.target = Luau::CodeGen::AssemblyOptions::Target::X64_SystemV;
@ -1690,4 +1699,352 @@ end
)"); )");
} }
TEST_CASE("CustomUserdataPropertyAccess")
{
// This test requires runtime component to be present
if (!Luau::CodeGen::isSupported())
return;
ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true},
{FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true},
{FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}};
CHECK_EQ("\n" + getCodegenAssembly(R"(
local function foo(v: vec2)
return v.X + v.Y
end
)",
/* includeIrTypes */ true),
R"(
; function foo($arg0) line 2
; R0: vec2 [argument]
bb_0:
CHECK_TAG R0, tuserdata, exit(entry)
JUMP bb_2
bb_2:
JUMP bb_bytecode_1
bb_bytecode_1:
%6 = LOAD_POINTER R0
CHECK_USERDATA_TAG %6, 12i, exit(0)
%8 = BUFFER_READF32 %6, 0i, tuserdata
%15 = BUFFER_READF32 %6, 4i, tuserdata
%24 = ADD_NUM %8, %15
STORE_DOUBLE R1, %24
STORE_TAG R1, tnumber
INTERRUPT 5u
RETURN R1, 1i
)");
}
TEST_CASE("CustomUserdataPropertyAccess2")
{
// This test requires runtime component to be present
if (!Luau::CodeGen::isSupported())
return;
ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true},
{FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true},
{FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}};
CHECK_EQ("\n" + getCodegenAssembly(R"(
local function foo(a: mat3)
return a.Row1 * a.Row2
end
)",
/* includeIrTypes */ true),
R"(
; function foo($arg0) line 2
; R0: mat3 [argument]
bb_0:
CHECK_TAG R0, tuserdata, exit(entry)
JUMP bb_2
bb_2:
JUMP bb_bytecode_1
bb_bytecode_1:
FALLBACK_GETTABLEKS 0u, R2, R0, K0
FALLBACK_GETTABLEKS 2u, R3, R0, K1
CHECK_TAG R2, tvector, exit(4)
CHECK_TAG R3, tvector, exit(4)
%14 = LOAD_TVALUE R2
%15 = LOAD_TVALUE R3
%16 = MUL_VEC %14, %15
%17 = TAG_VECTOR %16
STORE_TVALUE R1, %17
INTERRUPT 5u
RETURN R1, 1i
)");
}
TEST_CASE("CustomUserdataNamecall1")
{
// This test requires runtime component to be present
if (!Luau::CodeGen::isSupported())
return;
ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true},
{FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true},
{FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenAnalyzeHostVectorOps, true},
{FFlag::LuauCodegenUserdataOps, true}};
CHECK_EQ("\n" + getCodegenAssembly(R"(
local function foo(a: vec2, b: vec2)
return a:Dot(b)
end
)",
/* includeIrTypes */ true),
R"(
; function foo($arg0, $arg1) line 2
; R0: vec2 [argument]
; R1: vec2 [argument]
bb_0:
CHECK_TAG R0, tuserdata, exit(entry)
CHECK_TAG R1, tuserdata, exit(entry)
JUMP bb_2
bb_2:
JUMP bb_bytecode_1
bb_bytecode_1:
%6 = LOAD_TVALUE R1
STORE_TVALUE R4, %6
%10 = LOAD_POINTER R0
CHECK_USERDATA_TAG %10, 12i, exit(1)
%14 = LOAD_POINTER R4
CHECK_USERDATA_TAG %14, 12i, exit(1)
%16 = BUFFER_READF32 %10, 0i, tuserdata
%17 = BUFFER_READF32 %14, 0i, tuserdata
%18 = MUL_NUM %16, %17
%19 = BUFFER_READF32 %10, 4i, tuserdata
%20 = BUFFER_READF32 %14, 4i, tuserdata
%21 = MUL_NUM %19, %20
%22 = ADD_NUM %18, %21
STORE_DOUBLE R2, %22
STORE_TAG R2, tnumber
ADJUST_STACK_TO_REG R2, 1i
INTERRUPT 4u
RETURN R2, -1i
)");
}
TEST_CASE("CustomUserdataNamecall2")
{
// This test requires runtime component to be present
if (!Luau::CodeGen::isSupported())
return;
ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true},
{FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true},
{FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenAnalyzeHostVectorOps, true},
{FFlag::LuauCodegenUserdataOps, true}, {FFlag::LuauCodegenUserdataAlloc, true}};
CHECK_EQ("\n" + getCodegenAssembly(R"(
local function foo(a: vec2, b: vec2)
return a:Min(b)
end
)",
/* includeIrTypes */ true),
R"(
; function foo($arg0, $arg1) line 2
; R0: vec2 [argument]
; R1: vec2 [argument]
bb_0:
CHECK_TAG R0, tuserdata, exit(entry)
CHECK_TAG R1, tuserdata, exit(entry)
JUMP bb_2
bb_2:
JUMP bb_bytecode_1
bb_bytecode_1:
%6 = LOAD_TVALUE R1
STORE_TVALUE R4, %6
%10 = LOAD_POINTER R0
CHECK_USERDATA_TAG %10, 12i, exit(1)
%14 = LOAD_POINTER R4
CHECK_USERDATA_TAG %14, 12i, exit(1)
%16 = BUFFER_READF32 %10, 0i, tuserdata
%17 = BUFFER_READF32 %14, 0i, tuserdata
%18 = MIN_NUM %16, %17
%19 = BUFFER_READF32 %10, 4i, tuserdata
%20 = BUFFER_READF32 %14, 4i, tuserdata
%21 = MIN_NUM %19, %20
CHECK_GC
%23 = NEW_USERDATA 8i, 12i
BUFFER_WRITEF32 %23, 0i, %18, tuserdata
BUFFER_WRITEF32 %23, 4i, %21, tuserdata
STORE_POINTER R2, %23
STORE_TAG R2, tuserdata
ADJUST_STACK_TO_REG R2, 1i
INTERRUPT 4u
RETURN R2, -1i
)");
}
TEST_CASE("CustomUserdataMetamethodDirectFlow")
{
// This test requires runtime component to be present
if (!Luau::CodeGen::isSupported())
return;
ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true},
{FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true},
{FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}};
CHECK_EQ("\n" + getCodegenAssembly(R"(
local function foo(a: mat3, b: mat3)
return a * b
end
)",
/* includeIrTypes */ true),
R"(
; function foo($arg0, $arg1) line 2
; R0: mat3 [argument]
; R1: mat3 [argument]
bb_0:
CHECK_TAG R0, tuserdata, exit(entry)
CHECK_TAG R1, tuserdata, exit(entry)
JUMP bb_2
bb_2:
JUMP bb_bytecode_1
bb_bytecode_1:
SET_SAVEDPC 1u
DO_ARITH R2, R0, R1, 10i
INTERRUPT 1u
RETURN R2, 1i
)");
}
TEST_CASE("CustomUserdataMetamethodDirectFlow2")
{
// This test requires runtime component to be present
if (!Luau::CodeGen::isSupported())
return;
ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true},
{FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true},
{FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}};
CHECK_EQ("\n" + getCodegenAssembly(R"(
local function foo(a: mat3)
return -a
end
)",
/* includeIrTypes */ true),
R"(
; function foo($arg0) line 2
; R0: mat3 [argument]
bb_0:
CHECK_TAG R0, tuserdata, exit(entry)
JUMP bb_2
bb_2:
JUMP bb_bytecode_1
bb_bytecode_1:
SET_SAVEDPC 1u
DO_ARITH R1, R0, R0, 15i
INTERRUPT 1u
RETURN R1, 1i
)");
}
TEST_CASE("CustomUserdataMetamethodDirectFlow3")
{
// This test requires runtime component to be present
if (!Luau::CodeGen::isSupported())
return;
ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true},
{FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true},
{FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true}};
CHECK_EQ("\n" + getCodegenAssembly(R"(
local function foo(a: sequence)
return #a
end
)",
/* includeIrTypes */ true),
R"(
; function foo($arg0) line 2
; R0: userdata [argument]
bb_0:
CHECK_TAG R0, tuserdata, exit(entry)
JUMP bb_2
bb_2:
JUMP bb_bytecode_1
bb_bytecode_1:
SET_SAVEDPC 1u
DO_LEN R1, R0
INTERRUPT 1u
RETURN R1, 1i
)");
}
TEST_CASE("CustomUserdataMetamethod")
{
// This test requires runtime component to be present
if (!Luau::CodeGen::isSupported())
return;
ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true},
{FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCompileUserdataInfo, true},
{FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenUserdataOps, true},
{FFlag::LuauCodegenUserdataAlloc, true}};
CHECK_EQ("\n" + getCodegenAssembly(R"(
local function foo(a: vec2, b: vec2, c: vec2)
return -c + a * b
end
)",
/* includeIrTypes */ true),
R"(
; function foo($arg0, $arg1, $arg2) line 2
; R0: vec2 [argument]
; R1: vec2 [argument]
; R2: vec2 [argument]
bb_0:
CHECK_TAG R0, tuserdata, exit(entry)
CHECK_TAG R1, tuserdata, exit(entry)
CHECK_TAG R2, tuserdata, exit(entry)
JUMP bb_2
bb_2:
JUMP bb_bytecode_1
bb_bytecode_1:
%10 = LOAD_POINTER R2
CHECK_USERDATA_TAG %10, 12i, exit(0)
%12 = BUFFER_READF32 %10, 0i, tuserdata
%13 = BUFFER_READF32 %10, 4i, tuserdata
%14 = UNM_NUM %12
%15 = UNM_NUM %13
CHECK_GC
%17 = NEW_USERDATA 8i, 12i
BUFFER_WRITEF32 %17, 0i, %14, tuserdata
BUFFER_WRITEF32 %17, 4i, %15, tuserdata
STORE_POINTER R4, %17
STORE_TAG R4, tuserdata
%26 = LOAD_POINTER R0
CHECK_USERDATA_TAG %26, 12i, exit(1)
%28 = LOAD_POINTER R1
CHECK_USERDATA_TAG %28, 12i, exit(1)
%30 = BUFFER_READF32 %26, 0i, tuserdata
%31 = BUFFER_READF32 %28, 0i, tuserdata
%32 = MUL_NUM %30, %31
%33 = BUFFER_READF32 %26, 4i, tuserdata
%34 = BUFFER_READF32 %28, 4i, tuserdata
%35 = MUL_NUM %33, %34
%37 = NEW_USERDATA 8i, 12i
BUFFER_WRITEF32 %37, 0i, %32, tuserdata
BUFFER_WRITEF32 %37, 4i, %35, tuserdata
STORE_POINTER R5, %37
STORE_TAG R5, tuserdata
%50 = BUFFER_READF32 %17, 0i, tuserdata
%51 = BUFFER_READF32 %37, 0i, tuserdata
%52 = ADD_NUM %50, %51
%53 = BUFFER_READF32 %17, 4i, tuserdata
%54 = BUFFER_READF32 %37, 4i, tuserdata
%55 = ADD_NUM %53, %54
%57 = NEW_USERDATA 8i, 12i
BUFFER_WRITEF32 %57, 0i, %52, tuserdata
BUFFER_WRITEF32 %57, 4i, %55, tuserdata
STORE_POINTER R3, %57
STORE_TAG R3, tuserdata
INTERRUPT 3u
RETURN R3, 1i
)");
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -192,13 +192,13 @@ TEST_CASE("string_interpolation_double_brace")
auto brokenInterpBegin = lexer.next(); auto brokenInterpBegin = lexer.next();
CHECK_EQ(brokenInterpBegin.type, Lexeme::BrokenInterpDoubleBrace); CHECK_EQ(brokenInterpBegin.type, Lexeme::BrokenInterpDoubleBrace);
CHECK_EQ(std::string(brokenInterpBegin.data, brokenInterpBegin.length), std::string("foo")); CHECK_EQ(std::string(brokenInterpBegin.data, brokenInterpBegin.getLength()), std::string("foo"));
CHECK_EQ(lexer.next().type, Lexeme::Name); CHECK_EQ(lexer.next().type, Lexeme::Name);
auto interpEnd = lexer.next(); auto interpEnd = lexer.next();
CHECK_EQ(interpEnd.type, Lexeme::InterpStringEnd); CHECK_EQ(interpEnd.type, Lexeme::InterpStringEnd);
CHECK_EQ(std::string(interpEnd.data, interpEnd.length), std::string("}bar")); CHECK_EQ(std::string(interpEnd.data, interpEnd.getLength()), std::string("}bar"));
} }
TEST_CASE("string_interpolation_double_but_unmatched_brace") TEST_CASE("string_interpolation_double_but_unmatched_brace")

View file

@ -15,6 +15,8 @@
using namespace Luau; using namespace Luau;
LUAU_FASTFLAG(LuauAttributeSyntax);
#define NONSTRICT_REQUIRE_ERR_AT_POS(pos, result, idx) \ #define NONSTRICT_REQUIRE_ERR_AT_POS(pos, result, idx) \
do \ do \
{ \ { \
@ -68,6 +70,7 @@ struct NonStrictTypeCheckerFixture : Fixture
{ {
ScopedFastFlag flags[] = { ScopedFastFlag flags[] = {
{FFlag::DebugLuauDeferredConstraintResolution, true}, {FFlag::DebugLuauDeferredConstraintResolution, true},
{FFlag::LuauAttributeSyntax, true},
}; };
LoadDefinitionFileResult res = loadDefinition(definitions); LoadDefinitionFileResult res = loadDefinition(definitions);
LUAU_ASSERT(res.success); LUAU_ASSERT(res.success);
@ -78,6 +81,7 @@ struct NonStrictTypeCheckerFixture : Fixture
{ {
ScopedFastFlag flags[] = { ScopedFastFlag flags[] = {
{FFlag::DebugLuauDeferredConstraintResolution, true}, {FFlag::DebugLuauDeferredConstraintResolution, true},
{FFlag::LuauAttributeSyntax, true},
}; };
LoadDefinitionFileResult res = loadDefinition(definitions); LoadDefinitionFileResult res = loadDefinition(definitions);
LUAU_ASSERT(res.success); LUAU_ASSERT(res.success);
@ -85,21 +89,21 @@ struct NonStrictTypeCheckerFixture : Fixture
} }
std::string definitions = R"BUILTIN_SRC( std::string definitions = R"BUILTIN_SRC(
declare function @checked abs(n: number): number @checked declare function abs(n: number): number
declare function @checked lower(s: string): string @checked declare function lower(s: string): string
declare function cond() : boolean declare function cond() : boolean
declare function @checked contrived(n : Not<number>) : number @checked declare function contrived(n : Not<number>) : number
-- interesting types of things that we would like to mark as checked -- interesting types of things that we would like to mark as checked
declare function @checked onlyNums(...: number) : number @checked declare function onlyNums(...: number) : number
declare function @checked mixedArgs(x: string, ...: number) : number @checked declare function mixedArgs(x: string, ...: number) : number
declare function @checked optionalArg(x: string?) : number @checked declare function optionalArg(x: string?) : number
declare foo: { declare foo: {
bar: @checked (number) -> number, bar: @checked (number) -> number,
} }
declare function @checked optionalArgsAtTheEnd1(x: string, y: number?, z: number?) : number @checked declare function optionalArgsAtTheEnd1(x: string, y: number?, z: number?) : number
declare function @checked optionalArgsAtTheEnd2(x: string, y: number?, z: string) : number @checked declare function optionalArgsAtTheEnd2(x: string, y: number?, z: string) : number
type DateTypeArg = { type DateTypeArg = {
year: number, year: number,
@ -115,7 +119,7 @@ declare os : {
time: @checked (time: DateTypeArg?) -> number time: @checked (time: DateTypeArg?) -> number
} }
declare function @checked require(target : any) : any @checked declare function require(target : any) : any
)BUILTIN_SRC"; )BUILTIN_SRC";
}; };
@ -558,6 +562,10 @@ local E = require(script.Parent.A)
TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "nonstrict_shouldnt_warn_on_valid_buffer_use") TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "nonstrict_shouldnt_warn_on_valid_buffer_use")
{ {
ScopedFastFlag flags[] = {
{FFlag::LuauAttributeSyntax, true},
};
loadDefinition(R"( loadDefinition(R"(
declare buffer: { declare buffer: {
create: @checked (size: number) -> buffer, create: @checked (size: number) -> buffer,

View file

@ -16,6 +16,7 @@ LUAU_FASTINT(LuauRecursionLimit);
LUAU_FASTINT(LuauTypeLengthLimit); LUAU_FASTINT(LuauTypeLengthLimit);
LUAU_FASTINT(LuauParseErrorLimit); LUAU_FASTINT(LuauParseErrorLimit);
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTFLAG(LuauAttributeSyntax);
LUAU_FASTFLAG(LuauLeadingBarAndAmpersand); LUAU_FASTFLAG(LuauLeadingBarAndAmpersand);
namespace namespace
@ -3051,9 +3052,10 @@ TEST_CASE_FIXTURE(Fixture, "parse_top_level_checked_fn")
{ {
ParseOptions opts; ParseOptions opts;
opts.allowDeclarationSyntax = true; opts.allowDeclarationSyntax = true;
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
std::string src = R"BUILTIN_SRC( std::string src = R"BUILTIN_SRC(
declare function @checked abs(n: number): number @checked declare function abs(n: number): number
)BUILTIN_SRC"; )BUILTIN_SRC";
ParseResult pr = tryParse(src, opts); ParseResult pr = tryParse(src, opts);
@ -3063,13 +3065,14 @@ declare function @checked abs(n: number): number
AstStat* root = *(pr.root->body.data); AstStat* root = *(pr.root->body.data);
auto func = root->as<AstStatDeclareFunction>(); auto func = root->as<AstStatDeclareFunction>();
LUAU_ASSERT(func); LUAU_ASSERT(func);
LUAU_ASSERT(func->checkedFunction); LUAU_ASSERT(func->isCheckedFunction());
} }
TEST_CASE_FIXTURE(Fixture, "parse_declared_table_checked_member") TEST_CASE_FIXTURE(Fixture, "parse_declared_table_checked_member")
{ {
ParseOptions opts; ParseOptions opts;
opts.allowDeclarationSyntax = true; opts.allowDeclarationSyntax = true;
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
const std::string src = R"BUILTIN_SRC( const std::string src = R"BUILTIN_SRC(
declare math : { declare math : {
@ -3090,13 +3093,14 @@ TEST_CASE_FIXTURE(Fixture, "parse_declared_table_checked_member")
auto prop = *tbl->props.data; auto prop = *tbl->props.data;
auto func = prop.type->as<AstTypeFunction>(); auto func = prop.type->as<AstTypeFunction>();
LUAU_ASSERT(func); LUAU_ASSERT(func);
LUAU_ASSERT(func->checkedFunction); LUAU_ASSERT(func->isCheckedFunction());
} }
TEST_CASE_FIXTURE(Fixture, "parse_checked_outside_decl_fails") TEST_CASE_FIXTURE(Fixture, "parse_checked_outside_decl_fails")
{ {
ParseOptions opts; ParseOptions opts;
opts.allowDeclarationSyntax = true; opts.allowDeclarationSyntax = true;
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
ParseResult pr = tryParse(R"( ParseResult pr = tryParse(R"(
local @checked = 3 local @checked = 3
@ -3110,10 +3114,11 @@ TEST_CASE_FIXTURE(Fixture, "parse_checked_in_and_out_of_decl_fails")
{ {
ParseOptions opts; ParseOptions opts;
opts.allowDeclarationSyntax = true; opts.allowDeclarationSyntax = true;
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
auto pr = tryParse(R"( auto pr = tryParse(R"(
local @checked = 3 local @checked = 3
declare function @checked abs(n: number): number @checked declare function abs(n: number): number
)", )",
opts); opts);
LUAU_ASSERT(pr.errors.size() == 2); LUAU_ASSERT(pr.errors.size() == 2);
@ -3125,9 +3130,10 @@ TEST_CASE_FIXTURE(Fixture, "parse_checked_as_function_name_fails")
{ {
ParseOptions opts; ParseOptions opts;
opts.allowDeclarationSyntax = true; opts.allowDeclarationSyntax = true;
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
auto pr = tryParse(R"( auto pr = tryParse(R"(
function @checked(x: number) : number @checked function(x: number) : number
end end
)", )",
opts); opts);
@ -3138,6 +3144,7 @@ TEST_CASE_FIXTURE(Fixture, "cannot_use_@_as_variable_name")
{ {
ParseOptions opts; ParseOptions opts;
opts.allowDeclarationSyntax = true; opts.allowDeclarationSyntax = true;
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
auto pr = tryParse(R"( auto pr = tryParse(R"(
local @blah = 3 local @blah = 3
@ -3190,4 +3197,300 @@ TEST_CASE_FIXTURE(Fixture, "mixed_leading_intersection_and_union_not_allowed")
matchParseError("type A = | number & string & boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); matchParseError("type A = | number & string & boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses.");
} }
void checkAttribute(const AstAttr* attr, const AstAttr::Type type, const Location& location)
{
CHECK_EQ(attr->type, type);
CHECK_EQ(attr->location, location);
}
void checkFirstErrorForAttributes(const std::vector<ParseError>& errors, const size_t minSize, const Location& location, const std::string& message)
{
LUAU_ASSERT(minSize >= 1);
CHECK_GE(errors.size(), minSize);
CHECK_EQ(errors[0].getLocation(), location);
CHECK_EQ(errors[0].getMessage(), message);
}
TEST_CASE_FIXTURE(Fixture, "parse_attribute_on_function_stat")
{
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
AstStatBlock* stat = parse(R"(
@checked
function hello(x, y)
return x + y
end)");
LUAU_ASSERT(stat != nullptr);
AstStatFunction* statFun = stat->body.data[0]->as<AstStatFunction>();
LUAU_ASSERT(statFun != nullptr);
AstArray<AstAttr*> attributes = statFun->func->attributes;
CHECK_EQ(attributes.size, 1);
checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(1, 0), Position(1, 8)));
}
TEST_CASE_FIXTURE(Fixture, "parse_attribute_on_local_function_stat")
{
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
AstStatBlock* stat = parse(R"(
@checked
local function hello(x, y)
return x + y
end)");
LUAU_ASSERT(stat != nullptr);
AstStatLocalFunction* statFun = stat->body.data[0]->as<AstStatLocalFunction>();
LUAU_ASSERT(statFun != nullptr);
AstArray<AstAttr*> attributes = statFun->func->attributes;
CHECK_EQ(attributes.size, 1);
checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(1, 4), Position(1, 12)));
}
TEST_CASE_FIXTURE(Fixture, "empty_attribute_name_is_not_allowed")
{
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
ParseResult result = tryParse(R"(
@
function hello(x, y)
return x + y
end)");
checkFirstErrorForAttributes(result.errors, 1, Location(Position(1, 0), Position(1, 1)), "Attribute name is missing");
}
TEST_CASE_FIXTURE(Fixture, "dont_parse_attributes_on_non_function_stat")
{
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
ParseResult pr1 = tryParse(R"(
@checked
if a<0 then a = 0 end)");
checkFirstErrorForAttributes(pr1.errors, 1, Location(Position(2, 0), Position(2, 2)),
"Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'if' intead");
ParseResult pr2 = tryParse(R"(
local i = 1
@checked
while a[i] do
print(a[i])
i = i + 1
end)");
checkFirstErrorForAttributes(pr2.errors, 1, Location(Position(3, 0), Position(3, 5)),
"Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'while' intead");
ParseResult pr3 = tryParse(R"(
@checked
do
local a2 = 2*a
local d = sqrt(b^2 - 4*a*c)
x1 = (-b + d)/a2
x2 = (-b - d)/a2
end)");
checkFirstErrorForAttributes(pr3.errors, 1, Location(Position(2, 0), Position(2, 2)),
"Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'do' intead");
ParseResult pr4 = tryParse(R"(
@checked
for i=1,10 do print(i) end
)");
checkFirstErrorForAttributes(pr4.errors, 1, Location(Position(2, 0), Position(2, 3)),
"Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'for' intead");
ParseResult pr5 = tryParse(R"(
@checked
repeat
line = io.read()
until line ~= ""
)");
checkFirstErrorForAttributes(pr5.errors, 1, Location(Position(2, 0), Position(2, 6)),
"Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'repeat' intead");
ParseResult pr6 = tryParse(R"(
@checked
local x = 10
)");
checkFirstErrorForAttributes(
pr6.errors, 1, Location(Position(2, 6), Position(2, 7)), "Expected 'function' after local declaration with attribute, but got 'x' intead");
ParseResult pr7 = tryParse(R"(
local i = 1
while a[i] do
if a[i] == v then @checked break end
i = i + 1
end
)");
checkFirstErrorForAttributes(pr7.errors, 1, Location(Position(3, 31), Position(3, 36)),
"Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'break' intead");
ParseResult pr8 = tryParse(R"(
function foo1 () @checked return 'a' end
)");
checkFirstErrorForAttributes(pr8.errors, 1, Location(Position(1, 26), Position(1, 32)),
"Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'return' intead");
}
TEST_CASE_FIXTURE(Fixture, "parse_attribute_on_function_type_declaration")
{
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
ParseOptions opts;
opts.allowDeclarationSyntax = true;
std::string src = R"(
@checked declare function abs(n: number): number
)";
ParseResult pr = tryParse(src, opts);
CHECK_EQ(pr.errors.size(), 0);
LUAU_ASSERT(pr.root->body.size == 1);
AstStat* root = *(pr.root->body.data);
auto func = root->as<AstStatDeclareFunction>();
LUAU_ASSERT(func != nullptr);
CHECK(func->isCheckedFunction());
AstArray<AstAttr*> attributes = func->attributes;
checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(1, 0), Position(1, 8)));
}
TEST_CASE_FIXTURE(Fixture, "parse_attributes_on_function_type_declaration_in_table")
{
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
ParseOptions opts;
opts.allowDeclarationSyntax = true;
std::string src = R"(
declare bit32: {
band: @checked (...number) -> number
})";
ParseResult pr = tryParse(src, opts);
CHECK_EQ(pr.errors.size(), 0);
LUAU_ASSERT(pr.root->body.size == 1);
AstStat* root = *(pr.root->body.data);
AstStatDeclareGlobal* glob = root->as<AstStatDeclareGlobal>();
LUAU_ASSERT(glob);
auto tbl = glob->type->as<AstTypeTable>();
LUAU_ASSERT(tbl);
LUAU_ASSERT(tbl->props.size == 1);
AstTableProp prop = tbl->props.data[0];
AstTypeFunction* func = prop.type->as<AstTypeFunction>();
LUAU_ASSERT(func);
AstArray<AstAttr*> attributes = func->attributes;
CHECK_EQ(attributes.size, 1);
checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(2, 10), Position(2, 18)));
}
TEST_CASE_FIXTURE(Fixture, "dont_parse_attributes_on_non_function_type_declarations")
{
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
ParseOptions opts;
opts.allowDeclarationSyntax = true;
ParseResult pr1 = tryParse(R"(
@checked declare foo: number
)",
opts);
checkFirstErrorForAttributes(
pr1.errors, 1, Location(Position(1, 17), Position(1, 20)), "Expected a function type declaration after attribute, but got 'foo' intead");
ParseResult pr2 = tryParse(R"(
@checked declare class Foo
prop: number
function method(self, foo: number): string
end)",
opts);
checkFirstErrorForAttributes(
pr2.errors, 1, Location(Position(1, 17), Position(1, 22)), "Expected a function type declaration after attribute, but got 'class' intead");
ParseResult pr3 = tryParse(R"(
declare bit32: {
band: @checked number
})",
opts);
checkFirstErrorForAttributes(
pr3.errors, 1, Location(Position(2, 19), Position(2, 25)), "Expected '(' when parsing function parameters, got 'number'");
}
TEST_CASE_FIXTURE(Fixture, "attributes_cannot_be_duplicated")
{
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
ParseResult result = tryParse(R"(
@checked
@checked
function hello(x, y)
return x + y
end)");
checkFirstErrorForAttributes(result.errors, 1, Location(Position(2, 4), Position(2, 12)), "Cannot duplicate attribute '@checked'");
}
TEST_CASE_FIXTURE(Fixture, "unsupported_attributes_are_not_allowed")
{
ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
ParseResult result = tryParse(R"(
@checked
@cool_attribute
function hello(x, y)
return x + y
end)");
checkFirstErrorForAttributes(result.errors, 1, Location(Position(2, 4), Position(2, 19)), "Invalid attribute '@cool_attribute'");
}
TEST_CASE_FIXTURE(Fixture, "can_parse_leading_bar_unions_successfully")
{
ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true};
parse(R"(type A = | "Hello" | "World")");
}
TEST_CASE_FIXTURE(Fixture, "can_parse_leading_ampersand_intersections_successfully")
{
ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true};
parse(R"(type A = & { string } & { number })");
}
TEST_CASE_FIXTURE(Fixture, "mixed_leading_intersection_and_union_not_allowed")
{
ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand, true};
matchParseError("type A = & number | string | boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses.");
matchParseError("type A = | number & string & boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses.");
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -13,6 +13,7 @@ using namespace Luau;
LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction); LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction);
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTFLAG(DebugLuauSharedSelf); LUAU_FASTFLAG(DebugLuauSharedSelf);
LUAU_FASTFLAG(LuauAttributeSyntax);
TEST_SUITE_BEGIN("ToString"); TEST_SUITE_BEGIN("ToString");
@ -1010,10 +1011,11 @@ TEST_CASE_FIXTURE(Fixture, "checked_fn_toString")
{ {
ScopedFastFlag flags[] = { ScopedFastFlag flags[] = {
{FFlag::DebugLuauDeferredConstraintResolution, true}, {FFlag::DebugLuauDeferredConstraintResolution, true},
{FFlag::LuauAttributeSyntax, true},
}; };
auto _result = loadDefinition(R"( auto _result = loadDefinition(R"(
declare function @checked abs(n: number) : number @checked declare function abs(n: number) : number
)"); )");
auto result = check(Mode::Nonstrict, R"( auto result = check(Mode::Nonstrict, R"(

View file

@ -701,7 +701,7 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_basic")
REQUIRE(ut); REQUIRE(ut);
REQUIRE(ut->options.size() == 2); REQUIRE(ut->options.size() == 2);
CHECK_EQ(builtinTypes->nilType, ut->options[0]); CHECK_EQ(builtinTypes->nilType, follow(ut->options[0]));
CHECK_EQ(*builtinTypes->numberType, *ut->options[1]); CHECK_EQ(*builtinTypes->numberType, *ut->options[1]);
} }
else else
@ -1179,4 +1179,14 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "iteration_preserves_error_suppression")
CHECK("any" == toString(requireTypeAtPosition({3, 25}))); CHECK("any" == toString(requireTypeAtPosition({3, 25})));
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "tryDispatchIterableFunction_under_constrained_loop_should_not_assert")
{
CheckResult result = check(R"(
local function foo(Instance)
for _, Child in next, Instance:GetChildren() do
end
end
)");
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -3153,7 +3153,7 @@ TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch")
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERROR_COUNT(1, result);
if (FFlag::DebugLuauDeferredConstraintResolution) if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK_EQ("Value of type '{ x: number? }?' could be nil", toString(result.errors[0])); CHECK_EQ("Type 'nil' does not have key 'x'", toString(result.errors[0]));
else else
CHECK_EQ("Value of type '{| x: number? |}?' could be nil", toString(result.errors[0])); CHECK_EQ("Value of type '{| x: number? |}?' could be nil", toString(result.errors[0]));
CHECK_EQ("boolean", toString(requireType("u"))); CHECK_EQ("boolean", toString(requireType("u")));
@ -4439,7 +4439,13 @@ TEST_CASE_FIXTURE(Fixture, "insert_a_and_f_of_a_into_table_res_in_a_loop")
end end
)"); )");
LUAU_REQUIRE_NO_ERRORS(result); if (FFlag::DebugLuauDeferredConstraintResolution)
{
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK(get<FunctionExitsWithoutReturning>(result.errors[0]));
}
else
LUAU_REQUIRE_NO_ERRORS(result);
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "ipairs_adds_an_unbounded_indexer") TEST_CASE_FIXTURE(BuiltinsFixture, "ipairs_adds_an_unbounded_indexer")

View file

@ -0,0 +1,42 @@
-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
print('testing userdata')
function ecall(fn, ...)
local ok, err = pcall(fn, ...)
assert(not ok)
return err:sub((err:find(": ") or -1) + 2, #err)
end
local function realmad(a: vec2, b: vec2, c: vec2): vec2
return -c + a * b;
end
local function dm(s: vec2, t: vec2, u: vec2)
local x = s:Dot(t)
assert(x == 13)
local t = u:Min(s)
assert(t.X == 5)
assert(t.Y == 4)
end
local s: vec2 = vec2(5, 4)
local t: vec2 = vec2(1, 2)
local u: vec2 = vec2(10, 20)
local x: vec2 = realmad(s, t, u)
assert(x.X == -5)
assert(x.Y == -12)
dm(s, t, u)
local function mu(v: vec2)
assert(v.Magnitude == 2)
assert(v.Unit.X == 0)
assert(v.Unit.Y == 1)
end
mu(vec2(0, 2))
return 'OK'

View file

@ -4,7 +4,6 @@ AutocompleteTest.anonymous_autofilled_generic_type_pack_vararg
AutocompleteTest.autocomplete_string_singletons AutocompleteTest.autocomplete_string_singletons
AutocompleteTest.do_wrong_compatible_nonself_calls AutocompleteTest.do_wrong_compatible_nonself_calls
AutocompleteTest.string_singleton_as_table_key AutocompleteTest.string_singleton_as_table_key
AutocompleteTest.string_singleton_in_if_statement2
AutocompleteTest.suggest_table_keys AutocompleteTest.suggest_table_keys
AutocompleteTest.type_correct_suggestion_for_overloads AutocompleteTest.type_correct_suggestion_for_overloads
AutocompleteTest.type_correct_suggestion_in_table AutocompleteTest.type_correct_suggestion_in_table
@ -33,6 +32,15 @@ BuiltinTests.string_format_report_all_type_errors_at_correct_positions
BuiltinTests.string_format_use_correct_argument2 BuiltinTests.string_format_use_correct_argument2
BuiltinTests.table_freeze_is_generic BuiltinTests.table_freeze_is_generic
BuiltinTests.tonumber_returns_optional_number_type BuiltinTests.tonumber_returns_optional_number_type
ControlFlowAnalysis.for_record_do_if_not_x_break
ControlFlowAnalysis.for_record_do_if_not_x_continue
ControlFlowAnalysis.if_not_x_break_elif_not_y_break
ControlFlowAnalysis.if_not_x_break_elif_not_y_continue
ControlFlowAnalysis.if_not_x_break_elif_rand_break_elif_not_y_break
ControlFlowAnalysis.if_not_x_continue_elif_not_y_continue
ControlFlowAnalysis.if_not_x_continue_elif_not_y_throw_elif_not_z_fallthrough
ControlFlowAnalysis.if_not_x_continue_elif_rand_continue_elif_not_y_continue
ControlFlowAnalysis.if_not_x_return_elif_not_y_break
DefinitionTests.class_definition_overload_metamethods DefinitionTests.class_definition_overload_metamethods
Differ.metatable_metamissing_left Differ.metatable_metamissing_left
Differ.metatable_metamissing_right Differ.metatable_metamissing_right
@ -46,7 +54,6 @@ FrontendTest.trace_requires_in_nonstrict_mode
GenericsTests.apply_type_function_nested_generics1 GenericsTests.apply_type_function_nested_generics1
GenericsTests.better_mismatch_error_messages GenericsTests.better_mismatch_error_messages
GenericsTests.bound_tables_do_not_clone_original_fields GenericsTests.bound_tables_do_not_clone_original_fields
GenericsTests.correctly_instantiate_polymorphic_member_functions
GenericsTests.do_not_always_instantiate_generic_intersection_types GenericsTests.do_not_always_instantiate_generic_intersection_types
GenericsTests.do_not_infer_generic_functions GenericsTests.do_not_infer_generic_functions
GenericsTests.dont_substitute_bound_types GenericsTests.dont_substitute_bound_types
@ -135,6 +142,7 @@ RefinementTest.discriminate_from_isa_of_x
RefinementTest.discriminate_from_truthiness_of_x RefinementTest.discriminate_from_truthiness_of_x
RefinementTest.globals_can_be_narrowed_too RefinementTest.globals_can_be_narrowed_too
RefinementTest.isa_type_refinement_must_be_known_ahead_of_time RefinementTest.isa_type_refinement_must_be_known_ahead_of_time
RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true
RefinementTest.not_t_or_some_prop_of_t RefinementTest.not_t_or_some_prop_of_t
RefinementTest.refine_a_param_that_got_resolved_during_constraint_solving_stage RefinementTest.refine_a_param_that_got_resolved_during_constraint_solving_stage
RefinementTest.refine_a_property_of_some_global RefinementTest.refine_a_property_of_some_global
@ -278,7 +286,9 @@ TypeInferAnyError.can_subscript_any
TypeInferAnyError.for_in_loop_iterator_is_any TypeInferAnyError.for_in_loop_iterator_is_any
TypeInferAnyError.for_in_loop_iterator_is_any2 TypeInferAnyError.for_in_loop_iterator_is_any2
TypeInferAnyError.for_in_loop_iterator_is_any_pack TypeInferAnyError.for_in_loop_iterator_is_any_pack
TypeInferAnyError.for_in_loop_iterator_returns_any
TypeInferAnyError.for_in_loop_iterator_returns_any2 TypeInferAnyError.for_in_loop_iterator_returns_any2
TypeInferAnyError.replace_every_free_type_when_unifying_a_complex_function_with_any
TypeInferClasses.callable_classes TypeInferClasses.callable_classes
TypeInferClasses.cannot_unify_class_instance_with_primitive TypeInferClasses.cannot_unify_class_instance_with_primitive
TypeInferClasses.class_type_mismatch_with_name_conflict TypeInferClasses.class_type_mismatch_with_name_conflict
@ -337,6 +347,7 @@ TypeInferFunctions.too_many_arguments
TypeInferFunctions.too_many_arguments_error_location TypeInferFunctions.too_many_arguments_error_location
TypeInferFunctions.too_many_return_values_in_parentheses TypeInferFunctions.too_many_return_values_in_parentheses
TypeInferFunctions.too_many_return_values_no_function TypeInferFunctions.too_many_return_values_no_function
TypeInferFunctions.unifier_should_not_bind_free_types
TypeInferLoops.cli_68448_iterators_need_not_accept_nil TypeInferLoops.cli_68448_iterators_need_not_accept_nil
TypeInferLoops.dcr_iteration_on_never_gives_never TypeInferLoops.dcr_iteration_on_never_gives_never
TypeInferLoops.dcr_xpath_candidates TypeInferLoops.dcr_xpath_candidates
@ -363,7 +374,6 @@ TypeInferModules.require
TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2 TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2
TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon
TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory
TypeInferOOP.methods_are_topologically_sorted
TypeInferOOP.promise_type_error_too_complex TypeInferOOP.promise_type_error_too_complex
TypeInferOperators.add_type_family_works TypeInferOperators.add_type_family_works
TypeInferOperators.cli_38355_recursive_union TypeInferOperators.cli_38355_recursive_union