Merge branch 'master' of https://github.com/Roblox/luau into store-class-type-location

This commit is contained in:
JohnnyMorganz 2024-07-07 13:37:33 +02:00
commit 43e31b1f5c
198 changed files with 12446 additions and 4857 deletions

View file

@ -77,10 +77,12 @@ jobs:
valgrind --tool=callgrind ./luau-compile --null -O1 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O1 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau-compile --null -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau-compile --codegennull -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2-codegen | tee -a compile-output.txt
valgrind --tool=callgrind ./luau-compile --codegennull -O2 -t1 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2-t1-codegen | tee -a compile-output.txt
valgrind --tool=callgrind ./luau-compile --null -O0 bench/other/regex.lua 2>&1 | filter regex-O0 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau-compile --null -O1 bench/other/regex.lua 2>&1 | filter regex-O1 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau-compile --null -O2 bench/other/regex.lua 2>&1 | filter regex-O2 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau-compile --codegennull -O2 bench/other/regex.lua 2>&1 | filter regex-O2-codegen | tee -a compile-output.txt
valgrind --tool=callgrind ./luau-compile --codegennull -O2 -t1 bench/other/regex.lua 2>&1 | filter regex-O2-t1-codegen | tee -a compile-output.txt
- name: Checkout benchmark results
uses: actions/checkout@v3

2
.gitignore vendored
View file

@ -1,5 +1,7 @@
/build/
/build[.-]*/
/cmake/
/cmake[.-]*/
/coverage/
/.vs/
/.vscode/

View file

@ -57,7 +57,7 @@ struct GeneralizationConstraint
struct IterableConstraint
{
TypePackId iterator;
TypePackId variables;
std::vector<TypeId> variables;
const AstNode* nextAstFragment;
DenseHashMap<const AstNode*, TypeId>* astForInNextTypes;
@ -179,23 +179,6 @@ struct HasPropConstraint
bool suppressSimplification = false;
};
// result ~ setProp subjectType ["prop", "prop2", ...] propType
//
// If the subject is a table or table-like thing that already has the named
// property chain, we unify propType with that existing property type.
//
// If the subject is a free table, we augment it in place.
//
// If the subject is an unsealed table, result is an augmented table that
// includes that new prop.
struct SetPropConstraint
{
TypeId resultType;
TypeId subjectType;
std::vector<std::string> path;
TypeId propType;
};
// resultType ~ hasIndexer subjectType indexType
//
// If the subject type is a table or table-like thing that supports indexing,
@ -209,46 +192,48 @@ struct HasIndexerConstraint
TypeId indexType;
};
// result ~ setIndexer subjectType indexType propType
// assignProp lhsType propName rhsType
//
// If the subject is a table or table-like thing that already has an indexer,
// unify its indexType and propType with those from this constraint.
//
// If the table is a free or unsealed table, we augment it with a new indexer.
struct SetIndexerConstraint
// Assign a value of type rhsType into the named property of lhsType.
struct AssignPropConstraint
{
TypeId subjectType;
TypeId lhsType;
std::string propName;
TypeId rhsType;
/// The canonical write type of the property. It is _solely_ used to
/// populate astTypes during constraint resolution. Nothing should ever
/// block on it.
TypeId propType;
// When we generate constraints, we increment the remaining prop count on
// the table if we are able. This flag informs the solver as to whether or
// not it should in turn decrement the prop count when this constraint is
// dispatched.
bool decrementPropCount = false;
};
struct AssignIndexConstraint
{
TypeId lhsType;
TypeId indexType;
TypeId rhsType;
/// The canonical write type of the property. It is _solely_ used to
/// populate astTypes during constraint resolution. Nothing should ever
/// block on it.
TypeId propType;
};
// resultType ~ unpack sourceTypePack
// resultTypes ~ unpack sourceTypePack
//
// Similar to PackSubtypeConstraint, but with one important difference: If the
// sourcePack is blocked, this constraint blocks.
struct UnpackConstraint
{
TypePackId resultPack;
std::vector<TypeId> resultPack;
TypePackId sourcePack;
// UnpackConstraint is sometimes used to resolve the types of assignments.
// When this is the case, any LocalTypes in resultPack can have their
// domains extended by the corresponding type from sourcePack.
bool resultIsLValue = false;
};
// resultType ~ unpack sourceType
//
// The same as UnpackConstraint, but specialized for a pair of types as opposed to packs.
struct Unpack1Constraint
{
TypeId resultType;
TypeId sourceType;
// UnpackConstraint is sometimes used to resolve the types of assignments.
// When this is the case, any LocalTypes in resultPack can have their
// domains extended by the corresponding type from sourcePack.
bool resultIsLValue = false;
};
// ty ~ reduce ty
@ -268,8 +253,8 @@ struct ReducePackConstraint
};
using ConstraintV = Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, IterableConstraint, NameConstraint,
TypeAliasExpansionConstraint, FunctionCallConstraint, FunctionCheckConstraint, PrimitiveTypeConstraint, HasPropConstraint, SetPropConstraint,
HasIndexerConstraint, SetIndexerConstraint, UnpackConstraint, Unpack1Constraint, ReduceConstraint, ReducePackConstraint, EqualityConstraint>;
TypeAliasExpansionConstraint, FunctionCallConstraint, FunctionCheckConstraint, PrimitiveTypeConstraint, HasPropConstraint, HasIndexerConstraint,
AssignPropConstraint, AssignIndexConstraint, UnpackConstraint, ReduceConstraint, ReducePackConstraint, EqualityConstraint>;
struct Constraint
{
@ -284,11 +269,13 @@ struct Constraint
std::vector<NotNull<Constraint>> dependencies;
DenseHashSet<TypeId> getFreeTypes() const;
DenseHashSet<TypeId> getMaybeMutatedFreeTypes() const;
};
using ConstraintPtr = std::unique_ptr<Constraint>;
bool isReferenceCountedType(const TypeId typ);
inline Constraint& asMutable(const Constraint& c)
{
return const_cast<Constraint&>(c);

View file

@ -118,6 +118,8 @@ struct ConstraintGenerator
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope;
std::vector<RequireCycle> requireCycles;
DenseHashMap<TypeId, TypeIds> localTypes{nullptr};
DcrLogger* logger;
ConstraintGenerator(ModulePtr module, NotNull<Normalizer> normalizer, NotNull<ModuleResolver> moduleResolver, NotNull<BuiltinTypes> builtinTypes,
@ -254,18 +256,11 @@ private:
Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType);
std::tuple<TypeId, TypeId, RefinementId> checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType);
struct LValueBounds
{
std::optional<TypeId> annotationTy;
std::optional<TypeId> assignedTy;
};
LValueBounds checkLValue(const ScopePtr& scope, AstExpr* expr);
LValueBounds checkLValue(const ScopePtr& scope, AstExprLocal* local);
LValueBounds checkLValue(const ScopePtr& scope, AstExprGlobal* global);
LValueBounds checkLValue(const ScopePtr& scope, AstExprIndexName* indexName);
LValueBounds checkLValue(const ScopePtr& scope, AstExprIndexExpr* indexExpr);
LValueBounds updateProperty(const ScopePtr& scope, AstExpr* expr);
void visitLValue(const ScopePtr& scope, AstExpr* expr, TypeId rhsType);
void visitLValue(const ScopePtr& scope, AstExprLocal* local, TypeId rhsType);
void visitLValue(const ScopePtr& scope, AstExprGlobal* global, TypeId rhsType);
void visitLValue(const ScopePtr& scope, AstExprIndexName* indexName, TypeId rhsType);
void visitLValue(const ScopePtr& scope, AstExprIndexExpr* indexExpr, TypeId rhsType);
struct FunctionSignature
{
@ -361,6 +356,8 @@ private:
*/
void prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program);
bool recordPropertyAssignment(TypeId ty);
// Record the fact that a particular local has a particular type in at least
// one of its states.
void recordInferredBinding(AstLocal* local, TypeId ty);
@ -373,7 +370,8 @@ private:
*/
std::vector<std::optional<TypeId>> getExpectedCallTypesForFunctionOverloads(const TypeId fnType);
TypeId createFamilyInstance(TypeFamilyInstanceType instance, const ScopePtr& scope, Location location);
TypeId createTypeFamilyInstance(
const TypeFamily& family, std::vector<TypeId> typeArguments, std::vector<TypePackId> packArguments, const ScopePtr& scope, Location location);
};
/** Borrow a vector of pointers from a vector of owning pointers to constraints.

View file

@ -94,6 +94,10 @@ struct ConstraintSolver
// Irreducible/uninhabited type families or type pack families.
DenseHashSet<const void*> uninhabitedTypeFamilies{{}};
// The set of types that will definitely be unchanged by generalization.
DenseHashSet<TypeId> generalizedTypes_{nullptr};
const NotNull<DenseHashSet<TypeId>> generalizedTypes{&generalizedTypes_};
// Recorded errors that take place within the solver.
ErrorVec errors;
@ -103,6 +107,8 @@ struct ConstraintSolver
DcrLogger* logger;
TypeCheckLimits limits;
DenseHashMap<TypeId, const Constraint*> typeFamiliesToFinalize{nullptr};
explicit ConstraintSolver(NotNull<Normalizer> normalizer, NotNull<Scope> rootScope, std::vector<NotNull<Constraint>> constraints,
ModuleName moduleName, NotNull<ModuleResolver> moduleResolver, std::vector<RequireCycle> requireCycles, DcrLogger* logger,
TypeCheckLimits limits);
@ -116,8 +122,35 @@ struct ConstraintSolver
**/
void run();
/**
* Attempts to perform one final reduction on type families after every constraint has been completed
*
**/
void finalizeTypeFamilies();
bool isDone();
private:
/**
* Bind a type variable to another type.
*
* A constraint is required and will validate that blockedTy is owned by this
* constraint. This prevents one constraint from interfering with another's
* blocked types.
*
* Bind will also unblock the type variable for you.
*/
void bind(NotNull<const Constraint> constraint, TypeId ty, TypeId boundTo);
void bind(NotNull<const Constraint> constraint, TypePackId tp, TypePackId boundTo);
template<typename T, typename... Args>
void emplace(NotNull<const Constraint> constraint, TypeId ty, Args&&... args);
template<typename T, typename... Args>
void emplace(NotNull<const Constraint> constraint, TypePackId tp, Args&&... args);
public:
/** Attempt to dispatch a constraint. Returns true if it was successful. If
* tryDispatch() returns false, the constraint remains in the unsolved set
* and will be retried later.
@ -134,20 +167,15 @@ struct ConstraintSolver
bool tryDispatch(const FunctionCheckConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const HasPropConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const SetPropConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatchHasIndexer(
int& recursionDepth, NotNull<const Constraint> constraint, TypeId subjectType, TypeId indexType, TypeId resultType, Set<TypeId>& seen);
bool tryDispatch(const HasIndexerConstraint& c, NotNull<const Constraint> constraint);
std::pair<bool, std::optional<TypeId>> tryDispatchSetIndexer(
NotNull<const Constraint> constraint, TypeId subjectType, TypeId indexType, TypeId propType, bool expandFreeTypeBounds);
bool tryDispatch(const SetIndexerConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatchUnpack1(NotNull<const Constraint> constraint, TypeId resultType, TypeId sourceType, bool resultIsLValue);
bool tryDispatch(const AssignPropConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const AssignIndexConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const UnpackConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const Unpack1Constraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const ReduceConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const ReducePackConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const EqualityConstraint& c, NotNull<const Constraint> constraint, bool force);
@ -157,14 +185,28 @@ struct ConstraintSolver
bool tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
// for a, ... in next_function, t, ... do
bool tryDispatchIterableFunction(
TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(NotNull<const Constraint> constraint, TypeId subjectType,
const std::string& propName, ValueContext context, bool inConditional = false, bool suppressSimplification = false);
std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(NotNull<const Constraint> constraint, TypeId subjectType,
const std::string& propName, ValueContext context, bool inConditional, bool suppressSimplification, DenseHashSet<TypeId>& seen);
/**
* Generate constraints to unpack the types of srcTypes and assign each
* value to the corresponding BlockedType in destTypes.
*
* This function also overwrites the owners of each BlockedType. This is
* okay because this function is only used to decompose IterableConstraint
* into an UnpackConstraint.
*
* @param destTypes A vector of types comprised of BlockedTypes.
* @param srcTypes A TypePack that represents rvalues to be assigned.
* @returns The underlying UnpackConstraint. There's a bit of code in
* iteration that needs to pass blocks on to this constraint.
*/
NotNull<const Constraint> unpackAndAssign(const std::vector<TypeId> destTypes, TypePackId srcTypes, NotNull<const Constraint> constraint);
void block(NotNull<const Constraint> target, NotNull<const Constraint> constraint);
/**
* Block a constraint on the resolution of a Type.
@ -242,6 +284,24 @@ struct ConstraintSolver
void reportError(TypeErrorData&& data, const Location& location);
void reportError(TypeError e);
/**
* Shifts the count of references from `source` to `target`. This should be paired
* with any instance of binding a free type in order to maintain accurate refcounts.
* If `target` is not a free type, this is a noop.
* @param source the free type which is being bound
* @param target the type which the free type is being bound to
*/
void shiftReferences(TypeId source, TypeId target);
/**
* Generalizes the given free type if the reference counting allows it.
* @param the scope to generalize in
* @param type the free type we want to generalize
* @returns a non-free type that generalizes the argument, or `std::nullopt` if one
* does not exist
*/
std::optional<TypeId> generalizeFreeType(NotNull<Scope> scope, TypeId type, bool avoidSealingTables = false);
/**
* Checks the existing set of constraints to see if there exist any that contain
* the provided free type, indicating that it is not yet ready to be replaced by
@ -266,22 +326,6 @@ struct ConstraintSolver
template<typename TID>
bool unify(NotNull<const Constraint> constraint, TID subTy, TID superTy);
private:
/**
* Bind a BlockedType to another type while taking care not to bind it to
* itself in the case that resultTy == blockedTy. This can happen if we
* have a tautological constraint. When it does, we must instead bind
* blockedTy to a fresh type belonging to an appropriate scope.
*
* To determine which scope is appropriate, we also accept rootTy, which is
* to be the type that contains blockedTy.
*
* A constraint is required and will validate that blockedTy is owned by this
* constraint. This prevents one constraint from interfering with another's
* blocked types.
*/
void bindBlockedType(TypeId blockedTy, TypeId resultTy, TypeId rootTy, NotNull<const Constraint> constraint);
/**
* Marks a constraint as being blocked on a type or type pack. The constraint
* solver will not attempt to dispatch blocked constraints until their

View file

@ -191,7 +191,7 @@ struct Frontend
void queueModuleCheck(const std::vector<ModuleName>& names);
void queueModuleCheck(const ModuleName& name);
std::vector<ModuleName> checkQueuedModules(std::optional<FrontendOptions> optionOverride = {},
std::function<void(std::function<void()> task)> executeTask = {}, std::function<void(size_t done, size_t total)> progress = {});
std::function<void(std::function<void()> task)> executeTask = {}, std::function<bool(size_t done, size_t total)> progress = {});
std::optional<CheckResult> getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete = false);

View file

@ -0,0 +1,13 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Scope.h"
#include "Luau/NotNull.h"
#include "Luau/TypeFwd.h"
namespace Luau
{
std::optional<TypeId> generalize(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<Scope> scope,
NotNull<DenseHashSet<TypeId>> bakedTypes, TypeId ty, /* avoid sealing tables*/ bool avoidSealingTables = false);
}

View file

@ -27,12 +27,16 @@ struct ReplaceGenerics : Substitution
{
}
void resetState(const TxnLog* log, TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, TypeLevel level, Scope* scope,
const std::vector<TypeId>& generics, const std::vector<TypePackId>& genericPacks);
NotNull<BuiltinTypes> builtinTypes;
TypeLevel level;
Scope* scope;
std::vector<TypeId> generics;
std::vector<TypePackId> genericPacks;
bool ignoreChildren(TypeId ty) override;
bool isDirty(TypeId ty) override;
bool isDirty(TypePackId tp) override;
@ -48,13 +52,19 @@ struct Instantiation : Substitution
, builtinTypes(builtinTypes)
, level(level)
, scope(scope)
, reusableReplaceGenerics(log, arena, builtinTypes, level, scope, {}, {})
{
}
void resetState(const TxnLog* log, TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, TypeLevel level, Scope* scope);
NotNull<BuiltinTypes> builtinTypes;
TypeLevel level;
Scope* scope;
ReplaceGenerics reusableReplaceGenerics;
bool ignoreChildren(TypeId ty) override;
bool isDirty(TypeId ty) override;
bool isDirty(TypePackId tp) override;

View file

@ -102,6 +102,12 @@ struct Module
DenseHashMap<const AstType*, TypeId> astResolvedTypes{nullptr};
DenseHashMap<const AstTypePack*, TypePackId> astResolvedTypePacks{nullptr};
// The computed result type of a compound assignment. (eg foo += 1)
//
// Type checking uses this to check that the result of such an operation is
// actually compatible with the left-side operand.
DenseHashMap<const AstStat*, TypeId> astCompoundAssignResultTypes{nullptr};
DenseHashMap<TypeId, std::vector<std::pair<Location, TypeId>>> upperBoundContributors{nullptr};
// Map AST nodes to the scope they create. Cannot be NotNull<Scope> because

View file

@ -307,6 +307,9 @@ struct NormalizedType
/// Returns true if the type is a subtype of string(it could be a singleton). Behaves like Type::isString()
bool isSubtypeOfString() const;
/// Returns true if the type is a subtype of boolean(it could be a singleton). Behaves like Type::isBoolean()
bool isSubtypeOfBooleans() const;
/// Returns true if this type should result in error suppressing behavior.
bool shouldSuppressErrors() const;
@ -360,7 +363,6 @@ public:
Normalizer& operator=(Normalizer&) = delete;
// If this returns null, the typechecker should emit a "too complex" error
const NormalizedType* DEPRECATED_normalize(TypeId ty);
std::shared_ptr<const NormalizedType> normalize(TypeId ty);
void clearNormal(NormalizedType& norm);
@ -395,6 +397,7 @@ public:
TypeId negate(TypeId there);
void subtractPrimitive(NormalizedType& here, TypeId ty);
void subtractSingleton(NormalizedType& here, TypeId ty);
NormalizationResult intersectNormalWithNegationTy(TypeId toNegate, NormalizedType& intersect);
// ------- Normalizing intersections
TypeId intersectionOfTops(TypeId here, TypeId there);
@ -403,8 +406,8 @@ public:
void intersectClassesWithClass(NormalizedClassType& heres, TypeId there);
void intersectStrings(NormalizedStringType& here, const NormalizedStringType& there);
std::optional<TypePackId> intersectionOfTypePacks(TypePackId here, TypePackId there);
std::optional<TypeId> intersectionOfTables(TypeId here, TypeId there);
void intersectTablesWithTable(TypeIds& heres, TypeId there);
std::optional<TypeId> intersectionOfTables(TypeId here, TypeId there, Set<TypeId>& seenSet);
void intersectTablesWithTable(TypeIds& heres, TypeId there, Set<TypeId>& seenSetTypes);
void intersectTables(TypeIds& heres, const TypeIds& theres);
std::optional<TypeId> intersectionOfFunctions(TypeId here, TypeId there);
void intersectFunctionsWithFunction(NormalizedFunctionType& heress, TypeId there);
@ -412,7 +415,7 @@ public:
NormalizationResult intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there, Set<TypeId>& seenSetTypes);
NormalizationResult intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1);
NormalizationResult intersectNormalWithTy(NormalizedType& here, TypeId there, Set<TypeId>& seenSetTypes);
NormalizationResult normalizeIntersections(const std::vector<TypeId>& intersections, NormalizedType& outType);
NormalizationResult normalizeIntersections(const std::vector<TypeId>& intersections, NormalizedType& outType, Set<TypeId>& seenSet);
// Check for inhabitance
NormalizationResult isInhabited(TypeId ty);
@ -422,6 +425,7 @@ public:
// Check for intersections being inhabited
NormalizationResult isIntersectionInhabited(TypeId left, TypeId right);
NormalizationResult isIntersectionInhabited(TypeId left, TypeId right, Set<TypeId>& seenSet);
// -------- Convert back from a normalized type to a type
TypeId typeFromNormal(const NormalizedType& norm);

View file

@ -102,4 +102,12 @@ bool subsumesStrict(Scope* left, Scope* right);
// outermost-possible scope.
bool subsumes(Scope* left, Scope* right);
inline Scope* max(Scope* left, Scope* right)
{
if (subsumes(left, right))
return right;
else
return left;
}
} // namespace Luau

View file

@ -4,7 +4,6 @@
#include "Luau/Common.h"
#include "Luau/DenseHash.h"
LUAU_FASTFLAG(LuauFixSetIter)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
namespace Luau
@ -143,11 +142,8 @@ public:
: impl(impl_)
, end(end_)
{
if (FFlag::LuauFixSetIter || FFlag::DebugLuauDeferredConstraintResolution)
{
while (impl != end && impl->second == false)
++impl;
}
while (impl != end && impl->second == false)
++impl;
}
const T& operator*() const

View file

@ -5,6 +5,7 @@
#include "Luau/DenseHash.h"
#include "Luau/NotNull.h"
#include "Luau/TypeFwd.h"
#include <set>
namespace Luau
{
@ -19,6 +20,8 @@ struct SimplifyResult
};
SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId ty, TypeId discriminant);
SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, std::set<TypeId> parts);
SimplifyResult simplifyUnion(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId ty, TypeId discriminant);
enum class Relation

View file

@ -134,7 +134,8 @@ struct Tarjan
TarjanResult visitRoot(TypeId ty);
TarjanResult visitRoot(TypePackId ty);
void clearTarjan();
// Used to reuse the object for a new operation
void clearTarjan(const TxnLog* log);
// Get/set the dirty bit for an index (grows the vector if needed)
bool getDirty(int index);
@ -212,6 +213,8 @@ public:
std::optional<TypeId> substitute(TypeId ty);
std::optional<TypePackId> substitute(TypePackId tp);
void resetState(const TxnLog* log, TypeArena* arena);
TypeId replace(TypeId ty);
TypePackId replace(TypePackId tp);

View file

@ -86,24 +86,6 @@ struct FreeType
TypeId upperBound = nullptr;
};
/** A type that tracks the domain of a local variable.
*
* We consider each local's domain to be the union of all types assigned to it.
* We accomplish this with LocalType. Each time we dispatch an assignment to a
* local, we accumulate this union and decrement blockCount.
*
* When blockCount reaches 0, we can consider the LocalType to be "fully baked"
* and replace it with the union we've built.
*/
struct LocalType
{
TypeId domain;
int blockCount = 0;
// Used for debugging
std::string name;
};
struct GenericType
{
// By default, generics are global, with a synthetic name
@ -148,6 +130,7 @@ struct BlockedType
Constraint* getOwner() const;
void setOwner(Constraint* newOwner);
void replaceOwner(Constraint* newOwner);
private:
// The constraint that is intended to unblock this type. Other constraints
@ -471,6 +454,11 @@ struct TableType
// Methods of this table that have an untyped self will use the same shared self type.
std::optional<TypeId> selfTy;
// We track the number of as-yet-unadded properties to unsealed tables.
// Some constraints will use this information to decide whether or not they
// are able to dispatch.
size_t remainingProps = 0;
};
// Represents a metatable attached to a table type. Somewhat analogous to a bound type.
@ -672,9 +660,9 @@ struct NegationType
using ErrorType = Unifiable::Error;
using TypeVariant = Unifiable::Variant<TypeId, FreeType, LocalType, GenericType, PrimitiveType, BlockedType, PendingExpansionType, SingletonType,
FunctionType, TableType, MetatableType, ClassType, AnyType, UnionType, IntersectionType, LazyType, UnknownType, NeverType, NegationType,
TypeFamilyInstanceType>;
using TypeVariant =
Unifiable::Variant<TypeId, FreeType, GenericType, PrimitiveType, SingletonType, BlockedType, PendingExpansionType, FunctionType, TableType,
MetatableType, ClassType, AnyType, UnionType, IntersectionType, LazyType, UnknownType, NeverType, NegationType, TypeFamilyInstanceType>;
struct Type final
{

View file

@ -6,7 +6,6 @@
#include "Luau/NotNull.h"
#include "Luau/TypeCheckLimits.h"
#include "Luau/TypeFwd.h"
#include "Luau/Variant.h"
#include <functional>
#include <string>
@ -19,22 +18,6 @@ struct TypeArena;
struct TxnLog;
class Normalizer;
struct TypeFamilyQueue
{
NotNull<VecDeque<TypeId>> queuedTys;
NotNull<VecDeque<TypePackId>> queuedTps;
void add(TypeId instanceTy);
void add(TypePackId instanceTp);
template<typename T>
void add(const std::vector<T>& ts)
{
for (const T& t : ts)
enqueue(t);
}
};
struct TypeFamilyContext
{
NotNull<TypeArena> arena;
@ -99,8 +82,8 @@ struct TypeFamilyReductionResult
};
template<typename T>
using ReducerFunction = std::function<TypeFamilyReductionResult<T>(
T, NotNull<TypeFamilyQueue>, const std::vector<TypeId>&, const std::vector<TypePackId>&, NotNull<TypeFamilyContext>)>;
using ReducerFunction =
std::function<TypeFamilyReductionResult<T>(T, const std::vector<TypeId>&, const std::vector<TypePackId>&, NotNull<TypeFamilyContext>)>;
/// Represents a type function that may be applied to map a series of types and
/// type packs to a single output type.
@ -196,11 +179,12 @@ struct BuiltinTypeFamilies
TypeFamily keyofFamily;
TypeFamily rawkeyofFamily;
TypeFamily indexFamily;
TypeFamily rawgetFamily;
void addToScope(NotNull<TypeArena> arena, NotNull<Scope> scope) const;
};
const BuiltinTypeFamilies kBuiltinTypeFamilies{};
const BuiltinTypeFamilies& builtinTypeFunctions();
} // namespace Luau

View file

@ -4,6 +4,7 @@
#include "Luau/Anyification.h"
#include "Luau/ControlFlow.h"
#include "Luau/Error.h"
#include "Luau/Instantiation.h"
#include "Luau/Module.h"
#include "Luau/Predicate.h"
#include "Luau/Substitution.h"
@ -362,6 +363,8 @@ public:
UnifierSharedState unifierState;
Normalizer normalizer;
Instantiation reusableInstantiation;
std::vector<RequireCycle> requireCycles;
// Type inference limits

View file

@ -55,6 +55,9 @@ struct InConditionalContext
using ScopePtr = std::shared_ptr<struct Scope>;
std::optional<Property> findTableProperty(
NotNull<BuiltinTypes> builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location);
std::optional<TypeId> findMetatableEntry(
NotNull<BuiltinTypes> builtinTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location);
std::optional<TypeId> findTablePropertyRespectingMeta(

View file

@ -69,7 +69,6 @@ struct Unifier2
*/
bool unify(TypeId subTy, TypeId superTy);
bool unifyFreeWithType(TypeId subTy, TypeId superTy);
bool unify(const LocalType* subTy, TypeId superFn);
bool unify(TypeId subTy, const FunctionType* superFn);
bool unify(const UnionType* subUnion, TypeId superTy);
bool unify(TypeId subTy, const UnionType* superUnion);
@ -78,6 +77,11 @@ struct Unifier2
bool unify(TableType* subTable, const TableType* superTable);
bool unify(const MetatableType* subMetatable, const MetatableType* superMetatable);
bool unify(const AnyType* subAny, const FunctionType* superFn);
bool unify(const FunctionType* subFn, const AnyType* superAny);
bool unify(const AnyType* subAny, const TableType* superTable);
bool unify(const TableType* subTable, const AnyType* superAny);
// TODO think about this one carefully. We don't do unions or intersections of type packs
bool unify(TypePackId subTp, TypePackId superTp);

View file

@ -100,10 +100,6 @@ struct GenericTypeVisitor
{
return visit(ty);
}
virtual bool visit(TypeId ty, const LocalType& ftv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const GenericType& gtv)
{
return visit(ty);
@ -248,11 +244,6 @@ struct GenericTypeVisitor
else
visit(ty, *ftv);
}
else if (auto lt = get<LocalType>(ty))
{
if (visit(ty, *lt))
traverse(lt->domain);
}
else if (auto gtv = get<GenericType>(ty))
visit(ty, *gtv);
else if (auto etv = get<ErrorType>(ty))
@ -357,16 +348,38 @@ struct GenericTypeVisitor
{
if (visit(ty, *utv))
{
bool unionChanged = false;
for (TypeId optTy : utv->options)
{
traverse(optTy);
if (!get<UnionType>(follow(ty)))
{
unionChanged = true;
break;
}
}
if (unionChanged)
traverse(ty);
}
}
else if (auto itv = get<IntersectionType>(ty))
{
if (visit(ty, *itv))
{
bool intersectionChanged = false;
for (TypeId partTy : itv->parts)
{
traverse(partTy);
if (!get<IntersectionType>(follow(ty)))
{
intersectionChanged = true;
break;
}
}
if (intersectionChanged)
traverse(ty);
}
}
else if (auto ltv = get<LazyType>(ty))

View file

@ -8,6 +8,8 @@
#include <math.h>
LUAU_FASTFLAG(LuauDeclarationExtraPropData)
namespace Luau
{
@ -735,8 +737,21 @@ struct AstJsonEncoder : public AstVisitor
void write(class AstStatDeclareFunction* node)
{
writeNode(node, "AstStatDeclareFunction", [&]() {
// TODO: attributes
PROP(name);
if (FFlag::LuauDeclarationExtraPropData)
PROP(nameLocation);
PROP(params);
if (FFlag::LuauDeclarationExtraPropData)
{
PROP(paramNames);
PROP(vararg);
PROP(varargLocation);
}
PROP(retTypes);
PROP(generics);
PROP(genericPacks);
@ -747,6 +762,10 @@ struct AstJsonEncoder : public AstVisitor
{
writeNode(node, "AstStatDeclareGlobal", [&]() {
PROP(name);
if (FFlag::LuauDeclarationExtraPropData)
PROP(nameLocation);
PROP(type);
});
}
@ -756,8 +775,16 @@ struct AstJsonEncoder : public AstVisitor
writeRaw("{");
bool c = pushComma();
write("name", prop.name);
if (FFlag::LuauDeclarationExtraPropData)
write("nameLocation", prop.nameLocation);
writeType("AstDeclaredClassProp");
write("luauType", prop.ty);
if (FFlag::LuauDeclarationExtraPropData)
write("location", prop.location);
popComma(c);
writeRaw("}");
}

View file

@ -12,6 +12,7 @@
#include <algorithm>
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTFLAGVARIABLE(LuauFixBindingForGlobalPos, false);
namespace Luau
{
@ -332,6 +333,11 @@ std::optional<TypeId> findExpectedTypeAtPosition(const Module& module, const Sou
static std::optional<AstStatLocal*> findBindingLocalStatement(const SourceModule& source, const Binding& binding)
{
// Bindings coming from global sources (e.g., definition files) have a zero position.
// They cannot be defined from a local statement
if (FFlag::LuauFixBindingForGlobalPos && binding.location == Location{{0, 0}, {0, 0}})
return std::nullopt;
std::vector<AstNode*> nodes = findAstAncestryOfPosition(source, binding.location.begin);
auto iter = std::find_if(nodes.rbegin(), nodes.rend(), [](AstNode* node) {
return node->is<AstStatLocal>();

View file

@ -1830,12 +1830,21 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName
if (!sourceModule)
return {};
ModulePtr module = frontend.moduleResolverForAutocomplete.getModule(moduleName);
ModulePtr module;
if (FFlag::DebugLuauDeferredConstraintResolution)
module = frontend.moduleResolver.getModule(moduleName);
else
module = frontend.moduleResolverForAutocomplete.getModule(moduleName);
if (!module)
return {};
NotNull<BuiltinTypes> builtinTypes = frontend.builtinTypes;
Scope* globalScope = frontend.globalsForAutocomplete.globalScope.get();
Scope* globalScope;
if (FFlag::DebugLuauDeferredConstraintResolution)
globalScope = frontend.globals.globalScope.get();
else
globalScope = frontend.globalsForAutocomplete.globalScope.get();
TypeArena typeArena;
return autocomplete(*sourceModule, module, builtinTypes, &typeArena, globalScope, position, callback);

View file

@ -24,7 +24,6 @@
*/
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTFLAGVARIABLE(LuauMakeStringMethodsChecked, false);
namespace Luau
{
@ -217,7 +216,7 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
NotNull<BuiltinTypes> builtinTypes = globals.builtinTypes;
if (FFlag::DebugLuauDeferredConstraintResolution)
kBuiltinTypeFamilies.addToScope(NotNull{&arena}, NotNull{globals.globalScope.get()});
builtinTypeFunctions().addToScope(NotNull{&arena}, NotNull{globals.globalScope.get()});
LoadDefinitionFileResult loadResult = frontend.loadDefinitionFile(
globals, globals.globalScope, getBuiltinDefinitionSource(), "@luau", /* captureComments */ false, typeCheckForAutocomplete);
@ -257,21 +256,44 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
TypeId tableMetaMT = arena.addType(MetatableType{tabTy, genericMT});
// getmetatable : <MT>({ @metatable MT, {+ +} }) -> MT
addGlobalBinding(globals, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau");
// clang-format off
// setmetatable<T: {}, MT>(T, MT) -> { @metatable MT, T }
addGlobalBinding(globals, "setmetatable",
arena.addType(
FunctionType{
{genericMT},
{},
arena.addTypePack(TypePack{{tabTy, genericMT}}),
arena.addTypePack(TypePack{{tableMetaMT}})
}
), "@luau"
);
// clang-format on
if (FFlag::DebugLuauDeferredConstraintResolution)
{
TypeId genericT = arena.addType(GenericType{"T"});
TypeId tMetaMT = arena.addType(MetatableType{genericT, genericMT});
// clang-format off
// setmetatable<T: {}, MT>(T, MT) -> { @metatable MT, T }
addGlobalBinding(globals, "setmetatable",
arena.addType(
FunctionType{
{genericT, genericMT},
{},
arena.addTypePack(TypePack{{genericT, genericMT}}),
arena.addTypePack(TypePack{{tMetaMT}})
}
), "@luau"
);
// clang-format on
}
else
{
// clang-format off
// setmetatable<T: {}, MT>(T, MT) -> { @metatable MT, T }
addGlobalBinding(globals, "setmetatable",
arena.addType(
FunctionType{
{genericMT},
{},
arena.addTypePack(TypePack{{tabTy, genericMT}}),
arena.addTypePack(TypePack{{tableMetaMT}})
}
), "@luau"
);
// clang-format on
}
for (const auto& pair : globals.globalScope->bindings)
{
@ -291,7 +313,7 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
// declare function assert<T>(value: T, errorMessage: string?): intersect<T, ~(false?)>
TypeId genericT = arena.addType(GenericType{"T"});
TypeId refinedTy = arena.addType(TypeFamilyInstanceType{
NotNull{&kBuiltinTypeFamilies.intersectFamily}, {genericT, arena.addType(NegationType{builtinTypes->falsyType})}, {}});
NotNull{&builtinTypeFunctions().intersectFamily}, {genericT, arena.addType(NegationType{builtinTypes->falsyType})}, {}});
TypeId assertTy = arena.addType(FunctionType{
{genericT}, {}, arena.addTypePack(TypePack{{genericT, builtinTypes->optionalStringType}}), arena.addTypePack(TypePack{{refinedTy}})});
@ -773,153 +795,87 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
const TypePackId numberVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{numberType}});
if (FFlag::LuauMakeStringMethodsChecked)
{
FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack};
formatFTV.magicFunction = &magicFunctionFormat;
formatFTV.isCheckedFunction = true;
const TypeId formatFn = arena->addType(formatFTV);
attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat);
FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack};
formatFTV.magicFunction = &magicFunctionFormat;
formatFTV.isCheckedFunction = true;
const TypeId formatFn = arena->addType(formatFTV);
attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat);
const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ true);
const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ true);
const TypeId replArgType = arena->addType(
UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)),
makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ false)}});
const TypeId gsubFunc =
makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}, /* checked */ false);
const TypeId gmatchFunc = makeFunction(
*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}, /* checked */ true);
attachMagicFunction(gmatchFunc, magicFunctionGmatch);
attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch);
const TypeId replArgType =
arena->addType(UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)),
makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ false)}});
const TypeId gsubFunc =
makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}, /* checked */ false);
const TypeId gmatchFunc =
makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}, /* checked */ true);
attachMagicFunction(gmatchFunc, magicFunctionGmatch);
attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch);
FunctionType matchFuncTy{
arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})};
matchFuncTy.isCheckedFunction = true;
const TypeId matchFunc = arena->addType(matchFuncTy);
attachMagicFunction(matchFunc, magicFunctionMatch);
attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch);
FunctionType matchFuncTy{
arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})};
matchFuncTy.isCheckedFunction = true;
const TypeId matchFunc = arena->addType(matchFuncTy);
attachMagicFunction(matchFunc, magicFunctionMatch);
attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch);
FunctionType findFuncTy{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}),
arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})};
findFuncTy.isCheckedFunction = true;
const TypeId findFunc = arena->addType(findFuncTy);
attachMagicFunction(findFunc, magicFunctionFind);
attachDcrMagicFunction(findFunc, dcrMagicFunctionFind);
FunctionType findFuncTy{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}),
arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})};
findFuncTy.isCheckedFunction = true;
const TypeId findFunc = arena->addType(findFuncTy);
attachMagicFunction(findFunc, magicFunctionFind);
attachDcrMagicFunction(findFunc, dcrMagicFunctionFind);
// string.byte : string -> number? -> number? -> ...number
FunctionType stringDotByte{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList};
stringDotByte.isCheckedFunction = true;
// string.byte : string -> number? -> number? -> ...number
FunctionType stringDotByte{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList};
stringDotByte.isCheckedFunction = true;
// string.char : .... number -> string
FunctionType stringDotChar{numberVariadicList, arena->addTypePack({stringType})};
stringDotChar.isCheckedFunction = true;
// string.char : .... number -> string
FunctionType stringDotChar{numberVariadicList, arena->addTypePack({stringType})};
stringDotChar.isCheckedFunction = true;
// string.unpack : string -> string -> number? -> ...any
FunctionType stringDotUnpack{
arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}),
variadicTailPack,
};
stringDotUnpack.isCheckedFunction = true;
// string.unpack : string -> string -> number? -> ...any
FunctionType stringDotUnpack{
arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}),
variadicTailPack,
};
stringDotUnpack.isCheckedFunction = true;
TableType::Props stringLib = {
{"byte", {arena->addType(stringDotByte)}},
{"char", {arena->addType(stringDotChar)}},
{"find", {findFunc}},
{"format", {formatFn}}, // FIXME
{"gmatch", {gmatchFunc}},
{"gsub", {gsubFunc}},
{"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType}, /* checked */ true)}},
{"lower", {stringToStringType}},
{"match", {matchFunc}},
{"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType}, /* checked */ true)}},
{"reverse", {stringToStringType}},
{"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType}, /* checked */ true)}},
{"upper", {stringToStringType}},
{"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {},
{arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})},
/* checked */ true)}},
{"pack", {arena->addType(FunctionType{
arena->addTypePack(TypePack{{stringType}, variadicTailPack}),
oneStringPack,
})}},
{"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType}, /* checked */ true)}},
{"unpack", {arena->addType(stringDotUnpack)}},
};
assignPropDocumentationSymbols(stringLib, "@luau/global/string");
TableType::Props stringLib = {
{"byte", {arena->addType(stringDotByte)}},
{"char", {arena->addType(stringDotChar)}},
{"find", {findFunc}},
{"format", {formatFn}}, // FIXME
{"gmatch", {gmatchFunc}},
{"gsub", {gsubFunc}},
{"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType}, /* checked */ true)}},
{"lower", {stringToStringType}},
{"match", {matchFunc}},
{"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType}, /* checked */ true)}},
{"reverse", {stringToStringType}},
{"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType}, /* checked */ true)}},
{"upper", {stringToStringType}},
{"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {},
{arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})},
/* checked */ true)}},
{"pack", {arena->addType(FunctionType{
arena->addTypePack(TypePack{{stringType}, variadicTailPack}),
oneStringPack,
})}},
{"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType}, /* checked */ true)}},
{"unpack", {arena->addType(stringDotUnpack)}},
};
TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed});
assignPropDocumentationSymbols(stringLib, "@luau/global/string");
if (TableType* ttv = getMutable<TableType>(tableType))
ttv->name = "typeof(string)";
TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed});
return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed});
}
else
{
FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack};
formatFTV.magicFunction = &magicFunctionFormat;
const TypeId formatFn = arena->addType(formatFTV);
attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat);
if (TableType* ttv = getMutable<TableType>(tableType))
ttv->name = "typeof(string)";
const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType});
const TypeId replArgType = arena->addType(
UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)),
makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType})}});
const TypeId gsubFunc = makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType});
const TypeId gmatchFunc =
makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})});
attachMagicFunction(gmatchFunc, magicFunctionGmatch);
attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch);
const TypeId matchFunc = arena->addType(FunctionType{
arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})});
attachMagicFunction(matchFunc, magicFunctionMatch);
attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch);
const TypeId findFunc = arena->addType(FunctionType{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}),
arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})});
attachMagicFunction(findFunc, magicFunctionFind);
attachDcrMagicFunction(findFunc, dcrMagicFunctionFind);
TableType::Props stringLib = {
{"byte", {arena->addType(FunctionType{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}},
{"char", {arena->addType(FunctionType{numberVariadicList, arena->addTypePack({stringType})})}},
{"find", {findFunc}},
{"format", {formatFn}}, // FIXME
{"gmatch", {gmatchFunc}},
{"gsub", {gsubFunc}},
{"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}},
{"lower", {stringToStringType}},
{"match", {matchFunc}},
{"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}},
{"reverse", {stringToStringType}},
{"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}},
{"upper", {stringToStringType}},
{"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {},
{arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})})}},
{"pack", {arena->addType(FunctionType{
arena->addTypePack(TypePack{{stringType}, variadicTailPack}),
oneStringPack,
})}},
{"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}},
{"unpack", {arena->addType(FunctionType{
arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}),
variadicTailPack,
})}},
};
assignPropDocumentationSymbols(stringLib, "@luau/global/string");
TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed});
if (TableType* ttv = getMutable<TableType>(tableType))
ttv->name = "typeof(string)";
return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed});
}
return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed});
}
static std::optional<WithPredicate<TypePackId>> magicFunctionSelect(

View file

@ -271,11 +271,6 @@ private:
t->upperBound = shallowClone(t->upperBound);
}
void cloneChildren(LocalType* t)
{
t->domain = shallowClone(t->domain);
}
void cloneChildren(GenericType* t)
{
// TOOD: clone upper bounds.

View file

@ -13,12 +13,12 @@ Constraint::Constraint(NotNull<Scope> scope, const Location& location, Constrain
{
}
struct FreeTypeCollector : TypeOnceVisitor
struct ReferenceCountInitializer : TypeOnceVisitor
{
DenseHashSet<TypeId>* result;
FreeTypeCollector(DenseHashSet<TypeId>* result)
ReferenceCountInitializer(DenseHashSet<TypeId>* result)
: result(result)
{
}
@ -29,6 +29,18 @@ struct FreeTypeCollector : TypeOnceVisitor
return false;
}
bool visit(TypeId ty, const BlockedType&) override
{
result->insert(ty);
return false;
}
bool visit(TypeId ty, const PendingExpansionType&) override
{
result->insert(ty);
return false;
}
bool visit(TypeId ty, const ClassType&) override
{
// ClassTypes never contain free types.
@ -36,26 +48,89 @@ struct FreeTypeCollector : TypeOnceVisitor
}
};
DenseHashSet<TypeId> Constraint::getFreeTypes() const
bool isReferenceCountedType(const TypeId typ)
{
DenseHashSet<TypeId> types{{}};
FreeTypeCollector ftc{&types};
// n.b. this should match whatever `ReferenceCountInitializer` includes.
return get<FreeType>(typ) || get<BlockedType>(typ) || get<PendingExpansionType>(typ);
}
if (auto sc = get<SubtypeConstraint>(*this))
DenseHashSet<TypeId> Constraint::getMaybeMutatedFreeTypes() const
{
// For the purpose of this function and reference counting in general, we are only considering
// mutations that affect the _bounds_ of the free type, and not something that may bind the free
// type itself to a new type. As such, `ReduceConstraint` and `GeneralizationConstraint` have no
// contribution to the output set here.
DenseHashSet<TypeId> types{{}};
ReferenceCountInitializer rci{&types};
if (auto ec = get<EqualityConstraint>(*this))
{
ftc.traverse(sc->subType);
ftc.traverse(sc->superType);
rci.traverse(ec->resultType);
// `EqualityConstraints` should not mutate `assignmentType`.
}
else if (auto sc = get<SubtypeConstraint>(*this))
{
rci.traverse(sc->subType);
rci.traverse(sc->superType);
}
else if (auto psc = get<PackSubtypeConstraint>(*this))
{
ftc.traverse(psc->subPack);
ftc.traverse(psc->superPack);
rci.traverse(psc->subPack);
rci.traverse(psc->superPack);
}
else if (auto itc = get<IterableConstraint>(*this))
{
for (TypeId ty : itc->variables)
rci.traverse(ty);
// `IterableConstraints` should not mutate `iterator`.
}
else if (auto nc = get<NameConstraint>(*this))
{
rci.traverse(nc->namedType);
}
else if (auto taec = get<TypeAliasExpansionConstraint>(*this))
{
rci.traverse(taec->target);
}
else if (auto fchc = get<FunctionCheckConstraint>(*this))
{
rci.traverse(fchc->argsPack);
}
else if (auto ptc = get<PrimitiveTypeConstraint>(*this))
{
// we need to take into account primitive type constraints to prevent type families from reducing on
// primitive whose types we have not yet selected to be singleton or not.
ftc.traverse(ptc->freeType);
rci.traverse(ptc->freeType);
}
else if (auto hpc = get<HasPropConstraint>(*this))
{
rci.traverse(hpc->resultType);
// `HasPropConstraints` should not mutate `subjectType`.
}
else if (auto hic = get<HasIndexerConstraint>(*this))
{
rci.traverse(hic->resultType);
// `HasIndexerConstraint` should not mutate `subjectType` or `indexType`.
}
else if (auto apc = get<AssignPropConstraint>(*this))
{
rci.traverse(apc->lhsType);
rci.traverse(apc->rhsType);
}
else if (auto aic = get<AssignIndexConstraint>(*this))
{
rci.traverse(aic->lhsType);
rci.traverse(aic->indexType);
rci.traverse(aic->rhsType);
}
else if (auto uc = get<UnpackConstraint>(*this))
{
for (TypeId ty : uc->resultPack)
rci.traverse(ty);
// `UnpackConstraint` should not mutate `sourcePack`.
}
else if (auto rpc = get<ReducePackConstraint>(*this))
{
rci.traverse(rpc->tp);
}
return types;

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -763,7 +763,8 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprCall* c)
for (AstExpr* arg : c->args)
visitExpr(scope, arg);
return {defArena->freshCell(), nullptr};
// calls should be treated as subscripted.
return {defArena->freshCell(/* subscripted */ true), nullptr};
}
DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexName* i)

View file

@ -2,7 +2,7 @@
#include "Luau/BuiltinDefinitions.h"
LUAU_FASTFLAGVARIABLE(LuauCheckedEmbeddedDefinitions2, false);
LUAU_FASTFLAG(LuauCheckedFunctionSyntax);
LUAU_FASTFLAG(LuauAttributeSyntax);
namespace Luau
{
@ -320,9 +320,9 @@ declare os: {
clock: () -> number,
}
declare function @checked require(target: any): any
@checked declare function require(target: any): any
declare function @checked getfenv(target: any): { [string]: any }
@checked declare function getfenv(target: any): { [string]: any }
declare _G: any
declare _VERSION: string
@ -364,7 +364,7 @@ declare function select<A...>(i: string | number, ...: A...): ...any
-- (nil, string).
declare function loadstring<A...>(src: string, chunkname: string?): (((A...) -> any)?, string?)
declare function @checked newproxy(mt: boolean?): any
@checked declare function newproxy(mt: boolean?): any
declare coroutine: {
create: <A..., R...>(f: (A...) -> R...) -> thread,
@ -452,7 +452,7 @@ std::string getBuiltinDefinitionSource()
std::string result = kBuiltinDefinitionLuaSrc;
// Annotates each non generic function as checked
if (FFlag::LuauCheckedEmbeddedDefinitions2 && FFlag::LuauCheckedFunctionSyntax)
if (FFlag::LuauCheckedEmbeddedDefinitions2 && FFlag::LuauAttributeSyntax)
result = kBuiltinDefinitionLuaSrcChecked;
return result;

View file

@ -7,11 +7,14 @@
#include "Luau/NotNull.h"
#include "Luau/StringUtils.h"
#include "Luau/ToString.h"
#include "Luau/Type.h"
#include "Luau/TypeFamily.h"
#include <optional>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_set>
LUAU_FASTINTVARIABLE(LuauIndentTypeMismatchMaxTypeLength, 10)
@ -61,6 +64,17 @@ static std::string wrongNumberOfArgsString(
namespace Luau
{
// this list of binary operator type families is used for better stringification of type families errors
static const std::unordered_map<std::string, const char*> kBinaryOps{{"add", "+"}, {"sub", "-"}, {"mul", "*"}, {"div", "/"}, {"idiv", "//"},
{"pow", "^"}, {"mod", "%"}, {"concat", ".."}, {"and", "and"}, {"or", "or"}, {"lt", "< or >="}, {"le", "<= or >"}, {"eq", "== or ~="}};
// this list of unary operator type families is used for better stringification of type families errors
static const std::unordered_map<std::string, const char*> kUnaryOps{{"unm", "-"}, {"len", "#"}, {"not", "not"}};
// this list of type families will receive a special error indicating that the user should file a bug on the GitHub repository
// putting a type family in this list indicates that it is expected to _always_ reduce
static const std::unordered_set<std::string> kUnreachableTypeFamilies{"refine", "singleton", "union", "intersect"};
struct ErrorConverter
{
FileResolver* fileResolver = nullptr;
@ -565,6 +579,108 @@ struct ErrorConverter
std::string operator()(const UninhabitedTypeFamily& e) const
{
auto tfit = get<TypeFamilyInstanceType>(e.ty);
LUAU_ASSERT(tfit); // Luau analysis has actually done something wrong if this type is not a type family.
if (!tfit)
return "Unexpected type " + Luau::toString(e.ty) + " flagged as an uninhabited type family.";
// unary operators
if (auto unaryString = kUnaryOps.find(tfit->family->name); unaryString != kUnaryOps.end())
{
std::string result = "Operator '" + std::string(unaryString->second) + "' could not be applied to ";
if (tfit->typeArguments.size() == 1 && tfit->packArguments.empty())
{
result += "operand of type " + Luau::toString(tfit->typeArguments[0]);
if (tfit->family->name != "not")
result += "; there is no corresponding overload for __" + tfit->family->name;
}
else
{
// if it's not the expected case, we ought to add a specialization later, but this is a sane default.
result += "operands of types ";
bool isFirst = true;
for (auto arg : tfit->typeArguments)
{
if (!isFirst)
result += ", ";
result += Luau::toString(arg);
isFirst = false;
}
for (auto packArg : tfit->packArguments)
result += ", " + Luau::toString(packArg);
}
return result;
}
// binary operators
if (auto binaryString = kBinaryOps.find(tfit->family->name); binaryString != kBinaryOps.end())
{
std::string result = "Operator '" + std::string(binaryString->second) + "' could not be applied to operands of types ";
if (tfit->typeArguments.size() == 2 && tfit->packArguments.empty())
{
// this is the expected case.
result += Luau::toString(tfit->typeArguments[0]) + " and " + Luau::toString(tfit->typeArguments[1]);
}
else
{
// if it's not the expected case, we ought to add a specialization later, but this is a sane default.
bool isFirst = true;
for (auto arg : tfit->typeArguments)
{
if (!isFirst)
result += ", ";
result += Luau::toString(arg);
isFirst = false;
}
for (auto packArg : tfit->packArguments)
result += ", " + Luau::toString(packArg);
}
result += "; there is no corresponding overload for __" + tfit->family->name;
return result;
}
// miscellaneous
if ("keyof" == tfit->family->name || "rawkeyof" == tfit->family->name)
{
if (tfit->typeArguments.size() == 1 && tfit->packArguments.empty())
return "Type '" + toString(tfit->typeArguments[0]) + "' does not have keys, so '" + Luau::toString(e.ty) + "' is invalid";
else
return "Type family instance " + Luau::toString(e.ty) + " is ill-formed, and thus invalid";
}
if ("index" == tfit->family->name || "rawget" == tfit->family->name)
{
if (tfit->typeArguments.size() != 2)
return "Type family instance " + Luau::toString(e.ty) + " is ill-formed, and thus invalid";
if (auto errType = get<ErrorType>(tfit->typeArguments[1])) // Second argument to (index | rawget)<_,_> is not a type
return "Second argument to " + tfit->family->name + "<" + Luau::toString(tfit->typeArguments[0]) + ", _> is not a valid index type";
else // Property `indexer` does not exist on type `indexee`
return "Property '" + Luau::toString(tfit->typeArguments[1]) + "' does not exist on type '" + Luau::toString(tfit->typeArguments[0]) +
"'";
}
if (kUnreachableTypeFamilies.count(tfit->family->name))
{
return "Type family instance " + Luau::toString(e.ty) + " is uninhabited\n" +
"This is likely to be a bug, please report it at https://github.com/luau-lang/luau/issues";
}
// Everything should be specialized above to report a more descriptive error that hopefully does not mention "type families" explicitly.
// If we produce this message, it's an indication that we've missed a specialization and it should be fixed!
return "Type family instance " + Luau::toString(e.ty) + " is uninhabited";
}
@ -1205,7 +1321,7 @@ void copyError(T& e, TypeArena& destArena, CloneState& cloneState)
else if constexpr (std::is_same_v<T, ExplicitFunctionAnnotationRecommended>)
{
e.recommendedReturn = clone(e.recommendedReturn);
for (auto [_, t] : e.recommendedArgs)
for (auto& [_, t] : e.recommendedArgs)
t = clone(t);
}
else if constexpr (std::is_same_v<T, UninhabitedTypePackFamily>)

View file

@ -34,6 +34,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit)
LUAU_FASTINT(LuauTarjanChildLimit)
LUAU_FASTFLAG(LuauInferInNoCheckMode)
LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false)
LUAU_FASTFLAGVARIABLE(LuauCancelFromProgress, false)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false)
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJsonFile, false)
@ -440,6 +441,8 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
FrontendOptions frontendOptions = optionOverride.value_or(options);
if (FFlag::DebugLuauDeferredConstraintResolution)
frontendOptions.forAutocomplete = false;
if (std::optional<CheckResult> result = getCheckResult(name, true, frontendOptions.forAutocomplete))
return std::move(*result);
@ -492,9 +495,11 @@ void Frontend::queueModuleCheck(const ModuleName& name)
}
std::vector<ModuleName> Frontend::checkQueuedModules(std::optional<FrontendOptions> optionOverride,
std::function<void(std::function<void()> task)> executeTask, std::function<void(size_t done, size_t total)> progress)
std::function<void(std::function<void()> task)> executeTask, std::function<bool(size_t done, size_t total)> progress)
{
FrontendOptions frontendOptions = optionOverride.value_or(options);
if (FFlag::DebugLuauDeferredConstraintResolution)
frontendOptions.forAutocomplete = false;
// By taking data into locals, we make sure queue is cleared at the end, even if an ICE or a different exception is thrown
std::vector<ModuleName> currModuleQueue;
@ -673,7 +678,17 @@ std::vector<ModuleName> Frontend::checkQueuedModules(std::optional<FrontendOptio
}
if (progress)
progress(buildQueueItems.size() - remaining, buildQueueItems.size());
{
if (FFlag::LuauCancelFromProgress)
{
if (!progress(buildQueueItems.size() - remaining, buildQueueItems.size()))
cancelled = true;
}
else
{
progress(buildQueueItems.size() - remaining, buildQueueItems.size());
}
}
// Items cannot be submitted while holding the lock
for (size_t i : nextItems)
@ -707,6 +722,9 @@ std::vector<ModuleName> Frontend::checkQueuedModules(std::optional<FrontendOptio
std::optional<CheckResult> Frontend::getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete)
{
if (FFlag::DebugLuauDeferredConstraintResolution)
forAutocomplete = false;
auto it = sourceNodes.find(name);
if (it == sourceNodes.end() || it->second->hasDirtyModule(forAutocomplete))
@ -1003,11 +1021,10 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item)
module->astForInNextTypes.clear();
module->astResolvedTypes.clear();
module->astResolvedTypePacks.clear();
module->astCompoundAssignResultTypes.clear();
module->astScopes.clear();
module->upperBoundContributors.clear();
if (!FFlag::DebugLuauDeferredConstraintResolution)
module->scopes.clear();
module->scopes.clear();
}
if (mode != Mode::NoCheck)
@ -1196,12 +1213,6 @@ struct InternalTypeFinder : TypeOnceVisitor
return false;
}
bool visit(TypeId, const LocalType&) override
{
LUAU_ASSERT(false);
return false;
}
bool visit(TypePackId, const BlockedTypePack&) override
{
LUAU_ASSERT(false);
@ -1297,6 +1308,30 @@ ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vector<R
result->type = sourceModule.type;
result->upperBoundContributors = std::move(cs.upperBoundContributors);
if (result->timeout || result->cancelled)
{
// If solver was interrupted, skip typechecking and replace all module results with error-supressing types to avoid leaking blocked/pending
// types
ScopePtr moduleScope = result->getModuleScope();
moduleScope->returnType = builtinTypes->errorRecoveryTypePack();
for (auto& [name, ty] : result->declaredGlobals)
ty = builtinTypes->errorRecoveryType();
for (auto& [name, tf] : result->exportedTypeBindings)
tf.type = builtinTypes->errorRecoveryType();
}
else
{
if (mode == Mode::Nonstrict)
Luau::checkNonStrict(builtinTypes, iceHandler, NotNull{&unifierState}, NotNull{&dfg}, NotNull{&limits}, sourceModule, result.get());
else
Luau::check(builtinTypes, NotNull{&unifierState}, NotNull{&limits}, logger.get(), sourceModule, result.get());
}
unfreeze(result->interfaceTypes);
result->clonePublicInterface(builtinTypes, *iceHandler);
if (FFlag::DebugLuauForbidInternalTypes)
{
InternalTypeFinder finder;
@ -1325,30 +1360,6 @@ ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vector<R
finder.traverse(tp);
}
if (result->timeout || result->cancelled)
{
// If solver was interrupted, skip typechecking and replace all module results with error-supressing types to avoid leaking blocked/pending
// types
ScopePtr moduleScope = result->getModuleScope();
moduleScope->returnType = builtinTypes->errorRecoveryTypePack();
for (auto& [name, ty] : result->declaredGlobals)
ty = builtinTypes->errorRecoveryType();
for (auto& [name, tf] : result->exportedTypeBindings)
tf.type = builtinTypes->errorRecoveryType();
}
else
{
if (mode == Mode::Nonstrict)
Luau::checkNonStrict(builtinTypes, iceHandler, NotNull{&unifierState}, NotNull{&dfg}, NotNull{&limits}, sourceModule, result.get());
else
Luau::check(builtinTypes, NotNull{&unifierState}, NotNull{&limits}, logger.get(), sourceModule, result.get());
}
unfreeze(result->interfaceTypes);
result->clonePublicInterface(builtinTypes, *iceHandler);
// It would be nice if we could freeze the arenas before doing type
// checking, but we'll have to do some work to get there.
//

View file

@ -0,0 +1,910 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Generalization.h"
#include "Luau/Scope.h"
#include "Luau/Type.h"
#include "Luau/ToString.h"
#include "Luau/TypeArena.h"
#include "Luau/TypePack.h"
#include "Luau/VisitType.h"
namespace Luau
{
struct MutatingGeneralizer : TypeOnceVisitor
{
NotNull<BuiltinTypes> builtinTypes;
NotNull<Scope> scope;
NotNull<DenseHashSet<TypeId>> cachedTypes;
DenseHashMap<const void*, size_t> positiveTypes;
DenseHashMap<const void*, size_t> negativeTypes;
std::vector<TypeId> generics;
std::vector<TypePackId> genericPacks;
bool isWithinFunction = false;
bool avoidSealingTables = false;
MutatingGeneralizer(NotNull<BuiltinTypes> builtinTypes, NotNull<Scope> scope, NotNull<DenseHashSet<TypeId>> cachedTypes,
DenseHashMap<const void*, size_t> positiveTypes, DenseHashMap<const void*, size_t> negativeTypes, bool avoidSealingTables)
: TypeOnceVisitor(/* skipBoundTypes */ true)
, builtinTypes(builtinTypes)
, scope(scope)
, cachedTypes(cachedTypes)
, positiveTypes(std::move(positiveTypes))
, negativeTypes(std::move(negativeTypes))
, avoidSealingTables(avoidSealingTables)
{
}
static void replace(DenseHashSet<TypeId>& seen, TypeId haystack, TypeId needle, TypeId replacement)
{
haystack = follow(haystack);
if (seen.find(haystack))
return;
seen.insert(haystack);
if (UnionType* ut = getMutable<UnionType>(haystack))
{
for (auto iter = ut->options.begin(); iter != ut->options.end();)
{
// FIXME: I bet this function has reentrancy problems
TypeId option = follow(*iter);
if (option == needle && get<NeverType>(replacement))
{
iter = ut->options.erase(iter);
continue;
}
if (option == needle)
{
*iter = replacement;
iter++;
continue;
}
// advance the iterator, nothing after this can use it.
iter++;
if (seen.find(option))
continue;
seen.insert(option);
if (get<UnionType>(option))
replace(seen, option, needle, haystack);
else if (get<IntersectionType>(option))
replace(seen, option, needle, haystack);
}
if (ut->options.size() == 1)
{
TypeId onlyType = ut->options[0];
LUAU_ASSERT(onlyType != haystack);
emplaceType<BoundType>(asMutable(haystack), onlyType);
}
return;
}
if (IntersectionType* it = getMutable<IntersectionType>(needle))
{
for (auto iter = it->parts.begin(); iter != it->parts.end();)
{
// FIXME: I bet this function has reentrancy problems
TypeId part = follow(*iter);
if (part == needle && get<UnknownType>(replacement))
{
iter = it->parts.erase(iter);
continue;
}
if (part == needle)
{
*iter = replacement;
iter++;
continue;
}
// advance the iterator, nothing after this can use it.
iter++;
if (seen.find(part))
continue;
seen.insert(part);
if (get<UnionType>(part))
replace(seen, part, needle, haystack);
else if (get<IntersectionType>(part))
replace(seen, part, needle, haystack);
}
if (it->parts.size() == 1)
{
TypeId onlyType = it->parts[0];
LUAU_ASSERT(onlyType != needle);
emplaceType<BoundType>(asMutable(needle), onlyType);
}
return;
}
}
bool visit(TypeId ty, const FunctionType& ft) override
{
if (cachedTypes->contains(ty))
return false;
const bool oldValue = isWithinFunction;
isWithinFunction = true;
traverse(ft.argTypes);
traverse(ft.retTypes);
isWithinFunction = oldValue;
return false;
}
bool visit(TypeId ty, const FreeType&) override
{
LUAU_ASSERT(!cachedTypes->contains(ty));
const FreeType* ft = get<FreeType>(ty);
LUAU_ASSERT(ft);
traverse(ft->lowerBound);
traverse(ft->upperBound);
// It is possible for the above traverse() calls to cause ty to be
// transmuted. We must reacquire ft if this happens.
ty = follow(ty);
ft = get<FreeType>(ty);
if (!ft)
return false;
const size_t positiveCount = getCount(positiveTypes, ty);
const size_t negativeCount = getCount(negativeTypes, ty);
if (!positiveCount && !negativeCount)
return false;
const bool hasLowerBound = !get<NeverType>(follow(ft->lowerBound));
const bool hasUpperBound = !get<UnknownType>(follow(ft->upperBound));
DenseHashSet<TypeId> seen{nullptr};
seen.insert(ty);
if (!hasLowerBound && !hasUpperBound)
{
if (!isWithinFunction || (positiveCount + negativeCount == 1))
emplaceType<BoundType>(asMutable(ty), builtinTypes->unknownType);
else
{
emplaceType<GenericType>(asMutable(ty), scope);
generics.push_back(ty);
}
}
// It is possible that this free type has other free types in its upper
// or lower bounds. If this is the case, we must replace those
// references with never (for the lower bound) or unknown (for the upper
// bound).
//
// If we do not do this, we get tautological bounds like a <: a <: unknown.
else if (positiveCount && !hasUpperBound)
{
TypeId lb = follow(ft->lowerBound);
if (FreeType* lowerFree = getMutable<FreeType>(lb); lowerFree && lowerFree->upperBound == ty)
lowerFree->upperBound = builtinTypes->unknownType;
else
{
DenseHashSet<TypeId> replaceSeen{nullptr};
replace(replaceSeen, lb, ty, builtinTypes->unknownType);
}
if (lb != ty)
emplaceType<BoundType>(asMutable(ty), lb);
else if (!isWithinFunction || (positiveCount + negativeCount == 1))
emplaceType<BoundType>(asMutable(ty), builtinTypes->unknownType);
else
{
// if the lower bound is the type in question, we don't actually have a lower bound.
emplaceType<GenericType>(asMutable(ty), scope);
generics.push_back(ty);
}
}
else
{
TypeId ub = follow(ft->upperBound);
if (FreeType* upperFree = getMutable<FreeType>(ub); upperFree && upperFree->lowerBound == ty)
upperFree->lowerBound = builtinTypes->neverType;
else
{
DenseHashSet<TypeId> replaceSeen{nullptr};
replace(replaceSeen, ub, ty, builtinTypes->neverType);
}
if (ub != ty)
emplaceType<BoundType>(asMutable(ty), ub);
else if (!isWithinFunction || (positiveCount + negativeCount == 1))
emplaceType<BoundType>(asMutable(ty), builtinTypes->unknownType);
else
{
// if the upper bound is the type in question, we don't actually have an upper bound.
emplaceType<GenericType>(asMutable(ty), scope);
generics.push_back(ty);
}
}
return false;
}
size_t getCount(const DenseHashMap<const void*, size_t>& map, const void* ty)
{
if (const size_t* count = map.find(ty))
return *count;
else
return 0;
}
bool visit(TypeId ty, const TableType&) override
{
if (cachedTypes->contains(ty))
return false;
const size_t positiveCount = getCount(positiveTypes, ty);
const size_t negativeCount = getCount(negativeTypes, ty);
// FIXME: Free tables should probably just be replaced by upper bounds on free types.
//
// eg never <: 'a <: {x: number} & {z: boolean}
if (!positiveCount && !negativeCount)
return true;
TableType* tt = getMutable<TableType>(ty);
LUAU_ASSERT(tt);
if (!avoidSealingTables)
tt->state = TableState::Sealed;
return true;
}
bool visit(TypePackId tp, const FreeTypePack& ftp) override
{
if (!subsumes(scope, ftp.scope))
return true;
tp = follow(tp);
const size_t positiveCount = getCount(positiveTypes, tp);
const size_t negativeCount = getCount(negativeTypes, tp);
if (1 == positiveCount + negativeCount)
emplaceTypePack<BoundTypePack>(asMutable(tp), builtinTypes->unknownTypePack);
else
{
emplaceTypePack<GenericTypePack>(asMutable(tp), scope);
genericPacks.push_back(tp);
}
return true;
}
};
struct FreeTypeSearcher : TypeVisitor
{
NotNull<Scope> scope;
NotNull<DenseHashSet<TypeId>> cachedTypes;
explicit FreeTypeSearcher(NotNull<Scope> scope, NotNull<DenseHashSet<TypeId>> cachedTypes)
: TypeVisitor(/*skipBoundTypes*/ true)
, scope(scope)
, cachedTypes(cachedTypes)
{
}
enum Polarity
{
Positive,
Negative,
Both,
};
Polarity polarity = Positive;
void flip()
{
switch (polarity)
{
case Positive:
polarity = Negative;
break;
case Negative:
polarity = Positive;
break;
case Both:
break;
}
}
DenseHashSet<const void*> seenPositive{nullptr};
DenseHashSet<const void*> seenNegative{nullptr};
bool seenWithPolarity(const void* ty)
{
switch (polarity)
{
case Positive:
{
if (seenPositive.contains(ty))
return true;
seenPositive.insert(ty);
return false;
}
case Negative:
{
if (seenNegative.contains(ty))
return true;
seenNegative.insert(ty);
return false;
}
case Both:
{
if (seenPositive.contains(ty) && seenNegative.contains(ty))
return true;
seenPositive.insert(ty);
seenNegative.insert(ty);
return false;
}
}
return false;
}
// The keys in these maps are either TypeIds or TypePackIds. It's safe to
// mix them because we only use these pointers as unique keys. We never
// indirect them.
DenseHashMap<const void*, size_t> negativeTypes{0};
DenseHashMap<const void*, size_t> positiveTypes{0};
bool visit(TypeId ty) override
{
if (cachedTypes->contains(ty) || seenWithPolarity(ty))
return false;
LUAU_ASSERT(ty);
return true;
}
bool visit(TypeId ty, const FreeType& ft) override
{
if (cachedTypes->contains(ty) || seenWithPolarity(ty))
return false;
if (!subsumes(scope, ft.scope))
return true;
switch (polarity)
{
case Positive:
positiveTypes[ty]++;
break;
case Negative:
negativeTypes[ty]++;
break;
case Both:
positiveTypes[ty]++;
negativeTypes[ty]++;
break;
}
return true;
}
bool visit(TypeId ty, const TableType& tt) override
{
if (cachedTypes->contains(ty) || seenWithPolarity(ty))
return false;
if ((tt.state == TableState::Free || tt.state == TableState::Unsealed) && subsumes(scope, tt.scope))
{
switch (polarity)
{
case Positive:
positiveTypes[ty]++;
break;
case Negative:
negativeTypes[ty]++;
break;
case Both:
positiveTypes[ty]++;
negativeTypes[ty]++;
break;
}
}
for (const auto& [_name, prop] : tt.props)
{
if (prop.isReadOnly())
traverse(*prop.readTy);
else
{
LUAU_ASSERT(prop.isShared());
Polarity p = polarity;
polarity = Both;
traverse(prop.type());
polarity = p;
}
}
if (tt.indexer)
{
traverse(tt.indexer->indexType);
traverse(tt.indexer->indexResultType);
}
return false;
}
bool visit(TypeId ty, const FunctionType& ft) override
{
if (cachedTypes->contains(ty) || seenWithPolarity(ty))
return false;
flip();
traverse(ft.argTypes);
flip();
traverse(ft.retTypes);
return false;
}
bool visit(TypeId, const ClassType&) override
{
return false;
}
bool visit(TypePackId tp, const FreeTypePack& ftp) override
{
if (seenWithPolarity(tp))
return false;
if (!subsumes(scope, ftp.scope))
return true;
switch (polarity)
{
case Positive:
positiveTypes[tp]++;
break;
case Negative:
negativeTypes[tp]++;
break;
case Both:
positiveTypes[tp]++;
negativeTypes[tp]++;
break;
}
return true;
}
};
// We keep a running set of types that will not change under generalization and
// only have outgoing references to types that are the same. We use this to
// short circuit generalization. It improves performance quite a lot.
//
// We do this by tracing through the type and searching for types that are
// uncacheable. If a type has a reference to an uncacheable type, it is itself
// uncacheable.
//
// If a type has no outbound references to uncacheable types, we add it to the
// cache.
struct TypeCacher : TypeOnceVisitor
{
NotNull<DenseHashSet<TypeId>> cachedTypes;
DenseHashSet<TypeId> uncacheable{nullptr};
DenseHashSet<TypePackId> uncacheablePacks{nullptr};
explicit TypeCacher(NotNull<DenseHashSet<TypeId>> cachedTypes)
: TypeOnceVisitor(/* skipBoundTypes */ true)
, cachedTypes(cachedTypes)
{
}
void cache(TypeId ty)
{
cachedTypes->insert(ty);
}
bool isCached(TypeId ty) const
{
return cachedTypes->contains(ty);
}
void markUncacheable(TypeId ty)
{
uncacheable.insert(ty);
}
void markUncacheable(TypePackId tp)
{
uncacheablePacks.insert(tp);
}
bool isUncacheable(TypeId ty) const
{
return uncacheable.contains(ty);
}
bool isUncacheable(TypePackId tp) const
{
return uncacheablePacks.contains(tp);
}
bool visit(TypeId ty) override
{
if (isUncacheable(ty) || isCached(ty))
return false;
return true;
}
bool visit(TypeId ty, const FreeType& ft) override
{
// Free types are never cacheable.
LUAU_ASSERT(!isCached(ty));
if (!isUncacheable(ty))
{
traverse(ft.lowerBound);
traverse(ft.upperBound);
markUncacheable(ty);
}
return false;
}
bool visit(TypeId ty, const GenericType&) override
{
cache(ty);
return false;
}
bool visit(TypeId ty, const PrimitiveType&) override
{
cache(ty);
return false;
}
bool visit(TypeId ty, const SingletonType&) override
{
cache(ty);
return false;
}
bool visit(TypeId ty, const BlockedType&) override
{
markUncacheable(ty);
return false;
}
bool visit(TypeId ty, const PendingExpansionType&) override
{
markUncacheable(ty);
return false;
}
bool visit(TypeId ty, const FunctionType& ft) override
{
if (isCached(ty) || isUncacheable(ty))
return false;
traverse(ft.argTypes);
traverse(ft.retTypes);
for (TypeId gen : ft.generics)
traverse(gen);
bool uncacheable = false;
if (isUncacheable(ft.argTypes))
uncacheable = true;
else if (isUncacheable(ft.retTypes))
uncacheable = true;
for (TypeId argTy : ft.argTypes)
{
if (isUncacheable(argTy))
{
uncacheable = true;
break;
}
}
for (TypeId retTy : ft.retTypes)
{
if (isUncacheable(retTy))
{
uncacheable = true;
break;
}
}
for (TypeId g : ft.generics)
{
if (isUncacheable(g))
{
uncacheable = true;
break;
}
}
if (uncacheable)
markUncacheable(ty);
else
cache(ty);
return false;
}
bool visit(TypeId ty, const TableType& tt) override
{
if (isCached(ty) || isUncacheable(ty))
return false;
if (tt.boundTo)
{
traverse(*tt.boundTo);
if (isUncacheable(*tt.boundTo))
{
markUncacheable(ty);
return false;
}
}
bool uncacheable = false;
// This logic runs immediately after generalization, so any remaining
// unsealed tables are assuredly not cacheable. They may yet have
// properties added to them.
if (tt.state == TableState::Free || tt.state == TableState::Unsealed)
uncacheable = true;
for (const auto& [_name, prop] : tt.props)
{
if (prop.readTy)
{
traverse(*prop.readTy);
if (isUncacheable(*prop.readTy))
uncacheable = true;
}
if (prop.writeTy && prop.writeTy != prop.readTy)
{
traverse(*prop.writeTy);
if (isUncacheable(*prop.writeTy))
uncacheable = true;
}
}
if (tt.indexer)
{
traverse(tt.indexer->indexType);
if (isUncacheable(tt.indexer->indexType))
uncacheable = true;
traverse(tt.indexer->indexResultType);
if (isUncacheable(tt.indexer->indexResultType))
uncacheable = true;
}
if (uncacheable)
markUncacheable(ty);
else
cache(ty);
return false;
}
bool visit(TypeId ty, const ClassType&) override
{
cache(ty);
return false;
}
bool visit(TypeId ty, const AnyType&) override
{
cache(ty);
return false;
}
bool visit(TypeId ty, const UnionType& ut) override
{
if (isUncacheable(ty) || isCached(ty))
return false;
bool uncacheable = false;
for (TypeId partTy : ut.options)
{
traverse(partTy);
uncacheable |= isUncacheable(partTy);
}
if (uncacheable)
markUncacheable(ty);
else
cache(ty);
return false;
}
bool visit(TypeId ty, const IntersectionType& it) override
{
if (isUncacheable(ty) || isCached(ty))
return false;
bool uncacheable = false;
for (TypeId partTy : it.parts)
{
traverse(partTy);
uncacheable |= isUncacheable(partTy);
}
if (uncacheable)
markUncacheable(ty);
else
cache(ty);
return false;
}
bool visit(TypeId ty, const UnknownType&) override
{
cache(ty);
return false;
}
bool visit(TypeId ty, const NeverType&) override
{
cache(ty);
return false;
}
bool visit(TypeId ty, const NegationType& nt) override
{
if (!isCached(ty) && !isUncacheable(ty))
{
traverse(nt.ty);
if (isUncacheable(nt.ty))
markUncacheable(ty);
else
cache(ty);
}
return false;
}
bool visit(TypeId ty, const TypeFamilyInstanceType& tfit) override
{
if (isCached(ty) || isUncacheable(ty))
return false;
bool uncacheable = false;
for (TypeId argTy : tfit.typeArguments)
{
traverse(argTy);
if (isUncacheable(argTy))
uncacheable = true;
}
for (TypePackId argPack : tfit.packArguments)
{
traverse(argPack);
if (isUncacheable(argPack))
uncacheable = true;
}
if (uncacheable)
markUncacheable(ty);
else
cache(ty);
return false;
}
bool visit(TypePackId tp, const FreeTypePack&) override
{
markUncacheable(tp);
return false;
}
bool visit(TypePackId tp, const VariadicTypePack& vtp) override
{
if (isUncacheable(tp))
return false;
traverse(vtp.ty);
if (isUncacheable(vtp.ty))
markUncacheable(tp);
return false;
}
bool visit(TypePackId tp, const BlockedTypePack&) override
{
markUncacheable(tp);
return false;
}
bool visit(TypePackId tp, const TypeFamilyInstanceTypePack&) override
{
markUncacheable(tp);
return false;
}
};
std::optional<TypeId> generalize(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<Scope> scope,
NotNull<DenseHashSet<TypeId>> cachedTypes, TypeId ty, bool avoidSealingTables)
{
ty = follow(ty);
if (ty->owningArena != arena || ty->persistent)
return ty;
if (const FunctionType* ft = get<FunctionType>(ty); ft && (!ft->generics.empty() || !ft->genericPacks.empty()))
return ty;
FreeTypeSearcher fts{scope, cachedTypes};
fts.traverse(ty);
MutatingGeneralizer gen{builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), std::move(fts.negativeTypes), avoidSealingTables};
gen.traverse(ty);
/* MutatingGeneralizer mutates types in place, so it is possible that ty has
* been transmuted to a BoundType. We must follow it again and verify that
* we are allowed to mutate it before we attach generics to it.
*/
ty = follow(ty);
if (ty->owningArena != arena || ty->persistent)
return ty;
TypeCacher cacher{cachedTypes};
cacher.traverse(ty);
FunctionType* ftv = getMutable<FunctionType>(ty);
if (ftv)
{
ftv->generics = std::move(gen.generics);
ftv->genericPacks = std::move(gen.genericPacks);
}
return ty;
}
} // namespace Luau

View file

@ -11,10 +11,23 @@
#include <algorithm>
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAG(LuauReusableSubstitutions)
namespace Luau
{
void Instantiation::resetState(const TxnLog* log, TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, TypeLevel level, Scope* scope)
{
LUAU_ASSERT(FFlag::LuauReusableSubstitutions);
Substitution::resetState(log, arena);
this->builtinTypes = builtinTypes;
this->level = level;
this->scope = scope;
}
bool Instantiation::isDirty(TypeId ty)
{
if (const FunctionType* ftv = log->getMutable<FunctionType>(ty))
@ -58,13 +71,26 @@ TypeId Instantiation::clean(TypeId ty)
clone.argNames = ftv->argNames;
TypeId result = addType(std::move(clone));
// Annoyingly, we have to do this even if there are no generics,
// to replace any generic tables.
ReplaceGenerics replaceGenerics{log, arena, builtinTypes, level, scope, ftv->generics, ftv->genericPacks};
if (FFlag::LuauReusableSubstitutions)
{
// Annoyingly, we have to do this even if there are no generics,
// to replace any generic tables.
reusableReplaceGenerics.resetState(log, arena, builtinTypes, level, scope, ftv->generics, ftv->genericPacks);
// TODO: What to do if this returns nullopt?
// We don't have access to the error-reporting machinery
result = replaceGenerics.substitute(result).value_or(result);
// TODO: What to do if this returns nullopt?
// We don't have access to the error-reporting machinery
result = reusableReplaceGenerics.substitute(result).value_or(result);
}
else
{
// Annoyingly, we have to do this even if there are no generics,
// to replace any generic tables.
ReplaceGenerics replaceGenerics{log, arena, builtinTypes, level, scope, ftv->generics, ftv->genericPacks};
// TODO: What to do if this returns nullopt?
// We don't have access to the error-reporting machinery
result = replaceGenerics.substitute(result).value_or(result);
}
asMutable(result)->documentationSymbol = ty->documentationSymbol;
return result;
@ -76,6 +102,22 @@ TypePackId Instantiation::clean(TypePackId tp)
return tp;
}
void ReplaceGenerics::resetState(const TxnLog* log, TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, TypeLevel level, Scope* scope,
const std::vector<TypeId>& generics, const std::vector<TypePackId>& genericPacks)
{
LUAU_ASSERT(FFlag::LuauReusableSubstitutions);
Substitution::resetState(log, arena);
this->builtinTypes = builtinTypes;
this->level = level;
this->scope = scope;
this->generics = generics;
this->genericPacks = genericPacks;
}
bool ReplaceGenerics::ignoreChildren(TypeId ty)
{
if (const FunctionType* ftv = log->getMutable<FunctionType>(ty))

View file

@ -16,6 +16,11 @@ LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAG(LuauAttributeSyntax)
LUAU_FASTFLAG(LuauAttribute)
LUAU_FASTFLAG(LuauNativeAttribute)
LUAU_FASTFLAGVARIABLE(LintRedundantNativeAttribute, false)
namespace Luau
{
@ -2922,6 +2927,64 @@ static void lintComments(LintContext& context, const std::vector<HotComment>& ho
}
}
static bool hasNativeCommentDirective(const std::vector<HotComment>& hotcomments)
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
LUAU_ASSERT(FFlag::LuauNativeAttribute);
LUAU_ASSERT(FFlag::LintRedundantNativeAttribute);
for (const HotComment& hc : hotcomments)
{
if (hc.content.empty() || hc.content[0] == ' ' || hc.content[0] == '\t')
continue;
if (hc.header)
{
size_t space = hc.content.find_first_of(" \t");
std::string_view first = std::string_view(hc.content).substr(0, space);
if (first == "native")
return true;
}
}
return false;
}
struct LintRedundantNativeAttribute : AstVisitor
{
public:
LUAU_NOINLINE static void process(LintContext& context)
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
LUAU_ASSERT(FFlag::LuauNativeAttribute);
LUAU_ASSERT(FFlag::LintRedundantNativeAttribute);
LintRedundantNativeAttribute pass;
pass.context = &context;
context.root->visit(&pass);
}
private:
LintContext* context;
bool visit(AstExprFunction* node) override
{
node->body->visit(this);
for (const auto attribute : node->attributes)
{
if (attribute->type == AstAttr::Type::Native)
{
emitWarning(*context, LintWarning::Code_RedundantNativeAttribute, attribute->location,
"native attribute on a function is redundant in a native module; consider removing it");
}
}
return false;
}
};
std::vector<LintWarning> lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module,
const std::vector<HotComment>& hotcomments, const LintOptions& options)
{
@ -3008,6 +3071,13 @@ std::vector<LintWarning> lint(AstStat* root, const AstNameTable& names, const Sc
if (context.warningEnabled(LintWarning::Code_ComparisonPrecedence))
LintComparisonPrecedence::process(context);
if (FFlag::LuauAttributeSyntax && FFlag::LuauNativeAttribute && FFlag::LintRedundantNativeAttribute &&
context.warningEnabled(LintWarning::Code_RedundantNativeAttribute))
{
if (hasNativeCommentDirective(hotcomments))
LintRedundantNativeAttribute::process(context);
}
std::sort(context.result.begin(), context.result.end(), WarningComparator());
return context.result;

View file

@ -17,23 +17,23 @@
LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false)
LUAU_FASTFLAGVARIABLE(LuauNormalizeAwayUninhabitableTables, false)
LUAU_FASTFLAGVARIABLE(LuauFixNormalizeCaching, false);
LUAU_FASTFLAGVARIABLE(LuauNormalizeNotUnknownIntersection, false);
LUAU_FASTFLAGVARIABLE(LuauFixCyclicUnionsOfIntersections, false);
LUAU_FASTFLAGVARIABLE(LuauFixReduceStackPressure, false);
LUAU_FASTFLAGVARIABLE(LuauFixCyclicTablesBlowingStack, false);
// This could theoretically be 2000 on amd64, but x86 requires this.
LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200);
LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000);
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
static bool fixNormalizeCaching()
static bool fixReduceStackPressure()
{
return FFlag::LuauFixNormalizeCaching || FFlag::DebugLuauDeferredConstraintResolution;
return FFlag::LuauFixReduceStackPressure || FFlag::DebugLuauDeferredConstraintResolution;
}
static bool fixCyclicUnionsOfIntersections()
static bool fixCyclicTablesBlowingStack()
{
return FFlag::LuauFixCyclicUnionsOfIntersections || FFlag::DebugLuauDeferredConstraintResolution;
return FFlag::LuauFixCyclicTablesBlowingStack || FFlag::DebugLuauDeferredConstraintResolution;
}
namespace Luau
@ -45,6 +45,14 @@ static bool normalizeAwayUninhabitableTables()
return FFlag::LuauNormalizeAwayUninhabitableTables || FFlag::DebugLuauDeferredConstraintResolution;
}
static bool shouldEarlyExit(NormalizationResult res)
{
// if res is hit limits, return control flow
if (res == NormalizationResult::HitLimits || res == NormalizationResult::False)
return true;
return false;
}
TypeIds::TypeIds(std::initializer_list<TypeId> tys)
{
for (TypeId ty : tys)
@ -339,6 +347,12 @@ bool NormalizedType::isSubtypeOfString() const
!hasBuffers() && !hasTables() && !hasFunctions() && !hasTyvars();
}
bool NormalizedType::isSubtypeOfBooleans() const
{
return hasBooleans() && !hasTops() && !hasClasses() && !hasErrors() && !hasNils() && !hasNumbers() && !hasStrings() && !hasThreads() &&
!hasBuffers() && !hasTables() && !hasFunctions() && !hasTyvars();
}
bool NormalizedType::shouldSuppressErrors() const
{
return hasErrors() || get<AnyType>(tops);
@ -547,22 +561,21 @@ NormalizationResult Normalizer::isInhabited(TypeId ty, Set<TypeId>& seen)
return isInhabited(mtv->metatable, seen);
}
if (fixNormalizeCaching())
{
std::shared_ptr<const NormalizedType> norm = normalize(ty);
return isInhabited(norm.get(), seen);
}
else
{
const NormalizedType* norm = DEPRECATED_normalize(ty);
return isInhabited(norm, seen);
}
std::shared_ptr<const NormalizedType> norm = normalize(ty);
return isInhabited(norm.get(), seen);
}
NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right)
{
Set<TypeId> seen{nullptr};
return isIntersectionInhabited(left, right, seen);
}
NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right, Set<TypeId>& seenSet)
{
left = follow(left);
right = follow(right);
// We're asking if intersection is inahbited between left and right but we've already seen them ....
if (cacheInhabitance)
{
@ -570,12 +583,8 @@ NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId righ
return *result ? NormalizationResult::True : NormalizationResult::False;
}
Set<TypeId> seen{nullptr};
seen.insert(left);
seen.insert(right);
NormalizedType norm{builtinTypes};
NormalizationResult res = normalizeIntersections({left, right}, norm);
NormalizationResult res = normalizeIntersections({left, right}, norm, seenSet);
if (res != NormalizationResult::True)
{
if (cacheInhabitance && res == NormalizationResult::False)
@ -584,7 +593,7 @@ NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId righ
return res;
}
NormalizationResult result = isInhabited(&norm, seen);
NormalizationResult result = isInhabited(&norm, seenSet);
if (cacheInhabitance && result == NormalizationResult::True)
cachedIsInhabitedIntersection[{left, right}] = true;
@ -856,31 +865,6 @@ Normalizer::Normalizer(TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, Not
{
}
const NormalizedType* Normalizer::DEPRECATED_normalize(TypeId ty)
{
if (!arena)
sharedState->iceHandler->ice("Normalizing types outside a module");
auto found = cachedNormals.find(ty);
if (found != cachedNormals.end())
return found->second.get();
NormalizedType norm{builtinTypes};
Set<TypeId> seenSetTypes{nullptr};
NormalizationResult res = unionNormalWithTy(norm, ty, seenSetTypes);
if (res != NormalizationResult::True)
return nullptr;
if (norm.isUnknown())
{
clearNormal(norm);
norm.tops = builtinTypes->unknownType;
}
std::shared_ptr<NormalizedType> shared = std::make_shared<NormalizedType>(std::move(norm));
const NormalizedType* result = shared.get();
cachedNormals[ty] = std::move(shared);
return result;
}
static bool isCacheable(TypeId ty, Set<TypeId>& seen);
static bool isCacheable(TypePackId tp, Set<TypeId>& seen)
@ -935,9 +919,6 @@ static bool isCacheable(TypeId ty, Set<TypeId>& seen)
static bool isCacheable(TypeId ty)
{
if (!fixNormalizeCaching())
return true;
Set<TypeId> seen{nullptr};
return isCacheable(ty, seen);
}
@ -971,7 +952,7 @@ std::shared_ptr<const NormalizedType> Normalizer::normalize(TypeId ty)
return shared;
}
NormalizationResult Normalizer::normalizeIntersections(const std::vector<TypeId>& intersections, NormalizedType& outType)
NormalizationResult Normalizer::normalizeIntersections(const std::vector<TypeId>& intersections, NormalizedType& outType, Set<TypeId>& seenSet)
{
if (!arena)
sharedState->iceHandler->ice("Normalizing types outside a module");
@ -981,7 +962,7 @@ NormalizationResult Normalizer::normalizeIntersections(const std::vector<TypeId>
Set<TypeId> seenSetTypes{nullptr};
for (auto ty : intersections)
{
NormalizationResult res = intersectNormalWithTy(norm, ty, seenSetTypes);
NormalizationResult res = intersectNormalWithTy(norm, ty, seenSet);
if (res != NormalizationResult::True)
return res;
}
@ -1729,6 +1710,20 @@ bool Normalizer::withinResourceLimits()
return true;
}
NormalizationResult Normalizer::intersectNormalWithNegationTy(TypeId toNegate, NormalizedType& intersect)
{
std::optional<NormalizedType> negated;
std::shared_ptr<const NormalizedType> normal = normalize(toNegate);
negated = negateNormal(*normal);
if (!negated)
return NormalizationResult::False;
intersectNormals(intersect, *negated);
return NormalizationResult::True;
}
// See above for an explaination of `ignoreSmallerTyvars`.
NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, Set<TypeId>& seenSetTypes, int ignoreSmallerTyvars)
{
@ -1775,12 +1770,9 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
}
else if (const IntersectionType* itv = get<IntersectionType>(there))
{
if (fixCyclicUnionsOfIntersections())
{
if (seenSetTypes.count(there))
return NormalizationResult::True;
seenSetTypes.insert(there);
}
if (seenSetTypes.count(there))
return NormalizationResult::True;
seenSetTypes.insert(there);
NormalizedType norm{builtinTypes};
norm.tops = builtinTypes->anyType;
@ -1789,14 +1781,12 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
NormalizationResult res = intersectNormalWithTy(norm, *it, seenSetTypes);
if (res != NormalizationResult::True)
{
if (fixCyclicUnionsOfIntersections())
seenSetTypes.erase(there);
seenSetTypes.erase(there);
return res;
}
}
if (fixCyclicUnionsOfIntersections())
seenSetTypes.erase(there);
seenSetTypes.erase(there);
return unionNormals(here, norm);
}
@ -1814,12 +1804,6 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
if (!isCacheable(there))
here.isCacheable = false;
}
else if (auto lt = get<LocalType>(there))
{
// FIXME? This is somewhat questionable.
// Maybe we should assert because this should never happen?
unionNormalWithTy(here, lt->domain, seenSetTypes, ignoreSmallerTyvars);
}
else if (get<FunctionType>(there))
unionFunctionsWithFunction(here.functions, there);
else if (get<TableType>(there) || get<MetatableType>(there))
@ -1876,16 +1860,8 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
{
std::optional<NormalizedType> tn;
if (fixNormalizeCaching())
{
std::shared_ptr<const NormalizedType> thereNormal = normalize(ntv->ty);
tn = negateNormal(*thereNormal);
}
else
{
const NormalizedType* thereNormal = DEPRECATED_normalize(ntv->ty);
tn = negateNormal(*thereNormal);
}
std::shared_ptr<const NormalizedType> thereNormal = normalize(ntv->ty);
tn = negateNormal(*thereNormal);
if (!tn)
return NormalizationResult::False;
@ -2484,7 +2460,7 @@ std::optional<TypePackId> Normalizer::intersectionOfTypePacks(TypePackId here, T
return arena->addTypePack({});
}
std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there)
std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there, Set<TypeId>& seenSet)
{
if (here == there)
return here;
@ -2541,8 +2517,9 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
state = tttv->state;
TypeLevel level = max(httv->level, tttv->level);
TableType result{state, level};
Scope* scope = max(httv->scope, tttv->scope);
std::unique_ptr<TableType> result = nullptr;
bool hereSubThere = true;
bool thereSubHere = true;
@ -2563,8 +2540,43 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
if (tprop.readTy.has_value())
{
// if the intersection of the read types of a property is uninhabited, the whole table is `never`.
if (normalizeAwayUninhabitableTables() && NormalizationResult::False == isIntersectionInhabited(*hprop.readTy, *tprop.readTy))
return {builtinTypes->neverType};
if (fixReduceStackPressure())
{
// We've seen these table prop elements before and we're about to ask if their intersection
// is inhabited
if (fixCyclicTablesBlowingStack())
{
if (seenSet.contains(*hprop.readTy) && seenSet.contains(*tprop.readTy))
{
seenSet.erase(*hprop.readTy);
seenSet.erase(*tprop.readTy);
return {builtinTypes->neverType};
}
else
{
seenSet.insert(*hprop.readTy);
seenSet.insert(*tprop.readTy);
}
}
NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy, seenSet);
// Cleanup
if (fixCyclicTablesBlowingStack())
{
seenSet.erase(*hprop.readTy);
seenSet.erase(*tprop.readTy);
}
if (normalizeAwayUninhabitableTables() && NormalizationResult::True != res)
return {builtinTypes->neverType};
}
else
{
if (normalizeAwayUninhabitableTables() &&
NormalizationResult::False == isIntersectionInhabited(*hprop.readTy, *tprop.readTy))
return {builtinTypes->neverType};
}
TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result;
prop.readTy = ty;
@ -2614,14 +2626,21 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
// TODO: string indexers
if (prop.readTy || prop.writeTy)
result.props[name] = prop;
{
if (!result.get())
result = std::make_unique<TableType>(TableType{state, level, scope});
result->props[name] = prop;
}
}
for (const auto& [name, tprop] : tttv->props)
{
if (httv->props.count(name) == 0)
{
result.props[name] = tprop;
if (!result.get())
result = std::make_unique<TableType>(TableType{state, level, scope});
result->props[name] = tprop;
hereSubThere = false;
}
}
@ -2631,18 +2650,24 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
// TODO: What should intersection of indexes be?
TypeId index = unionType(httv->indexer->indexType, tttv->indexer->indexType);
TypeId indexResult = intersectionType(httv->indexer->indexResultType, tttv->indexer->indexResultType);
result.indexer = {index, indexResult};
if (!result.get())
result = std::make_unique<TableType>(TableType{state, level, scope});
result->indexer = {index, indexResult};
hereSubThere &= (httv->indexer->indexType == index) && (httv->indexer->indexResultType == indexResult);
thereSubHere &= (tttv->indexer->indexType == index) && (tttv->indexer->indexResultType == indexResult);
}
else if (httv->indexer)
{
result.indexer = httv->indexer;
if (!result.get())
result = std::make_unique<TableType>(TableType{state, level, scope});
result->indexer = httv->indexer;
thereSubHere = false;
}
else if (tttv->indexer)
{
result.indexer = tttv->indexer;
if (!result.get())
result = std::make_unique<TableType>(TableType{state, level, scope});
result->indexer = tttv->indexer;
hereSubThere = false;
}
@ -2652,12 +2677,17 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
else if (thereSubHere)
table = ttable;
else
table = arena->addType(std::move(result));
{
if (result.get())
table = arena->addType(std::move(*result));
else
table = arena->addType(TableType{state, level, scope});
}
if (tmtable && hmtable)
{
// NOTE: this assumes metatables are ivariant
if (std::optional<TypeId> mtable = intersectionOfTables(hmtable, tmtable))
if (std::optional<TypeId> mtable = intersectionOfTables(hmtable, tmtable, seenSet))
{
if (table == htable && *mtable == hmtable)
return here;
@ -2687,12 +2717,12 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
return table;
}
void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there)
void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there, Set<TypeId>& seenSetTypes)
{
TypeIds tmp;
for (TypeId here : heres)
{
if (std::optional<TypeId> inter = intersectionOfTables(here, there))
if (std::optional<TypeId> inter = intersectionOfTables(here, there, seenSetTypes))
tmp.insert(*inter);
}
heres.retain(tmp);
@ -2706,7 +2736,8 @@ void Normalizer::intersectTables(TypeIds& heres, const TypeIds& theres)
{
for (TypeId there : theres)
{
if (std::optional<TypeId> inter = intersectionOfTables(here, there))
Set<TypeId> seenSetTypes{nullptr};
if (std::optional<TypeId> inter = intersectionOfTables(here, there, seenSetTypes))
tmp.insert(*inter);
}
}
@ -3047,7 +3078,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
return NormalizationResult::True;
}
else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) ||
get<TypeFamilyInstanceType>(there) || get<LocalType>(there))
get<TypeFamilyInstanceType>(there))
{
NormalizedType thereNorm{builtinTypes};
NormalizedType topNorm{builtinTypes};
@ -3056,10 +3087,6 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
here.isCacheable = false;
return intersectNormals(here, thereNorm);
}
else if (auto lt = get<LocalType>(there))
{
return intersectNormalWithTy(here, lt->domain, seenSetTypes);
}
NormalizedTyvars tyvars = std::move(here.tyvars);
@ -3074,7 +3101,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
{
TypeIds tables = std::move(here.tables);
clearNormal(here);
intersectTablesWithTable(tables, there);
intersectTablesWithTable(tables, there, seenSetTypes);
here.tables = std::move(tables);
}
else if (get<ClassType>(there))
@ -3148,60 +3175,17 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
subtractSingleton(here, follow(ntv->ty));
else if (get<ClassType>(t))
{
if (fixNormalizeCaching())
{
std::shared_ptr<const NormalizedType> normal = normalize(t);
std::optional<NormalizedType> negated = negateNormal(*normal);
if (!negated)
return NormalizationResult::False;
intersectNormals(here, *negated);
}
else
{
const NormalizedType* normal = DEPRECATED_normalize(t);
std::optional<NormalizedType> negated = negateNormal(*normal);
if (!negated)
return NormalizationResult::False;
intersectNormals(here, *negated);
}
NormalizationResult res = intersectNormalWithNegationTy(t, here);
if (shouldEarlyExit(res))
return res;
}
else if (const UnionType* itv = get<UnionType>(t))
{
if (fixNormalizeCaching())
for (TypeId part : itv->options)
{
for (TypeId part : itv->options)
{
std::shared_ptr<const NormalizedType> normalPart = normalize(part);
std::optional<NormalizedType> negated = negateNormal(*normalPart);
if (!negated)
return NormalizationResult::False;
intersectNormals(here, *negated);
}
}
else
{
if (fixNormalizeCaching())
{
for (TypeId part : itv->options)
{
std::shared_ptr<const NormalizedType> normalPart = normalize(part);
std::optional<NormalizedType> negated = negateNormal(*normalPart);
if (!negated)
return NormalizationResult::False;
intersectNormals(here, *negated);
}
}
else
{
for (TypeId part : itv->options)
{
const NormalizedType* normalPart = DEPRECATED_normalize(part);
std::optional<NormalizedType> negated = negateNormal(*normalPart);
if (!negated)
return NormalizationResult::False;
intersectNormals(here, *negated);
}
}
NormalizationResult res = intersectNormalWithNegationTy(part, here);
if (shouldEarlyExit(res))
return res;
}
}
else if (get<AnyType>(t))

View file

@ -1,5 +0,0 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Common.h"
LUAU_FASTFLAGVARIABLE(LuauFixSetIter, false)

View file

@ -1255,6 +1255,10 @@ TypeId TypeSimplifier::union_(TypeId left, TypeId right)
case Relation::Coincident:
case Relation::Superset:
return left;
case Relation::Subset:
newParts.insert(right);
changed = true;
break;
default:
newParts.insert(part);
newParts.insert(right);
@ -1364,6 +1368,17 @@ SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<
return SimplifyResult{res, std::move(s.blockedTypes)};
}
SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, std::set<TypeId> parts)
{
LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution);
TypeSimplifier s{builtinTypes, arena};
TypeId res = s.intersectFromParts(std::move(parts));
return SimplifyResult{res, std::move(s.blockedTypes)};
}
SimplifyResult simplifyUnion(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId left, TypeId right)
{
LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution);

View file

@ -11,6 +11,7 @@
LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTINTVARIABLE(LuauTarjanPreallocationSize, 256);
LUAU_FASTFLAG(LuauReusableSubstitutions)
namespace Luau
{
@ -24,8 +25,6 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a
// We decline to copy them.
if constexpr (std::is_same_v<T, FreeType>)
return ty;
else if constexpr (std::is_same_v<T, LocalType>)
return ty;
else if constexpr (std::is_same_v<T, BoundType>)
{
// This should never happen, but visit() cannot see it.
@ -148,6 +147,8 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a
}
Tarjan::Tarjan()
: typeToIndex(nullptr, FFlag::LuauReusableSubstitutions ? FInt::LuauTarjanPreallocationSize : 0)
, packToIndex(nullptr, FFlag::LuauReusableSubstitutions ? FInt::LuauTarjanPreallocationSize : 0)
{
nodes.reserve(FInt::LuauTarjanPreallocationSize);
stack.reserve(FInt::LuauTarjanPreallocationSize);
@ -448,14 +449,31 @@ TarjanResult Tarjan::visitRoot(TypePackId tp)
return loop();
}
void Tarjan::clearTarjan()
void Tarjan::clearTarjan(const TxnLog* log)
{
typeToIndex.clear();
packToIndex.clear();
if (FFlag::LuauReusableSubstitutions)
{
typeToIndex.clear(~0u);
packToIndex.clear(~0u);
}
else
{
typeToIndex.clear();
packToIndex.clear();
}
nodes.clear();
stack.clear();
if (FFlag::LuauReusableSubstitutions)
{
childCount = 0;
// childLimit setting stays the same
this->log = log;
}
edgesTy.clear();
edgesTp.clear();
worklist.clear();
@ -530,7 +548,6 @@ Substitution::Substitution(const TxnLog* log_, TypeArena* arena)
{
log = log_;
LUAU_ASSERT(log);
LUAU_ASSERT(arena);
}
void Substitution::dontTraverseInto(TypeId ty)
@ -548,7 +565,7 @@ std::optional<TypeId> Substitution::substitute(TypeId ty)
ty = log->follow(ty);
// clear algorithm state for reentrancy
clearTarjan();
clearTarjan(log);
auto result = findDirty(ty);
if (result != TarjanResult::Ok)
@ -581,7 +598,7 @@ std::optional<TypePackId> Substitution::substitute(TypePackId tp)
tp = log->follow(tp);
// clear algorithm state for reentrancy
clearTarjan();
clearTarjan(log);
auto result = findDirty(tp);
if (result != TarjanResult::Ok)
@ -609,6 +626,23 @@ std::optional<TypePackId> Substitution::substitute(TypePackId tp)
return newTp;
}
void Substitution::resetState(const TxnLog* log, TypeArena* arena)
{
LUAU_ASSERT(FFlag::LuauReusableSubstitutions);
clearTarjan(log);
this->arena = arena;
newTypes.clear();
newPacks.clear();
replacedTypes.clear();
replacedTypePacks.clear();
noTraverseTypes.clear();
noTraverseTypePacks.clear();
}
TypeId Substitution::clone(TypeId ty)
{
return shallowClone(ty, *arena, log, /* alwaysClone */ true);

View file

@ -1438,6 +1438,7 @@ SubtypingResult Subtyping::isCovariantWith(
result.andAlso(isCovariantWith(env, subNorm->strings, superNorm->strings));
result.andAlso(isCovariantWith(env, subNorm->strings, superNorm->tables));
result.andAlso(isCovariantWith(env, subNorm->threads, superNorm->threads));
result.andAlso(isCovariantWith(env, subNorm->buffers, superNorm->buffers));
result.andAlso(isCovariantWith(env, subNorm->tables, superNorm->tables));
result.andAlso(isCovariantWith(env, subNorm->functions, superNorm->functions));
// isCovariantWith(subNorm->tyvars, superNorm->tyvars);

View file

@ -337,7 +337,9 @@ TypeId matchLiteralType(NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes,
TypeId matchedType = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier,
expectedTableTy->indexer->indexResultType, *propTy, item.value, toBlock);
tableTy->indexer->indexResultType = matchedType;
// if the index result type is the prop type, we can replace it with the matched type here.
if (tableTy->indexer->indexResultType == *propTy)
tableTy->indexer->indexResultType = matchedType;
}
}
else if (item.kind == AstExprTable::Item::General)

View file

@ -262,14 +262,6 @@ void StateDot::visitChildren(TypeId ty, int index)
visitChild(t.upperBound, index, "[upperBound]");
}
}
else if constexpr (std::is_same_v<T, LocalType>)
{
formatAppend(result, "LocalType");
finishNodeLabel(ty);
finishNode();
visitChild(t.domain, 1, "[domain]");
}
else if constexpr (std::is_same_v<T, AnyType>)
{
formatAppend(result, "AnyType %d", index);

View file

@ -20,7 +20,6 @@
#include <string>
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAGVARIABLE(LuauToStringiteTypesSingleLine, false)
/*
* Enables increasing levels of verbosity for Luau type names when stringifying.
@ -101,16 +100,6 @@ struct FindCyclicTypes final : TypeVisitor
return false;
}
bool visit(TypeId ty, const LocalType& lt) override
{
if (!visited.insert(ty))
return false;
traverse(lt.domain);
return false;
}
bool visit(TypeId ty, const TableType& ttv) override
{
if (!visited.insert(ty))
@ -526,21 +515,6 @@ struct TypeStringifier
}
}
void operator()(TypeId ty, const LocalType& lt)
{
state.emit("l-");
state.emit(lt.name);
if (FInt::DebugLuauVerboseTypeNames >= 1)
{
state.emit("[");
state.emit(lt.blockCount);
state.emit("]");
}
state.emit("=[");
stringify(lt.domain);
state.emit("]");
}
void operator()(TypeId, const BoundType& btv)
{
stringify(btv.boundTo);
@ -1725,6 +1699,18 @@ std::string generateName(size_t i)
return n;
}
std::string toStringVector(const std::vector<TypeId>& types, ToStringOptions& opts)
{
std::string s;
for (TypeId ty : types)
{
if (!s.empty())
s += ", ";
s += toString(ty, opts);
}
return s;
}
std::string toString(const Constraint& constraint, ToStringOptions& opts)
{
auto go = [&opts](auto&& c) -> std::string {
@ -1755,7 +1741,7 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts)
else if constexpr (std::is_same_v<T, IterableConstraint>)
{
std::string iteratorStr = tos(c.iterator);
std::string variableStr = tos(c.variables);
std::string variableStr = toStringVector(c.variables, opts);
return variableStr + " ~ iterate " + iteratorStr;
}
@ -1788,23 +1774,16 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts)
{
return tos(c.resultType) + " ~ hasProp " + tos(c.subjectType) + ", \"" + c.prop + "\" ctx=" + std::to_string(int(c.context));
}
else if constexpr (std::is_same_v<T, SetPropConstraint>)
{
const std::string pathStr = c.path.size() == 1 ? "\"" + c.path[0] + "\"" : "[\"" + join(c.path, "\", \"") + "\"]";
return tos(c.resultType) + " ~ setProp " + tos(c.subjectType) + ", " + pathStr + " " + tos(c.propType);
}
else if constexpr (std::is_same_v<T, HasIndexerConstraint>)
{
return tos(c.resultType) + " ~ hasIndexer " + tos(c.subjectType) + " " + tos(c.indexType);
}
else if constexpr (std::is_same_v<T, SetIndexerConstraint>)
{
return "setIndexer " + tos(c.subjectType) + " [ " + tos(c.indexType) + " ] " + tos(c.propType);
}
else if constexpr (std::is_same_v<T, AssignPropConstraint>)
return "assignProp " + tos(c.lhsType) + " " + c.propName + " " + tos(c.rhsType);
else if constexpr (std::is_same_v<T, AssignIndexConstraint>)
return "assignIndex " + tos(c.lhsType) + " " + tos(c.indexType) + " " + tos(c.rhsType);
else if constexpr (std::is_same_v<T, UnpackConstraint>)
return tos(c.resultPack) + " ~ ...unpack " + tos(c.sourcePack);
else if constexpr (std::is_same_v<T, Unpack1Constraint>)
return tos(c.resultType) + " ~ unpack " + tos(c.sourceType);
return toStringVector(c.resultPack, opts) + " ~ ...unpack " + tos(c.sourcePack);
else if constexpr (std::is_same_v<T, ReduceConstraint>)
return "reduce " + tos(c.ty);
else if constexpr (std::is_same_v<T, ReducePackConstraint>)

View file

@ -1182,11 +1182,11 @@ std::string toString(AstNode* node)
Printer printer(writer);
printer.writeTypes = true;
if (auto statNode = dynamic_cast<AstStat*>(node))
if (auto statNode = node->asStat())
printer.visualize(*statNode);
else if (auto exprNode = dynamic_cast<AstExpr*>(node))
else if (auto exprNode = node->asExpr())
printer.visualize(*exprNode);
else if (auto typeNode = dynamic_cast<AstType*>(node))
else if (auto typeNode = node->asType())
printer.visualizeTypeAnnotation(*typeNode);
return writer.str();

View file

@ -561,6 +561,11 @@ void BlockedType::setOwner(Constraint* newOwner)
owner = newOwner;
}
void BlockedType::replaceOwner(Constraint* newOwner)
{
owner = newOwner;
}
PendingExpansionType::PendingExpansionType(
std::optional<AstName> prefix, AstName name, std::vector<TypeId> typeArguments, std::vector<TypePackId> packArguments)
: prefix(prefix)

View file

@ -338,10 +338,6 @@ public:
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("free"), std::nullopt, Location());
}
AstType* operator()(const LocalType& lt)
{
return Luau::visit(*this, lt.domain->ty);
}
AstType* operator()(const UnionType& uv)
{
AstArray<AstType*> unionTypes;

View file

@ -446,7 +446,6 @@ struct TypeChecker2
.errors;
if (!isErrorSuppressing(location, instance))
reportErrors(std::move(errors));
return instance;
}
@ -1108,10 +1107,13 @@ struct TypeChecker2
void visit(AstStatCompoundAssign* stat)
{
AstExprBinary fake{stat->location, stat->op, stat->var, stat->value};
TypeId resultTy = visit(&fake, stat);
visit(&fake, stat);
TypeId* resultTy = module->astCompoundAssignResultTypes.find(stat);
LUAU_ASSERT(resultTy);
TypeId varTy = lookupType(stat->var);
testIsSubtype(resultTy, varTy, stat->location);
testIsSubtype(*resultTy, varTy, stat->location);
}
void visit(AstStatFunction* stat)
@ -1242,13 +1244,14 @@ struct TypeChecker2
void visit(AstExprConstantBool* expr)
{
#if defined(LUAU_ENABLE_ASSERT)
// booleans use specialized inference logic for singleton types, which can lead to real type errors here.
const TypeId bestType = expr->value ? builtinTypes->trueType : builtinTypes->falseType;
const TypeId inferredType = lookupType(expr);
const SubtypingResult r = subtyping->isSubtype(bestType, inferredType);
LUAU_ASSERT(r.isSubtype || isErrorSuppressing(expr->location, inferredType));
#endif
if (!r.isSubtype && !isErrorSuppressing(expr->location, inferredType))
reportError(TypeMismatch{inferredType, bestType}, expr->location);
}
void visit(AstExprConstantNumber* expr)
@ -1264,13 +1267,14 @@ struct TypeChecker2
void visit(AstExprConstantString* expr)
{
#if defined(LUAU_ENABLE_ASSERT)
// strings use specialized inference logic for singleton types, which can lead to real type errors here.
const TypeId bestType = module->internalTypes.addType(SingletonType{StringSingleton{std::string{expr->value.data, expr->value.size}}});
const TypeId inferredType = lookupType(expr);
const SubtypingResult r = subtyping->isSubtype(bestType, inferredType);
LUAU_ASSERT(r.isSubtype || isErrorSuppressing(expr->location, inferredType));
#endif
if (!r.isSubtype && !isErrorSuppressing(expr->location, inferredType))
reportError(TypeMismatch{inferredType, bestType}, expr->location);
}
void visit(AstExprLocal* expr)
@ -1280,7 +1284,9 @@ struct TypeChecker2
void visit(AstExprGlobal* expr)
{
// TODO!
NotNull<Scope> scope = stack.back();
if (!scope->lookup(expr->name))
reportError(UnknownSymbol{expr->name.value, UnknownSymbol::Binding}, expr->location);
}
void visit(AstExprVarargs* expr)
@ -1534,6 +1540,24 @@ struct TypeChecker2
visitExprName(indexName->expr, indexName->location, indexName->index.value, context, builtinTypes->stringType);
}
void indexExprMetatableHelper(AstExprIndexExpr* indexExpr, const MetatableType* metaTable, TypeId exprType, TypeId indexType)
{
if (auto tt = get<TableType>(follow(metaTable->table)); tt && tt->indexer)
testIsSubtype(indexType, tt->indexer->indexType, indexExpr->index->location);
else if (auto mt = get<MetatableType>(follow(metaTable->table)))
indexExprMetatableHelper(indexExpr, mt, exprType, indexType);
else if (auto tmt = get<TableType>(follow(metaTable->metatable)); tmt && tmt->indexer)
testIsSubtype(indexType, tmt->indexer->indexType, indexExpr->index->location);
else if (auto mtmt = get<MetatableType>(follow(metaTable->metatable)))
indexExprMetatableHelper(indexExpr, mtmt, exprType, indexType);
else
{
LUAU_ASSERT(tt || get<PrimitiveType>(follow(metaTable->table)));
reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location);
}
}
void visit(AstExprIndexExpr* indexExpr, ValueContext context)
{
if (auto str = indexExpr->index->as<AstExprConstantString>())
@ -1557,6 +1581,10 @@ struct TypeChecker2
else
reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location);
}
else if (auto mt = get<MetatableType>(exprType))
{
return indexExprMetatableHelper(indexExpr, mt, exprType, indexType);
}
else if (auto cls = get<ClassType>(exprType))
{
if (cls->indexer)
@ -1577,6 +1605,19 @@ struct TypeChecker2
reportError(OptionalValueAccess{exprType}, indexExpr->location);
}
}
else if (auto exprIntersection = get<IntersectionType>(exprType))
{
for (TypeId part : exprIntersection)
{
(void)part;
}
}
else if (get<NeverType>(exprType) || isErrorSuppressing(indexExpr->location, exprType))
{
// Nothing
}
else
reportError(NotATable{exprType}, indexExpr->location);
}
void visit(AstExprFunction* fn)
@ -1589,7 +1630,6 @@ struct TypeChecker2
functionDeclStack.push_back(inferredFnTy);
std::shared_ptr<const NormalizedType> normalizedFnTy = normalizer.normalize(inferredFnTy);
const FunctionType* inferredFtv = get<FunctionType>(normalizedFnTy->functions.parts.front());
if (!normalizedFnTy)
{
reportError(CodeTooComplex{}, fn->location);
@ -1684,16 +1724,23 @@ struct TypeChecker2
if (fn->returnAnnotation)
visit(*fn->returnAnnotation);
// If the function type has a family annotation, we need to see if we can suggest an annotation
TypeFamilyReductionGuesser guesser{NotNull{&module->internalTypes}, builtinTypes, NotNull{&normalizer}};
for (TypeId retTy : inferredFtv->retTypes)
if (normalizedFnTy)
{
if (get<TypeFamilyInstanceType>(follow(retTy)))
const FunctionType* inferredFtv = get<FunctionType>(normalizedFnTy->functions.parts.front());
LUAU_ASSERT(inferredFtv);
TypeFamilyReductionGuesser guesser{NotNull{&module->internalTypes}, builtinTypes, NotNull{&normalizer}};
for (TypeId retTy : inferredFtv->retTypes)
{
TypeFamilyReductionGuessResult result = guesser.guessTypeFamilyReductionForFunction(*fn, inferredFtv, retTy);
if (result.shouldRecommendAnnotation)
reportError(
ExplicitFunctionAnnotationRecommended{std::move(result.guessedFunctionAnnotations), result.guessedReturnType}, fn->location);
if (get<TypeFamilyInstanceType>(follow(retTy)))
{
TypeFamilyReductionGuessResult result = guesser.guessTypeFamilyReductionForFunction(*fn, inferredFtv, retTy);
if (result.shouldRecommendAnnotation)
reportError(ExplicitFunctionAnnotationRecommended{std::move(result.guessedFunctionAnnotations), result.guessedReturnType},
fn->location);
}
}
}
@ -1822,7 +1869,7 @@ struct TypeChecker2
bool isStringOperation =
(normLeft ? normLeft->isSubtypeOfString() : isString(leftType)) && (normRight ? normRight->isSubtypeOfString() : isString(rightType));
leftType = follow(leftType);
if (get<AnyType>(leftType) || get<ErrorType>(leftType) || get<NeverType>(leftType))
return leftType;
else if (get<AnyType>(rightType) || get<ErrorType>(rightType) || get<NeverType>(rightType))
@ -2091,24 +2138,39 @@ struct TypeChecker2
TypeId annotationType = lookupAnnotation(expr->annotation);
TypeId computedType = lookupType(expr->expr);
// Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case.
if (subtyping->isSubtype(annotationType, computedType).isSubtype)
return;
if (subtyping->isSubtype(computedType, annotationType).isSubtype)
return;
switch (shouldSuppressErrors(NotNull{&normalizer}, computedType).orElse(shouldSuppressErrors(NotNull{&normalizer}, annotationType)))
{
case ErrorSuppression::Suppress:
return;
case ErrorSuppression::NormalizationFailed:
reportError(NormalizationTooComplex{}, expr->location);
return;
case ErrorSuppression::DoNotSuppress:
break;
}
reportError(TypesAreUnrelated{computedType, annotationType}, expr->location);
switch (normalizer.isInhabited(computedType))
{
case NormalizationResult::True:
break;
case NormalizationResult::False:
return;
case NormalizationResult::HitLimits:
reportError(NormalizationTooComplex{}, expr->location);
return;
}
switch (normalizer.isIntersectionInhabited(computedType, annotationType))
{
case NormalizationResult::True:
return;
case NormalizationResult::False:
reportError(TypesAreUnrelated{computedType, annotationType}, expr->location);
break;
case NormalizationResult::HitLimits:
reportError(NormalizationTooComplex{}, expr->location);
break;
}
}
void visit(AstExprIfElse* expr)
@ -2710,6 +2772,8 @@ struct TypeChecker2
fetch(builtinTypes->stringType);
if (normValid)
fetch(norm->threads);
if (normValid)
fetch(norm->buffers);
if (normValid)
{

File diff suppressed because it is too large Load diff

View file

@ -33,13 +33,12 @@ LUAU_FASTFLAG(LuauKnowsTheDataModel3)
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false)
LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false)
LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAGVARIABLE(LuauMetatableInstantiationCloneCheck, false)
LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false)
LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false)
LUAU_FASTFLAGVARIABLE(LuauRemoveBadRelationalOperatorWarning, false)
LUAU_FASTFLAGVARIABLE(LuauForbidAliasNamedTypeof, false)
LUAU_FASTFLAGVARIABLE(LuauOkWithIteratingOverTableProperties, false)
LUAU_FASTFLAG(LuauFixNormalizeCaching)
LUAU_FASTFLAGVARIABLE(LuauReusableSubstitutions, false)
LUAU_FASTFLAG(LuauDeclarationExtraPropData)
namespace Luau
{
@ -216,6 +215,7 @@ TypeChecker::TypeChecker(const ScopePtr& globalScope, ModuleResolver* resolver,
, iceHandler(iceHandler)
, unifierState(iceHandler)
, normalizer(nullptr, builtinTypes, NotNull{&unifierState})
, reusableInstantiation(TxnLog::empty(), nullptr, builtinTypes, {}, nullptr)
, nilType(builtinTypes->nilType)
, numberType(builtinTypes->numberType)
, stringType(builtinTypes->stringType)
@ -668,7 +668,7 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std
{
if (const auto& typealias = stat->as<AstStatTypeAlias>())
{
if (typealias->name == kParseNameError || (FFlag::LuauForbidAliasNamedTypeof && typealias->name == "typeof"))
if (typealias->name == kParseNameError || typealias->name == "typeof")
continue;
auto& bindings = typealias->exported ? scope->exportedTypeBindings : scope->privateTypeBindings;
@ -1536,7 +1536,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& ty
if (name == kParseNameError)
return ControlFlow::None;
if (FFlag::LuauForbidAliasNamedTypeof && name == "typeof")
if (name == "typeof")
{
reportError(typealias.location, GenericError{"Type aliases cannot be named typeof"});
return ControlFlow::None;
@ -1657,7 +1657,7 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea
// If the alias is missing a name, we can't do anything with it. Ignore it.
// Also, typeof is not a valid type alias name. We will report an error for
// this in check()
if (name == kParseNameError || (FFlag::LuauForbidAliasNamedTypeof && name == "typeof"))
if (name == kParseNameError || name == "typeof")
return;
std::optional<TypeFun> binding;
@ -1784,12 +1784,55 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass&
ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}});
ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes});
ftv->hasSelf = true;
if (FFlag::LuauDeclarationExtraPropData)
{
FunctionDefinition defn;
defn.definitionModuleName = currentModule->name;
defn.definitionLocation = prop.location;
// No data is preserved for varargLocation
defn.originalNameLocation = prop.nameLocation;
ftv->definition = defn;
}
}
}
if (assignTo.count(propName) == 0)
{
assignTo[propName] = {propTy};
if (FFlag::LuauDeclarationExtraPropData)
assignTo[propName] = {propTy, /*deprecated*/ false, /*deprecatedSuggestion*/ "", prop.location};
else
assignTo[propName] = {propTy};
}
else if (FFlag::LuauDeclarationExtraPropData)
{
Luau::Property& prop = assignTo[propName];
TypeId currentTy = prop.type();
// We special-case this logic to keep the intersection flat; otherwise we
// would create a ton of nested intersection types.
if (const IntersectionType* itv = get<IntersectionType>(currentTy))
{
std::vector<TypeId> options = itv->parts;
options.push_back(propTy);
TypeId newItv = addType(IntersectionType{std::move(options)});
prop.readTy = newItv;
prop.writeTy = newItv;
}
else if (get<FunctionType>(currentTy))
{
TypeId intersection = addType(IntersectionType{{currentTy, propTy}});
prop.readTy = intersection;
prop.writeTy = intersection;
}
else
{
reportError(declaredClass.location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())});
}
}
else
{
@ -1841,7 +1884,18 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFuncti
TypePackId argPack = resolveTypePack(funScope, global.params);
TypePackId retPack = resolveTypePack(funScope, global.retTypes);
TypeId fnType = addType(FunctionType{funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack});
FunctionDefinition defn;
if (FFlag::LuauDeclarationExtraPropData)
{
defn.definitionModuleName = currentModule->name;
defn.definitionLocation = global.location;
defn.varargLocation = global.vararg ? std::make_optional(global.varargLocation) : std::nullopt;
defn.originalNameLocation = global.nameLocation;
}
TypeId fnType = addType(FunctionType{funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack, defn});
FunctionType* ftv = getMutable<FunctionType>(fnType);
ftv->argNames.reserve(global.paramNames.size);
@ -2649,24 +2703,12 @@ static std::optional<bool> areEqComparable(NotNull<TypeArena> arena, NotNull<Nor
NormalizationResult nr;
if (FFlag::LuauFixNormalizeCaching)
{
TypeId c = arena->addType(IntersectionType{{a, b}});
std::shared_ptr<const NormalizedType> n = normalizer->normalize(c);
if (!n)
return std::nullopt;
TypeId c = arena->addType(IntersectionType{{a, b}});
std::shared_ptr<const NormalizedType> n = normalizer->normalize(c);
if (!n)
return std::nullopt;
nr = normalizer->isInhabited(n.get());
}
else
{
TypeId c = arena->addType(IntersectionType{{a, b}});
const NormalizedType* n = normalizer->DEPRECATED_normalize(c);
if (!n)
return std::nullopt;
nr = normalizer->isInhabited(n);
}
nr = normalizer->isInhabited(n.get());
switch (nr)
{
@ -4879,12 +4921,27 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat
if (ftv && ftv->hasNoFreeOrGenericTypes)
return ty;
Instantiation instantiation{log, &currentModule->internalTypes, builtinTypes, scope->level, /*scope*/ nullptr};
std::optional<TypeId> instantiated;
if (instantiationChildLimit)
instantiation.childLimit = *instantiationChildLimit;
if (FFlag::LuauReusableSubstitutions)
{
reusableInstantiation.resetState(log, &currentModule->internalTypes, builtinTypes, scope->level, /*scope*/ nullptr);
if (instantiationChildLimit)
reusableInstantiation.childLimit = *instantiationChildLimit;
instantiated = reusableInstantiation.substitute(ty);
}
else
{
Instantiation instantiation{log, &currentModule->internalTypes, builtinTypes, scope->level, /*scope*/ nullptr};
if (instantiationChildLimit)
instantiation.childLimit = *instantiationChildLimit;
instantiated = instantiation.substitute(ty);
}
std::optional<TypeId> instantiated = instantiation.substitute(ty);
if (instantiated.has_value())
return *instantiated;
else
@ -5633,8 +5690,8 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf,
TypeId instantiated = *maybeInstantiated;
TypeId target = follow(instantiated);
const TableType* tfTable = FFlag::LuauMetatableInstantiationCloneCheck ? getTableType(tf.type) : nullptr;
bool needsClone = follow(tf.type) == target || (FFlag::LuauMetatableInstantiationCloneCheck && tfTable != nullptr && tfTable == getTableType(target));
const TableType* tfTable = getTableType(tf.type);
bool needsClone = follow(tf.type) == target || (tfTable != nullptr && tfTable == getTableType(target));
bool shouldMutate = getTableType(tf.type);
TableType* ttv = getMutableTableType(target);

View file

@ -38,6 +38,59 @@ bool occursCheck(TypeId needle, TypeId haystack)
return false;
}
// FIXME: Property is quite large.
//
// Returning it on the stack like this isn't great. We'd like to just return a
// const Property*, but we mint a property of type any if the subject type is
// any.
std::optional<Property> findTableProperty(NotNull<BuiltinTypes> builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location)
{
if (get<AnyType>(ty))
return Property::rw(ty);
if (const TableType* tableType = getTableType(ty))
{
const auto& it = tableType->props.find(name);
if (it != tableType->props.end())
return it->second;
}
std::optional<TypeId> mtIndex = findMetatableEntry(builtinTypes, errors, ty, "__index", location);
int count = 0;
while (mtIndex)
{
TypeId index = follow(*mtIndex);
if (count >= 100)
return std::nullopt;
++count;
if (const auto& itt = getTableType(index))
{
const auto& fit = itt->props.find(name);
if (fit != itt->props.end())
return fit->second.type();
}
else if (const auto& itf = get<FunctionType>(index))
{
std::optional<TypeId> r = first(follow(itf->retTypes));
if (!r)
return builtinTypes->nilType;
else
return *r;
}
else if (get<AnyType>(index))
return builtinTypes->anyType;
else
errors.push_back(TypeError{location, GenericError{"__index should either be a function or table. Got " + toString(index)}});
mtIndex = findMetatableEntry(builtinTypes, errors, *mtIndex, "__index", location);
}
return std::nullopt;
}
std::optional<TypeId> findMetatableEntry(
NotNull<BuiltinTypes> builtinTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location)
{

View file

@ -23,7 +23,7 @@ LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAGVARIABLE(LuauFixIndexerSubtypingOrdering, false)
LUAU_FASTFLAGVARIABLE(LuauUnifierShouldNotCopyError, false)
LUAU_FASTFLAG(LuauFixNormalizeCaching)
LUAU_FASTFLAGVARIABLE(LuauUnifierRecursionOnRestart, false)
namespace Luau
{
@ -580,28 +580,14 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
{
if (normalize)
{
if (FFlag::LuauFixNormalizeCaching)
{
// TODO: there are probably cheaper ways to check if any <: T.
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
// TODO: there are probably cheaper ways to check if any <: T.
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
if (!superNorm)
return reportError(location, NormalizationTooComplex{});
if (!superNorm)
return reportError(location, NormalizationTooComplex{});
if (!log.get<AnyType>(superNorm->tops))
failure = true;
}
else
{
// TODO: there are probably cheaper ways to check if any <: T.
const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy);
if (!superNorm)
return reportError(location, NormalizationTooComplex{});
if (!log.get<AnyType>(superNorm->tops))
failure = true;
}
if (!log.get<AnyType>(superNorm->tops))
failure = true;
}
else
failure = true;
@ -962,30 +948,15 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp
// We deal with this by type normalization.
Unifier innerState = makeChildUnifier();
if (FFlag::LuauFixNormalizeCaching)
{
std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy);
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
if (!subNorm || !superNorm)
return reportError(location, NormalizationTooComplex{});
else if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
innerState.tryUnifyNormalizedTypes(
subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption);
else
innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible");
}
std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy);
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
if (!subNorm || !superNorm)
return reportError(location, NormalizationTooComplex{});
else if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
innerState.tryUnifyNormalizedTypes(
subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption);
else
{
const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy);
const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy);
if (!subNorm || !superNorm)
return reportError(location, NormalizationTooComplex{});
else if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
innerState.tryUnifyNormalizedTypes(
subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption);
else
innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible");
}
innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible");
if (!innerState.failure)
log.concat(std::move(innerState.log));
@ -999,30 +970,14 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp
// It is possible that T <: A | B even though T </: A and T </:B
// for example boolean <: true | false.
// We deal with this by type normalization.
if (FFlag::LuauFixNormalizeCaching)
{
std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy);
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
if (!subNorm || !superNorm)
reportError(location, NormalizationTooComplex{});
else if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
tryUnifyNormalizedTypes(
subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption);
else
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible");
}
std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy);
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
if (!subNorm || !superNorm)
reportError(location, NormalizationTooComplex{});
else if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption);
else
{
const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy);
const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy);
if (!subNorm || !superNorm)
reportError(location, NormalizationTooComplex{});
else if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
tryUnifyNormalizedTypes(
subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption);
else
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible");
}
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible");
}
else if (!found)
{
@ -1125,24 +1080,12 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType*
// It is possible that A & B <: T even though A </: T and B </: T
// for example (string?) & ~nil <: string.
// We deal with this by type normalization.
if (FFlag::LuauFixNormalizeCaching)
{
std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy);
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
if (subNorm && superNorm)
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible");
else
reportError(location, NormalizationTooComplex{});
}
std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy);
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
if (subNorm && superNorm)
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible");
else
{
const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy);
const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy);
if (subNorm && superNorm)
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible");
else
reportError(location, NormalizationTooComplex{});
}
reportError(location, NormalizationTooComplex{});
return;
}
@ -1192,24 +1135,12 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType*
// for example string? & number? <: nil.
// We deal with this by type normalization.
if (FFlag::LuauFixNormalizeCaching)
{
std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy);
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
if (subNorm && superNorm)
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible");
else
reportError(location, NormalizationTooComplex{});
}
std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy);
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
if (subNorm && superNorm)
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible");
else
{
const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy);
const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy);
if (subNorm && superNorm)
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible");
else
reportError(location, NormalizationTooComplex{});
}
reportError(location, NormalizationTooComplex{});
}
else if (!found)
{
@ -2249,7 +2180,18 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection,
// If one of the types stopped being a table altogether, we need to restart from the top
if ((superTy != superTyNew || activeSubTy != subTyNew) && errors.empty())
return tryUnify(subTy, superTy, false, isIntersection);
{
if (FFlag::LuauUnifierRecursionOnRestart)
{
RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit);
tryUnify(subTy, superTy, false, isIntersection);
return;
}
else
{
return tryUnify(subTy, superTy, false, isIntersection);
}
}
// Otherwise, restart only the table unification
TableType* newSuperTable = log.getMutable<TableType>(superTyNew);
@ -2328,7 +2270,18 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection,
// If one of the types stopped being a table altogether, we need to restart from the top
if ((superTy != superTyNew || activeSubTy != subTyNew) && errors.empty())
return tryUnify(subTy, superTy, false, isIntersection);
{
if (FFlag::LuauUnifierRecursionOnRestart)
{
RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit);
tryUnify(subTy, superTy, false, isIntersection);
return;
}
else
{
return tryUnify(subTy, superTy, false, isIntersection);
}
}
// Recursive unification can change the txn log, and invalidate the old
// table. If we detect that this has happened, we start over, with the updated
@ -2712,32 +2665,16 @@ void Unifier::tryUnifyNegations(TypeId subTy, TypeId superTy)
if (!log.get<NegationType>(subTy) && !log.get<NegationType>(superTy))
ice("tryUnifyNegations superTy or subTy must be a negation type");
if (FFlag::LuauFixNormalizeCaching)
{
std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy);
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
if (!subNorm || !superNorm)
return reportError(location, NormalizationTooComplex{});
std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy);
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
if (!subNorm || !superNorm)
return reportError(location, NormalizationTooComplex{});
// T </: ~U iff T <: U
Unifier state = makeChildUnifier();
state.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "");
if (state.errors.empty())
reportError(location, TypeMismatch{superTy, subTy, mismatchContext()});
}
else
{
const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy);
const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy);
if (!subNorm || !superNorm)
return reportError(location, NormalizationTooComplex{});
// T </: ~U iff T <: U
Unifier state = makeChildUnifier();
state.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "");
if (state.errors.empty())
reportError(location, TypeMismatch{superTy, subTy, mismatchContext()});
}
// T </: ~U iff T <: U
Unifier state = makeChildUnifier();
state.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "");
if (state.errors.empty())
reportError(location, TypeMismatch{superTy, subTy, mismatchContext()});
}
static void queueTypePack(std::vector<TypeId>& queue, DenseHashSet<TypePackId>& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack)

View file

@ -158,12 +158,6 @@ bool Unifier2::unify(TypeId subTy, TypeId superTy)
if (subFree || superFree)
return true;
if (auto subLocal = getMutable<LocalType>(subTy))
{
subLocal->domain = mkUnion(subLocal->domain, superTy);
expandedFreeTypes[subTy].push_back(superTy);
}
auto subFn = get<FunctionType>(subTy);
auto superFn = get<FunctionType>(superTy);
if (subFn && superFn)
@ -204,25 +198,21 @@ bool Unifier2::unify(TypeId subTy, TypeId superTy)
auto subAny = get<AnyType>(subTy);
auto superAny = get<AnyType>(superTy);
if (subAny && superAny)
return true;
else if (subAny && superFn)
{
// If `any` is the subtype, then we can propagate that inward.
bool argResult = unify(superFn->argTypes, builtinTypes->anyTypePack);
bool retResult = unify(builtinTypes->anyTypePack, superFn->retTypes);
return argResult && retResult;
}
else if (subFn && superAny)
{
// If `any` is the supertype, then we can propagate that inward.
bool argResult = unify(builtinTypes->anyTypePack, subFn->argTypes);
bool retResult = unify(subFn->retTypes, builtinTypes->anyTypePack);
return argResult && retResult;
}
auto subTable = getMutable<TableType>(subTy);
auto superTable = get<TableType>(superTy);
if (subAny && superAny)
return true;
else if (subAny && superFn)
return unify(subAny, superFn);
else if (subFn && superAny)
return unify(subFn, superAny);
else if (subAny && superTable)
return unify(subAny, superTable);
else if (subTable && superAny)
return unify(subTable, superAny);
if (subTable && superTable)
{
// `boundTo` works like a bound type, and therefore we'd replace it
@ -451,7 +441,16 @@ bool Unifier2::unify(TableType* subTable, const TableType* superTable)
* an indexer, we therefore conclude that the unsealed table has the
* same indexer.
*/
subTable->indexer = *superTable->indexer;
TypeId indexType = superTable->indexer->indexType;
if (TypeId* subst = genericSubstitutions.find(indexType))
indexType = *subst;
TypeId indexResultType = superTable->indexer->indexResultType;
if (TypeId* subst = genericSubstitutions.find(indexResultType))
indexResultType = *subst;
subTable->indexer = TableIndexer{indexType, indexResultType};
}
return result;
@ -462,6 +461,62 @@ bool Unifier2::unify(const MetatableType* subMetatable, const MetatableType* sup
return unify(subMetatable->metatable, superMetatable->metatable) && unify(subMetatable->table, superMetatable->table);
}
bool Unifier2::unify(const AnyType* subAny, const FunctionType* superFn)
{
// If `any` is the subtype, then we can propagate that inward.
bool argResult = unify(superFn->argTypes, builtinTypes->anyTypePack);
bool retResult = unify(builtinTypes->anyTypePack, superFn->retTypes);
return argResult && retResult;
}
bool Unifier2::unify(const FunctionType* subFn, const AnyType* superAny)
{
// If `any` is the supertype, then we can propagate that inward.
bool argResult = unify(builtinTypes->anyTypePack, subFn->argTypes);
bool retResult = unify(subFn->retTypes, builtinTypes->anyTypePack);
return argResult && retResult;
}
bool Unifier2::unify(const AnyType* subAny, const TableType* superTable)
{
for (const auto& [propName, prop] : superTable->props)
{
if (prop.readTy)
unify(builtinTypes->anyType, *prop.readTy);
if (prop.writeTy)
unify(*prop.writeTy, builtinTypes->anyType);
}
if (superTable->indexer)
{
unify(builtinTypes->anyType, superTable->indexer->indexType);
unify(builtinTypes->anyType, superTable->indexer->indexResultType);
}
return true;
}
bool Unifier2::unify(const TableType* subTable, const AnyType* superAny)
{
for (const auto& [propName, prop] : subTable->props)
{
if (prop.readTy)
unify(*prop.readTy, builtinTypes->anyType);
if (prop.writeTy)
unify(builtinTypes->anyType, *prop.writeTy);
}
if (subTable->indexer)
{
unify(subTable->indexer->indexType, builtinTypes->anyType);
unify(subTable->indexer->indexResultType, builtinTypes->anyType);
}
return true;
}
// FIXME? This should probably return an ErrorVec or an optional<TypeError>
// rather than a boolean to signal an occurs check failure.
bool Unifier2::unify(TypePackId subTp, TypePackId superTp)
@ -596,6 +651,43 @@ struct FreeTypeSearcher : TypeVisitor
}
}
DenseHashSet<const void*> seenPositive{nullptr};
DenseHashSet<const void*> seenNegative{nullptr};
bool seenWithPolarity(const void* ty)
{
switch (polarity)
{
case Positive:
{
if (seenPositive.contains(ty))
return true;
seenPositive.insert(ty);
return false;
}
case Negative:
{
if (seenNegative.contains(ty))
return true;
seenNegative.insert(ty);
return false;
}
case Both:
{
if (seenPositive.contains(ty) && seenNegative.contains(ty))
return true;
seenPositive.insert(ty);
seenNegative.insert(ty);
return false;
}
}
return false;
}
// The keys in these maps are either TypeIds or TypePackIds. It's safe to
// mix them because we only use these pointers as unique keys. We never
// indirect them.
@ -604,12 +696,18 @@ struct FreeTypeSearcher : TypeVisitor
bool visit(TypeId ty) override
{
if (seenWithPolarity(ty))
return false;
LUAU_ASSERT(ty);
return true;
}
bool visit(TypeId ty, const FreeType& ft) override
{
if (seenWithPolarity(ty))
return false;
if (!subsumes(scope, ft.scope))
return true;
@ -632,6 +730,9 @@ struct FreeTypeSearcher : TypeVisitor
bool visit(TypeId ty, const TableType& tt) override
{
if (seenWithPolarity(ty))
return false;
if ((tt.state == TableState::Free || tt.state == TableState::Unsealed) && subsumes(scope, tt.scope))
{
switch (polarity)
@ -675,6 +776,9 @@ struct FreeTypeSearcher : TypeVisitor
bool visit(TypeId ty, const FunctionType& ft) override
{
if (seenWithPolarity(ty))
return false;
flip();
traverse(ft.argTypes);
flip();
@ -691,6 +795,9 @@ struct FreeTypeSearcher : TypeVisitor
bool visit(TypePackId tp, const FreeTypePack& ftp) override
{
if (seenWithPolarity(tp))
return false;
if (!subsumes(scope, ftp.scope))
return true;
@ -712,315 +819,6 @@ struct FreeTypeSearcher : TypeVisitor
}
};
struct MutatingGeneralizer : TypeOnceVisitor
{
NotNull<BuiltinTypes> builtinTypes;
NotNull<Scope> scope;
DenseHashMap<const void*, size_t> positiveTypes;
DenseHashMap<const void*, size_t> negativeTypes;
std::vector<TypeId> generics;
std::vector<TypePackId> genericPacks;
bool isWithinFunction = false;
MutatingGeneralizer(NotNull<BuiltinTypes> builtinTypes, NotNull<Scope> scope, DenseHashMap<const void*, size_t> positiveTypes,
DenseHashMap<const void*, size_t> negativeTypes)
: TypeOnceVisitor(/* skipBoundTypes */ true)
, builtinTypes(builtinTypes)
, scope(scope)
, positiveTypes(std::move(positiveTypes))
, negativeTypes(std::move(negativeTypes))
{
}
static void replace(DenseHashSet<TypeId>& seen, TypeId haystack, TypeId needle, TypeId replacement)
{
haystack = follow(haystack);
if (seen.find(haystack))
return;
seen.insert(haystack);
if (UnionType* ut = getMutable<UnionType>(haystack))
{
for (auto iter = ut->options.begin(); iter != ut->options.end();)
{
// FIXME: I bet this function has reentrancy problems
TypeId option = follow(*iter);
if (option == needle && get<NeverType>(replacement))
{
iter = ut->options.erase(iter);
continue;
}
if (option == needle)
{
*iter = replacement;
iter++;
continue;
}
// advance the iterator, nothing after this can use it.
iter++;
if (seen.find(option))
continue;
seen.insert(option);
if (get<UnionType>(option))
replace(seen, option, needle, haystack);
else if (get<IntersectionType>(option))
replace(seen, option, needle, haystack);
}
if (ut->options.size() == 1)
{
TypeId onlyType = ut->options[0];
LUAU_ASSERT(onlyType != haystack);
emplaceType<BoundType>(asMutable(haystack), onlyType);
}
return;
}
if (IntersectionType* it = getMutable<IntersectionType>(needle))
{
for (auto iter = it->parts.begin(); iter != it->parts.end();)
{
// FIXME: I bet this function has reentrancy problems
TypeId part = follow(*iter);
if (part == needle && get<UnknownType>(replacement))
{
iter = it->parts.erase(iter);
continue;
}
if (part == needle)
{
*iter = replacement;
iter++;
continue;
}
// advance the iterator, nothing after this can use it.
iter++;
if (seen.find(part))
continue;
seen.insert(part);
if (get<UnionType>(part))
replace(seen, part, needle, haystack);
else if (get<IntersectionType>(part))
replace(seen, part, needle, haystack);
}
if (it->parts.size() == 1)
{
TypeId onlyType = it->parts[0];
LUAU_ASSERT(onlyType != needle);
emplaceType<BoundType>(asMutable(needle), onlyType);
}
return;
}
}
bool visit(TypeId ty, const FunctionType& ft) override
{
const bool oldValue = isWithinFunction;
isWithinFunction = true;
traverse(ft.argTypes);
traverse(ft.retTypes);
isWithinFunction = oldValue;
return false;
}
bool visit(TypeId ty, const FreeType&) override
{
const FreeType* ft = get<FreeType>(ty);
LUAU_ASSERT(ft);
traverse(ft->lowerBound);
traverse(ft->upperBound);
// It is possible for the above traverse() calls to cause ty to be
// transmuted. We must reacquire ft if this happens.
ty = follow(ty);
ft = get<FreeType>(ty);
if (!ft)
return false;
const size_t positiveCount = getCount(positiveTypes, ty);
const size_t negativeCount = getCount(negativeTypes, ty);
if (!positiveCount && !negativeCount)
return false;
const bool hasLowerBound = !get<NeverType>(follow(ft->lowerBound));
const bool hasUpperBound = !get<UnknownType>(follow(ft->upperBound));
DenseHashSet<TypeId> seen{nullptr};
seen.insert(ty);
if (!hasLowerBound && !hasUpperBound)
{
if (!isWithinFunction || (positiveCount + negativeCount == 1))
emplaceType<BoundType>(asMutable(ty), builtinTypes->unknownType);
else
{
emplaceType<GenericType>(asMutable(ty), scope);
generics.push_back(ty);
}
}
// It is possible that this free type has other free types in its upper
// or lower bounds. If this is the case, we must replace those
// references with never (for the lower bound) or unknown (for the upper
// bound).
//
// If we do not do this, we get tautological bounds like a <: a <: unknown.
else if (positiveCount && !hasUpperBound)
{
TypeId lb = follow(ft->lowerBound);
if (FreeType* lowerFree = getMutable<FreeType>(lb); lowerFree && lowerFree->upperBound == ty)
lowerFree->upperBound = builtinTypes->unknownType;
else
{
DenseHashSet<TypeId> replaceSeen{nullptr};
replace(replaceSeen, lb, ty, builtinTypes->unknownType);
}
if (lb != ty)
emplaceType<BoundType>(asMutable(ty), lb);
else if (!isWithinFunction || (positiveCount + negativeCount == 1))
emplaceType<BoundType>(asMutable(ty), builtinTypes->unknownType);
else
{
// if the lower bound is the type in question, we don't actually have a lower bound.
emplaceType<GenericType>(asMutable(ty), scope);
generics.push_back(ty);
}
}
else
{
TypeId ub = follow(ft->upperBound);
if (FreeType* upperFree = getMutable<FreeType>(ub); upperFree && upperFree->lowerBound == ty)
upperFree->lowerBound = builtinTypes->neverType;
else
{
DenseHashSet<TypeId> replaceSeen{nullptr};
replace(replaceSeen, ub, ty, builtinTypes->neverType);
}
if (ub != ty)
emplaceType<BoundType>(asMutable(ty), ub);
else if (!isWithinFunction || (positiveCount + negativeCount == 1))
emplaceType<BoundType>(asMutable(ty), builtinTypes->unknownType);
else
{
// if the upper bound is the type in question, we don't actually have an upper bound.
emplaceType<GenericType>(asMutable(ty), scope);
generics.push_back(ty);
}
}
return false;
}
size_t getCount(const DenseHashMap<const void*, size_t>& map, const void* ty)
{
if (const size_t* count = map.find(ty))
return *count;
else
return 0;
}
bool visit(TypeId ty, const TableType&) override
{
const size_t positiveCount = getCount(positiveTypes, ty);
const size_t negativeCount = getCount(negativeTypes, ty);
// FIXME: Free tables should probably just be replaced by upper bounds on free types.
//
// eg never <: 'a <: {x: number} & {z: boolean}
if (!positiveCount && !negativeCount)
return true;
TableType* tt = getMutable<TableType>(ty);
LUAU_ASSERT(tt);
tt->state = TableState::Sealed;
return true;
}
bool visit(TypePackId tp, const FreeTypePack& ftp) override
{
if (!subsumes(scope, ftp.scope))
return true;
tp = follow(tp);
const size_t positiveCount = getCount(positiveTypes, tp);
const size_t negativeCount = getCount(negativeTypes, tp);
if (1 == positiveCount + negativeCount)
emplaceTypePack<BoundTypePack>(asMutable(tp), builtinTypes->unknownTypePack);
else
{
emplaceTypePack<GenericTypePack>(asMutable(tp), scope);
genericPacks.push_back(tp);
}
return true;
}
};
std::optional<TypeId> Unifier2::generalize(TypeId ty)
{
ty = follow(ty);
if (ty->owningArena != arena || ty->persistent)
return ty;
if (const FunctionType* ft = get<FunctionType>(ty); ft && (!ft->generics.empty() || !ft->genericPacks.empty()))
return ty;
FreeTypeSearcher fts{scope};
fts.traverse(ty);
MutatingGeneralizer gen{builtinTypes, scope, std::move(fts.positiveTypes), std::move(fts.negativeTypes)};
gen.traverse(ty);
/* MutatingGeneralizer mutates types in place, so it is possible that ty has
* been transmuted to a BoundType. We must follow it again and verify that
* we are allowed to mutate it before we attach generics to it.
*/
ty = follow(ty);
if (ty->owningArena != arena || ty->persistent)
return ty;
FunctionType* ftv = getMutable<FunctionType>(ty);
if (ftv)
{
ftv->generics = std::move(gen.generics);
ftv->genericPacks = std::move(gen.genericPacks);
}
return ty;
}
TypeId Unifier2::mkUnion(TypeId left, TypeId right)
{
left = follow(left);

View file

@ -60,6 +60,8 @@ class AstStat;
class AstStatBlock;
class AstExpr;
class AstTypePack;
class AstAttr;
class AstExprTable;
struct AstLocal
{
@ -172,6 +174,10 @@ public:
{
return nullptr;
}
virtual AstAttr* asAttr()
{
return nullptr;
}
template<typename T>
bool is() const
@ -193,6 +199,29 @@ public:
Location location;
};
class AstAttr : public AstNode
{
public:
LUAU_RTTI(AstAttr)
enum Type
{
Checked,
Native,
};
AstAttr(const Location& location, Type type);
AstAttr* asAttr() override
{
return this;
}
void visit(AstVisitor* visitor) override;
Type type;
};
class AstExpr : public AstNode
{
public:
@ -384,13 +413,17 @@ class AstExprFunction : public AstExpr
public:
LUAU_RTTI(AstExprFunction)
AstExprFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
AstLocal* self, const AstArray<AstLocal*>& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth,
const AstName& debugname, const std::optional<AstTypeList>& returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr,
AstExprFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, AstLocal* self, const AstArray<AstLocal*>& args, bool vararg,
const Location& varargLocation, AstStatBlock* body, size_t functionDepth, const AstName& debugname,
const std::optional<AstTypeList>& returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr,
const std::optional<Location>& argLocation = std::nullopt);
void visit(AstVisitor* visitor) override;
bool hasNativeAttribute() const;
AstArray<AstAttr*> attributes;
AstArray<AstGenericType> generics;
AstArray<AstGenericTypePack> genericPacks;
AstLocal* self;
@ -793,11 +826,12 @@ class AstStatDeclareGlobal : public AstStat
public:
LUAU_RTTI(AstStatDeclareGlobal)
AstStatDeclareGlobal(const Location& location, const AstName& name, AstType* type);
AstStatDeclareGlobal(const Location& location, const AstName& name, const Location& nameLocation, AstType* type);
void visit(AstVisitor* visitor) override;
AstName name;
Location nameLocation;
AstType* type;
};
@ -806,31 +840,38 @@ class AstStatDeclareFunction : public AstStat
public:
LUAU_RTTI(AstStatDeclareFunction)
AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params, const AstArray<AstArgumentName>& paramNames,
const AstTypeList& retTypes);
AstStatDeclareFunction(const Location& location, const AstName& name, const Location& nameLocation, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params, const AstArray<AstArgumentName>& paramNames, bool vararg,
const Location& varargLocation, const AstTypeList& retTypes);
AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params, const AstArray<AstArgumentName>& paramNames,
const AstTypeList& retTypes, bool checkedFunction);
AstStatDeclareFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstName& name, const Location& nameLocation,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params,
const AstArray<AstArgumentName>& paramNames, bool vararg, const Location& varargLocation, const AstTypeList& retTypes);
void visit(AstVisitor* visitor) override;
bool isCheckedFunction() const;
AstArray<AstAttr*> attributes;
AstName name;
Location nameLocation;
AstArray<AstGenericType> generics;
AstArray<AstGenericTypePack> genericPacks;
AstTypeList params;
AstArray<AstArgumentName> paramNames;
bool vararg = false;
Location varargLocation;
AstTypeList retTypes;
bool checkedFunction;
};
struct AstDeclaredClassProp
{
AstName name;
Location nameLocation;
AstType* ty = nullptr;
bool isMethod = false;
Location location;
};
enum class AstTableAccess
@ -936,17 +977,20 @@ public:
AstTypeFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const AstTypeList& returnTypes);
AstTypeFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const AstTypeList& returnTypes, bool checkedFunction);
AstTypeFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames,
const AstTypeList& returnTypes);
void visit(AstVisitor* visitor) override;
bool isCheckedFunction() const;
AstArray<AstAttr*> attributes;
AstArray<AstGenericType> generics;
AstArray<AstGenericTypePack> genericPacks;
AstTypeList argTypes;
AstArray<std::optional<AstArgumentName>> argNames;
AstTypeList returnTypes;
bool checkedFunction;
};
class AstTypeTypeof : public AstType
@ -1105,6 +1149,11 @@ public:
return true;
}
virtual bool visit(class AstAttr* node)
{
return visit(static_cast<AstNode*>(node));
}
virtual bool visit(class AstExpr* node)
{
return visit(static_cast<AstNode*>(node));

View file

@ -87,6 +87,8 @@ struct Lexeme
Comment,
BlockComment,
Attribute,
BrokenString,
BrokenComment,
BrokenUnicode,
@ -115,14 +117,20 @@ struct Lexeme
ReservedTrue,
ReservedUntil,
ReservedWhile,
ReservedChecked,
Reserved_END
};
Type type;
Location location;
// Field declared here, before the union, to ensure that Lexeme size is 32 bytes.
private:
// length is used to extract a slice from the input buffer.
// This field is only valid for certain lexeme types which don't duplicate portions of input
// but instead store a pointer to a location in the input buffer and the length of lexeme.
unsigned int length;
public:
union
{
const char* data; // String, Number, Comment
@ -135,9 +143,13 @@ struct Lexeme
Lexeme(const Location& location, Type type, const char* data, size_t size);
Lexeme(const Location& location, Type type, const char* name);
unsigned int getLength() const;
std::string toString() const;
};
static_assert(sizeof(Lexeme) <= 32, "Size of `Lexeme` struct should be up to 32 bytes.");
class AstNameTable
{
public:

View file

@ -82,8 +82,8 @@ private:
// if exp then block {elseif exp then block} [else block] end |
// for Name `=' exp `,' exp [`,' exp] do block end |
// for namelist in explist do block end |
// function funcname funcbody |
// local function Name funcbody |
// [attributes] function funcname funcbody |
// [attributes] local function Name funcbody |
// local namelist [`=' explist]
// laststat ::= return [explist] | break
AstStat* parseStat();
@ -114,11 +114,25 @@ private:
AstExpr* parseFunctionName(Location start, bool& hasself, AstName& debugname);
// function funcname funcbody
AstStat* parseFunctionStat();
LUAU_FORCEINLINE AstStat* parseFunctionStat(const AstArray<AstAttr*>& attributes = {nullptr, 0});
std::pair<bool, AstAttr::Type> validateAttribute(const char* attributeName, const TempVector<AstAttr*>& attributes);
// attribute ::= '@' NAME
void parseAttribute(TempVector<AstAttr*>& attribute);
// attributes ::= {attribute}
AstArray<AstAttr*> parseAttributes();
// attributes local function Name funcbody
// attributes function funcname funcbody
// attributes `declare function' Name`(' [parlist] `)' [`:` Type]
// declare Name '{' Name ':' attributes `(' [parlist] `)' [`:` Type] '}'
AstStat* parseAttributeStat();
// local function Name funcbody |
// local namelist [`=' explist]
AstStat* parseLocal();
AstStat* parseLocal(const AstArray<AstAttr*>& attributes);
// return [explist]
AstStat* parseReturn();
@ -130,7 +144,7 @@ private:
// `declare global' Name: Type |
// `declare function' Name`(' [parlist] `)' [`:` Type]
AstStat* parseDeclaration(const Location& start);
AstStat* parseDeclaration(const Location& start, const AstArray<AstAttr*>& attributes);
// varlist `=' explist
AstStat* parseAssignment(AstExpr* initial);
@ -143,7 +157,7 @@ private:
// funcbodyhead ::= `(' [namelist [`,' `...'] | `...'] `)' [`:` Type]
// funcbody ::= funcbodyhead block end
std::pair<AstExprFunction*, AstLocal*> parseFunctionBody(
bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName);
bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName, const AstArray<AstAttr*>& attributes);
// explist ::= {exp `,'} exp
void parseExprList(TempVector<AstExpr*>& result);
@ -176,10 +190,10 @@ private:
AstTableIndexer* parseTableIndexer(AstTableAccess access, std::optional<Location> accessLocation);
AstTypeOrPack parseFunctionType(bool allowPack, bool isCheckedFunction = false);
AstType* parseFunctionTypeTail(const Lexeme& begin, AstArray<AstGenericType> generics, AstArray<AstGenericTypePack> genericPacks,
AstArray<AstType*> params, AstArray<std::optional<AstArgumentName>> paramNames, AstTypePack* varargAnnotation,
bool isCheckedFunction = false);
AstTypeOrPack parseFunctionType(bool allowPack, const AstArray<AstAttr*>& attributes);
AstType* parseFunctionTypeTail(const Lexeme& begin, const AstArray<AstAttr*>& attributes, AstArray<AstGenericType> generics,
AstArray<AstGenericTypePack> genericPacks, AstArray<AstType*> params, AstArray<std::optional<AstArgumentName>> paramNames,
AstTypePack* varargAnnotation);
AstType* parseTableType(bool inDeclarationContext = false);
AstTypeOrPack parseSimpleType(bool allowPack, bool inDeclarationContext = false);
@ -220,7 +234,7 @@ private:
// asexp -> simpleexp [`::' Type]
AstExpr* parseAssertionExpr();
// simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | FUNCTION body | primaryexp
// simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | [attributes] FUNCTION body | primaryexp
AstExpr* parseSimpleExpr();
// args ::= `(' [explist] `)' | tableconstructor | String
@ -393,6 +407,7 @@ private:
std::vector<unsigned int> matchRecoveryStopOnToken;
std::vector<AstAttr*> scratchAttr;
std::vector<AstStat*> scratchStat;
std::vector<AstArray<char>> scratchString;
std::vector<AstExpr*> scratchExpr;

View file

@ -134,6 +134,14 @@ struct ThreadContext
static constexpr size_t kEventFlushLimit = 8192;
};
using ThreadContextProvider = ThreadContext& (*)();
inline ThreadContextProvider& threadContextProvider()
{
static ThreadContextProvider handler = nullptr;
return handler;
}
ThreadContext& getThreadContext();
struct Scope

View file

@ -3,6 +3,8 @@
#include "Luau/Common.h"
LUAU_FASTFLAG(LuauAttributeSyntax);
LUAU_FASTFLAG(LuauNativeAttribute);
namespace Luau
{
@ -16,6 +18,17 @@ static void visitTypeList(AstVisitor* visitor, const AstTypeList& list)
list.tailType->visit(visitor);
}
AstAttr::AstAttr(const Location& location, Type type)
: AstNode(ClassIndex(), location)
, type(type)
{
}
void AstAttr::visit(AstVisitor* visitor)
{
visitor->visit(this);
}
int gAstRttiIndex = 0;
AstExprGroup::AstExprGroup(const Location& location, AstExpr* expr)
@ -161,11 +174,12 @@ void AstExprIndexExpr::visit(AstVisitor* visitor)
}
}
AstExprFunction::AstExprFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
AstLocal* self, const AstArray<AstLocal*>& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth,
const AstName& debugname, const std::optional<AstTypeList>& returnAnnotation, AstTypePack* varargAnnotation,
const std::optional<Location>& argLocation)
AstExprFunction::AstExprFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, AstLocal* self, const AstArray<AstLocal*>& args, bool vararg, const Location& varargLocation,
AstStatBlock* body, size_t functionDepth, const AstName& debugname, const std::optional<AstTypeList>& returnAnnotation,
AstTypePack* varargAnnotation, const std::optional<Location>& argLocation)
: AstExpr(ClassIndex(), location)
, attributes(attributes)
, generics(generics)
, genericPacks(genericPacks)
, self(self)
@ -201,6 +215,18 @@ void AstExprFunction::visit(AstVisitor* visitor)
}
}
bool AstExprFunction::hasNativeAttribute() const
{
LUAU_ASSERT(FFlag::LuauNativeAttribute);
for (const auto attribute : attributes)
{
if (attribute->type == AstAttr::Type::Native)
return true;
}
return false;
}
AstExprTable::AstExprTable(const Location& location, const AstArray<Item>& items)
: AstExpr(ClassIndex(), location)
, items(items)
@ -679,9 +705,10 @@ void AstStatTypeAlias::visit(AstVisitor* visitor)
}
}
AstStatDeclareGlobal::AstStatDeclareGlobal(const Location& location, const AstName& name, AstType* type)
AstStatDeclareGlobal::AstStatDeclareGlobal(const Location& location, const AstName& name, const Location& nameLocation, AstType* type)
: AstStat(ClassIndex(), location)
, name(name)
, nameLocation(nameLocation)
, type(type)
{
}
@ -692,31 +719,37 @@ void AstStatDeclareGlobal::visit(AstVisitor* visitor)
type->visit(visitor);
}
AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params, const AstArray<AstArgumentName>& paramNames,
const AstTypeList& retTypes)
AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const Location& nameLocation,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params,
const AstArray<AstArgumentName>& paramNames, bool vararg, const Location& varargLocation, const AstTypeList& retTypes)
: AstStat(ClassIndex(), location)
, attributes()
, name(name)
, nameLocation(nameLocation)
, generics(generics)
, genericPacks(genericPacks)
, params(params)
, paramNames(paramNames)
, vararg(vararg)
, varargLocation(varargLocation)
, retTypes(retTypes)
, checkedFunction(false)
{
}
AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params, const AstArray<AstArgumentName>& paramNames,
const AstTypeList& retTypes, bool checkedFunction)
AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstName& name,
const Location& nameLocation, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& params, const AstArray<AstArgumentName>& paramNames, bool vararg, const Location& varargLocation, const AstTypeList& retTypes)
: AstStat(ClassIndex(), location)
, attributes(attributes)
, name(name)
, nameLocation(nameLocation)
, generics(generics)
, genericPacks(genericPacks)
, params(params)
, paramNames(paramNames)
, vararg(vararg)
, varargLocation(varargLocation)
, retTypes(retTypes)
, checkedFunction(checkedFunction)
{
}
@ -729,6 +762,19 @@ void AstStatDeclareFunction::visit(AstVisitor* visitor)
}
}
bool AstStatDeclareFunction::isCheckedFunction() const
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
for (const AstAttr* attr : attributes)
{
if (attr->type == AstAttr::Type::Checked)
return true;
}
return false;
}
AstStatDeclareClass::AstStatDeclareClass(const Location& location, const AstName& name, std::optional<AstName> superName,
const AstArray<AstDeclaredClassProp>& props, AstTableIndexer* indexer)
: AstStat(ClassIndex(), location)
@ -820,25 +866,26 @@ void AstTypeTable::visit(AstVisitor* visitor)
AstTypeFunction::AstTypeFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const AstTypeList& returnTypes)
: AstType(ClassIndex(), location)
, attributes()
, generics(generics)
, genericPacks(genericPacks)
, argTypes(argTypes)
, argNames(argNames)
, returnTypes(returnTypes)
, checkedFunction(false)
{
LUAU_ASSERT(argNames.size == 0 || argNames.size == argTypes.types.size);
}
AstTypeFunction::AstTypeFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const AstTypeList& returnTypes, bool checkedFunction)
AstTypeFunction::AstTypeFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames,
const AstTypeList& returnTypes)
: AstType(ClassIndex(), location)
, attributes(attributes)
, generics(generics)
, genericPacks(genericPacks)
, argTypes(argTypes)
, argNames(argNames)
, returnTypes(returnTypes)
, checkedFunction(checkedFunction)
{
LUAU_ASSERT(argNames.size == 0 || argNames.size == argTypes.types.size);
}
@ -852,6 +899,19 @@ void AstTypeFunction::visit(AstVisitor* visitor)
}
}
bool AstTypeFunction::isCheckedFunction() const
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
for (const AstAttr* attr : attributes)
{
if (attr->type == AstAttr::Type::Checked)
return true;
}
return false;
}
AstTypeTypeof::AstTypeTypeof(const Location& location, AstExpr* expr)
: AstType(ClassIndex(), location)
, expr(expr)

View file

@ -8,7 +8,7 @@
#include <limits.h>
LUAU_FASTFLAGVARIABLE(LuauLexerLookaheadRemembersBraceType, false)
LUAU_FASTFLAGVARIABLE(LuauCheckedFunctionSyntax, false)
LUAU_FASTFLAGVARIABLE(LuauAttributeSyntax, false)
namespace Luau
{
@ -103,11 +103,19 @@ Lexeme::Lexeme(const Location& location, Type type, const char* name)
, length(0)
, name(name)
{
LUAU_ASSERT(type == Name || (type >= Reserved_BEGIN && type < Lexeme::Reserved_END));
LUAU_ASSERT(type == Name || type == Attribute || (type >= Reserved_BEGIN && type < Lexeme::Reserved_END));
}
unsigned int Lexeme::getLength() const
{
LUAU_ASSERT(type == RawString || type == QuotedString || type == InterpStringBegin || type == InterpStringMid || type == InterpStringEnd ||
type == InterpStringSimple || type == BrokenInterpDoubleBrace || type == Number || type == Comment || type == BlockComment);
return length;
}
static const char* kReserved[] = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in", "local", "nil", "not", "or",
"repeat", "return", "then", "true", "until", "while", "@checked"};
"repeat", "return", "then", "true", "until", "while"};
std::string Lexeme::toString() const
{
@ -192,6 +200,10 @@ std::string Lexeme::toString() const
case Comment:
return "comment";
case Attribute:
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
return name ? format("'%s'", name) : "attribute";
case BrokenString:
return "malformed string";
@ -279,7 +291,7 @@ std::pair<AstName, Lexeme::Type> AstNameTable::getOrAddWithType(const char* name
nameData[length] = 0;
const_cast<Entry&>(entry).value = AstName(nameData);
const_cast<Entry&>(entry).type = Lexeme::Name;
const_cast<Entry&>(entry).type = (name[0] == '@' ? Lexeme::Attribute : Lexeme::Name);
return std::make_pair(entry.value, entry.type);
}
@ -995,16 +1007,10 @@ Lexeme Lexer::readNext()
}
case '@':
{
if (FFlag::LuauCheckedFunctionSyntax)
if (FFlag::LuauAttributeSyntax)
{
// We're trying to lex the token @checked
LUAU_ASSERT(peekch() == '@');
std::pair<AstName, Lexeme::Type> maybeChecked = readName();
if (maybeChecked.second != Lexeme::ReservedChecked)
return Lexeme(Location(start, position()), Lexeme::Error);
return Lexeme(Location(start, position()), maybeChecked.second, maybeChecked.first.value);
std::pair<AstName, Lexeme::Type> attribute = readName();
return Lexeme(Location(start, position()), Lexeme::Attribute, attribute.first.value);
}
}
default:

View file

@ -16,13 +16,24 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100)
// Warning: If you are introducing new syntax, ensure that it is behind a separate
// flag so that we don't break production games by reverting syntax changes.
// See docs/SyntaxChanges.md for an explanation.
LUAU_FASTFLAG(LuauCheckedFunctionSyntax)
LUAU_FASTFLAGVARIABLE(LuauReadWritePropertySyntax, false)
LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false)
LUAU_FASTFLAG(LuauAttributeSyntax)
LUAU_FASTFLAGVARIABLE(LuauLeadingBarAndAmpersand2, false)
LUAU_FASTFLAGVARIABLE(LuauNativeAttribute, false)
LUAU_FASTFLAGVARIABLE(LuauAttributeSyntaxFunExpr, false)
LUAU_FASTFLAGVARIABLE(LuauDeclarationExtraPropData, false)
namespace Luau
{
struct AttributeEntry
{
const char* name;
AstAttr::Type type;
};
AttributeEntry kAttributeEntries[] = {{"@checked", AstAttr::Type::Checked}, {"@native", AstAttr::Type::Native}, {nullptr, AstAttr::Type::Checked}};
ParseError::ParseError(const Location& location, const std::string& message)
: location(location)
, message(message)
@ -281,7 +292,9 @@ AstStatBlock* Parser::parseBlockNoScope()
// for binding `=' exp `,' exp [`,' exp] do block end |
// for namelist in explist do block end |
// function funcname funcbody |
// attributes function funcname funcbody |
// local function Name funcbody |
// local attributes function Name funcbody |
// local namelist [`=' explist]
// laststat ::= return [explist] | break
AstStat* Parser::parseStat()
@ -300,13 +313,16 @@ AstStat* Parser::parseStat()
case Lexeme::ReservedRepeat:
return parseRepeat();
case Lexeme::ReservedFunction:
return parseFunctionStat();
return parseFunctionStat(AstArray<AstAttr*>({nullptr, 0}));
case Lexeme::ReservedLocal:
return parseLocal();
return parseLocal(AstArray<AstAttr*>({nullptr, 0}));
case Lexeme::ReservedReturn:
return parseReturn();
case Lexeme::ReservedBreak:
return parseBreak();
case Lexeme::Attribute:
if (FFlag::LuauAttributeSyntax)
return parseAttributeStat();
default:;
}
@ -344,7 +360,7 @@ AstStat* Parser::parseStat()
if (options.allowDeclarationSyntax)
{
if (ident == "declare")
return parseDeclaration(expr->location);
return parseDeclaration(expr->location, AstArray<AstAttr*>({nullptr, 0}));
}
// skip unexpected symbol if lexer couldn't advance at all (statements are parsed in a loop)
@ -653,7 +669,7 @@ AstExpr* Parser::parseFunctionName(Location start, bool& hasself, AstName& debug
}
// function funcname funcbody
AstStat* Parser::parseFunctionStat()
AstStat* Parser::parseFunctionStat(const AstArray<AstAttr*>& attributes)
{
Location start = lexer.current().location;
@ -666,16 +682,129 @@ AstStat* Parser::parseFunctionStat()
matchRecoveryStopOnToken[Lexeme::ReservedEnd]++;
AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, nullptr).first;
AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, nullptr, attributes).first;
matchRecoveryStopOnToken[Lexeme::ReservedEnd]--;
return allocator.alloc<AstStatFunction>(Location(start, body->location), expr, body);
}
std::pair<bool, AstAttr::Type> Parser::validateAttribute(const char* attributeName, const TempVector<AstAttr*>& attributes)
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
AstAttr::Type type;
// check if the attribute name is valid
bool found = false;
for (int i = 0; kAttributeEntries[i].name; ++i)
{
found = !strcmp(attributeName, kAttributeEntries[i].name);
if (found)
{
type = kAttributeEntries[i].type;
if (!FFlag::LuauNativeAttribute && type == AstAttr::Type::Native)
found = false;
break;
}
}
if (!found)
{
if (strlen(attributeName) == 1)
report(lexer.current().location, "Attribute name is missing");
else
report(lexer.current().location, "Invalid attribute '%s'", attributeName);
}
else
{
// check that attribute is not duplicated
for (const AstAttr* attr : attributes)
{
if (attr->type == type)
{
report(lexer.current().location, "Cannot duplicate attribute '%s'", attributeName);
}
}
}
return {found, type};
}
// attribute ::= '@' NAME
void Parser::parseAttribute(TempVector<AstAttr*>& attributes)
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
LUAU_ASSERT(lexer.current().type == Lexeme::Type::Attribute);
Location loc = lexer.current().location;
const char* name = lexer.current().name;
const auto [found, type] = validateAttribute(name, attributes);
nextLexeme();
if (found)
attributes.push_back(allocator.alloc<AstAttr>(loc, type));
}
// attributes ::= {attribute}
AstArray<AstAttr*> Parser::parseAttributes()
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
Lexeme::Type type = lexer.current().type;
LUAU_ASSERT(type == Lexeme::Attribute);
TempVector<AstAttr*> attributes(scratchAttr);
while (lexer.current().type == Lexeme::Attribute)
parseAttribute(attributes);
return copy(attributes);
}
// attributes local function Name funcbody
// attributes function funcname funcbody
// attributes `declare function' Name`(' [parlist] `)' [`:` Type]
// declare Name '{' Name ':' attributes `(' [parlist] `)' [`:` Type] '}'
AstStat* Parser::parseAttributeStat()
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
AstArray<AstAttr*> attributes = parseAttributes();
Lexeme::Type type = lexer.current().type;
switch (type)
{
case Lexeme::Type::ReservedFunction:
return parseFunctionStat(attributes);
case Lexeme::Type::ReservedLocal:
return parseLocal(attributes);
case Lexeme::Type::Name:
if (options.allowDeclarationSyntax && !strcmp("declare", lexer.current().data))
{
AstExpr* expr = parsePrimaryExpr(/* asStatement= */ true);
return parseDeclaration(expr->location, attributes);
}
default:
return reportStatError(lexer.current().location, {}, {},
"Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got %s intead",
lexer.current().toString().c_str());
}
}
// local function Name funcbody |
// local bindinglist [`=' explist]
AstStat* Parser::parseLocal()
AstStat* Parser::parseLocal(const AstArray<AstAttr*>& attributes)
{
Location start = lexer.current().location;
@ -695,7 +824,7 @@ AstStat* Parser::parseLocal()
matchRecoveryStopOnToken[Lexeme::ReservedEnd]++;
auto [body, var] = parseFunctionBody(false, matchFunction, name.name, &name);
auto [body, var] = parseFunctionBody(false, matchFunction, name.name, &name, attributes);
matchRecoveryStopOnToken[Lexeme::ReservedEnd]--;
@ -705,6 +834,12 @@ AstStat* Parser::parseLocal()
}
else
{
if (FFlag::LuauAttributeSyntax && attributes.size != 0)
{
return reportStatError(lexer.current().location, {}, {}, "Expected 'function' after local declaration with attribute, but got %s intead",
lexer.current().toString().c_str());
}
matchRecoveryStopOnToken['=']++;
TempVector<Binding> names(scratchBinding);
@ -775,8 +910,16 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported)
AstDeclaredClassProp Parser::parseDeclaredClassMethod()
{
Location start;
if (FFlag::LuauDeclarationExtraPropData)
start = lexer.current().location;
nextLexeme();
Location start = lexer.current().location;
if (!FFlag::LuauDeclarationExtraPropData)
start = lexer.current().location;
Name fnName = parseName("function name");
// TODO: generic method declarations CLI-39909
@ -801,15 +944,15 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod()
expectMatchAndConsume(')', matchParen);
AstTypeList retTypes = parseOptionalReturnType().value_or(AstTypeList{copy<AstType*>(nullptr, 0), nullptr});
Location end = lexer.current().location;
Location end = FFlag::LuauDeclarationExtraPropData ? lexer.previousLocation() : lexer.current().location;
TempVector<AstType*> vars(scratchType);
TempVector<std::optional<AstArgumentName>> varNames(scratchOptArgName);
if (args.size() == 0 || args[0].name.name != "self" || args[0].annotation != nullptr)
{
return AstDeclaredClassProp{
fnName.name, reportTypeError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), true};
return AstDeclaredClassProp{fnName.name, FFlag::LuauDeclarationExtraPropData ? fnName.location : Location{},
reportTypeError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), true};
}
// Skip the first index.
@ -829,21 +972,21 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod()
AstType* fnType = allocator.alloc<AstTypeFunction>(
Location(start, end), generics, genericPacks, AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes);
return AstDeclaredClassProp{fnName.name, fnType, true};
return AstDeclaredClassProp{fnName.name, FFlag::LuauDeclarationExtraPropData ? fnName.location : Location{}, fnType, true,
FFlag::LuauDeclarationExtraPropData ? Location(start, end) : Location{}};
}
AstStat* Parser::parseDeclaration(const Location& start)
AstStat* Parser::parseDeclaration(const Location& start, const AstArray<AstAttr*>& attributes)
{
// `declare` token is already parsed at this point
if (FFlag::LuauAttributeSyntax && (attributes.size != 0) && (lexer.current().type != Lexeme::ReservedFunction))
return reportStatError(lexer.current().location, {}, {}, "Expected a function type declaration after attribute, but got %s intead",
lexer.current().toString().c_str());
if (lexer.current().type == Lexeme::ReservedFunction)
{
nextLexeme();
bool checkedFunction = false;
if (FFlag::LuauCheckedFunctionSyntax && lexer.current().type == Lexeme::ReservedChecked)
{
checkedFunction = true;
nextLexeme();
}
Name globalName = parseName("global function name");
auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false);
@ -881,8 +1024,12 @@ AstStat* Parser::parseDeclaration(const Location& start)
if (vararg && !varargAnnotation)
return reportStatError(Location(start, end), {}, {}, "All declaration parameters must be annotated");
return allocator.alloc<AstStatDeclareFunction>(Location(start, end), globalName.name, generics, genericPacks,
AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes, checkedFunction);
if (FFlag::LuauDeclarationExtraPropData)
return allocator.alloc<AstStatDeclareFunction>(Location(start, end), attributes, globalName.name, globalName.location, generics,
genericPacks, AstTypeList{copy(vars), varargAnnotation}, copy(varNames), vararg, varargLocation, retTypes);
else
return allocator.alloc<AstStatDeclareFunction>(Location(start, end), attributes, globalName.name, Location{}, generics, genericPacks,
AstTypeList{copy(vars), varargAnnotation}, copy(varNames), false, Location{}, retTypes);
}
else if (AstName(lexer.current().name) == "class")
{
@ -912,19 +1059,42 @@ AstStat* Parser::parseDeclaration(const Location& start)
const Lexeme begin = lexer.current();
nextLexeme(); // [
std::optional<AstArray<char>> chars = parseCharArray();
if (FFlag::LuauDeclarationExtraPropData)
{
const Location nameBegin = lexer.current().location;
std::optional<AstArray<char>> chars = parseCharArray();
expectMatchAndConsume(']', begin);
expectAndConsume(':', "property type annotation");
AstType* type = parseType();
const Location nameEnd = lexer.previousLocation();
// since AstName contains a char*, it can't contain null
bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size);
expectMatchAndConsume(']', begin);
expectAndConsume(':', "property type annotation");
AstType* type = parseType();
if (chars && !containsNull)
props.push_back(AstDeclaredClassProp{AstName(chars->data), type, false});
// since AstName contains a char*, it can't contain null
bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size);
if (chars && !containsNull)
props.push_back(AstDeclaredClassProp{
AstName(chars->data), Location(nameBegin, nameEnd), type, false, Location(begin.location, lexer.previousLocation())});
else
report(begin.location, "String literal contains malformed escape sequence or \\0");
}
else
report(begin.location, "String literal contains malformed escape sequence or \\0");
{
std::optional<AstArray<char>> chars = parseCharArray();
expectMatchAndConsume(']', begin);
expectAndConsume(':', "property type annotation");
AstType* type = parseType();
// since AstName contains a char*, it can't contain null
bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size);
if (chars && !containsNull)
props.push_back(AstDeclaredClassProp{AstName(chars->data), Location{}, type, false});
else
report(begin.location, "String literal contains malformed escape sequence or \\0");
}
}
else if (lexer.current().type == '[')
{
@ -942,12 +1112,21 @@ AstStat* Parser::parseDeclaration(const Location& start)
indexer = parseTableIndexer(AstTableAccess::ReadWrite, std::nullopt);
}
}
else if (FFlag::LuauDeclarationExtraPropData)
{
Location propStart = lexer.current().location;
Name propName = parseName("property name");
expectAndConsume(':', "property type annotation");
AstType* propType = parseType();
props.push_back(
AstDeclaredClassProp{propName.name, propName.location, propType, false, Location(propStart, lexer.previousLocation())});
}
else
{
Name propName = parseName("property name");
expectAndConsume(':', "property type annotation");
AstType* propType = parseType();
props.push_back(AstDeclaredClassProp{propName.name, propType, false});
props.push_back(AstDeclaredClassProp{propName.name, Location{}, propType, false});
}
}
@ -961,7 +1140,8 @@ AstStat* Parser::parseDeclaration(const Location& start)
expectAndConsume(':', "global variable declaration");
AstType* type = parseType(/* in declaration context */ true);
return allocator.alloc<AstStatDeclareGlobal>(Location(start, type->location), globalName->name, type);
return allocator.alloc<AstStatDeclareGlobal>(
Location(start, type->location), globalName->name, FFlag::LuauDeclarationExtraPropData ? globalName->location : Location{}, type);
}
else
{
@ -1036,7 +1216,7 @@ std::pair<AstLocal*, AstArray<AstLocal*>> Parser::prepareFunctionArguments(const
// funcbody ::= `(' [parlist] `)' [`:' ReturnType] block end
// parlist ::= bindinglist [`,' `...'] | `...'
std::pair<AstExprFunction*, AstLocal*> Parser::parseFunctionBody(
bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName)
bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName, const AstArray<AstAttr*>& attributes)
{
Location start = matchFunction.location;
@ -1088,7 +1268,7 @@ std::pair<AstExprFunction*, AstLocal*> Parser::parseFunctionBody(
bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchFunction);
body->hasEnd = hasEnd;
return {allocator.alloc<AstExprFunction>(Location(start, end), generics, genericPacks, self, vars, vararg, varargLocation, body,
return {allocator.alloc<AstExprFunction>(Location(start, end), attributes, generics, genericPacks, self, vars, vararg, varargLocation, body,
functionStack.size(), debugname, typelist, varargAnnotation, argLocation),
funLocal};
}
@ -1297,7 +1477,7 @@ std::pair<Location, AstTypeList> Parser::parseReturnType()
return {location, AstTypeList{copy(result), varargAnnotation}};
}
AstType* tail = parseFunctionTypeTail(begin, {}, {}, copy(result), copy(resultNames), varargAnnotation);
AstType* tail = parseFunctionTypeTail(begin, {nullptr, 0}, {}, {}, copy(result), copy(resultNames), varargAnnotation);
return {Location{location, tail->location}, AstTypeList{copy(&tail, 1), varargAnnotation}};
}
@ -1340,22 +1520,19 @@ AstType* Parser::parseTableType(bool inDeclarationContext)
AstTableAccess access = AstTableAccess::ReadWrite;
std::optional<Location> accessLocation;
if (FFlag::LuauReadWritePropertySyntax || FFlag::DebugLuauDeferredConstraintResolution)
if (lexer.current().type == Lexeme::Name && lexer.lookahead().type != ':')
{
if (lexer.current().type == Lexeme::Name && lexer.lookahead().type != ':')
if (AstName(lexer.current().name) == "read")
{
if (AstName(lexer.current().name) == "read")
{
accessLocation = lexer.current().location;
access = AstTableAccess::Read;
lexer.next();
}
else if (AstName(lexer.current().name) == "write")
{
accessLocation = lexer.current().location;
access = AstTableAccess::Write;
lexer.next();
}
accessLocation = lexer.current().location;
access = AstTableAccess::Read;
lexer.next();
}
else if (AstName(lexer.current().name) == "write")
{
accessLocation = lexer.current().location;
access = AstTableAccess::Write;
lexer.next();
}
}
@ -1439,7 +1616,7 @@ AstType* Parser::parseTableType(bool inDeclarationContext)
// ReturnType ::= Type | `(' TypeList `)'
// FunctionType ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType
AstTypeOrPack Parser::parseFunctionType(bool allowPack, bool isCheckedFunction)
AstTypeOrPack Parser::parseFunctionType(bool allowPack, const AstArray<AstAttr*>& attributes)
{
incrementRecursionCounter("type annotation");
@ -1487,11 +1664,12 @@ AstTypeOrPack Parser::parseFunctionType(bool allowPack, bool isCheckedFunction)
AstArray<std::optional<AstArgumentName>> paramNames = copy(names);
return {parseFunctionTypeTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation, isCheckedFunction), {}};
return {parseFunctionTypeTail(begin, attributes, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}};
}
AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray<AstGenericType> generics, AstArray<AstGenericTypePack> genericPacks,
AstArray<AstType*> params, AstArray<std::optional<AstArgumentName>> paramNames, AstTypePack* varargAnnotation, bool isCheckedFunction)
AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, const AstArray<AstAttr*>& attributes, AstArray<AstGenericType> generics,
AstArray<AstGenericTypePack> genericPacks, AstArray<AstType*> params, AstArray<std::optional<AstArgumentName>> paramNames,
AstTypePack* varargAnnotation)
{
incrementRecursionCounter("type annotation");
@ -1516,7 +1694,7 @@ AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray<AstGenericT
AstTypeList paramTypes = AstTypeList{params, varargAnnotation};
return allocator.alloc<AstTypeFunction>(
Location(begin.location, endLocation), generics, genericPacks, paramTypes, paramNames, returnTypeList, isCheckedFunction);
Location(begin.location, endLocation), attributes, generics, genericPacks, paramTypes, paramNames, returnTypeList);
}
// Type ::=
@ -1528,7 +1706,11 @@ AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray<AstGenericT
AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin)
{
TempVector<AstType*> parts(scratchType);
parts.push_back(type);
if (!FFlag::LuauLeadingBarAndAmpersand2 || type != nullptr)
{
parts.push_back(type);
}
incrementRecursionCounter("type annotation");
@ -1553,6 +1735,8 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin)
}
else if (c == '?')
{
LUAU_ASSERT(parts.size() >= 1);
Location loc = lexer.current().location;
nextLexeme();
@ -1585,7 +1769,7 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin)
}
if (parts.size() == 1)
return type;
return FFlag::LuauLeadingBarAndAmpersand2 ? parts[0] : type;
if (isUnion && isIntersection)
{
@ -1628,15 +1812,34 @@ AstTypeOrPack Parser::parseTypeOrPack()
AstType* Parser::parseType(bool inDeclarationContext)
{
unsigned int oldRecursionCount = recursionCounter;
// recursion counter is incremented in parseSimpleType
// recursion counter is incremented in parseSimpleType and/or parseTypeSuffix
Location begin = lexer.current().location;
AstType* type = parseSimpleType(/* allowPack= */ false, /* in declaration context */ inDeclarationContext).type;
if (FFlag::LuauLeadingBarAndAmpersand2)
{
AstType* type = nullptr;
recursionCounter = oldRecursionCount;
Lexeme::Type c = lexer.current().type;
if (c != '|' && c != '&')
{
type = parseSimpleType(/* allowPack= */ false, /* in declaration context */ inDeclarationContext).type;
recursionCounter = oldRecursionCount;
}
return parseTypeSuffix(type, begin);
AstType* typeWithSuffix = parseTypeSuffix(type, begin);
recursionCounter = oldRecursionCount;
return typeWithSuffix;
}
else
{
AstType* type = parseSimpleType(/* allowPack= */ false, /* in declaration context */ inDeclarationContext).type;
recursionCounter = oldRecursionCount;
return parseTypeSuffix(type, begin);
}
}
// Type ::= nil | Name[`.' Name] [ `<' Type [`,' ...] `>' ] | `typeof' `(' expr `)' | `{' [PropList] `}'
@ -1647,7 +1850,21 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext)
Location start = lexer.current().location;
if (lexer.current().type == Lexeme::ReservedNil)
AstArray<AstAttr*> attributes{nullptr, 0};
if (lexer.current().type == Lexeme::Attribute)
{
if (!inDeclarationContext || !FFlag::LuauAttributeSyntax)
{
return {reportTypeError(start, {}, "attributes are not allowed in declaration context")};
}
else
{
attributes = Parser::parseAttributes();
return parseFunctionType(allowPack, attributes);
}
}
else if (lexer.current().type == Lexeme::ReservedNil)
{
nextLexeme();
return {allocator.alloc<AstTypeReference>(start, std::nullopt, nameNil, std::nullopt, start), {}};
@ -1735,15 +1952,9 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext)
{
return {parseTableType(/* inDeclarationContext */ inDeclarationContext), {}};
}
else if (FFlag::LuauCheckedFunctionSyntax && inDeclarationContext && lexer.current().type == Lexeme::ReservedChecked)
{
LUAU_ASSERT(FFlag::LuauCheckedFunctionSyntax);
nextLexeme();
return parseFunctionType(allowPack, /* isCheckedFunction */ true);
}
else if (lexer.current().type == '(' || lexer.current().type == '<')
{
return parseFunctionType(allowPack);
return parseFunctionType(allowPack, AstArray<AstAttr*>({nullptr, 0}));
}
else if (lexer.current().type == Lexeme::ReservedFunction)
{
@ -2213,11 +2424,24 @@ static ConstantNumberParseResult parseDouble(double& result, const char* data)
return ConstantNumberParseResult::Ok;
}
// simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | FUNCTION body | primaryexp
// simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | [attributes] FUNCTION body | primaryexp
AstExpr* Parser::parseSimpleExpr()
{
Location start = lexer.current().location;
AstArray<AstAttr*> attributes{nullptr, 0};
if (FFlag::LuauAttributeSyntax && FFlag::LuauAttributeSyntaxFunExpr && lexer.current().type == Lexeme::Attribute)
{
attributes = parseAttributes();
if (lexer.current().type != Lexeme::ReservedFunction)
{
return reportExprError(
start, {}, "Expected 'function' declaration after attribute, but got %s intead", lexer.current().toString().c_str());
}
}
if (lexer.current().type == Lexeme::ReservedNil)
{
nextLexeme();
@ -2241,7 +2465,7 @@ AstExpr* Parser::parseSimpleExpr()
Lexeme matchFunction = lexer.current();
nextLexeme();
return parseFunctionBody(false, matchFunction, AstName(), nullptr).first;
return parseFunctionBody(false, matchFunction, AstName(), nullptr, attributes).first;
}
else if (lexer.current().type == Lexeme::Number)
{
@ -2671,7 +2895,7 @@ std::optional<AstArray<char>> Parser::parseCharArray()
LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString ||
lexer.current().type == Lexeme::InterpStringSimple);
scratchData.assign(lexer.current().data, lexer.current().length);
scratchData.assign(lexer.current().data, lexer.current().getLength());
if (lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::InterpStringSimple)
{
@ -2716,7 +2940,7 @@ AstExpr* Parser::parseInterpString()
endLocation = currentLexeme.location;
scratchData.assign(currentLexeme.data, currentLexeme.length);
scratchData.assign(currentLexeme.data, currentLexeme.getLength());
if (!Lexer::fixupQuotedString(scratchData))
{
@ -2789,7 +3013,7 @@ AstExpr* Parser::parseNumber()
{
Location start = lexer.current().location;
scratchData.assign(lexer.current().data, lexer.current().length);
scratchData.assign(lexer.current().data, lexer.current().getLength());
// Remove all internal _ - they don't hold any meaning and this allows parsing code to just pass the string pointer to strtod et al
if (scratchData.find('_') != std::string::npos)
@ -3144,11 +3368,11 @@ void Parser::nextLexeme()
return;
// Comments starting with ! are called "hot comments" and contain directives for type checking / linting / compiling
if (lexeme.type == Lexeme::Comment && lexeme.length && lexeme.data[0] == '!')
if (lexeme.type == Lexeme::Comment && lexeme.getLength() && lexeme.data[0] == '!')
{
const char* text = lexeme.data;
unsigned int end = lexeme.length;
unsigned int end = lexeme.getLength();
while (end > 0 && isSpace(text[end - 1]))
--end;

View file

@ -250,6 +250,10 @@ void flushEvents(GlobalContext& context, uint32_t threadId, const std::vector<Ev
ThreadContext& getThreadContext()
{
// Check custom provider that which might implement a custom TLS
if (auto provider = threadContextProvider())
return provider();
thread_local ThreadContext context;
return context;
}

View file

@ -317,6 +317,7 @@ static bool compileFile(const char* name, CompileFormat format, Luau::CodeGen::A
{
options.includeAssembly = format != CompileFormat::CodegenIr;
options.includeIr = format != CompileFormat::CodegenAsm;
options.includeIrTypes = format != CompileFormat::CodegenAsm;
options.includeOutlinedCode = format == CompileFormat::CodegenVerbose;
}

View file

@ -144,7 +144,10 @@ static int lua_require(lua_State* L)
if (luau_load(ML, resolvedRequire.chunkName.c_str(), bytecode.data(), bytecode.size(), 0) == 0)
{
if (codegen)
Luau::CodeGen::compile(ML, -1);
{
Luau::CodeGen::CompilationOptions nativeOptions;
Luau::CodeGen::compile(ML, -1, nativeOptions);
}
if (coverageActive())
coverageTrack(ML, -1);
@ -253,12 +256,16 @@ void setupState(lua_State* L)
void setupArguments(lua_State* L, int argc, char** argv)
{
lua_checkstack(L, argc);
for (int i = 0; i < argc; ++i)
lua_pushstring(L, argv[i]);
}
std::string runCode(lua_State* L, const std::string& source)
{
lua_checkstack(L, LUA_MINSTACK);
std::string bytecode = Luau::compile(source, copts());
if (luau_load(L, "=stdin", bytecode.data(), bytecode.size(), 0) != 0)
@ -429,6 +436,8 @@ static void completeIndexer(lua_State* L, const std::string& editBuffer, const A
std::string_view lookup = editBuffer;
bool completeOnlyFunctions = false;
lua_checkstack(L, LUA_MINSTACK);
// Push the global variable table to begin the search
lua_pushvalue(L, LUA_GLOBALSINDEX);
@ -602,7 +611,10 @@ static bool runFile(const char* name, lua_State* GL, bool repl)
if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0)
{
if (codegen)
Luau::CodeGen::compile(L, -1);
{
Luau::CodeGen::CompilationOptions nativeOptions;
Luau::CodeGen::compile(L, -1, nativeOptions);
}
if (coverageActive())
coverageTrack(L, -1);

View file

@ -229,6 +229,7 @@ if(LUAU_BUILD_TESTS)
target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler Luau.CodeGen)
target_compile_options(Luau.Conformance PRIVATE ${LUAU_OPTIONS})
target_compile_definitions(Luau.Conformance PRIVATE DOCTEST_CONFIG_DOUBLE_STRINGIFY)
target_include_directories(Luau.Conformance PRIVATE extern)
target_link_libraries(Luau.Conformance PRIVATE Luau.Analysis Luau.Compiler Luau.CodeGen Luau.VM)
if(CMAKE_SYSTEM_NAME MATCHES "Android|iOS")

View file

@ -13,10 +13,11 @@ namespace CodeGen
{
struct IrFunction;
struct HostIrHooks;
void loadBytecodeTypeInfo(IrFunction& function);
void buildBytecodeBlocks(IrFunction& function, const std::vector<uint8_t>& jumpTargets);
void analyzeBytecodeTypes(IrFunction& function);
void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks);
} // namespace CodeGen
} // namespace Luau

View file

@ -12,6 +12,12 @@
struct lua_State;
#if defined(__x86_64__) || defined(_M_X64)
#define CODEGEN_TARGET_X64
#elif defined(__aarch64__) || defined(_M_ARM64)
#define CODEGEN_TARGET_A64
#endif
namespace Luau
{
namespace CodeGen
@ -40,8 +46,12 @@ enum class CodeGenCompilationResult
CodeGenAssemblerFinalizationFailure = 7, // Failure during assembler finalization
CodeGenLoweringFailure = 8, // Lowering failed
AllocationFailed = 9, // Native codegen failed due to an allocation error
Count = 10,
};
std::string toString(const CodeGenCompilationResult& result);
struct ProtoCompilationFailure
{
CodeGenCompilationResult result = CodeGenCompilationResult::Success;
@ -62,6 +72,97 @@ struct CompilationResult
}
};
struct IrBuilder;
struct IrOp;
using HostVectorOperationBytecodeType = uint8_t (*)(const char* member, size_t memberLength);
using HostVectorAccessHandler = bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos);
using HostVectorNamecallHandler = bool (*)(
IrBuilder& builder, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos);
enum class HostMetamethod
{
Add,
Sub,
Mul,
Div,
Idiv,
Mod,
Pow,
Minus,
Equal,
LessThan,
LessEqual,
Length,
Concat,
};
using HostUserdataOperationBytecodeType = uint8_t (*)(uint8_t type, const char* member, size_t memberLength);
using HostUserdataMetamethodBytecodeType = uint8_t (*)(uint8_t lhsTy, uint8_t rhsTy, HostMetamethod method);
using HostUserdataAccessHandler = bool (*)(
IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos);
using HostUserdataMetamethodHandler = bool (*)(
IrBuilder& builder, uint8_t lhsTy, uint8_t rhsTy, int resultReg, IrOp lhs, IrOp rhs, HostMetamethod method, int pcpos);
using HostUserdataNamecallHandler = bool (*)(
IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos);
struct HostIrHooks
{
// Suggest result type of a vector field access
HostVectorOperationBytecodeType vectorAccessBytecodeType = nullptr;
// Suggest result type of a vector function namecall
HostVectorOperationBytecodeType vectorNamecallBytecodeType = nullptr;
// Handle vector value field access
// 'sourceReg' is guaranteed to be a vector
// Guards should take a VM exit to 'pcpos'
HostVectorAccessHandler vectorAccess = nullptr;
// Handle namecall performed on a vector value
// 'sourceReg' (self argument) is guaranteed to be a vector
// All other arguments can be of any type
// Guards should take a VM exit to 'pcpos'
HostVectorNamecallHandler vectorNamecall = nullptr;
// Suggest result type of a userdata field access
HostUserdataOperationBytecodeType userdataAccessBytecodeType = nullptr;
// Suggest result type of a metamethod call
HostUserdataMetamethodBytecodeType userdataMetamethodBytecodeType = nullptr;
// Suggest result type of a userdata namecall
HostUserdataOperationBytecodeType userdataNamecallBytecodeType = nullptr;
// Handle userdata value field access
// 'sourceReg' is guaranteed to be a userdata, but tag has to be checked
// Write to 'resultReg' might invalidate 'sourceReg'
// Guards should take a VM exit to 'pcpos'
HostUserdataAccessHandler userdataAccess = nullptr;
// Handle metamethod operation on a userdata value
// 'lhs' and 'rhs' operands can be VM registers of constants
// Operand types have to be checked and userdata operand tags have to be checked
// Write to 'resultReg' might invalidate source operands
// Guards should take a VM exit to 'pcpos'
HostUserdataMetamethodHandler userdataMetamethod = nullptr;
// Handle namecall performed on a userdata value
// 'sourceReg' (self argument) is guaranteed to be a userdata, but tag has to be checked
// All other arguments can be of any type
// Guards should take a VM exit to 'pcpos'
HostUserdataNamecallHandler userdataNamecall = nullptr;
};
struct CompilationOptions
{
unsigned int flags = 0;
HostIrHooks hooks;
// null-terminated array of userdata types names that might have custom lowering
const char* const* userdataTypes = nullptr;
};
struct CompilationStats
{
size_t bytecodeSizeBytes = 0;
@ -101,8 +202,17 @@ using UniqueSharedCodeGenContext = std::unique_ptr<SharedCodeGenContext, SharedC
// SharedCodeGenContext must be destroyed before this function is called.
void destroySharedCodeGenContext(const SharedCodeGenContext* codeGenContext) noexcept;
void create(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext);
// Initializes native code-gen on the provided Luau VM, using a VM-specific
// code-gen context and either the default allocator parameters or custom
// allocator parameters.
void create(lua_State* L);
void create(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext);
void create(lua_State* L, size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext);
// Initializes native code-gen on the provided Luau VM, using the provided
// SharedCodeGenContext. Note that after this function is called, the
// SharedCodeGenContext must not be destroyed until after the Luau VM L is
// destroyed via lua_close.
void create(lua_State* L, SharedCodeGenContext* codeGenContext);
// Check if native execution is enabled
@ -111,11 +221,20 @@ void create(lua_State* L, SharedCodeGenContext* codeGenContext);
// Enable or disable native execution according to `enabled` argument
void setNativeExecutionEnabled(lua_State* L, bool enabled);
// Given a name, this function must return the index of the type which matches the type array used all CompilationOptions and AssemblyOptions
// If the type is unknown, 0xff has to be returned
using UserdataRemapperCallback = uint8_t(void* context, const char* name, size_t nameLength);
void setUserdataRemapper(lua_State* L, void* context, UserdataRemapperCallback cb);
using ModuleId = std::array<uint8_t, 16>;
// Builds target function and all inner functions
CompilationResult compile(lua_State* L, int idx, unsigned int flags = 0, CompilationStats* stats = nullptr);
CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags = 0, CompilationStats* stats = nullptr);
CompilationResult compile(lua_State* L, int idx, unsigned int flags, CompilationStats* stats = nullptr);
CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats = nullptr);
CompilationResult compile(lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats = nullptr);
CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats = nullptr);
using AnnotatorFn = void (*)(void* context, std::string& result, int fid, int instpos);
@ -160,7 +279,7 @@ struct AssemblyOptions
Target target = Host;
unsigned int flags = 0;
CompilationOptions compilationOptions;
bool outputBinary = false;

View file

@ -16,11 +16,11 @@ namespace Luau
namespace CodeGen
{
struct AssemblyOptions;
struct HostIrHooks;
struct IrBuilder
{
IrBuilder();
IrBuilder(const HostIrHooks& hostHooks);
void buildFunctionIr(Proto* proto);
@ -54,6 +54,7 @@ struct IrBuilder
IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d);
IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e);
IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f);
IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f, IrOp g);
IrOp block(IrBlockKind kind); // Requested kind can be ignored if we are in an outlined sequence
IrOp blockAtInst(uint32_t index);
@ -64,13 +65,17 @@ struct IrBuilder
IrOp vmExit(uint32_t pcpos);
const HostIrHooks& hostHooks;
bool inTerminatedBlock = false;
bool interruptRequested = false;
bool activeFastcallFallback = false;
IrOp fastcallFallbackReturn;
int fastcallSkipTarget = -1;
// Force builder to skip source commands
int cmdSkipTarget = -1;
IrFunction function;

View file

@ -31,7 +31,7 @@ enum
// * Rn - VM stack register slot, n in 0..254
// * Kn - VM proto constant slot, n in 0..2^23-1
// * UPn - VM function upvalue slot, n in 0..199
// * A, B, C, D, E are instruction arguments
// * A, B, C, D, E, F, G are instruction arguments
enum class IrCmd : uint8_t
{
NOP,
@ -179,6 +179,10 @@ enum class IrCmd : uint8_t
// A: double
ABS_NUM,
// Get the sign of the argument (math.sign)
// A: double
SIGN_NUM,
// Add/Sub/Mul/Div/Idiv two vectors
// A, B: TValue
ADD_VEC,
@ -290,6 +294,11 @@ enum class IrCmd : uint8_t
// C: block
TRY_CALL_FASTGETTM,
// Create new tagged userdata
// A: int (size)
// B: int (tag)
NEW_USERDATA,
// Convert integer into a double number
// A: int
INT_TO_NUM,
@ -321,13 +330,12 @@ enum class IrCmd : uint8_t
// This is used to recover after calling a variadic function
ADJUST_STACK_TO_TOP,
// Execute fastcall builtin function in-place
// Execute fastcall builtin function with 1 argument in-place
// This is used for a few builtins that can have more than 1 result and cannot be represented as a regular instruction
// A: unsigned int (builtin id)
// B: Rn (result start)
// C: Rn (argument start)
// D: Rn or Kn or undef (optional second argument)
// E: int (argument count)
// F: int (result count)
// C: Rn (first argument)
// D: int (result count)
FASTCALL,
// Call the fastcall builtin function
@ -335,8 +343,9 @@ enum class IrCmd : uint8_t
// B: Rn (result start)
// C: Rn (argument start)
// D: Rn or Kn or undef (optional second argument)
// E: int (argument count or -1 to use all arguments up to stack top)
// F: int (result count or -1 to preserve all results and adjust stack top)
// E: Rn or Kn or undef (optional third argument)
// F: int (argument count or -1 to use all arguments up to stack top)
// G: int (result count or -1 to preserve all results and adjust stack top)
INVOKE_FASTCALL,
// Check that fastcall builtin function invocation was successful (negative result count jumps to fallback)
@ -460,6 +469,13 @@ enum class IrCmd : uint8_t
// When undef is specified instead of a block, execution is aborted on check failure
CHECK_BUFFER_LEN,
// Guard against userdata tag mismatch
// A: pointer (userdata)
// B: int (tag)
// C: block/vmexit/undef
// When undef is specified instead of a block, execution is aborted on check failure
CHECK_USERDATA_TAG,
// Special operations
// Check interrupt handler
@ -857,6 +873,7 @@ struct IrInst
IrOp d;
IrOp e;
IrOp f;
IrOp g;
uint32_t lastUse = 0;
uint16_t useCount = 0;
@ -911,6 +928,7 @@ struct IrInstHash
h = mix(h, key.d);
h = mix(h, key.e);
h = mix(h, key.f);
h = mix(h, key.g);
// MurmurHash2 tail
h ^= h >> 13;
@ -925,7 +943,7 @@ struct IrInstEq
{
bool operator()(const IrInst& a, const IrInst& b) const
{
return a.cmd == b.cmd && a.a == b.a && a.b == b.b && a.c == b.c && a.d == b.d && a.e == b.e && a.f == b.f;
return a.cmd == b.cmd && a.a == b.a && a.b == b.b && a.c == b.c && a.d == b.d && a.e == b.e && a.f == b.f && a.g == b.g;
}
};

View file

@ -31,9 +31,11 @@ void toString(IrToStringContext& ctx, IrOp op);
void toString(std::string& result, IrConst constant);
const char* getBytecodeTypeName(uint8_t type);
const char* getBytecodeTypeName_DEPRECATED(uint8_t type);
const char* getBytecodeTypeName(uint8_t type, const char* const* userdataTypes);
void toString(std::string& result, const BytecodeTypes& bcTypes);
void toString_DEPRECATED(std::string& result, const BytecodeTypes& bcTypes);
void toString(std::string& result, const BytecodeTypes& bcTypes, const char* const* userdataTypes);
void toStringDetailed(
IrToStringContext& ctx, const IrBlock& block, uint32_t blockIdx, const IrInst& inst, uint32_t instIdx, IncludeUseInfo includeUseInfo);

View file

@ -11,6 +11,7 @@ namespace CodeGen
{
struct IrBuilder;
enum class HostMetamethod;
inline bool isJumpD(LuauOpcode op)
{
@ -63,6 +64,7 @@ inline bool isFastCall(LuauOpcode op)
case LOP_FASTCALL1:
case LOP_FASTCALL2:
case LOP_FASTCALL2K:
case LOP_FASTCALL3:
return true;
default:
@ -129,6 +131,7 @@ inline bool isNonTerminatingJump(IrCmd cmd)
case IrCmd::CHECK_NODE_NO_NEXT:
case IrCmd::CHECK_NODE_VALUE:
case IrCmd::CHECK_BUFFER_LEN:
case IrCmd::CHECK_USERDATA_TAG:
return true;
default:
break;
@ -168,6 +171,7 @@ inline bool hasResult(IrCmd cmd)
case IrCmd::ROUND_NUM:
case IrCmd::SQRT_NUM:
case IrCmd::ABS_NUM:
case IrCmd::SIGN_NUM:
case IrCmd::ADD_VEC:
case IrCmd::SUB_VEC:
case IrCmd::MUL_VEC:
@ -182,6 +186,7 @@ inline bool hasResult(IrCmd cmd)
case IrCmd::DUP_TABLE:
case IrCmd::TRY_NUM_TO_INDEX:
case IrCmd::TRY_CALL_FASTGETTM:
case IrCmd::NEW_USERDATA:
case IrCmd::INT_TO_NUM:
case IrCmd::UINT_TO_NUM:
case IrCmd::NUM_TO_INT:
@ -241,6 +246,12 @@ IrValueKind getCmdValueKind(IrCmd cmd);
bool isGCO(uint8_t tag);
// Optional bit has to be cleared at call site, otherwise, this will return 'false' for 'userdata?'
bool isUserdataBytecodeType(uint8_t ty);
bool isCustomUserdataBytecodeType(uint8_t ty);
HostMetamethod tmToHostMetamethod(int tm);
// Manually add or remove use of an operand
void addUse(IrFunction& function, IrOp op);
void removeUse(IrFunction& function, IrOp op);

View file

@ -4,7 +4,7 @@
#include "Luau/Common.h"
#include "Luau/IrData.h"
LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5)
LUAU_FASTFLAG(LuauCodegenFastcall3)
namespace Luau
{
@ -112,12 +112,48 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i
visitor.useRange(vmRegOp(inst.a), function.intOp(inst.b));
break;
// TODO: FASTCALL is more restrictive than INVOKE_FASTCALL; we should either determine the exact semantics, or rework it
case IrCmd::FASTCALL:
case IrCmd::INVOKE_FASTCALL:
if (int count = function.intOp(inst.e); count != -1)
if (FFlag::LuauCodegenFastcall3)
{
if (count >= 3)
visitor.use(inst.c);
if (int nresults = function.intOp(inst.d); nresults != -1)
visitor.defRange(vmRegOp(inst.b), nresults);
}
else
{
if (int count = function.intOp(inst.e); count != -1)
{
if (count >= 3)
{
CODEGEN_ASSERT(inst.d.kind == IrOpKind::VmReg && vmRegOp(inst.d) == vmRegOp(inst.c) + 1);
visitor.useRange(vmRegOp(inst.c), count);
}
else
{
if (count >= 1)
visitor.use(inst.c);
if (count >= 2)
visitor.maybeUse(inst.d); // Argument can also be a VmConst
}
}
else
{
visitor.useVarargs(vmRegOp(inst.c));
}
// Multiple return sequences (count == -1) are defined by ADJUST_STACK_TO_REG
if (int count = function.intOp(inst.f); count != -1)
visitor.defRange(vmRegOp(inst.b), count);
}
break;
case IrCmd::INVOKE_FASTCALL:
if (int count = function.intOp(FFlag::LuauCodegenFastcall3 ? inst.f : inst.e); count != -1)
{
// Only LOP_FASTCALL3 lowering is allowed to have third optional argument
if (count >= 3 && (!FFlag::LuauCodegenFastcall3 || inst.e.kind == IrOpKind::Undef))
{
CODEGEN_ASSERT(inst.d.kind == IrOpKind::VmReg && vmRegOp(inst.d) == vmRegOp(inst.c) + 1);
@ -130,6 +166,9 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i
if (count >= 2)
visitor.maybeUse(inst.d); // Argument can also be a VmConst
if (FFlag::LuauCodegenFastcall3 && count >= 3)
visitor.maybeUse(inst.e); // Argument can also be a VmConst
}
}
else
@ -138,7 +177,7 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i
}
// Multiple return sequences (count == -1) are defined by ADJUST_STACK_TO_REG
if (int count = function.intOp(inst.f); count != -1)
if (int count = function.intOp(FFlag::LuauCodegenFastcall3 ? inst.g : inst.f); count != -1)
visitor.defRange(vmRegOp(inst.b), count);
break;
case IrCmd::FORGLOOP:
@ -188,15 +227,8 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i
visitor.def(inst.b);
break;
case IrCmd::FALLBACK_FORGPREP:
if (FFlag::LuauCodegenRemoveDeadStores5)
{
// This instruction doesn't always redefine Rn, Rn+1, Rn+2, so we have to mark it as implicit use
visitor.useRange(vmRegOp(inst.b), 3);
}
else
{
visitor.use(inst.b);
}
// This instruction doesn't always redefine Rn, Rn+1, Rn+2, so we have to mark it as implicit use
visitor.useRange(vmRegOp(inst.b), 3);
visitor.defRange(vmRegOp(inst.b), 3);
break;
@ -214,12 +246,6 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i
visitor.use(inst.a);
break;
// After optimizations with DebugLuauAbortingChecks enabled, CHECK_TAG Rn, tag, block instructions are generated
case IrCmd::CHECK_TAG:
if (!FFlag::LuauCodegenRemoveDeadStores5)
visitor.maybeUse(inst.a);
break;
default:
// All instructions which reference registers have to be handled explicitly
CODEGEN_ASSERT(inst.a.kind != IrOpKind::VmReg);
@ -228,6 +254,7 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i
CODEGEN_ASSERT(inst.d.kind != IrOpKind::VmReg);
CODEGEN_ASSERT(inst.e.kind != IrOpKind::VmReg);
CODEGEN_ASSERT(inst.f.kind != IrOpKind::VmReg);
CODEGEN_ASSERT(inst.g.kind != IrOpKind::VmReg);
break;
}
}

View file

@ -16,7 +16,7 @@ namespace CodeGen
{
// This value is used in 'finishFunction' to mark the function that spans to the end of the whole code block
static uint32_t kFullBlockFuncton = ~0u;
static uint32_t kFullBlockFunction = ~0u;
class UnwindBuilder
{
@ -52,11 +52,10 @@ public:
virtual void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list<X64::RegisterX64> gpr,
const std::vector<X64::RegisterX64>& simd) = 0;
virtual size_t getSize() const = 0;
virtual size_t getFunctionCount() const = 0;
virtual size_t getUnwindInfoSize(size_t blockSize) const = 0;
// This will place the unwinding data at the target address and might update values of some fields
virtual void finalize(char* target, size_t offset, void* funcAddress, size_t funcSize) const = 0;
virtual size_t finalize(char* target, size_t offset, void* funcAddress, size_t blockSize) const = 0;
};
} // namespace CodeGen

View file

@ -33,10 +33,9 @@ public:
void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list<X64::RegisterX64> gpr,
const std::vector<X64::RegisterX64>& simd) override;
size_t getSize() const override;
size_t getFunctionCount() const override;
size_t getUnwindInfoSize(size_t blockSize = 0) const override;
void finalize(char* target, size_t offset, void* funcAddress, size_t funcSize) const override;
size_t finalize(char* target, size_t offset, void* funcAddress, size_t blockSize) const override;
private:
size_t beginOffset = 0;

View file

@ -53,10 +53,9 @@ public:
void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list<X64::RegisterX64> gpr,
const std::vector<X64::RegisterX64>& simd) override;
size_t getSize() const override;
size_t getFunctionCount() const override;
size_t getUnwindInfoSize(size_t blockSize = 0) const override;
void finalize(char* target, size_t offset, void* funcAddress, size_t funcSize) const override;
size_t finalize(char* target, size_t offset, void* funcAddress, size_t blockSize) const override;
private:
size_t beginOffset = 0;

View file

@ -826,7 +826,7 @@ void AssemblyBuilderX64::vcvtss2sd(OperandX64 dst, OperandX64 src1, OperandX64 s
else
CODEGEN_ASSERT(src2.memSize == SizeX64::dword);
placeAvx("vcvtsd2ss", dst, src1, src2, 0x5a, false, AVX_0F, AVX_F3);
placeAvx("vcvtss2sd", dst, src1, src2, 0x5a, false, AVX_0F, AVX_F3);
}
void AssemblyBuilderX64::vroundsd(OperandX64 dst, OperandX64 src1, OperandX64 src2, RoundingModeX64 roundingMode)

View file

@ -2,35 +2,26 @@
#include "Luau/BytecodeAnalysis.h"
#include "Luau/BytecodeUtils.h"
#include "Luau/CodeGen.h"
#include "Luau/IrData.h"
#include "Luau/IrUtils.h"
#include "lobject.h"
#include "lstate.h"
#include <algorithm>
LUAU_FASTFLAG(LuauCodegenDirectUserdataFlow)
LUAU_FASTFLAG(LuauLoadTypeInfo) // Because new VM typeinfo load changes the format used by Codegen, same flag is used
LUAU_FASTFLAGVARIABLE(LuauCodegenTypeInfo, false) // New analysis is flagged separately
LUAU_FASTFLAG(LuauTypeInfoLookupImprovement)
LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataOps, false)
LUAU_FASTFLAGVARIABLE(LuauCodegenFastcall3, false)
namespace Luau
{
namespace CodeGen
{
static bool hasTypedParameters(Proto* proto)
{
CODEGEN_ASSERT(!FFlag::LuauLoadTypeInfo);
return proto->typeinfo && proto->numparams != 0;
}
template<typename T>
static T read(uint8_t* data, size_t& offset)
{
CODEGEN_ASSERT(FFlag::LuauLoadTypeInfo);
T result;
memcpy(&result, data + offset, sizeof(T));
offset += sizeof(T);
@ -40,8 +31,6 @@ static T read(uint8_t* data, size_t& offset)
static uint32_t readVarInt(uint8_t* data, size_t& offset)
{
CODEGEN_ASSERT(FFlag::LuauLoadTypeInfo);
uint32_t result = 0;
uint32_t shift = 0;
@ -59,25 +48,15 @@ static uint32_t readVarInt(uint8_t* data, size_t& offset)
void loadBytecodeTypeInfo(IrFunction& function)
{
CODEGEN_ASSERT(FFlag::LuauLoadTypeInfo);
Proto* proto = function.proto;
if (FFlag::LuauTypeInfoLookupImprovement)
{
if (!proto)
return;
}
else
{
if (!proto || !proto->typeinfo)
return;
}
if (!proto)
return;
BytecodeTypeInfo& typeInfo = function.bcTypeInfo;
// If there is no typeinfo, we generate default values for arguments and upvalues
if (FFlag::LuauTypeInfoLookupImprovement && !proto->typeinfo)
if (!proto->typeinfo)
{
typeInfo.argumentTypes.resize(proto->numparams, LBC_TYPE_ANY);
typeInfo.upvalueTypes.resize(proto->nups, LBC_TYPE_ANY);
@ -91,8 +70,6 @@ void loadBytecodeTypeInfo(IrFunction& function)
uint32_t upvalCount = readVarInt(data, offset);
uint32_t localCount = readVarInt(data, offset);
CODEGEN_ASSERT(upvalCount == unsigned(proto->nups));
if (typeSize != 0)
{
uint8_t* types = (uint8_t*)data + offset;
@ -110,6 +87,8 @@ void loadBytecodeTypeInfo(IrFunction& function)
if (upvalCount != 0)
{
CODEGEN_ASSERT(upvalCount == unsigned(proto->nups));
typeInfo.upvalueTypes.resize(upvalCount);
uint8_t* types = (uint8_t*)data + offset;
@ -137,8 +116,6 @@ void loadBytecodeTypeInfo(IrFunction& function)
static void prepareRegTypeInfoLookups(BytecodeTypeInfo& typeInfo)
{
CODEGEN_ASSERT(FFlag::LuauTypeInfoLookupImprovement);
// Sort by register first, then by end PC
std::sort(typeInfo.regTypes.begin(), typeInfo.regTypes.end(), [](const BytecodeRegTypeInfo& a, const BytecodeRegTypeInfo& b) {
if (a.reg != b.reg)
@ -171,47 +148,30 @@ static void prepareRegTypeInfoLookups(BytecodeTypeInfo& typeInfo)
static BytecodeRegTypeInfo* findRegType(BytecodeTypeInfo& info, uint8_t reg, int pc)
{
CODEGEN_ASSERT(FFlag::LuauCodegenTypeInfo);
if (FFlag::LuauTypeInfoLookupImprovement)
{
auto b = info.regTypes.begin() + info.regTypeOffsets[reg];
auto e = info.regTypes.begin() + info.regTypeOffsets[reg + 1];
// Doen't have info
if (b == e)
return nullptr;
// No info after the last live range
if (pc >= (e - 1)->endpc)
return nullptr;
for (auto it = b; it != e; ++it)
{
CODEGEN_ASSERT(it->reg == reg);
if (pc >= it->startpc && pc < it->endpc)
return &*it;
}
auto b = info.regTypes.begin() + info.regTypeOffsets[reg];
auto e = info.regTypes.begin() + info.regTypeOffsets[reg + 1];
// Doen't have info
if (b == e)
return nullptr;
}
else
{
for (BytecodeRegTypeInfo& el : info.regTypes)
{
if (reg == el.reg && pc >= el.startpc && pc < el.endpc)
return &el;
}
// No info after the last live range
if (pc >= (e - 1)->endpc)
return nullptr;
for (auto it = b; it != e; ++it)
{
CODEGEN_ASSERT(it->reg == reg);
if (pc >= it->startpc && pc < it->endpc)
return &*it;
}
return nullptr;
}
static void refineRegType(BytecodeTypeInfo& info, uint8_t reg, int pc, uint8_t ty)
{
CODEGEN_ASSERT(FFlag::LuauCodegenTypeInfo);
if (ty != LBC_TYPE_ANY)
{
if (BytecodeRegTypeInfo* regType = findRegType(info, reg, pc))
@ -220,7 +180,7 @@ static void refineRegType(BytecodeTypeInfo& info, uint8_t reg, int pc, uint8_t t
if (regType->type == LBC_TYPE_ANY)
regType->type = ty;
}
else if (FFlag::LuauTypeInfoLookupImprovement && reg < info.argumentTypes.size())
else if (reg < info.argumentTypes.size())
{
if (info.argumentTypes[reg] == LBC_TYPE_ANY)
info.argumentTypes[reg] = ty;
@ -230,8 +190,6 @@ static void refineRegType(BytecodeTypeInfo& info, uint8_t reg, int pc, uint8_t t
static void refineUpvalueType(BytecodeTypeInfo& info, int up, uint8_t ty)
{
CODEGEN_ASSERT(FFlag::LuauCodegenTypeInfo);
if (ty != LBC_TYPE_ANY)
{
if (size_t(up) < info.upvalueTypes.size())
@ -558,6 +516,49 @@ static void applyBuiltinCall(int bfid, BytecodeTypes& types)
}
}
static HostMetamethod opcodeToHostMetamethod(LuauOpcode op)
{
switch (op)
{
case LOP_ADD:
return HostMetamethod::Add;
case LOP_SUB:
return HostMetamethod::Sub;
case LOP_MUL:
return HostMetamethod::Mul;
case LOP_DIV:
return HostMetamethod::Div;
case LOP_IDIV:
return HostMetamethod::Idiv;
case LOP_MOD:
return HostMetamethod::Mod;
case LOP_POW:
return HostMetamethod::Pow;
case LOP_ADDK:
return HostMetamethod::Add;
case LOP_SUBK:
return HostMetamethod::Sub;
case LOP_MULK:
return HostMetamethod::Mul;
case LOP_DIVK:
return HostMetamethod::Div;
case LOP_IDIVK:
return HostMetamethod::Idiv;
case LOP_MODK:
return HostMetamethod::Mod;
case LOP_POWK:
return HostMetamethod::Pow;
case LOP_SUBRK:
return HostMetamethod::Sub;
case LOP_DIVRK:
return HostMetamethod::Div;
default:
CODEGEN_ASSERT(!"opcode is not assigned to a host metamethod");
}
return HostMetamethod::Add;
}
void buildBytecodeBlocks(IrFunction& function, const std::vector<uint8_t>& jumpTargets)
{
Proto* proto = function.proto;
@ -607,15 +608,14 @@ void buildBytecodeBlocks(IrFunction& function, const std::vector<uint8_t>& jumpT
}
}
void analyzeBytecodeTypes(IrFunction& function)
void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
{
Proto* proto = function.proto;
CODEGEN_ASSERT(proto);
BytecodeTypeInfo& bcTypeInfo = function.bcTypeInfo;
if (FFlag::LuauTypeInfoLookupImprovement)
prepareRegTypeInfoLookups(bcTypeInfo);
prepareRegTypeInfoLookups(bcTypeInfo);
// Setup our current knowledge of type tags based on arguments
uint8_t regTags[256];
@ -631,48 +631,31 @@ void analyzeBytecodeTypes(IrFunction& function)
// At the block start, reset or knowledge to the starting state
// In the future we might be able to propagate some info between the blocks as well
if (FFlag::LuauLoadTypeInfo)
for (size_t i = 0; i < bcTypeInfo.argumentTypes.size(); i++)
{
for (size_t i = 0; i < bcTypeInfo.argumentTypes.size(); i++)
{
uint8_t et = bcTypeInfo.argumentTypes[i];
uint8_t et = bcTypeInfo.argumentTypes[i];
// TODO: if argument is optional, this might force a VM exit unnecessarily
regTags[i] = et & ~LBC_TYPE_OPTIONAL_BIT;
}
}
else
{
if (hasTypedParameters(proto))
{
for (int i = 0; i < proto->numparams; ++i)
{
uint8_t et = proto->typeinfo[2 + i];
// TODO: if argument is optional, this might force a VM exit unnecessarily
regTags[i] = et & ~LBC_TYPE_OPTIONAL_BIT;
}
}
// TODO: if argument is optional, this might force a VM exit unnecessarily
regTags[i] = et & ~LBC_TYPE_OPTIONAL_BIT;
}
for (int i = proto->numparams; i < proto->maxstacksize; ++i)
regTags[i] = LBC_TYPE_ANY;
LuauBytecodeType knownNextCallResult = LBC_TYPE_ANY;
for (int i = block.startpc; i <= block.finishpc;)
{
const Instruction* pc = &proto->code[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc));
if (FFlag::LuauCodegenTypeInfo)
// Assign known register types from local type information
// TODO: this is an expensive walk for each instruction
// TODO: it's best to lookup when register is actually used in the instruction
for (BytecodeRegTypeInfo& el : bcTypeInfo.regTypes)
{
// Assign known register types from local type information
// TODO: this is an expensive walk for each instruction
// TODO: it's best to lookup when register is actually used in the instruction
for (BytecodeRegTypeInfo& el : bcTypeInfo.regTypes)
{
if (el.type != LBC_TYPE_ANY && i >= el.startpc && i < el.endpc)
regTags[el.reg] = el.type;
}
if (el.type != LBC_TYPE_ANY && i >= el.startpc && i < el.endpc)
regTags[el.reg] = el.type;
}
BytecodeTypes& bcType = function.bcTypes[i];
@ -694,8 +677,7 @@ void analyzeBytecodeTypes(IrFunction& function)
regTags[ra] = LBC_TYPE_BOOLEAN;
bcType.result = regTags[ra];
if (FFlag::LuauCodegenTypeInfo)
refineRegType(bcTypeInfo, ra, i, bcType.result);
refineRegType(bcTypeInfo, ra, i, bcType.result);
break;
}
case LOP_LOADN:
@ -704,8 +686,7 @@ void analyzeBytecodeTypes(IrFunction& function)
regTags[ra] = LBC_TYPE_NUMBER;
bcType.result = regTags[ra];
if (FFlag::LuauCodegenTypeInfo)
refineRegType(bcTypeInfo, ra, i, bcType.result);
refineRegType(bcTypeInfo, ra, i, bcType.result);
break;
}
case LOP_LOADK:
@ -716,8 +697,7 @@ void analyzeBytecodeTypes(IrFunction& function)
regTags[ra] = bcType.a;
bcType.result = regTags[ra];
if (FFlag::LuauCodegenTypeInfo)
refineRegType(bcTypeInfo, ra, i, bcType.result);
refineRegType(bcTypeInfo, ra, i, bcType.result);
break;
}
case LOP_LOADKX:
@ -728,8 +708,7 @@ void analyzeBytecodeTypes(IrFunction& function)
regTags[ra] = bcType.a;
bcType.result = regTags[ra];
if (FFlag::LuauCodegenTypeInfo)
refineRegType(bcTypeInfo, ra, i, bcType.result);
refineRegType(bcTypeInfo, ra, i, bcType.result);
break;
}
case LOP_MOVE:
@ -740,8 +719,7 @@ void analyzeBytecodeTypes(IrFunction& function)
regTags[ra] = regTags[rb];
bcType.result = regTags[ra];
if (FFlag::LuauCodegenTypeInfo)
refineRegType(bcTypeInfo, ra, i, bcType.result);
refineRegType(bcTypeInfo, ra, i, bcType.result);
break;
}
case LOP_GETTABLE:
@ -771,10 +749,51 @@ void analyzeBytecodeTypes(IrFunction& function)
regTags[ra] = LBC_TYPE_ANY;
// Assuming that vector component is being indexed
// TODO: check what key is used
if (bcType.a == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_NUMBER;
if (FFlag::LuauCodegenUserdataOps)
{
TString* str = gco2ts(function.proto->k[kc].value.gc);
const char* field = getstr(str);
if (bcType.a == LBC_TYPE_VECTOR)
{
if (str->len == 1)
{
// Same handling as LOP_GETTABLEKS block in lvmexecute.cpp - case-insensitive comparison with "X" / "Y" / "Z"
char ch = field[0] | ' ';
if (ch == 'x' || ch == 'y' || ch == 'z')
regTags[ra] = LBC_TYPE_NUMBER;
}
if (regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType)
regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len);
}
else if (isCustomUserdataBytecodeType(bcType.a))
{
if (regTags[ra] == LBC_TYPE_ANY && hostHooks.userdataAccessBytecodeType)
regTags[ra] = hostHooks.userdataAccessBytecodeType(bcType.a, field, str->len);
}
}
else
{
if (bcType.a == LBC_TYPE_VECTOR)
{
TString* str = gco2ts(function.proto->k[kc].value.gc);
const char* field = getstr(str);
if (str->len == 1)
{
// Same handling as LOP_GETTABLEKS block in lvmexecute.cpp - case-insensitive comparison with "X" / "Y" / "Z"
char ch = field[0] | ' ';
if (ch == 'x' || ch == 'y' || ch == 'z')
regTags[ra] = LBC_TYPE_NUMBER;
}
if (regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType)
regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len);
}
}
bcType.result = regTags[ra];
break;
@ -810,6 +829,9 @@ void analyzeBytecodeTypes(IrFunction& function)
regTags[ra] = LBC_TYPE_NUMBER;
else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
bcType.result = regTags[ra];
break;
@ -839,6 +861,11 @@ void analyzeBytecodeTypes(IrFunction& function)
if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
}
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
{
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
}
bcType.result = regTags[ra];
break;
@ -857,6 +884,9 @@ void analyzeBytecodeTypes(IrFunction& function)
if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER)
regTags[ra] = LBC_TYPE_NUMBER;
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
bcType.result = regTags[ra];
break;
@ -877,6 +907,9 @@ void analyzeBytecodeTypes(IrFunction& function)
regTags[ra] = LBC_TYPE_NUMBER;
else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
bcType.result = regTags[ra];
break;
@ -906,6 +939,11 @@ void analyzeBytecodeTypes(IrFunction& function)
if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
}
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
{
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
}
bcType.result = regTags[ra];
break;
@ -924,6 +962,9 @@ void analyzeBytecodeTypes(IrFunction& function)
if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER)
regTags[ra] = LBC_TYPE_NUMBER;
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
bcType.result = regTags[ra];
break;
@ -943,6 +984,9 @@ void analyzeBytecodeTypes(IrFunction& function)
regTags[ra] = LBC_TYPE_NUMBER;
else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
bcType.result = regTags[ra];
break;
@ -970,6 +1014,11 @@ void analyzeBytecodeTypes(IrFunction& function)
if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
}
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
{
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
}
bcType.result = regTags[ra];
break;
@ -998,6 +1047,8 @@ void analyzeBytecodeTypes(IrFunction& function)
regTags[ra] = LBC_TYPE_NUMBER;
else if (bcType.a == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR;
else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && isCustomUserdataBytecodeType(bcType.a))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, LBC_TYPE_ANY, HostMetamethod::Minus);
bcType.result = regTags[ra];
break;
@ -1036,8 +1087,7 @@ void analyzeBytecodeTypes(IrFunction& function)
regTags[ra + 3] = bcType.c;
regTags[ra] = bcType.result;
if (FFlag::LuauCodegenTypeInfo)
refineRegType(bcTypeInfo, ra, i, bcType.result);
refineRegType(bcTypeInfo, ra, i, bcType.result);
break;
}
case LOP_FASTCALL1:
@ -1055,8 +1105,7 @@ void analyzeBytecodeTypes(IrFunction& function)
regTags[LUAU_INSN_B(*pc)] = bcType.a;
regTags[ra] = bcType.result;
if (FFlag::LuauCodegenTypeInfo)
refineRegType(bcTypeInfo, ra, i, bcType.result);
refineRegType(bcTypeInfo, ra, i, bcType.result);
break;
}
case LOP_FASTCALL2:
@ -1074,8 +1123,29 @@ void analyzeBytecodeTypes(IrFunction& function)
regTags[int(pc[1])] = bcType.b;
regTags[ra] = bcType.result;
if (FFlag::LuauCodegenTypeInfo)
refineRegType(bcTypeInfo, ra, i, bcType.result);
refineRegType(bcTypeInfo, ra, i, bcType.result);
break;
}
case LOP_FASTCALL3:
{
CODEGEN_ASSERT(FFlag::LuauCodegenFastcall3);
int bfid = LUAU_INSN_A(*pc);
int skip = LUAU_INSN_C(*pc);
int aux = pc[1];
Instruction call = pc[skip + 1];
CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL);
int ra = LUAU_INSN_A(call);
applyBuiltinCall(bfid, bcType);
regTags[LUAU_INSN_B(*pc)] = bcType.a;
regTags[aux & 0xff] = bcType.b;
regTags[(aux >> 8) & 0xff] = bcType.c;
regTags[ra] = bcType.result;
refineRegType(bcTypeInfo, ra, i, bcType.result);
break;
}
case LOP_FORNPREP:
@ -1086,12 +1156,9 @@ void analyzeBytecodeTypes(IrFunction& function)
regTags[ra + 1] = LBC_TYPE_NUMBER;
regTags[ra + 2] = LBC_TYPE_NUMBER;
if (FFlag::LuauCodegenTypeInfo)
{
refineRegType(bcTypeInfo, ra, i, regTags[ra]);
refineRegType(bcTypeInfo, ra + 1, i, regTags[ra + 1]);
refineRegType(bcTypeInfo, ra + 2, i, regTags[ra + 2]);
}
refineRegType(bcTypeInfo, ra, i, regTags[ra]);
refineRegType(bcTypeInfo, ra + 1, i, regTags[ra + 1]);
refineRegType(bcTypeInfo, ra + 2, i, regTags[ra + 2]);
break;
}
case LOP_FORNLOOP:
@ -1121,61 +1188,88 @@ void analyzeBytecodeTypes(IrFunction& function)
}
case LOP_NAMECALL:
{
if (FFlag::LuauCodegenDirectUserdataFlow)
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
uint32_t kc = pc[1];
bcType.a = regTags[rb];
bcType.b = getBytecodeConstantTag(proto, kc);
// While namecall might result in a callable table, we assume the function fast path
regTags[ra] = LBC_TYPE_FUNCTION;
// Namecall places source register into target + 1
regTags[ra + 1] = bcType.a;
bcType.result = LBC_TYPE_FUNCTION;
if (FFlag::LuauCodegenUserdataOps)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
uint32_t kc = pc[1];
TString* str = gco2ts(function.proto->k[kc].value.gc);
const char* field = getstr(str);
bcType.a = regTags[rb];
bcType.b = getBytecodeConstantTag(proto, kc);
// While namecall might result in a callable table, we assume the function fast path
regTags[ra] = LBC_TYPE_FUNCTION;
// Namecall places source register into target + 1
regTags[ra + 1] = bcType.a;
bcType.result = LBC_TYPE_FUNCTION;
if (bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType)
knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len));
else if (isCustomUserdataBytecodeType(bcType.a) && hostHooks.userdataNamecallBytecodeType)
knownNextCallResult = LuauBytecodeType(hostHooks.userdataNamecallBytecodeType(bcType.a, field, str->len));
}
else
{
if (bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType)
{
TString* str = gco2ts(function.proto->k[kc].value.gc);
const char* field = getstr(str);
knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len));
}
}
break;
}
case LOP_CALL:
{
int ra = LUAU_INSN_A(*pc);
if (knownNextCallResult != LBC_TYPE_ANY)
{
bcType.result = knownNextCallResult;
knownNextCallResult = LBC_TYPE_ANY;
regTags[ra] = bcType.result;
}
refineRegType(bcTypeInfo, ra, i, bcType.result);
break;
}
case LOP_GETUPVAL:
{
if (FFlag::LuauCodegenTypeInfo)
int ra = LUAU_INSN_A(*pc);
int up = LUAU_INSN_B(*pc);
bcType.a = LBC_TYPE_ANY;
if (size_t(up) < bcTypeInfo.upvalueTypes.size())
{
int ra = LUAU_INSN_A(*pc);
int up = LUAU_INSN_B(*pc);
uint8_t et = bcTypeInfo.upvalueTypes[up];
bcType.a = LBC_TYPE_ANY;
if (size_t(up) < bcTypeInfo.upvalueTypes.size())
{
uint8_t et = bcTypeInfo.upvalueTypes[up];
// TODO: if argument is optional, this might force a VM exit unnecessarily
bcType.a = et & ~LBC_TYPE_OPTIONAL_BIT;
}
regTags[ra] = bcType.a;
bcType.result = regTags[ra];
// TODO: if argument is optional, this might force a VM exit unnecessarily
bcType.a = et & ~LBC_TYPE_OPTIONAL_BIT;
}
regTags[ra] = bcType.a;
bcType.result = regTags[ra];
break;
}
case LOP_SETUPVAL:
{
if (FFlag::LuauCodegenTypeInfo)
{
int ra = LUAU_INSN_A(*pc);
int up = LUAU_INSN_B(*pc);
int ra = LUAU_INSN_A(*pc);
int up = LUAU_INSN_B(*pc);
refineUpvalueType(bcTypeInfo, up, regTags[ra]);
}
refineUpvalueType(bcTypeInfo, up, regTags[ra]);
break;
}
case LOP_GETGLOBAL:
case LOP_SETGLOBAL:
case LOP_CALL:
case LOP_RETURN:
case LOP_JUMP:
case LOP_JUMPBACK:

View file

@ -8,6 +8,8 @@
#include "lobject.h"
#include "lstate.h"
LUAU_FASTFLAG(LuauNativeAttribute)
namespace Luau
{
namespace CodeGen
@ -56,7 +58,10 @@ std::vector<FunctionBytecodeSummary> summarizeBytecode(lua_State* L, int idx, un
Proto* root = clvalue(func)->l.p;
std::vector<Proto*> protos;
gatherFunctions(protos, root, CodeGen_ColdFunctions);
if (FFlag::LuauNativeAttribute)
gatherFunctions(protos, root, CodeGen_ColdFunctions, root->flags & LPF_NATIVE_FUNCTION);
else
gatherFunctions_DEPRECATED(protos, root, CodeGen_ColdFunctions);
std::vector<FunctionBytecodeSummary> summaries;
summaries.reserve(protos.size());

View file

@ -7,7 +7,7 @@
#include <string.h>
#include <stdlib.h>
#if defined(_WIN32) && defined(_M_X64)
#if defined(_WIN32) && defined(CODEGEN_TARGET_X64)
#ifndef WIN32_LEAN_AND_MEAN
#define WIN32_LEAN_AND_MEAN
@ -26,7 +26,7 @@ extern "C" void __deregister_frame(const void*) __attribute__((weak));
extern "C" void __unw_add_dynamic_fde() __attribute__((weak));
#endif
#if defined(__APPLE__) && defined(__aarch64__)
#if defined(__APPLE__) && defined(CODEGEN_TARGET_A64)
#include <sys/sysctl.h>
#include <mach-o/loader.h>
#include <dlfcn.h>
@ -48,7 +48,7 @@ namespace Luau
namespace CodeGen
{
#if defined(__APPLE__) && defined(__aarch64__)
#if defined(__APPLE__) && defined(CODEGEN_TARGET_A64)
static int findDynamicUnwindSections(uintptr_t addr, unw_dynamic_unwind_sections_t* info)
{
// Define a minimal mach header for JIT'd code.
@ -102,17 +102,17 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz
UnwindBuilder* unwind = (UnwindBuilder*)context;
// All unwinding related data is placed together at the start of the block
size_t unwindSize = unwind->getSize();
size_t unwindSize = unwind->getUnwindInfoSize(blockSize);
unwindSize = (unwindSize + (kCodeAlignment - 1)) & ~(kCodeAlignment - 1); // Match code allocator alignment
CODEGEN_ASSERT(blockSize >= unwindSize);
char* unwindData = (char*)block;
unwind->finalize(unwindData, unwindSize, block, blockSize);
[[maybe_unused]] size_t functionCount = unwind->finalize(unwindData, unwindSize, block, blockSize);
#if defined(_WIN32) && defined(_M_X64)
#if defined(_WIN32) && defined(CODEGEN_TARGET_X64)
#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP | WINAPI_PARTITION_SYSTEM)
if (!RtlAddFunctionTable((RUNTIME_FUNCTION*)block, uint32_t(unwind->getFunctionCount()), uintptr_t(block)))
if (!RtlAddFunctionTable((RUNTIME_FUNCTION*)block, uint32_t(functionCount), uintptr_t(block)))
{
CODEGEN_ASSERT(!"Failed to allocate function table");
return nullptr;
@ -126,7 +126,7 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz
visitFdeEntries(unwindData, __register_frame);
#endif
#if defined(__APPLE__) && defined(__aarch64__)
#if defined(__APPLE__) && defined(CODEGEN_TARGET_A64)
// Starting from macOS 14, we need to register unwind section callback to state that our ABI doesn't require pointer authentication
// This might conflict with other JITs that do the same; unfortunately this is the best we can do for now.
static unw_add_find_dynamic_unwind_sections_t unw_add_find_dynamic_unwind_sections =
@ -141,7 +141,7 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz
void destroyBlockUnwindInfo(void* context, void* unwindData)
{
#if defined(_WIN32) && defined(_M_X64)
#if defined(_WIN32) && defined(CODEGEN_TARGET_X64)
#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP | WINAPI_PARTITION_SYSTEM)
if (!RtlDeleteFunctionTable((RUNTIME_FUNCTION*)unwindData))
@ -161,12 +161,12 @@ void destroyBlockUnwindInfo(void* context, void* unwindData)
bool isUnwindSupported()
{
#if defined(_WIN32) && defined(_M_X64)
#if defined(_WIN32) && defined(CODEGEN_TARGET_X64)
return true;
#elif defined(__ANDROID__)
// Current unwind information is not compatible with Android
return false;
#elif defined(__APPLE__) && defined(__aarch64__)
#elif defined(__APPLE__) && defined(CODEGEN_TARGET_A64)
char ver[256];
size_t verLength = sizeof(ver);
// libunwind on macOS 12 and earlier (which maps to osrelease 21) assumes JIT frames use pointer authentication without a way to override that

View file

@ -27,7 +27,7 @@
#include <memory>
#include <optional>
#if defined(__x86_64__) || defined(_M_X64)
#if defined(CODEGEN_TARGET_X64)
#ifdef _MSC_VER
#include <intrin.h> // __cpuid
#else
@ -35,7 +35,7 @@
#endif
#endif
#if defined(__aarch64__)
#if defined(CODEGEN_TARGET_A64)
#ifdef __APPLE__
#include <sys/sysctl.h>
#endif
@ -58,186 +58,41 @@ LUAU_FASTINTVARIABLE(CodegenHeuristicsBlockLimit, 32'768) // 32 K
// Current value is based on some member variables being limited to 16 bits
LUAU_FASTINTVARIABLE(CodegenHeuristicsBlockInstructionLimit, 65'536) // 64 K
LUAU_FASTFLAG(LuauCodegenContext)
namespace Luau
{
namespace CodeGen
{
static const Instruction kCodeEntryInsn = LOP_NATIVECALL;
void* gPerfLogContext = nullptr;
PerfLogFn gPerfLogFn = nullptr;
struct OldNativeProto
std::string toString(const CodeGenCompilationResult& result)
{
Proto* p;
void* execdata;
uintptr_t exectarget;
};
// Additional data attached to Proto::execdata
// Guaranteed to be aligned to 16 bytes
struct ExtraExecData
{
size_t execDataSize;
size_t codeSize;
};
static int alignTo(int value, int align)
{
CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
CODEGEN_ASSERT(align > 0 && (align & (align - 1)) == 0);
return (value + (align - 1)) & ~(align - 1);
}
// Returns the size of execdata required to store all code offsets and ExtraExecData structure at proper alignment
// Always a multiple of 4 bytes
static int calculateExecDataSize(Proto* proto)
{
CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
int size = proto->sizecode * sizeof(uint32_t);
size = alignTo(size, 16);
size += sizeof(ExtraExecData);
return size;
}
// Returns pointer to the ExtraExecData inside the Proto::execdata
// Even though 'execdata' is a field in Proto, we require it to support cases where it's not attached to Proto during construction
ExtraExecData* getExtraExecData(Proto* proto, void* execdata)
{
CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
int size = proto->sizecode * sizeof(uint32_t);
size = alignTo(size, 16);
return reinterpret_cast<ExtraExecData*>(reinterpret_cast<char*>(execdata) + size);
}
static OldNativeProto createOldNativeProto(Proto* proto, const IrBuilder& ir)
{
CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
int execDataSize = calculateExecDataSize(proto);
CODEGEN_ASSERT(execDataSize % 4 == 0);
uint32_t* execData = new uint32_t[execDataSize / 4];
uint32_t instTarget = ir.function.entryLocation;
for (int i = 0; i < proto->sizecode; i++)
switch (result)
{
CODEGEN_ASSERT(ir.function.bcMapping[i].asmLocation >= instTarget);
execData[i] = ir.function.bcMapping[i].asmLocation - instTarget;
case CodeGenCompilationResult::Success:
return "Success";
case CodeGenCompilationResult::NothingToCompile:
return "NothingToCompile";
case CodeGenCompilationResult::NotNativeModule:
return "NotNativeModule";
case CodeGenCompilationResult::CodeGenNotInitialized:
return "CodeGenNotInitialized";
case CodeGenCompilationResult::CodeGenOverflowInstructionLimit:
return "CodeGenOverflowInstructionLimit";
case CodeGenCompilationResult::CodeGenOverflowBlockLimit:
return "CodeGenOverflowBlockLimit";
case CodeGenCompilationResult::CodeGenOverflowBlockInstructionLimit:
return "CodeGenOverflowBlockInstructionLimit";
case CodeGenCompilationResult::CodeGenAssemblerFinalizationFailure:
return "CodeGenAssemblerFinalizationFailure";
case CodeGenCompilationResult::CodeGenLoweringFailure:
return "CodeGenLoweringFailure";
case CodeGenCompilationResult::AllocationFailed:
return "AllocationFailed";
case CodeGenCompilationResult::Count:
return "Count";
}
// Set first instruction offset to 0 so that entering this function still executes any generated entry code.
execData[0] = 0;
ExtraExecData* extra = getExtraExecData(proto, execData);
memset(extra, 0, sizeof(ExtraExecData));
extra->execDataSize = execDataSize;
// entry target will be relocated when assembly is finalized
return {proto, execData, instTarget};
}
static void destroyExecData(void* execdata)
{
CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
delete[] static_cast<uint32_t*>(execdata);
}
static void logPerfFunction(Proto* p, uintptr_t addr, unsigned size)
{
CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
CODEGEN_ASSERT(p->source);
const char* source = getstr(p->source);
source = (source[0] == '=' || source[0] == '@') ? source + 1 : "[string]";
char name[256];
snprintf(name, sizeof(name), "<luau> %s:%d %s", source, p->linedefined, p->debugname ? getstr(p->debugname) : "");
if (gPerfLogFn)
gPerfLogFn(gPerfLogContext, addr, size, name);
}
template<typename AssemblyBuilder>
static std::optional<OldNativeProto> createNativeFunction(
AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, uint32_t& totalIrInstCount, CodeGenCompilationResult& result)
{
CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
IrBuilder ir;
ir.buildFunctionIr(proto);
unsigned instCount = unsigned(ir.function.instructions.size());
if (totalIrInstCount + instCount >= unsigned(FInt::CodegenHeuristicsInstructionLimit.value))
{
result = CodeGenCompilationResult::CodeGenOverflowInstructionLimit;
return std::nullopt;
}
totalIrInstCount += instCount;
if (!lowerFunction(ir, build, helpers, proto, {}, /* stats */ nullptr, result))
return std::nullopt;
return createOldNativeProto(proto, ir);
}
static NativeState* getNativeState(lua_State* L)
{
CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
return static_cast<NativeState*>(L->global->ecb.context);
}
static void onCloseState(lua_State* L)
{
CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
delete getNativeState(L);
L->global->ecb = lua_ExecutionCallbacks();
}
static void onDestroyFunction(lua_State* L, Proto* proto)
{
CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
destroyExecData(proto->execdata);
proto->execdata = nullptr;
proto->exectarget = 0;
proto->codeentry = proto->code;
}
static int onEnter(lua_State* L, Proto* proto)
{
CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
NativeState* data = getNativeState(L);
CODEGEN_ASSERT(proto->execdata);
CODEGEN_ASSERT(L->ci->savedpc >= proto->code && L->ci->savedpc < proto->code + proto->sizecode);
uintptr_t target = proto->exectarget + static_cast<uint32_t*>(proto->execdata)[L->ci->savedpc - proto->code];
// Returns 1 to finish the function in the VM
return GateFn(data->context.gateEntry)(L, proto, target, &data->context);
}
// used to disable native execution, unconditionally
static int onEnterDisabled(lua_State* L, Proto* proto)
{
CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
return 1;
CODEGEN_ASSERT(false);
return "";
}
void onDisable(lua_State* L, Proto* proto)
@ -279,18 +134,7 @@ void onDisable(lua_State* L, Proto* proto)
});
}
static size_t getMemorySize(lua_State* L, Proto* proto)
{
CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
ExtraExecData* extra = getExtraExecData(proto, proto->execdata);
// While execDataSize is exactly the size of the allocation we made and hold for 'execdata' field, the code size is approximate
// This is because code+data page is shared and owned by all Proto from a single module and each one can keep the whole region alive
// So individual Proto being freed by GC will not reflect memory use by native code correctly
return extra->execDataSize + extra->codeSize;
}
#if defined(__aarch64__)
#if defined(CODEGEN_TARGET_A64)
unsigned int getCpuFeaturesA64()
{
unsigned int result = 0;
@ -326,7 +170,7 @@ bool isSupported()
return false;
#endif
#if defined(__x86_64__) || defined(_M_X64)
#if defined(CODEGEN_TARGET_X64)
int cpuinfo[4] = {};
#ifdef _MSC_VER
__cpuid(cpuinfo, 1);
@ -341,273 +185,12 @@ bool isSupported()
return false;
return true;
#elif defined(__aarch64__)
#elif defined(CODEGEN_TARGET_A64)
return true;
#else
return false;
#endif
}
static void create_OLD(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext)
{
CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
CODEGEN_ASSERT(isSupported());
std::unique_ptr<NativeState> data = std::make_unique<NativeState>(allocationCallback, allocationCallbackContext);
#if defined(_WIN32)
data->unwindBuilder = std::make_unique<UnwindBuilderWin>();
#else
data->unwindBuilder = std::make_unique<UnwindBuilderDwarf2>();
#endif
data->codeAllocator.context = data->unwindBuilder.get();
data->codeAllocator.createBlockUnwindInfo = createBlockUnwindInfo;
data->codeAllocator.destroyBlockUnwindInfo = destroyBlockUnwindInfo;
initFunctions(*data);
#if defined(__x86_64__) || defined(_M_X64)
if (!X64::initHeaderFunctions(*data))
return;
#elif defined(__aarch64__)
if (!A64::initHeaderFunctions(*data))
return;
#endif
if (gPerfLogFn)
gPerfLogFn(gPerfLogContext, uintptr_t(data->context.gateEntry), 4096, "<luau gate>");
lua_ExecutionCallbacks* ecb = &L->global->ecb;
ecb->context = data.release();
ecb->close = onCloseState;
ecb->destroy = onDestroyFunction;
ecb->enter = onEnter;
ecb->disable = onDisable;
ecb->getmemorysize = getMemorySize;
}
void create(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext)
{
if (FFlag::LuauCodegenContext)
{
create_NEW(L, allocationCallback, allocationCallbackContext);
}
else
{
create_OLD(L, allocationCallback, allocationCallbackContext);
}
}
void create(lua_State* L)
{
if (FFlag::LuauCodegenContext)
{
create_NEW(L);
}
else
{
create(L, nullptr, nullptr);
}
}
void create(lua_State* L, SharedCodeGenContext* codeGenContext)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
create_NEW(L, codeGenContext);
}
[[nodiscard]] bool isNativeExecutionEnabled(lua_State* L)
{
if (FFlag::LuauCodegenContext)
{
return isNativeExecutionEnabled_NEW(L);
}
else
{
return getNativeState(L) ? (L->global->ecb.enter == onEnter) : false;
}
}
void setNativeExecutionEnabled(lua_State* L, bool enabled)
{
if (FFlag::LuauCodegenContext)
{
setNativeExecutionEnabled_NEW(L, enabled);
}
else
{
if (getNativeState(L))
L->global->ecb.enter = enabled ? onEnter : onEnterDisabled;
}
}
static CompilationResult compile_OLD(lua_State* L, int idx, unsigned int flags, CompilationStats* stats)
{
CompilationResult compilationResult;
CODEGEN_ASSERT(lua_isLfunction(L, idx));
const TValue* func = luaA_toobject(L, idx);
Proto* root = clvalue(func)->l.p;
if ((flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0)
{
compilationResult.result = CodeGenCompilationResult::NotNativeModule;
return compilationResult;
}
// If initialization has failed, do not compile any functions
NativeState* data = getNativeState(L);
if (!data)
{
compilationResult.result = CodeGenCompilationResult::CodeGenNotInitialized;
return compilationResult;
}
std::vector<Proto*> protos;
gatherFunctions(protos, root, flags);
// Skip protos that have been compiled during previous invocations of CodeGen::compile
protos.erase(std::remove_if(protos.begin(), protos.end(),
[](Proto* p) {
return p == nullptr || p->execdata != nullptr;
}),
protos.end());
if (protos.empty())
{
compilationResult.result = CodeGenCompilationResult::NothingToCompile;
return compilationResult;
}
if (stats != nullptr)
stats->functionsTotal = uint32_t(protos.size());
#if defined(__aarch64__)
static unsigned int cpuFeatures = getCpuFeaturesA64();
A64::AssemblyBuilderA64 build(/* logText= */ false, cpuFeatures);
#else
X64::AssemblyBuilderX64 build(/* logText= */ false);
#endif
ModuleHelpers helpers;
#if defined(__aarch64__)
A64::assembleHelpers(build, helpers);
#else
X64::assembleHelpers(build, helpers);
#endif
std::vector<OldNativeProto> results;
results.reserve(protos.size());
uint32_t totalIrInstCount = 0;
for (Proto* p : protos)
{
CodeGenCompilationResult protoResult = CodeGenCompilationResult::Success;
if (std::optional<OldNativeProto> np = createNativeFunction(build, helpers, p, totalIrInstCount, protoResult))
results.push_back(*np);
else
compilationResult.protoFailures.push_back({protoResult, p->debugname ? getstr(p->debugname) : "", p->linedefined});
}
// Very large modules might result in overflowing a jump offset; in this case we currently abandon the entire module
if (!build.finalize())
{
for (OldNativeProto result : results)
destroyExecData(result.execdata);
compilationResult.result = CodeGenCompilationResult::CodeGenAssemblerFinalizationFailure;
return compilationResult;
}
// If no functions were assembled, we don't need to allocate/copy executable pages for helpers
if (results.empty())
return compilationResult;
uint8_t* nativeData = nullptr;
size_t sizeNativeData = 0;
uint8_t* codeStart = nullptr;
if (!data->codeAllocator.allocate(build.data.data(), int(build.data.size()), reinterpret_cast<const uint8_t*>(build.code.data()),
int(build.code.size() * sizeof(build.code[0])), nativeData, sizeNativeData, codeStart))
{
for (OldNativeProto result : results)
destroyExecData(result.execdata);
compilationResult.result = CodeGenCompilationResult::AllocationFailed;
return compilationResult;
}
if (gPerfLogFn && results.size() > 0)
gPerfLogFn(gPerfLogContext, uintptr_t(codeStart), uint32_t(results[0].exectarget), "<luau helpers>");
for (size_t i = 0; i < results.size(); ++i)
{
uint32_t begin = uint32_t(results[i].exectarget);
uint32_t end = i + 1 < results.size() ? uint32_t(results[i + 1].exectarget) : uint32_t(build.code.size() * sizeof(build.code[0]));
CODEGEN_ASSERT(begin < end);
if (gPerfLogFn)
logPerfFunction(results[i].p, uintptr_t(codeStart) + begin, end - begin);
ExtraExecData* extra = getExtraExecData(results[i].p, results[i].execdata);
extra->codeSize = end - begin;
}
for (const OldNativeProto& result : results)
{
// the memory is now managed by VM and will be freed via onDestroyFunction
result.p->execdata = result.execdata;
result.p->exectarget = uintptr_t(codeStart) + result.exectarget;
result.p->codeentry = &kCodeEntryInsn;
}
if (stats != nullptr)
{
for (const OldNativeProto& result : results)
{
stats->bytecodeSizeBytes += result.p->sizecode * sizeof(Instruction);
// Account for the native -> bytecode instruction offsets mapping:
stats->nativeMetadataSizeBytes += result.p->sizecode * sizeof(uint32_t);
}
stats->functionsCompiled += uint32_t(results.size());
stats->nativeCodeSizeBytes += build.code.size();
stats->nativeDataSizeBytes += build.data.size();
}
return compilationResult;
}
CompilationResult compile(lua_State* L, int idx, unsigned int flags, CompilationStats* stats)
{
if (FFlag::LuauCodegenContext)
{
return compile_NEW(L, idx, flags, stats);
}
else
{
return compile_OLD(L, idx, flags, stats);
}
}
CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
return compile_NEW(moduleId, L, idx, flags, stats);
}
void setPerfLog(void* context, PerfLogFn logFn)
{
gPerfLogContext = context;
gPerfLogFn = logFn;
}
} // namespace CodeGen
} // namespace Luau

View file

@ -253,44 +253,11 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde
// Our entry function is special, it spans the whole remaining code area
unwind.startFunction();
unwind.prologueA64(prologueSize, kStackSize, {x29, x30, x19, x20, x21, x22, x23, x24, x25});
unwind.finishFunction(build.getLabelOffset(locations.start), kFullBlockFuncton);
unwind.finishFunction(build.getLabelOffset(locations.start), kFullBlockFunction);
return locations;
}
bool initHeaderFunctions(NativeState& data)
{
AssemblyBuilderA64 build(/* logText= */ false);
UnwindBuilder& unwind = *data.unwindBuilder.get();
unwind.startInfo(UnwindBuilder::A64);
EntryLocations entryLocations = buildEntryFunction(build, unwind);
build.finalize();
unwind.finishInfo();
CODEGEN_ASSERT(build.data.empty());
uint8_t* codeStart = nullptr;
if (!data.codeAllocator.allocate(build.data.data(), int(build.data.size()), reinterpret_cast<const uint8_t*>(build.code.data()),
int(build.code.size() * sizeof(build.code[0])), data.gateData, data.gateDataSize, codeStart))
{
CODEGEN_ASSERT(!"Failed to create entry function");
return false;
}
// Set the offset at the begining so that functions in new blocks will not overlay the locations
// specified by the unwind information of the entry function
unwind.setBeginOffset(build.getLabelOffset(entryLocations.prologueEnd));
data.context.gateEntry = codeStart + build.getLabelOffset(entryLocations.start);
data.context.gateExit = codeStart + build.getLabelOffset(entryLocations.epilogueStart);
return true;
}
bool initHeaderFunctions(BaseCodeGenContext& codeGenContext)
{
AssemblyBuilderA64 build(/* logText= */ false);

View file

@ -7,7 +7,6 @@ namespace CodeGen
{
class BaseCodeGenContext;
struct NativeState;
struct ModuleHelpers;
namespace A64
@ -15,7 +14,6 @@ namespace A64
class AssemblyBuilderA64;
bool initHeaderFunctions(NativeState& data);
bool initHeaderFunctions(BaseCodeGenContext& codeGenContext);
void assembleHelpers(AssemblyBuilderA64& build, ModuleHelpers& helpers);

View file

@ -12,13 +12,50 @@
#include "lapi.h"
LUAU_FASTFLAG(LuauCodegenTypeInfo)
LUAU_FASTFLAG(LuauLoadUserdataInfo)
LUAU_FASTFLAG(LuauNativeAttribute)
namespace Luau
{
namespace CodeGen
{
static const LocVar* tryFindLocal(const Proto* proto, int reg, int pcpos)
{
for (int i = 0; i < proto->sizelocvars; i++)
{
const LocVar& local = proto->locvars[i];
if (reg == local.reg && pcpos >= local.startpc && pcpos < local.endpc)
return &local;
}
return nullptr;
}
const char* tryFindLocalName(const Proto* proto, int reg, int pcpos)
{
const LocVar* var = tryFindLocal(proto, reg, pcpos);
if (var && var->varname)
return getstr(var->varname);
return nullptr;
}
const char* tryFindUpvalueName(const Proto* proto, int upval)
{
if (proto->upvalues)
{
CODEGEN_ASSERT(upval < proto->sizeupvalues);
if (proto->upvalues[upval])
return getstr(proto->upvalues[upval]);
}
return nullptr;
}
template<typename AssemblyBuilder>
static void logFunctionHeader(AssemblyBuilder& build, Proto* proto)
{
@ -29,10 +66,8 @@ static void logFunctionHeader(AssemblyBuilder& build, Proto* proto)
for (int i = 0; i < proto->numparams; i++)
{
LocVar* var = proto->locvars ? &proto->locvars[proto->sizelocvars - proto->numparams + i] : nullptr;
if (var && var->varname)
build.logAppend("%s%s", i == 0 ? "" : ", ", getstr(var->varname));
if (const char* name = tryFindLocalName(proto, i, 0))
build.logAppend("%s%s", i == 0 ? "" : ", ", name);
else
build.logAppend("%s$arg%d", i == 0 ? "" : ", ", i);
}
@ -49,9 +84,9 @@ static void logFunctionHeader(AssemblyBuilder& build, Proto* proto)
}
template<typename AssemblyBuilder>
static void logFunctionTypes(AssemblyBuilder& build, const IrFunction& function)
static void logFunctionTypes_DEPRECATED(AssemblyBuilder& build, const IrFunction& function)
{
CODEGEN_ASSERT(FFlag::LuauCodegenTypeInfo);
CODEGEN_ASSERT(!FFlag::LuauLoadUserdataInfo);
const BytecodeTypeInfo& typeInfo = function.bcTypeInfo;
@ -60,7 +95,12 @@ static void logFunctionTypes(AssemblyBuilder& build, const IrFunction& function)
uint8_t ty = typeInfo.argumentTypes[i];
if (ty != LBC_TYPE_ANY)
build.logAppend("; R%d: %s [argument]\n", int(i), getBytecodeTypeName(ty));
{
if (const char* name = tryFindLocalName(function.proto, int(i), 0))
build.logAppend("; R%d: %s [argument '%s']\n", int(i), getBytecodeTypeName_DEPRECATED(ty), name);
else
build.logAppend("; R%d: %s [argument]\n", int(i), getBytecodeTypeName_DEPRECATED(ty));
}
}
for (size_t i = 0; i < typeInfo.upvalueTypes.size(); i++)
@ -68,12 +108,73 @@ static void logFunctionTypes(AssemblyBuilder& build, const IrFunction& function)
uint8_t ty = typeInfo.upvalueTypes[i];
if (ty != LBC_TYPE_ANY)
build.logAppend("; U%d: %s\n", int(i), getBytecodeTypeName(ty));
{
if (const char* name = tryFindUpvalueName(function.proto, int(i)))
build.logAppend("; U%d: %s ['%s']\n", int(i), getBytecodeTypeName_DEPRECATED(ty), name);
else
build.logAppend("; U%d: %s\n", int(i), getBytecodeTypeName_DEPRECATED(ty));
}
}
for (const BytecodeRegTypeInfo& el : typeInfo.regTypes)
{
build.logAppend("; R%d: %s from %d to %d\n", el.reg, getBytecodeTypeName(el.type), el.startpc, el.endpc);
// Using last active position as the PC because 'startpc' for type info is before local is initialized
if (const char* name = tryFindLocalName(function.proto, el.reg, el.endpc - 1))
build.logAppend("; R%d: %s from %d to %d [local '%s']\n", el.reg, getBytecodeTypeName_DEPRECATED(el.type), el.startpc, el.endpc, name);
else
build.logAppend("; R%d: %s from %d to %d\n", el.reg, getBytecodeTypeName_DEPRECATED(el.type), el.startpc, el.endpc);
}
}
template<typename AssemblyBuilder>
static void logFunctionTypes(AssemblyBuilder& build, const IrFunction& function, const char* const* userdataTypes)
{
CODEGEN_ASSERT(FFlag::LuauLoadUserdataInfo);
const BytecodeTypeInfo& typeInfo = function.bcTypeInfo;
for (size_t i = 0; i < typeInfo.argumentTypes.size(); i++)
{
uint8_t ty = typeInfo.argumentTypes[i];
const char* type = getBytecodeTypeName(ty, userdataTypes);
const char* optional = (ty & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" : "";
if (ty != LBC_TYPE_ANY)
{
if (const char* name = tryFindLocalName(function.proto, int(i), 0))
build.logAppend("; R%d: %s%s [argument '%s']\n", int(i), type, optional, name);
else
build.logAppend("; R%d: %s%s [argument]\n", int(i), type, optional);
}
}
for (size_t i = 0; i < typeInfo.upvalueTypes.size(); i++)
{
uint8_t ty = typeInfo.upvalueTypes[i];
const char* type = getBytecodeTypeName(ty, userdataTypes);
const char* optional = (ty & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" : "";
if (ty != LBC_TYPE_ANY)
{
if (const char* name = tryFindUpvalueName(function.proto, int(i)))
build.logAppend("; U%d: %s%s ['%s']\n", int(i), type, optional, name);
else
build.logAppend("; U%d: %s%s\n", int(i), type, optional);
}
}
for (const BytecodeRegTypeInfo& el : typeInfo.regTypes)
{
const char* type = getBytecodeTypeName(el.type, userdataTypes);
const char* optional = (el.type & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" : "";
// Using last active position as the PC because 'startpc' for type info is before local is initialized
if (const char* name = tryFindLocalName(function.proto, el.reg, el.endpc - 1))
build.logAppend("; R%d: %s%s from %d to %d [local '%s']\n", el.reg, type, optional, el.startpc, el.endpc, name);
else
build.logAppend("; R%d: %s%s from %d to %d\n", el.reg, type, optional, el.startpc, el.endpc);
}
}
@ -93,11 +194,14 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A
{
Proto* root = clvalue(func)->l.p;
if ((options.flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0)
if ((options.compilationOptions.flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0)
return std::string();
std::vector<Proto*> protos;
gatherFunctions(protos, root, options.flags);
if (FFlag::LuauNativeAttribute)
gatherFunctions(protos, root, options.compilationOptions.flags, root->flags & LPF_NATIVE_FUNCTION);
else
gatherFunctions_DEPRECATED(protos, root, options.compilationOptions.flags);
protos.erase(std::remove_if(protos.begin(), protos.end(),
[](Proto* p) {
@ -125,7 +229,7 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A
for (Proto* p : protos)
{
IrBuilder ir;
IrBuilder ir(options.compilationOptions.hooks);
ir.buildFunctionIr(p);
unsigned asmSize = build.getCodeSize();
unsigned asmCount = build.getInstructionCount();
@ -133,8 +237,13 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A
if (options.includeAssembly || options.includeIr)
logFunctionHeader(build, p);
if (FFlag::LuauCodegenTypeInfo && options.includeIrTypes)
logFunctionTypes(build, ir.function);
if (options.includeIrTypes)
{
if (FFlag::LuauLoadUserdataInfo)
logFunctionTypes(build, ir.function, options.compilationOptions.userdataTypes);
else
logFunctionTypes_DEPRECATED(build, ir.function);
}
CodeGenCompilationResult result = CodeGenCompilationResult::Success;
@ -189,7 +298,7 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A
return build.text;
}
#if defined(__aarch64__)
#if defined(CODEGEN_TARGET_A64)
unsigned int getCpuFeaturesA64();
#endif
@ -202,7 +311,7 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options, Lowering
{
case AssemblyOptions::Host:
{
#if defined(__aarch64__)
#if defined(CODEGEN_TARGET_A64)
static unsigned int cpuFeatures = getCpuFeaturesA64();
A64::AssemblyBuilderA64 build(/* logText= */ options.includeAssembly, cpuFeatures);
#else

View file

@ -12,12 +12,9 @@
#include "lapi.h"
LUAU_FASTFLAGVARIABLE(LuauCodegenContext, false)
LUAU_FASTFLAGVARIABLE(LuauCodegenCheckNullContext, false)
LUAU_FASTINT(LuauCodeGenBlockSize)
LUAU_FASTINT(LuauCodeGenMaxTotalSize)
LUAU_FASTINTVARIABLE(LuauCodeGenBlockSize, 4 * 1024 * 1024)
LUAU_FASTINTVARIABLE(LuauCodeGenMaxTotalSize, 256 * 1024 * 1024)
LUAU_FASTFLAG(LuauNativeAttribute)
namespace Luau
{
@ -27,14 +24,19 @@ namespace CodeGen
static const Instruction kCodeEntryInsn = LOP_NATIVECALL;
// From CodeGen.cpp
extern void* gPerfLogContext;
extern PerfLogFn gPerfLogFn;
static void* gPerfLogContext = nullptr;
static PerfLogFn gPerfLogFn = nullptr;
unsigned int getCpuFeaturesA64();
void setPerfLog(void* context, PerfLogFn logFn)
{
gPerfLogContext = context;
gPerfLogFn = logFn;
}
static void logPerfFunction(Proto* p, uintptr_t addr, unsigned size)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
CODEGEN_ASSERT(p->source);
const char* source = getstr(p->source);
@ -50,8 +52,6 @@ static void logPerfFunction(Proto* p, uintptr_t addr, unsigned size)
static void logPerfFunctions(
const std::vector<Proto*>& moduleProtos, const uint8_t* nativeModuleBaseAddress, const std::vector<NativeProtoExecDataPtr>& nativeProtos)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
if (gPerfLogFn == nullptr)
return;
@ -83,8 +83,6 @@ static void logPerfFunctions(
template<bool Release, typename NativeProtosVector>
[[nodiscard]] static uint32_t bindNativeProtos(const std::vector<Proto*>& moduleProtos, NativeProtosVector& nativeProtos)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
uint32_t protosBound = 0;
auto protoIt = moduleProtos.begin();
@ -125,7 +123,6 @@ template<bool Release, typename NativeProtosVector>
BaseCodeGenContext::BaseCodeGenContext(size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext)
: codeAllocator{blockSize, maxTotalSize, allocationCallback, allocationCallbackContext}
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
CODEGEN_ASSERT(isSupported());
#if defined(_WIN32)
@ -143,12 +140,10 @@ BaseCodeGenContext::BaseCodeGenContext(size_t blockSize, size_t maxTotalSize, Al
[[nodiscard]] bool BaseCodeGenContext::initHeaderFunctions()
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
#if defined(__x86_64__) || defined(_M_X64)
#if defined(CODEGEN_TARGET_X64)
if (!X64::initHeaderFunctions(*this))
return false;
#elif defined(__aarch64__)
#elif defined(CODEGEN_TARGET_A64)
if (!A64::initHeaderFunctions(*this))
return false;
#endif
@ -164,13 +159,10 @@ StandaloneCodeGenContext::StandaloneCodeGenContext(
size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext)
: BaseCodeGenContext{blockSize, maxTotalSize, allocationCallback, allocationCallbackContext}
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
}
[[nodiscard]] std::optional<ModuleBindResult> StandaloneCodeGenContext::tryBindExistingModule(const ModuleId&, const std::vector<Proto*>&)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
// The StandaloneCodeGenContext does not support sharing of native code
return {};
}
@ -178,8 +170,6 @@ StandaloneCodeGenContext::StandaloneCodeGenContext(
[[nodiscard]] ModuleBindResult StandaloneCodeGenContext::bindModule(const std::optional<ModuleId>&, const std::vector<Proto*>& moduleProtos,
std::vector<NativeProtoExecDataPtr> nativeProtos, const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
uint8_t* nativeData = nullptr;
size_t sizeNativeData = 0;
uint8_t* codeStart = nullptr;
@ -205,8 +195,6 @@ StandaloneCodeGenContext::StandaloneCodeGenContext(
void StandaloneCodeGenContext::onCloseState() noexcept
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
// The StandaloneCodeGenContext is owned by the one VM that owns it, so when
// that VM is destroyed, we destroy *this as well:
delete this;
@ -214,8 +202,6 @@ void StandaloneCodeGenContext::onCloseState() noexcept
void StandaloneCodeGenContext::onDestroyFunction(void* execdata) noexcept
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
destroyNativeProtoExecData(static_cast<uint32_t*>(execdata));
}
@ -225,14 +211,11 @@ SharedCodeGenContext::SharedCodeGenContext(
: BaseCodeGenContext{blockSize, maxTotalSize, allocationCallback, allocationCallbackContext}
, sharedAllocator{&codeAllocator}
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
}
[[nodiscard]] std::optional<ModuleBindResult> SharedCodeGenContext::tryBindExistingModule(
const ModuleId& moduleId, const std::vector<Proto*>& moduleProtos)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
NativeModuleRef nativeModule = sharedAllocator.tryGetNativeModule(moduleId);
if (nativeModule.empty())
{
@ -249,8 +232,6 @@ SharedCodeGenContext::SharedCodeGenContext(
[[nodiscard]] ModuleBindResult SharedCodeGenContext::bindModule(const std::optional<ModuleId>& moduleId, const std::vector<Proto*>& moduleProtos,
std::vector<NativeProtoExecDataPtr> nativeProtos, const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
const std::pair<NativeModuleRef, bool> insertionResult = [&]() -> std::pair<NativeModuleRef, bool> {
if (moduleId.has_value())
{
@ -279,8 +260,6 @@ SharedCodeGenContext::SharedCodeGenContext(
void SharedCodeGenContext::onCloseState() noexcept
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
// The lifetime of the SharedCodeGenContext is managed separately from the
// VMs that use it. When a VM is destroyed, we don't need to do anything
// here.
@ -288,23 +267,17 @@ void SharedCodeGenContext::onCloseState() noexcept
void SharedCodeGenContext::onDestroyFunction(void* execdata) noexcept
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
getNativeProtoExecDataHeader(static_cast<const uint32_t*>(execdata)).nativeModule->release();
}
[[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext()
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
return createSharedCodeGenContext(size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), nullptr, nullptr);
}
[[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext(AllocationCallback* allocationCallback, void* allocationCallbackContext)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
return createSharedCodeGenContext(
size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), allocationCallback, allocationCallbackContext);
}
@ -312,8 +285,6 @@ void SharedCodeGenContext::onDestroyFunction(void* execdata) noexcept
[[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext(
size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
UniqueSharedCodeGenContext codeGenContext{new SharedCodeGenContext{blockSize, maxTotalSize, nullptr, nullptr}};
if (!codeGenContext->initHeaderFunctions())
@ -324,38 +295,28 @@ void SharedCodeGenContext::onDestroyFunction(void* execdata) noexcept
void destroySharedCodeGenContext(const SharedCodeGenContext* codeGenContext) noexcept
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
delete codeGenContext;
}
void SharedCodeGenContextDeleter::operator()(const SharedCodeGenContext* codeGenContext) const noexcept
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
destroySharedCodeGenContext(codeGenContext);
}
[[nodiscard]] static BaseCodeGenContext* getCodeGenContext(lua_State* L) noexcept
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
return static_cast<BaseCodeGenContext*>(L->global->ecb.context);
}
static void onCloseState(lua_State* L) noexcept
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
getCodeGenContext(L)->onCloseState();
L->global->ecb = lua_ExecutionCallbacks{};
}
static void onDestroyFunction(lua_State* L, Proto* proto) noexcept
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
getCodeGenContext(L)->onDestroyFunction(proto->execdata);
proto->execdata = nullptr;
proto->exectarget = 0;
@ -364,8 +325,6 @@ static void onDestroyFunction(lua_State* L, Proto* proto) noexcept
static int onEnter(lua_State* L, Proto* proto)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
BaseCodeGenContext* codeGenContext = getCodeGenContext(L);
CODEGEN_ASSERT(proto->execdata);
@ -379,8 +338,6 @@ static int onEnter(lua_State* L, Proto* proto)
static int onEnterDisabled(lua_State* L, Proto* proto)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
return 1;
}
@ -389,8 +346,6 @@ void onDisable(lua_State* L, Proto* proto);
static size_t getMemorySize(lua_State* L, Proto* proto)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
const NativeProtoExecDataHeader& execDataHeader = getNativeProtoExecDataHeader(static_cast<const uint32_t*>(proto->execdata));
const size_t execDataSize = sizeof(NativeProtoExecDataHeader) + execDataHeader.bytecodeInstructionCount * sizeof(Instruction);
@ -403,8 +358,7 @@ static size_t getMemorySize(lua_State* L, Proto* proto)
static void initializeExecutionCallbacks(lua_State* L, BaseCodeGenContext* codeGenContext) noexcept
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
CODEGEN_ASSERT(!FFlag::LuauCodegenCheckNullContext || codeGenContext != nullptr);
CODEGEN_ASSERT(codeGenContext != nullptr);
lua_ExecutionCallbacks* ecb = &L->global->ecb;
@ -416,24 +370,18 @@ static void initializeExecutionCallbacks(lua_State* L, BaseCodeGenContext* codeG
ecb->getmemorysize = getMemorySize;
}
void create_NEW(lua_State* L)
void create(lua_State* L)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
return create_NEW(L, size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), nullptr, nullptr);
return create(L, size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), nullptr, nullptr);
}
void create_NEW(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext)
void create(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
return create_NEW(L, size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), allocationCallback, allocationCallbackContext);
return create(L, size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), allocationCallback, allocationCallbackContext);
}
void create_NEW(lua_State* L, size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext)
void create(lua_State* L, size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
std::unique_ptr<StandaloneCodeGenContext> codeGenContext =
std::make_unique<StandaloneCodeGenContext>(blockSize, maxTotalSize, allocationCallback, allocationCallbackContext);
@ -443,17 +391,13 @@ void create_NEW(lua_State* L, size_t blockSize, size_t maxTotalSize, AllocationC
initializeExecutionCallbacks(L, codeGenContext.release());
}
void create_NEW(lua_State* L, SharedCodeGenContext* codeGenContext)
void create(lua_State* L, SharedCodeGenContext* codeGenContext)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
initializeExecutionCallbacks(L, codeGenContext);
}
[[nodiscard]] static NativeProtoExecDataPtr createNativeProtoExecData(Proto* proto, const IrBuilder& ir)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
NativeProtoExecDataPtr nativeExecData = createNativeProtoExecData(proto->sizecode);
uint32_t instTarget = ir.function.entryLocation;
@ -478,12 +422,10 @@ void create_NEW(lua_State* L, SharedCodeGenContext* codeGenContext)
}
template<typename AssemblyBuilder>
[[nodiscard]] static NativeProtoExecDataPtr createNativeFunction(
AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, uint32_t& totalIrInstCount, CodeGenCompilationResult& result)
[[nodiscard]] static NativeProtoExecDataPtr createNativeFunction(AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto,
uint32_t& totalIrInstCount, const HostIrHooks& hooks, CodeGenCompilationResult& result)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
IrBuilder ir;
IrBuilder ir(hooks);
ir.buildFunctionIr(proto);
unsigned instCount = unsigned(ir.function.instructions.size());
@ -505,15 +447,14 @@ template<typename AssemblyBuilder>
}
[[nodiscard]] static CompilationResult compileInternal(
const std::optional<ModuleId>& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats)
const std::optional<ModuleId>& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
CODEGEN_ASSERT(lua_isLfunction(L, idx));
const TValue* func = luaA_toobject(L, idx);
Proto* root = clvalue(func)->l.p;
if ((flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0)
if ((options.flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0 && (root->flags & LPF_NATIVE_FUNCTION) == 0)
return CompilationResult{CodeGenCompilationResult::NotNativeModule};
BaseCodeGenContext* codeGenContext = getCodeGenContext(L);
@ -521,7 +462,10 @@ template<typename AssemblyBuilder>
return CompilationResult{CodeGenCompilationResult::CodeGenNotInitialized};
std::vector<Proto*> protos;
gatherFunctions(protos, root, flags);
if (FFlag::LuauNativeAttribute)
gatherFunctions(protos, root, options.flags, root->flags & LPF_NATIVE_FUNCTION);
else
gatherFunctions_DEPRECATED(protos, root, options.flags);
// Skip protos that have been compiled during previous invocations of CodeGen::compile
protos.erase(std::remove_if(protos.begin(), protos.end(),
@ -547,7 +491,7 @@ template<typename AssemblyBuilder>
}
}
#if defined(__aarch64__)
#if defined(CODEGEN_TARGET_A64)
static unsigned int cpuFeatures = getCpuFeaturesA64();
A64::AssemblyBuilderA64 build(/* logText= */ false, cpuFeatures);
#else
@ -555,7 +499,7 @@ template<typename AssemblyBuilder>
#endif
ModuleHelpers helpers;
#if defined(__aarch64__)
#if defined(CODEGEN_TARGET_A64)
A64::assembleHelpers(build, helpers);
#else
X64::assembleHelpers(build, helpers);
@ -572,7 +516,7 @@ template<typename AssemblyBuilder>
{
CodeGenCompilationResult protoResult = CodeGenCompilationResult::Success;
NativeProtoExecDataPtr nativeExecData = createNativeFunction(build, helpers, protos[i], totalIrInstCount, protoResult);
NativeProtoExecDataPtr nativeExecData = createNativeFunction(build, helpers, protos[i], totalIrInstCount, options.hooks, protoResult);
if (nativeExecData != nullptr)
{
nativeProtos.push_back(std::move(nativeExecData));
@ -639,34 +583,60 @@ template<typename AssemblyBuilder>
return compilationResult;
}
CompilationResult compile_NEW(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats)
CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
return compileInternal(moduleId, L, idx, flags, stats);
return compileInternal(moduleId, L, idx, options, stats);
}
CompilationResult compile_NEW(lua_State* L, int idx, unsigned int flags, CompilationStats* stats)
CompilationResult compile(lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
return compileInternal({}, L, idx, flags, stats);
return compileInternal({}, L, idx, options, stats);
}
[[nodiscard]] bool isNativeExecutionEnabled_NEW(lua_State* L)
CompilationResult compile(lua_State* L, int idx, unsigned int flags, CompilationStats* stats)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
return compileInternal({}, L, idx, CompilationOptions{flags}, stats);
}
CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats)
{
return compileInternal(moduleId, L, idx, CompilationOptions{flags}, stats);
}
[[nodiscard]] bool isNativeExecutionEnabled(lua_State* L)
{
return getCodeGenContext(L) != nullptr && L->global->ecb.enter == onEnter;
}
void setNativeExecutionEnabled_NEW(lua_State* L, bool enabled)
void setNativeExecutionEnabled(lua_State* L, bool enabled)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
if (getCodeGenContext(L) != nullptr)
L->global->ecb.enter = enabled ? onEnter : onEnterDisabled;
}
static uint8_t userdataRemapperWrap(lua_State* L, const char* str, size_t len)
{
if (BaseCodeGenContext* codegenCtx = getCodeGenContext(L))
{
uint8_t index = codegenCtx->userdataRemapper(codegenCtx->userdataRemappingContext, str, len);
if (index < (LBC_TYPE_TAGGED_USERDATA_END - LBC_TYPE_TAGGED_USERDATA_BASE))
return LBC_TYPE_TAGGED_USERDATA_BASE + index;
}
return LBC_TYPE_USERDATA;
}
void setUserdataRemapper(lua_State* L, void* context, UserdataRemapperCallback cb)
{
if (BaseCodeGenContext* codegenCtx = getCodeGenContext(L))
{
codegenCtx->userdataRemappingContext = context;
codegenCtx->userdataRemapper = cb;
L->global->ecb.gettypemapping = cb ? userdataRemapperWrap : nullptr;
}
}
} // namespace CodeGen
} // namespace Luau

View file

@ -50,6 +50,9 @@ public:
uint8_t* gateData = nullptr;
size_t gateDataSize = 0;
void* userdataRemappingContext = nullptr;
UserdataRemapperCallback* userdataRemapper = nullptr;
NativeContext context;
};
@ -88,33 +91,5 @@ private:
SharedCodeAllocator sharedAllocator;
};
// The following will become the public interface, and can be moved into
// CodeGen.h after the shared allocator work is complete. When the old
// implementation is removed, the _NEW suffix can be dropped from these
// functions.
// Initializes native code-gen on the provided Luau VM, using a VM-specific
// code-gen context and either the default allocator parameters or custom
// allocator parameters.
void create_NEW(lua_State* L);
void create_NEW(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext);
void create_NEW(lua_State* L, size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext);
// Initializes native code-gen on the provided Luau VM, using the provided
// SharedCodeGenContext. Note that after this function is called, the
// SharedCodeGenContext must not be destroyed until after the Luau VM L is
// destroyed via lua_close.
void create_NEW(lua_State* L, SharedCodeGenContext* codeGenContext);
CompilationResult compile_NEW(lua_State* L, int idx, unsigned int flags, CompilationStats* stats);
CompilationResult compile_NEW(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats);
// Returns true if native execution is currently enabled for this VM
[[nodiscard]] bool isNativeExecutionEnabled_NEW(lua_State* L);
// Enables or disables native excution for this VM
void setNativeExecutionEnabled_NEW(lua_State* L, bool enabled);
} // namespace CodeGen
} // namespace Luau

View file

@ -27,14 +27,15 @@ LUAU_FASTFLAG(DebugCodegenSkipNumbering)
LUAU_FASTINT(CodegenHeuristicsInstructionLimit)
LUAU_FASTINT(CodegenHeuristicsBlockLimit)
LUAU_FASTINT(CodegenHeuristicsBlockInstructionLimit)
LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5)
LUAU_FASTFLAG(LuauLoadUserdataInfo)
LUAU_FASTFLAG(LuauNativeAttribute)
namespace Luau
{
namespace CodeGen
{
inline void gatherFunctions(std::vector<Proto*>& results, Proto* proto, unsigned int flags)
inline void gatherFunctions_DEPRECATED(std::vector<Proto*>& results, Proto* proto, unsigned int flags)
{
if (results.size() <= size_t(proto->bytecodeid))
results.resize(proto->bytecodeid + 1);
@ -49,7 +50,36 @@ inline void gatherFunctions(std::vector<Proto*>& results, Proto* proto, unsigned
// Recursively traverse child protos even if we aren't compiling this one
for (int i = 0; i < proto->sizep; i++)
gatherFunctions(results, proto->p[i], flags);
gatherFunctions_DEPRECATED(results, proto->p[i], flags);
}
inline void gatherFunctionsHelper(
std::vector<Proto*>& results, Proto* proto, const unsigned int flags, const bool hasNativeFunctions, const bool root)
{
if (results.size() <= size_t(proto->bytecodeid))
results.resize(proto->bytecodeid + 1);
// Skip protos that we've already compiled in this run: this happens because at -O2, inlined functions get their protos reused
if (results[proto->bytecodeid])
return;
// if native module, compile cold functions if requested
// if not native module, compile function if it has native attribute and is not root
bool shouldGather = hasNativeFunctions ? (!root && (proto->flags & LPF_NATIVE_FUNCTION) != 0)
: ((proto->flags & LPF_NATIVE_COLD) == 0 || (flags & CodeGen_ColdFunctions) != 0);
if (shouldGather)
results[proto->bytecodeid] = proto;
// Recursively traverse child protos even if we aren't compiling this one
for (int i = 0; i < proto->sizep; i++)
gatherFunctionsHelper(results, proto->p[i], flags, hasNativeFunctions, false);
}
inline void gatherFunctions(std::vector<Proto*>& results, Proto* root, const unsigned int flags, const bool hasNativeFunctions = false)
{
LUAU_ASSERT(FFlag::LuauNativeAttribute);
gatherFunctionsHelper(results, root, flags, hasNativeFunctions, true);
}
inline unsigned getInstructionCount(const std::vector<IrInst>& instructions, IrCmd cmd)
@ -149,7 +179,11 @@ inline bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction&
if (bcTypes.result != LBC_TYPE_ANY || bcTypes.a != LBC_TYPE_ANY || bcTypes.b != LBC_TYPE_ANY || bcTypes.c != LBC_TYPE_ANY)
{
toString(ctx.result, bcTypes);
if (FFlag::LuauLoadUserdataInfo)
toString(ctx.result, bcTypes, options.compilationOptions.userdataTypes);
else
toString_DEPRECATED(ctx.result, bcTypes);
build.logAppend("\n");
}
}
@ -312,8 +346,7 @@ inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& build, ModuleHelpers&
}
}
if (FFlag::LuauCodegenRemoveDeadStores5)
markDeadStoresInBlockChains(ir);
markDeadStoresInBlockChains(ir);
}
std::vector<uint32_t> sortedBlocks = getSortedBlockOrder(ir.function);

View file

@ -14,6 +14,7 @@
#include "lstate.h"
#include "lstring.h"
#include "ltable.h"
#include "ludata.h"
#include <string.h>
@ -219,6 +220,20 @@ void callEpilogC(lua_State* L, int nresults, int n)
L->top = (nresults == LUA_MULTRET) ? res : cip->top;
}
Udata* newUserdata(lua_State* L, size_t s, int tag)
{
Udata* u = luaU_newudata(L, s, tag);
if (Table* h = L->global->udatamt[tag])
{
u->metatable = h;
luaC_objbarrier(L, u, h);
}
return u;
}
// Extracted as-is from lvmexecute.cpp with the exception of control flow (reentry) and removed interrupts/savedpc
Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults)
{

View file

@ -17,6 +17,8 @@ void forgPrepXnextFallback(lua_State* L, TValue* ra, int pc);
Closure* callProlog(lua_State* L, TValue* ra, StkId argtop, int nresults);
void callEpilogC(lua_State* L, int nresults, int n);
Udata* newUserdata(lua_State* L, size_t s, int tag);
#define CALL_FALLBACK_YIELD 1
Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults);

View file

@ -181,44 +181,11 @@ static EntryLocations buildEntryFunction(AssemblyBuilderX64& build, UnwindBuilde
build.ret();
// Our entry function is special, it spans the whole remaining code area
unwind.finishFunction(build.getLabelOffset(locations.start), kFullBlockFuncton);
unwind.finishFunction(build.getLabelOffset(locations.start), kFullBlockFunction);
return locations;
}
bool initHeaderFunctions(NativeState& data)
{
AssemblyBuilderX64 build(/* logText= */ false);
UnwindBuilder& unwind = *data.unwindBuilder.get();
unwind.startInfo(UnwindBuilder::X64);
EntryLocations entryLocations = buildEntryFunction(build, unwind);
build.finalize();
unwind.finishInfo();
CODEGEN_ASSERT(build.data.empty());
uint8_t* codeStart = nullptr;
if (!data.codeAllocator.allocate(
build.data.data(), int(build.data.size()), build.code.data(), int(build.code.size()), data.gateData, data.gateDataSize, codeStart))
{
CODEGEN_ASSERT(!"Failed to create entry function");
return false;
}
// Set the offset at the begining so that functions in new blocks will not overlay the locations
// specified by the unwind information of the entry function
unwind.setBeginOffset(build.getLabelOffset(entryLocations.prologueEnd));
data.context.gateEntry = codeStart + build.getLabelOffset(entryLocations.start);
data.context.gateExit = codeStart + build.getLabelOffset(entryLocations.epilogueStart);
return true;
}
bool initHeaderFunctions(BaseCodeGenContext& codeGenContext)
{
AssemblyBuilderX64 build(/* logText= */ false);

View file

@ -7,7 +7,6 @@ namespace CodeGen
{
class BaseCodeGenContext;
struct NativeState;
struct ModuleHelpers;
namespace X64
@ -15,7 +14,6 @@ namespace X64
class AssemblyBuilderX64;
bool initHeaderFunctions(NativeState& data);
bool initHeaderFunctions(BaseCodeGenContext& codeGenContext);
void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers);

View file

@ -12,7 +12,7 @@
#include "lstate.h"
LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5)
LUAU_FASTFLAG(LuauCodegenMathSign)
namespace Luau
{
@ -29,17 +29,13 @@ static void emitBuiltinMathFrexp(IrRegAllocX64& regs, AssemblyBuilderX64& build,
callWrap.call(qword[rNativeContext + offsetof(NativeContext, libm_frexp)]);
build.vmovsd(luauRegValue(ra), xmm0);
if (FFlag::LuauCodegenRemoveDeadStores5)
build.mov(luauRegTag(ra), LUA_TNUMBER);
build.mov(luauRegTag(ra), LUA_TNUMBER);
if (nresults > 1)
{
build.vcvtsi2sd(xmm0, xmm0, dword[sTemporarySlot + 0]);
build.vmovsd(luauRegValue(ra + 1), xmm0);
if (FFlag::LuauCodegenRemoveDeadStores5)
build.mov(luauRegTag(ra + 1), LUA_TNUMBER);
build.mov(luauRegTag(ra + 1), LUA_TNUMBER);
}
}
@ -52,21 +48,19 @@ static void emitBuiltinMathModf(IrRegAllocX64& regs, AssemblyBuilderX64& build,
build.vmovsd(xmm1, qword[sTemporarySlot + 0]);
build.vmovsd(luauRegValue(ra), xmm1);
if (FFlag::LuauCodegenRemoveDeadStores5)
build.mov(luauRegTag(ra), LUA_TNUMBER);
build.mov(luauRegTag(ra), LUA_TNUMBER);
if (nresults > 1)
{
build.vmovsd(luauRegValue(ra + 1), xmm0);
if (FFlag::LuauCodegenRemoveDeadStores5)
build.mov(luauRegTag(ra + 1), LUA_TNUMBER);
build.mov(luauRegTag(ra + 1), LUA_TNUMBER);
}
}
static void emitBuiltinMathSign(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg)
{
CODEGEN_ASSERT(!FFlag::LuauCodegenMathSign);
ScopedRegX64 tmp0{regs, SizeX64::xmmword};
ScopedRegX64 tmp1{regs, SizeX64::xmmword};
ScopedRegX64 tmp2{regs, SizeX64::xmmword};
@ -90,23 +84,22 @@ static void emitBuiltinMathSign(IrRegAllocX64& regs, AssemblyBuilderX64& build,
build.vblendvpd(tmp0.reg, tmp2.reg, build.f64x2(1, 1), tmp0.reg);
build.vmovsd(luauRegValue(ra), tmp0.reg);
if (FFlag::LuauCodegenRemoveDeadStores5)
build.mov(luauRegTag(ra), LUA_TNUMBER);
build.mov(luauRegTag(ra), LUA_TNUMBER);
}
void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, OperandX64 arg2, int nparams, int nresults)
void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, int nresults)
{
switch (bfid)
{
case LBF_MATH_FREXP:
CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2));
CODEGEN_ASSERT(nresults == 1 || nresults == 2);
return emitBuiltinMathFrexp(regs, build, ra, arg, nresults);
case LBF_MATH_MODF:
CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2));
CODEGEN_ASSERT(nresults == 1 || nresults == 2);
return emitBuiltinMathModf(regs, build, ra, arg, nresults);
case LBF_MATH_SIGN:
CODEGEN_ASSERT(nparams == 1 && nresults == 1);
CODEGEN_ASSERT(!FFlag::LuauCodegenMathSign);
CODEGEN_ASSERT(nresults == 1);
return emitBuiltinMathSign(regs, build, ra, arg);
default:
CODEGEN_ASSERT(!"Missing x64 lowering");

View file

@ -16,7 +16,7 @@ class AssemblyBuilderX64;
struct OperandX64;
struct IrRegAllocX64;
void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, OperandX64 arg2, int nparams, int nresults);
void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, int nresults);
} // namespace X64
} // namespace CodeGen

View file

@ -22,8 +22,6 @@ namespace Luau
namespace CodeGen
{
struct NativeState;
namespace A64
{

View file

@ -155,8 +155,37 @@ void callArithHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, Ope
callWrap.addArgument(SizeX64::qword, luauRegAddress(ra));
callWrap.addArgument(SizeX64::qword, b);
callWrap.addArgument(SizeX64::qword, c);
callWrap.addArgument(SizeX64::dword, tm);
callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarith)]);
switch (tm)
{
case TM_ADD:
callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithadd)]);
break;
case TM_SUB:
callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithsub)]);
break;
case TM_MUL:
callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithmul)]);
break;
case TM_DIV:
callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithdiv)]);
break;
case TM_IDIV:
callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithidiv)]);
break;
case TM_MOD:
callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithmod)]);
break;
case TM_POW:
callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithpow)]);
break;
case TM_UNM:
callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithunm)]);
break;
default:
CODEGEN_ASSERT(!"Invalid doarith helper operation tag");
break;
}
emitUpdateBase(build);
}

View file

@ -26,7 +26,6 @@ namespace CodeGen
{
enum class IrCondition : uint8_t;
struct NativeState;
struct IrOp;
namespace X64

View file

@ -13,6 +13,8 @@
#include <stddef.h>
LUAU_FASTFLAGVARIABLE(LuauCodegenInstG, false)
namespace Luau
{
namespace CodeGen
@ -52,6 +54,9 @@ void updateUseCounts(IrFunction& function)
checkOp(inst.d);
checkOp(inst.e);
checkOp(inst.f);
if (FFlag::LuauCodegenInstG)
checkOp(inst.g);
}
}
@ -95,6 +100,9 @@ void updateLastUseLocations(IrFunction& function, const std::vector<uint32_t>& s
checkOp(inst.d);
checkOp(inst.e);
checkOp(inst.f);
if (FFlag::LuauCodegenInstG)
checkOp(inst.g);
}
}
}
@ -128,6 +136,12 @@ uint32_t getNextInstUse(IrFunction& function, uint32_t targetInstIdx, uint32_t s
if (inst.f.kind == IrOpKind::Inst && inst.f.index == targetInstIdx)
return i;
if (FFlag::LuauCodegenInstG)
{
if (inst.g.kind == IrOpKind::Inst && inst.g.index == targetInstIdx)
return i;
}
}
// There must be a next use since there is the last use location
@ -165,6 +179,9 @@ std::pair<uint32_t, uint32_t> getLiveInOutValueCount(IrFunction& function, IrBlo
checkOp(inst.d);
checkOp(inst.e);
checkOp(inst.f);
if (FFlag::LuauCodegenInstG)
checkOp(inst.g);
}
return std::make_pair(liveIns, liveOuts);
@ -488,6 +505,9 @@ static void computeCfgBlockEdges(IrFunction& function)
checkOp(inst.d);
checkOp(inst.e);
checkOp(inst.f);
if (FFlag::LuauCodegenInstG)
checkOp(inst.g);
}
}

View file

@ -13,8 +13,9 @@
#include <string.h>
LUAU_FASTFLAG(LuauLoadTypeInfo) // Because new VM typeinfo load changes the format used by Codegen, same flag is used
LUAU_FASTFLAG(LuauTypeInfoLookupImprovement)
LUAU_FASTFLAG(LuauLoadUserdataInfo)
LUAU_FASTFLAG(LuauCodegenInstG)
LUAU_FASTFLAG(LuauCodegenFastcall3)
namespace Luau
{
@ -23,120 +24,25 @@ namespace CodeGen
constexpr unsigned kNoAssociatedBlockIndex = ~0u;
IrBuilder::IrBuilder()
: constantMap({IrConstKind::Tag, ~0ull})
IrBuilder::IrBuilder(const HostIrHooks& hostHooks)
: hostHooks(hostHooks)
, constantMap({IrConstKind::Tag, ~0ull})
{
}
static bool hasTypedParameters_DEPRECATED(Proto* proto)
{
CODEGEN_ASSERT(!FFlag::LuauLoadTypeInfo);
return proto->typeinfo && proto->numparams != 0;
}
static void buildArgumentTypeChecks_DEPRECATED(IrBuilder& build, Proto* proto)
{
CODEGEN_ASSERT(!FFlag::LuauLoadTypeInfo);
CODEGEN_ASSERT(hasTypedParameters_DEPRECATED(proto));
for (int i = 0; i < proto->numparams; ++i)
{
uint8_t et = proto->typeinfo[2 + i];
uint8_t tag = et & ~LBC_TYPE_OPTIONAL_BIT;
uint8_t optional = et & LBC_TYPE_OPTIONAL_BIT;
if (tag == LBC_TYPE_ANY)
continue;
IrOp load = build.inst(IrCmd::LOAD_TAG, build.vmReg(i));
IrOp nextCheck;
if (optional)
{
nextCheck = build.block(IrBlockKind::Internal);
IrOp fallbackCheck = build.block(IrBlockKind::Internal);
build.inst(IrCmd::JUMP_EQ_TAG, load, build.constTag(LUA_TNIL), nextCheck, fallbackCheck);
build.beginBlock(fallbackCheck);
}
switch (tag)
{
case LBC_TYPE_NIL:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TNIL), build.vmExit(kVmExitEntryGuardPc));
break;
case LBC_TYPE_BOOLEAN:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TBOOLEAN), build.vmExit(kVmExitEntryGuardPc));
break;
case LBC_TYPE_NUMBER:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TNUMBER), build.vmExit(kVmExitEntryGuardPc));
break;
case LBC_TYPE_STRING:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TSTRING), build.vmExit(kVmExitEntryGuardPc));
break;
case LBC_TYPE_TABLE:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TTABLE), build.vmExit(kVmExitEntryGuardPc));
break;
case LBC_TYPE_FUNCTION:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TFUNCTION), build.vmExit(kVmExitEntryGuardPc));
break;
case LBC_TYPE_THREAD:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TTHREAD), build.vmExit(kVmExitEntryGuardPc));
break;
case LBC_TYPE_USERDATA:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TUSERDATA), build.vmExit(kVmExitEntryGuardPc));
break;
case LBC_TYPE_VECTOR:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TVECTOR), build.vmExit(kVmExitEntryGuardPc));
break;
case LBC_TYPE_BUFFER:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TBUFFER), build.vmExit(kVmExitEntryGuardPc));
break;
}
if (optional)
{
build.inst(IrCmd::JUMP, nextCheck);
build.beginBlock(nextCheck);
}
}
// If the last argument is optional, we can skip creating a new internal block since one will already have been created.
if (!(proto->typeinfo[2 + proto->numparams - 1] & LBC_TYPE_OPTIONAL_BIT))
{
IrOp next = build.block(IrBlockKind::Internal);
build.inst(IrCmd::JUMP, next);
build.beginBlock(next);
}
}
static bool hasTypedParameters(const BytecodeTypeInfo& typeInfo)
{
CODEGEN_ASSERT(FFlag::LuauLoadTypeInfo);
if (FFlag::LuauTypeInfoLookupImprovement)
for (auto el : typeInfo.argumentTypes)
{
for (auto el : typeInfo.argumentTypes)
{
if (el != LBC_TYPE_ANY)
return true;
}
if (el != LBC_TYPE_ANY)
return true;
}
return false;
}
else
{
return !typeInfo.argumentTypes.empty();
}
return false;
}
static void buildArgumentTypeChecks(IrBuilder& build)
{
CODEGEN_ASSERT(FFlag::LuauLoadTypeInfo);
const BytecodeTypeInfo& typeInfo = build.function.bcTypeInfo;
CODEGEN_ASSERT(hasTypedParameters(typeInfo));
@ -195,6 +101,19 @@ static void buildArgumentTypeChecks(IrBuilder& build)
case LBC_TYPE_BUFFER:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TBUFFER), build.vmExit(kVmExitEntryGuardPc));
break;
default:
if (FFlag::LuauLoadUserdataInfo)
{
if (tag >= LBC_TYPE_TAGGED_USERDATA_BASE && tag < LBC_TYPE_TAGGED_USERDATA_END)
{
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TUSERDATA), build.vmExit(kVmExitEntryGuardPc));
}
else
{
CODEGEN_ASSERT(!"unknown argument type tag");
}
}
break;
}
if (optional)
@ -219,18 +138,17 @@ void IrBuilder::buildFunctionIr(Proto* proto)
function.proto = proto;
function.variadic = proto->is_vararg != 0;
if (FFlag::LuauLoadTypeInfo)
loadBytecodeTypeInfo(function);
loadBytecodeTypeInfo(function);
// Reserve entry block
bool generateTypeChecks = FFlag::LuauLoadTypeInfo ? hasTypedParameters(function.bcTypeInfo) : hasTypedParameters_DEPRECATED(proto);
bool generateTypeChecks = hasTypedParameters(function.bcTypeInfo);
IrOp entry = generateTypeChecks ? block(IrBlockKind::Internal) : IrOp{};
// Rebuild original control flow blocks
rebuildBytecodeBasicBlocks(proto);
// Infer register tags in bytecode
analyzeBytecodeTypes(function);
analyzeBytecodeTypes(function, hostHooks);
function.bcMapping.resize(proto->sizecode, {~0u, ~0u});
@ -238,10 +156,7 @@ void IrBuilder::buildFunctionIr(Proto* proto)
{
beginBlock(entry);
if (FFlag::LuauLoadTypeInfo)
buildArgumentTypeChecks(*this);
else
buildArgumentTypeChecks_DEPRECATED(*this, proto);
buildArgumentTypeChecks(*this);
inst(IrCmd::JUMP, blockAtInst(0));
}
@ -283,10 +198,10 @@ void IrBuilder::buildFunctionIr(Proto* proto)
translateInst(op, pc, i);
if (fastcallSkipTarget != -1)
if (cmdSkipTarget != -1)
{
nexti = fastcallSkipTarget;
fastcallSkipTarget = -1;
nexti = cmdSkipTarget;
cmdSkipTarget = -1;
}
}
@ -535,16 +450,21 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i)
translateInstCloseUpvals(*this, pc);
break;
case LOP_FASTCALL:
handleFastcallFallback(translateFastCallN(*this, pc, i, false, 0, {}), pc, i);
handleFastcallFallback(translateFastCallN(*this, pc, i, false, 0, {}, {}), pc, i);
break;
case LOP_FASTCALL1:
handleFastcallFallback(translateFastCallN(*this, pc, i, true, 1, undef()), pc, i);
handleFastcallFallback(translateFastCallN(*this, pc, i, true, 1, undef(), undef()), pc, i);
break;
case LOP_FASTCALL2:
handleFastcallFallback(translateFastCallN(*this, pc, i, true, 2, vmReg(pc[1])), pc, i);
handleFastcallFallback(translateFastCallN(*this, pc, i, true, 2, vmReg(pc[1]), undef()), pc, i);
break;
case LOP_FASTCALL2K:
handleFastcallFallback(translateFastCallN(*this, pc, i, true, 2, vmConst(pc[1])), pc, i);
handleFastcallFallback(translateFastCallN(*this, pc, i, true, 2, vmConst(pc[1]), undef()), pc, i);
break;
case LOP_FASTCALL3:
CODEGEN_ASSERT(FFlag::LuauCodegenFastcall3);
handleFastcallFallback(translateFastCallN(*this, pc, i, true, 3, vmReg(pc[1] & 0xff), vmReg((pc[1] >> 8) & 0xff)), pc, i);
break;
case LOP_FORNPREP:
translateInstForNPrep(*this, pc, i);
@ -613,7 +533,8 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i)
translateInstCapture(*this, pc, i);
break;
case LOP_NAMECALL:
translateInstNamecall(*this, pc, i);
if (translateInstNamecall(*this, pc, i))
cmdSkipTarget = i + 3;
break;
case LOP_PREPVARARGS:
inst(IrCmd::FALLBACK_PREPVARARGS, constUint(i), constInt(LUAU_INSN_A(*pc)));
@ -654,7 +575,7 @@ void IrBuilder::handleFastcallFallback(IrOp fallbackOrUndef, const Instruction*
}
else
{
fastcallSkipTarget = i + skip + 2;
cmdSkipTarget = i + skip + 2;
}
}
@ -725,6 +646,9 @@ void IrBuilder::clone(const IrBlock& source, bool removeCurrentTerminator)
redirect(clone.e);
redirect(clone.f);
if (FFlag::LuauCodegenInstG)
redirect(clone.g);
addUse(function, clone.a);
addUse(function, clone.b);
addUse(function, clone.c);
@ -732,11 +656,17 @@ void IrBuilder::clone(const IrBlock& source, bool removeCurrentTerminator)
addUse(function, clone.e);
addUse(function, clone.f);
if (FFlag::LuauCodegenInstG)
addUse(function, clone.g);
// Instructions that referenced the original will have to be adjusted to use the clone
instRedir[index] = uint32_t(function.instructions.size());
// Reconstruct the fresh clone
inst(clone.cmd, clone.a, clone.b, clone.c, clone.d, clone.e, clone.f);
if (FFlag::LuauCodegenInstG)
inst(clone.cmd, clone.a, clone.b, clone.c, clone.d, clone.e, clone.f, clone.g);
else
inst(clone.cmd, clone.a, clone.b, clone.c, clone.d, clone.e, clone.f);
}
}
@ -834,8 +764,33 @@ IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e)
IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f)
{
if (FFlag::LuauCodegenInstG)
{
return inst(cmd, a, b, c, d, e, f, {});
}
else
{
uint32_t index = uint32_t(function.instructions.size());
function.instructions.push_back({cmd, a, b, c, d, e, f});
CODEGEN_ASSERT(!inTerminatedBlock);
if (isBlockTerminator(cmd))
{
function.blocks[activeBlockIdx].finish = index;
inTerminatedBlock = true;
}
return {IrOpKind::Inst, index};
}
}
IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f, IrOp g)
{
CODEGEN_ASSERT(FFlag::LuauCodegenInstG);
uint32_t index = uint32_t(function.instructions.size());
function.instructions.push_back({cmd, a, b, c, d, e, f});
function.instructions.push_back({cmd, a, b, c, d, e, f, g});
CODEGEN_ASSERT(!inTerminatedBlock);

View file

@ -7,6 +7,9 @@
#include <stdarg.h>
LUAU_FASTFLAG(LuauLoadUserdataInfo)
LUAU_FASTFLAG(LuauCodegenInstG)
namespace Luau
{
namespace CodeGen
@ -151,6 +154,8 @@ const char* getCmdName(IrCmd cmd)
return "SQRT_NUM";
case IrCmd::ABS_NUM:
return "ABS_NUM";
case IrCmd::SIGN_NUM:
return "SIGN_NUM";
case IrCmd::ADD_VEC:
return "ADD_VEC";
case IrCmd::SUB_VEC:
@ -197,6 +202,8 @@ const char* getCmdName(IrCmd cmd)
return "TRY_NUM_TO_INDEX";
case IrCmd::TRY_CALL_FASTGETTM:
return "TRY_CALL_FASTGETTM";
case IrCmd::NEW_USERDATA:
return "NEW_USERDATA";
case IrCmd::INT_TO_NUM:
return "INT_TO_NUM";
case IrCmd::UINT_TO_NUM:
@ -255,6 +262,8 @@ const char* getCmdName(IrCmd cmd)
return "CHECK_NODE_VALUE";
case IrCmd::CHECK_BUFFER_LEN:
return "CHECK_BUFFER_LEN";
case IrCmd::CHECK_USERDATA_TAG:
return "CHECK_USERDATA_TAG";
case IrCmd::INTERRUPT:
return "INTERRUPT";
case IrCmd::CHECK_GC:
@ -411,6 +420,9 @@ void toString(IrToStringContext& ctx, const IrInst& inst, uint32_t index)
checkOp(inst.d, ", ");
checkOp(inst.e, ", ");
checkOp(inst.f, ", ");
if (FFlag::LuauCodegenInstG)
checkOp(inst.g, ", ");
}
void toString(IrToStringContext& ctx, const IrBlock& block, uint32_t index)
@ -480,8 +492,10 @@ void toString(std::string& result, IrConst constant)
}
}
const char* getBytecodeTypeName(uint8_t type)
const char* getBytecodeTypeName_DEPRECATED(uint8_t type)
{
CODEGEN_ASSERT(!FFlag::LuauLoadUserdataInfo);
switch (type & ~LBC_TYPE_OPTIONAL_BIT)
{
case LBC_TYPE_NIL:
@ -512,13 +526,78 @@ const char* getBytecodeTypeName(uint8_t type)
return nullptr;
}
void toString(std::string& result, const BytecodeTypes& bcTypes)
const char* getBytecodeTypeName(uint8_t type, const char* const* userdataTypes)
{
CODEGEN_ASSERT(FFlag::LuauLoadUserdataInfo);
// Optional bit should be handled externally
type = type & ~LBC_TYPE_OPTIONAL_BIT;
if (type >= LBC_TYPE_TAGGED_USERDATA_BASE && type < LBC_TYPE_TAGGED_USERDATA_END)
{
if (userdataTypes)
return userdataTypes[type - LBC_TYPE_TAGGED_USERDATA_BASE];
return "userdata";
}
switch (type)
{
case LBC_TYPE_NIL:
return "nil";
case LBC_TYPE_BOOLEAN:
return "boolean";
case LBC_TYPE_NUMBER:
return "number";
case LBC_TYPE_STRING:
return "string";
case LBC_TYPE_TABLE:
return "table";
case LBC_TYPE_FUNCTION:
return "function";
case LBC_TYPE_THREAD:
return "thread";
case LBC_TYPE_USERDATA:
return "userdata";
case LBC_TYPE_VECTOR:
return "vector";
case LBC_TYPE_BUFFER:
return "buffer";
case LBC_TYPE_ANY:
return "any";
}
CODEGEN_ASSERT(!"Unhandled type in getBytecodeTypeName");
return nullptr;
}
void toString_DEPRECATED(std::string& result, const BytecodeTypes& bcTypes)
{
CODEGEN_ASSERT(!FFlag::LuauLoadUserdataInfo);
if (bcTypes.c != LBC_TYPE_ANY)
append(result, "%s <- %s, %s, %s", getBytecodeTypeName(bcTypes.result), getBytecodeTypeName(bcTypes.a), getBytecodeTypeName(bcTypes.b),
getBytecodeTypeName(bcTypes.c));
append(result, "%s <- %s, %s, %s", getBytecodeTypeName_DEPRECATED(bcTypes.result), getBytecodeTypeName_DEPRECATED(bcTypes.a),
getBytecodeTypeName_DEPRECATED(bcTypes.b), getBytecodeTypeName_DEPRECATED(bcTypes.c));
else
append(result, "%s <- %s, %s", getBytecodeTypeName(bcTypes.result), getBytecodeTypeName(bcTypes.a), getBytecodeTypeName(bcTypes.b));
append(result, "%s <- %s, %s", getBytecodeTypeName_DEPRECATED(bcTypes.result), getBytecodeTypeName_DEPRECATED(bcTypes.a),
getBytecodeTypeName_DEPRECATED(bcTypes.b));
}
void toString(std::string& result, const BytecodeTypes& bcTypes, const char* const* userdataTypes)
{
CODEGEN_ASSERT(FFlag::LuauLoadUserdataInfo);
append(result, "%s%s", getBytecodeTypeName(bcTypes.result, userdataTypes), (bcTypes.result & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" : "");
append(result, " <- ");
append(result, "%s%s", getBytecodeTypeName(bcTypes.a, userdataTypes), (bcTypes.a & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" : "");
append(result, ", ");
append(result, "%s%s", getBytecodeTypeName(bcTypes.b, userdataTypes), (bcTypes.b & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" : "");
if (bcTypes.c != LBC_TYPE_ANY)
{
append(result, ", ");
append(result, "%s%s", getBytecodeTypeName(bcTypes.c, userdataTypes), (bcTypes.c & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" : "");
}
}
static void appendBlockSet(IrToStringContext& ctx, BlockIteratorWrapper blocks)
@ -583,6 +662,8 @@ static RegisterSet getJumpTargetExtraLiveIn(IrToStringContext& ctx, const IrBloc
op = inst.e;
else if (inst.f.kind == IrOpKind::Block)
op = inst.f;
else if (FFlag::LuauCodegenInstG && inst.g.kind == IrOpKind::Block)
op = inst.g;
if (op.kind == IrOpKind::Block && op.index < ctx.cfg.in.size())
{
@ -867,6 +948,9 @@ std::string toDot(const IrFunction& function, bool includeInst)
checkOp(inst.d);
checkOp(inst.e);
checkOp(inst.f);
if (FFlag::LuauCodegenInstG)
checkOp(inst.g);
}
}

View file

@ -11,7 +11,11 @@
#include "lstate.h"
#include "lgc.h"
LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5)
LUAU_FASTFLAG(LuauCodegenUserdataOps)
LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataAlloc, false)
LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataOpsFixA64, false)
LUAU_FASTFLAG(LuauCodegenFastcall3)
LUAU_FASTFLAG(LuauCodegenMathSign)
namespace Luau
{
@ -193,78 +197,51 @@ static void emitInvokeLibm1P(AssemblyBuilderA64& build, size_t func, int arg)
build.blr(x1);
}
static bool emitBuiltin(
AssemblyBuilderA64& build, IrFunction& function, IrRegAllocA64& regs, int bfid, int res, int arg, IrOp args, int nparams, int nresults)
static bool emitBuiltin(AssemblyBuilderA64& build, IrFunction& function, IrRegAllocA64& regs, int bfid, int res, int arg, int nresults)
{
switch (bfid)
{
case LBF_MATH_FREXP:
{
if (FFlag::LuauCodegenRemoveDeadStores5)
{
CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2));
emitInvokeLibm1P(build, offsetof(NativeContext, libm_frexp), arg);
build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n)));
CODEGEN_ASSERT(nresults == 1 || nresults == 2);
emitInvokeLibm1P(build, offsetof(NativeContext, libm_frexp), arg);
build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n)));
RegisterA64 temp = regs.allocTemp(KindA64::w);
build.mov(temp, LUA_TNUMBER);
build.str(temp, mem(rBase, res * sizeof(TValue) + offsetof(TValue, tt)));
RegisterA64 temp = regs.allocTemp(KindA64::w);
build.mov(temp, LUA_TNUMBER);
build.str(temp, mem(rBase, res * sizeof(TValue) + offsetof(TValue, tt)));
if (nresults == 2)
{
build.ldr(w0, sTemporary);
build.scvtf(d1, w0);
build.str(d1, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, value.n)));
build.str(temp, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, tt)));
}
}
else
if (nresults == 2)
{
CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2));
emitInvokeLibm1P(build, offsetof(NativeContext, libm_frexp), arg);
build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n)));
if (nresults == 2)
{
build.ldr(w0, sTemporary);
build.scvtf(d1, w0);
build.str(d1, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, value.n)));
}
build.ldr(w0, sTemporary);
build.scvtf(d1, w0);
build.str(d1, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, value.n)));
build.str(temp, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, tt)));
}
return true;
}
case LBF_MATH_MODF:
{
if (FFlag::LuauCodegenRemoveDeadStores5)
{
CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2));
emitInvokeLibm1P(build, offsetof(NativeContext, libm_modf), arg);
build.ldr(d1, sTemporary);
build.str(d1, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n)));
CODEGEN_ASSERT(nresults == 1 || nresults == 2);
emitInvokeLibm1P(build, offsetof(NativeContext, libm_modf), arg);
build.ldr(d1, sTemporary);
build.str(d1, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n)));
RegisterA64 temp = regs.allocTemp(KindA64::w);
build.mov(temp, LUA_TNUMBER);
build.str(temp, mem(rBase, res * sizeof(TValue) + offsetof(TValue, tt)));
RegisterA64 temp = regs.allocTemp(KindA64::w);
build.mov(temp, LUA_TNUMBER);
build.str(temp, mem(rBase, res * sizeof(TValue) + offsetof(TValue, tt)));
if (nresults == 2)
{
build.str(d0, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, value.n)));
build.str(temp, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, tt)));
}
}
else
if (nresults == 2)
{
CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2));
emitInvokeLibm1P(build, offsetof(NativeContext, libm_modf), arg);
build.ldr(d1, sTemporary);
build.str(d1, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n)));
if (nresults == 2)
build.str(d0, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, value.n)));
build.str(d0, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, value.n)));
build.str(temp, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, tt)));
}
return true;
}
case LBF_MATH_SIGN:
{
CODEGEN_ASSERT(nparams == 1 && nresults == 1);
CODEGEN_ASSERT(!FFlag::LuauCodegenMathSign);
CODEGEN_ASSERT(nresults == 1);
build.ldr(d0, mem(rBase, arg * sizeof(TValue) + offsetof(TValue, value.n)));
build.fcmpz(d0);
build.fmov(d0, 0.0);
@ -274,12 +251,10 @@ static bool emitBuiltin(
build.fcsel(d0, d1, d0, getConditionFP(IrCondition::Less));
build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n)));
if (FFlag::LuauCodegenRemoveDeadStores5)
{
RegisterA64 temp = regs.allocTemp(KindA64::w);
build.mov(temp, LUA_TNUMBER);
build.str(temp, mem(rBase, res * sizeof(TValue) + offsetof(TValue, tt)));
}
RegisterA64 temp = regs.allocTemp(KindA64::w);
build.mov(temp, LUA_TNUMBER);
build.str(temp, mem(rBase, res * sizeof(TValue) + offsetof(TValue, tt)));
return true;
}
@ -723,6 +698,24 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
build.fabs(inst.regA64, temp);
break;
}
case IrCmd::SIGN_NUM:
{
CODEGEN_ASSERT(FFlag::LuauCodegenMathSign);
inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a});
RegisterA64 temp = tempDouble(inst.a);
RegisterA64 temp0 = regs.allocTemp(KindA64::d);
RegisterA64 temp1 = regs.allocTemp(KindA64::d);
build.fcmpz(temp);
build.fmov(temp0, 0.0);
build.fmov(temp1, 1.0);
build.fcsel(inst.regA64, temp1, temp0, getConditionFP(IrCondition::Greater));
build.fmov(temp1, -1.0);
build.fcsel(inst.regA64, temp1, inst.regA64, getConditionFP(IrCondition::Less));
break;
}
case IrCmd::ADD_VEC:
{
inst.regA64 = regs.allocReuse(KindA64::q, index, {inst.a, inst.b});
@ -1082,6 +1075,19 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
inst.regA64 = regs.takeReg(x0, index);
break;
}
case IrCmd::NEW_USERDATA:
{
CODEGEN_ASSERT(FFlag::LuauCodegenUserdataAlloc);
regs.spill(build, index);
build.mov(x0, rState);
build.mov(x1, intOp(inst.a));
build.mov(x2, intOp(inst.b));
build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, newUserdata)));
build.blr(x3);
inst.regA64 = regs.takeReg(x0, index);
break;
}
case IrCmd::INT_TO_NUM:
{
inst.regA64 = regs.allocReg(KindA64::d, index);
@ -1188,34 +1194,88 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
}
case IrCmd::FASTCALL:
regs.spill(build, index);
error |= !emitBuiltin(build, function, regs, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), inst.d, intOp(inst.e), intOp(inst.f));
if (FFlag::LuauCodegenFastcall3)
error |= !emitBuiltin(build, function, regs, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.d));
else
error |= !emitBuiltin(build, function, regs, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.f));
break;
case IrCmd::INVOKE_FASTCALL:
{
regs.spill(build, index);
build.mov(x0, rState);
build.add(x1, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue)));
build.add(x2, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue)));
build.mov(w3, intOp(inst.f)); // nresults
if (inst.d.kind == IrOpKind::VmReg)
build.add(x4, rBase, uint16_t(vmRegOp(inst.d) * sizeof(TValue)));
else if (inst.d.kind == IrOpKind::VmConst)
emitAddOffset(build, x4, rConstants, vmConstOp(inst.d) * sizeof(TValue));
else
CODEGEN_ASSERT(inst.d.kind == IrOpKind::Undef);
// nparams
if (intOp(inst.e) == LUA_MULTRET)
if (FFlag::LuauCodegenFastcall3)
{
// L->top - (ra + 1)
build.ldr(x5, mem(rState, offsetof(lua_State, top)));
build.sub(x5, x5, rBase);
build.sub(x5, x5, uint16_t((vmRegOp(inst.b) + 1) * sizeof(TValue)));
build.lsr(x5, x5, kTValueSizeLog2);
// We might need a temporary and we have to preserve it over the spill
RegisterA64 temp = regs.allocTemp(KindA64::q);
regs.spill(build, index, {temp});
build.mov(x0, rState);
build.add(x1, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue)));
build.add(x2, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue)));
build.mov(w3, intOp(inst.g)); // nresults
// 'E' argument can only be produced by LOP_FASTCALL3 lowering
if (inst.e.kind != IrOpKind::Undef)
{
CODEGEN_ASSERT(intOp(inst.f) == 3);
build.ldr(x4, mem(rState, offsetof(lua_State, top)));
build.ldr(temp, mem(rBase, vmRegOp(inst.d) * sizeof(TValue)));
build.str(temp, mem(x4, 0));
build.ldr(temp, mem(rBase, vmRegOp(inst.e) * sizeof(TValue)));
build.str(temp, mem(x4, sizeof(TValue)));
}
else
{
if (inst.d.kind == IrOpKind::VmReg)
build.add(x4, rBase, uint16_t(vmRegOp(inst.d) * sizeof(TValue)));
else if (inst.d.kind == IrOpKind::VmConst)
emitAddOffset(build, x4, rConstants, vmConstOp(inst.d) * sizeof(TValue));
else
CODEGEN_ASSERT(inst.d.kind == IrOpKind::Undef);
}
// nparams
if (intOp(inst.f) == LUA_MULTRET)
{
// L->top - (ra + 1)
build.ldr(x5, mem(rState, offsetof(lua_State, top)));
build.sub(x5, x5, rBase);
build.sub(x5, x5, uint16_t((vmRegOp(inst.b) + 1) * sizeof(TValue)));
build.lsr(x5, x5, kTValueSizeLog2);
}
else
build.mov(w5, intOp(inst.f));
}
else
build.mov(w5, intOp(inst.e));
{
regs.spill(build, index);
build.mov(x0, rState);
build.add(x1, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue)));
build.add(x2, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue)));
build.mov(w3, intOp(inst.f)); // nresults
if (inst.d.kind == IrOpKind::VmReg)
build.add(x4, rBase, uint16_t(vmRegOp(inst.d) * sizeof(TValue)));
else if (inst.d.kind == IrOpKind::VmConst)
emitAddOffset(build, x4, rConstants, vmConstOp(inst.d) * sizeof(TValue));
else
CODEGEN_ASSERT(inst.d.kind == IrOpKind::Undef);
// nparams
if (intOp(inst.e) == LUA_MULTRET)
{
// L->top - (ra + 1)
build.ldr(x5, mem(rState, offsetof(lua_State, top)));
build.sub(x5, x5, rBase);
build.sub(x5, x5, uint16_t((vmRegOp(inst.b) + 1) * sizeof(TValue)));
build.lsr(x5, x5, kTValueSizeLog2);
}
else
build.mov(w5, intOp(inst.e));
}
build.ldr(x6, mem(rNativeContext, offsetof(NativeContext, luauF_table) + uintOp(inst.a) * sizeof(luau_FastFunction)));
build.blr(x6);
@ -1242,9 +1302,38 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
else
build.add(x3, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue)));
build.mov(w4, TMS(intOp(inst.d)));
build.ldr(x5, mem(rNativeContext, offsetof(NativeContext, luaV_doarith)));
build.blr(x5);
switch (TMS(intOp(inst.d)))
{
case TM_ADD:
build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithadd)));
break;
case TM_SUB:
build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithsub)));
break;
case TM_MUL:
build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithmul)));
break;
case TM_DIV:
build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithdiv)));
break;
case TM_IDIV:
build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithidiv)));
break;
case TM_MOD:
build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithmod)));
break;
case TM_POW:
build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithpow)));
break;
case TM_UNM:
build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithunm)));
break;
default:
CODEGEN_ASSERT(!"Invalid doarith helper operation tag");
break;
}
build.blr(x4);
emitUpdateBase(build);
break;
@ -1388,35 +1477,14 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
Label fresh; // used when guard aborts execution or jumps to a VM exit
Label& fail = getTargetLabel(inst.c, fresh);
if (FFlag::LuauCodegenRemoveDeadStores5)
if (tagOp(inst.b) == 0)
{
if (tagOp(inst.b) == 0)
{
build.cbnz(regOp(inst.a), fail);
}
else
{
build.cmp(regOp(inst.a), tagOp(inst.b));
build.b(ConditionA64::NotEqual, fail);
}
build.cbnz(regOp(inst.a), fail);
}
else
{
// To support DebugLuauAbortingChecks, CHECK_TAG with VmReg has to be handled
RegisterA64 tag = inst.a.kind == IrOpKind::VmReg ? regs.allocTemp(KindA64::w) : regOp(inst.a);
if (inst.a.kind == IrOpKind::VmReg)
build.ldr(tag, mem(rBase, vmRegOp(inst.a) * sizeof(TValue) + offsetof(TValue, tt)));
if (tagOp(inst.b) == 0)
{
build.cbnz(tag, fail);
}
else
{
build.cmp(tag, tagOp(inst.b));
build.b(ConditionA64::NotEqual, fail);
}
build.cmp(regOp(inst.a), tagOp(inst.b));
build.b(ConditionA64::NotEqual, fail);
}
finalizeTargetLabel(inst.c, fresh);
@ -1638,6 +1706,24 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
finalizeTargetLabel(inst.d, fresh);
break;
}
case IrCmd::CHECK_USERDATA_TAG:
{
CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps);
Label fresh; // used when guard aborts execution or jumps to a VM exit
Label& fail = getTargetLabel(inst.c, fresh);
RegisterA64 temp = regs.allocTemp(KindA64::w);
build.ldrb(temp, mem(regOp(inst.a), offsetof(Udata, tag)));
if (FFlag::LuauCodegenUserdataOpsFixA64)
build.cmp(temp, intOp(inst.b));
else
build.cmp(temp, tagOp(inst.b));
build.b(ConditionA64::NotEqual, fail);
finalizeTargetLabel(inst.c, fresh);
break;
}
case IrCmd::INTERRUPT:
{
regs.spill(build, index);
@ -2269,7 +2355,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_READI8:
{
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b});
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
build.ldrsb(inst.regA64, addr);
break;
@ -2278,7 +2364,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_READU8:
{
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b});
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
build.ldrb(inst.regA64, addr);
break;
@ -2287,7 +2373,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_WRITEI8:
{
RegisterA64 temp = tempInt(inst.c);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d));
build.strb(temp, addr);
break;
@ -2296,7 +2382,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_READI16:
{
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b});
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
build.ldrsh(inst.regA64, addr);
break;
@ -2305,7 +2391,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_READU16:
{
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b});
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
build.ldrh(inst.regA64, addr);
break;
@ -2314,7 +2400,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_WRITEI16:
{
RegisterA64 temp = tempInt(inst.c);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d));
build.strh(temp, addr);
break;
@ -2323,7 +2409,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_READI32:
{
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b});
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
build.ldr(inst.regA64, addr);
break;
@ -2332,7 +2418,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_WRITEI32:
{
RegisterA64 temp = tempInt(inst.c);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d));
build.str(temp, addr);
break;
@ -2342,7 +2428,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{
inst.regA64 = regs.allocReg(KindA64::d, index);
RegisterA64 temp = castReg(KindA64::s, inst.regA64); // safe to alias a fresh register
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
build.ldr(temp, addr);
build.fcvt(inst.regA64, temp);
@ -2353,7 +2439,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{
RegisterA64 temp1 = tempDouble(inst.c);
RegisterA64 temp2 = regs.allocTemp(KindA64::s);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d));
build.fcvt(temp2, temp1);
build.str(temp2, addr);
@ -2363,7 +2449,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_READF64:
{
inst.regA64 = regs.allocReg(KindA64::d, index);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
build.ldr(inst.regA64, addr);
break;
@ -2372,7 +2458,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_WRITEF64:
{
RegisterA64 temp = tempDouble(inst.c);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d));
build.str(temp, addr);
break;
@ -2600,32 +2686,68 @@ AddressA64 IrLoweringA64::tempAddr(IrOp op, int offset)
}
}
AddressA64 IrLoweringA64::tempAddrBuffer(IrOp bufferOp, IrOp indexOp)
AddressA64 IrLoweringA64::tempAddrBuffer(IrOp bufferOp, IrOp indexOp, uint8_t tag)
{
if (indexOp.kind == IrOpKind::Inst)
if (FFlag::LuauCodegenUserdataOps)
{
RegisterA64 temp = regs.allocTemp(KindA64::x);
build.add(temp, regOp(bufferOp), regOp(indexOp)); // implicit uxtw
return mem(temp, offsetof(Buffer, data));
}
else if (indexOp.kind == IrOpKind::Constant)
{
// Since the resulting address may be used to load any size, including 1 byte, from an unaligned offset, we are limited by unscaled encoding
if (unsigned(intOp(indexOp)) + offsetof(Buffer, data) <= 255)
return mem(regOp(bufferOp), int(intOp(indexOp) + offsetof(Buffer, data)));
CODEGEN_ASSERT(tag == LUA_TUSERDATA || tag == LUA_TBUFFER);
int dataOffset = tag == LUA_TBUFFER ? offsetof(Buffer, data) : offsetof(Udata, data);
// indexOp can only be negative in dead code (since offsets are checked); this avoids assertion in emitAddOffset
if (intOp(indexOp) < 0)
return mem(regOp(bufferOp), offsetof(Buffer, data));
if (indexOp.kind == IrOpKind::Inst)
{
RegisterA64 temp = regs.allocTemp(KindA64::x);
build.add(temp, regOp(bufferOp), regOp(indexOp)); // implicit uxtw
return mem(temp, dataOffset);
}
else if (indexOp.kind == IrOpKind::Constant)
{
// Since the resulting address may be used to load any size, including 1 byte, from an unaligned offset, we are limited by unscaled
// encoding
if (unsigned(intOp(indexOp)) + dataOffset <= 255)
return mem(regOp(bufferOp), int(intOp(indexOp) + dataOffset));
RegisterA64 temp = regs.allocTemp(KindA64::x);
emitAddOffset(build, temp, regOp(bufferOp), size_t(intOp(indexOp)));
return mem(temp, offsetof(Buffer, data));
// indexOp can only be negative in dead code (since offsets are checked); this avoids assertion in emitAddOffset
if (intOp(indexOp) < 0)
return mem(regOp(bufferOp), dataOffset);
RegisterA64 temp = regs.allocTemp(KindA64::x);
emitAddOffset(build, temp, regOp(bufferOp), size_t(intOp(indexOp)));
return mem(temp, dataOffset);
}
else
{
CODEGEN_ASSERT(!"Unsupported instruction form");
return noreg;
}
}
else
{
CODEGEN_ASSERT(!"Unsupported instruction form");
return noreg;
if (indexOp.kind == IrOpKind::Inst)
{
RegisterA64 temp = regs.allocTemp(KindA64::x);
build.add(temp, regOp(bufferOp), regOp(indexOp)); // implicit uxtw
return mem(temp, offsetof(Buffer, data));
}
else if (indexOp.kind == IrOpKind::Constant)
{
// Since the resulting address may be used to load any size, including 1 byte, from an unaligned offset, we are limited by unscaled
// encoding
if (unsigned(intOp(indexOp)) + offsetof(Buffer, data) <= 255)
return mem(regOp(bufferOp), int(intOp(indexOp) + offsetof(Buffer, data)));
// indexOp can only be negative in dead code (since offsets are checked); this avoids assertion in emitAddOffset
if (intOp(indexOp) < 0)
return mem(regOp(bufferOp), offsetof(Buffer, data));
RegisterA64 temp = regs.allocTemp(KindA64::x);
emitAddOffset(build, temp, regOp(bufferOp), size_t(intOp(indexOp)));
return mem(temp, offsetof(Buffer, data));
}
else
{
CODEGEN_ASSERT(!"Unsupported instruction form");
return noreg;
}
}
}

View file

@ -44,7 +44,7 @@ struct IrLoweringA64
RegisterA64 tempInt(IrOp op);
RegisterA64 tempUint(IrOp op);
AddressA64 tempAddr(IrOp op, int offset);
AddressA64 tempAddrBuffer(IrOp bufferOp, IrOp indexOp);
AddressA64 tempAddrBuffer(IrOp bufferOp, IrOp indexOp, uint8_t tag);
// May emit restore instructions
RegisterA64 regOp(IrOp op);

View file

@ -15,6 +15,11 @@
#include "lstate.h"
#include "lgc.h"
LUAU_FASTFLAG(LuauCodegenUserdataOps)
LUAU_FASTFLAG(LuauCodegenUserdataAlloc)
LUAU_FASTFLAG(LuauCodegenFastcall3)
LUAU_FASTFLAG(LuauCodegenMathSign)
namespace Luau
{
namespace CodeGen
@ -586,6 +591,33 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
build.vandpd(inst.regX64, inst.regX64, build.i64(~(1LL << 63)));
break;
case IrCmd::SIGN_NUM:
{
CODEGEN_ASSERT(FFlag::LuauCodegenMathSign);
inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a});
ScopedRegX64 tmp0{regs, SizeX64::xmmword};
ScopedRegX64 tmp1{regs, SizeX64::xmmword};
ScopedRegX64 tmp2{regs, SizeX64::xmmword};
build.vxorpd(tmp0.reg, tmp0.reg, tmp0.reg);
// Set tmp1 to -1 if arg < 0, else 0
build.vcmpltsd(tmp1.reg, regOp(inst.a), tmp0.reg);
build.vmovsd(tmp2.reg, build.f64(-1));
build.vandpd(tmp1.reg, tmp1.reg, tmp2.reg);
// Set mask bit to 1 if 0 < arg, else 0
build.vcmpltsd(inst.regX64, tmp0.reg, regOp(inst.a));
// Result = (mask-bit == 1) ? 1.0 : tmp1
// If arg < 0 then tmp1 is -1 and mask-bit is 0, result is -1
// If arg == 0 then tmp1 is 0 and mask-bit is 0, result is 0
// If arg > 0 then tmp1 is 0 and mask-bit is 1, result is 1
build.vblendvpd(inst.regX64, tmp1.reg, build.f64x2(1, 1), inst.regX64);
break;
}
case IrCmd::ADD_VEC:
{
inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b});
@ -905,6 +937,18 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
inst.regX64 = regs.takeReg(rax, index);
break;
}
case IrCmd::NEW_USERDATA:
{
CODEGEN_ASSERT(FFlag::LuauCodegenUserdataAlloc);
IrCallWrapperX64 callWrap(regs, build, index);
callWrap.addArgument(SizeX64::qword, rState);
callWrap.addArgument(SizeX64::qword, intOp(inst.a));
callWrap.addArgument(SizeX64::dword, intOp(inst.b));
callWrap.call(qword[rNativeContext + offsetof(NativeContext, newUserdata)]);
inst.regX64 = regs.takeReg(rax, index);
break;
}
case IrCmd::INT_TO_NUM:
inst.regX64 = regs.allocReg(SizeX64::xmmword, index);
@ -993,9 +1037,10 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::FASTCALL:
{
OperandX64 arg2 = inst.d.kind != IrOpKind::Undef ? memRegDoubleOp(inst.d) : OperandX64{0};
emitBuiltin(regs, build, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), arg2, intOp(inst.e), intOp(inst.f));
if (FFlag::LuauCodegenFastcall3)
emitBuiltin(regs, build, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.d));
else
emitBuiltin(regs, build, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.f));
break;
}
case IrCmd::INVOKE_FASTCALL:
@ -1003,25 +1048,49 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
unsigned bfid = uintOp(inst.a);
OperandX64 args = 0;
ScopedRegX64 argsAlt{regs};
if (inst.d.kind == IrOpKind::VmReg)
args = luauRegAddress(vmRegOp(inst.d));
else if (inst.d.kind == IrOpKind::VmConst)
args = luauConstantAddress(vmConstOp(inst.d));
// 'E' argument can only be produced by LOP_FASTCALL3
if (FFlag::LuauCodegenFastcall3 && inst.e.kind != IrOpKind::Undef)
{
CODEGEN_ASSERT(intOp(inst.f) == 3);
ScopedRegX64 tmp{regs, SizeX64::xmmword};
argsAlt.alloc(SizeX64::qword);
build.mov(argsAlt.reg, qword[rState + offsetof(lua_State, top)]);
build.vmovups(tmp.reg, luauReg(vmRegOp(inst.d)));
build.vmovups(xmmword[argsAlt.reg], tmp.reg);
build.vmovups(tmp.reg, luauReg(vmRegOp(inst.e)));
build.vmovups(xmmword[argsAlt.reg + sizeof(TValue)], tmp.reg);
}
else
CODEGEN_ASSERT(inst.d.kind == IrOpKind::Undef);
{
if (inst.d.kind == IrOpKind::VmReg)
args = luauRegAddress(vmRegOp(inst.d));
else if (inst.d.kind == IrOpKind::VmConst)
args = luauConstantAddress(vmConstOp(inst.d));
else
CODEGEN_ASSERT(inst.d.kind == IrOpKind::Undef);
}
int ra = vmRegOp(inst.b);
int arg = vmRegOp(inst.c);
int nparams = intOp(inst.e);
int nresults = intOp(inst.f);
int nparams = intOp(FFlag::LuauCodegenFastcall3 ? inst.f : inst.e);
int nresults = intOp(FFlag::LuauCodegenFastcall3 ? inst.g : inst.f);
IrCallWrapperX64 callWrap(regs, build, index);
callWrap.addArgument(SizeX64::qword, rState);
callWrap.addArgument(SizeX64::qword, luauRegAddress(ra));
callWrap.addArgument(SizeX64::qword, luauRegAddress(arg));
callWrap.addArgument(SizeX64::dword, nresults);
callWrap.addArgument(SizeX64::qword, args);
if (FFlag::LuauCodegenFastcall3 && inst.e.kind != IrOpKind::Undef)
callWrap.addArgument(SizeX64::qword, argsAlt);
else
callWrap.addArgument(SizeX64::qword, args);
if (nparams == LUA_MULTRET)
{
@ -1350,6 +1419,14 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
}
break;
}
case IrCmd::CHECK_USERDATA_TAG:
{
CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps);
build.cmp(byte[regOp(inst.a) + offsetof(Udata, tag)], intOp(inst.b));
jumpOrAbortOnUndef(ConditionX64::NotEqual, inst.c, next);
break;
}
case IrCmd::INTERRUPT:
{
unsigned pcpos = uintOp(inst.a);
@ -1895,71 +1972,71 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
case IrCmd::BUFFER_READI8:
inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
build.movsx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b)]);
build.movsx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
break;
case IrCmd::BUFFER_READU8:
inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
build.movzx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b)]);
build.movzx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
break;
case IrCmd::BUFFER_WRITEI8:
{
OperandX64 value = inst.c.kind == IrOpKind::Inst ? byteReg(regOp(inst.c)) : OperandX64(int8_t(intOp(inst.c)));
build.mov(byte[bufferAddrOp(inst.a, inst.b)], value);
build.mov(byte[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], value);
break;
}
case IrCmd::BUFFER_READI16:
inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
build.movsx(inst.regX64, word[bufferAddrOp(inst.a, inst.b)]);
build.movsx(inst.regX64, word[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
break;
case IrCmd::BUFFER_READU16:
inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
build.movzx(inst.regX64, word[bufferAddrOp(inst.a, inst.b)]);
build.movzx(inst.regX64, word[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
break;
case IrCmd::BUFFER_WRITEI16:
{
OperandX64 value = inst.c.kind == IrOpKind::Inst ? wordReg(regOp(inst.c)) : OperandX64(int16_t(intOp(inst.c)));
build.mov(word[bufferAddrOp(inst.a, inst.b)], value);
build.mov(word[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], value);
break;
}
case IrCmd::BUFFER_READI32:
inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
build.mov(inst.regX64, dword[bufferAddrOp(inst.a, inst.b)]);
build.mov(inst.regX64, dword[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
break;
case IrCmd::BUFFER_WRITEI32:
{
OperandX64 value = inst.c.kind == IrOpKind::Inst ? regOp(inst.c) : OperandX64(intOp(inst.c));
build.mov(dword[bufferAddrOp(inst.a, inst.b)], value);
build.mov(dword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], value);
break;
}
case IrCmd::BUFFER_READF32:
inst.regX64 = regs.allocReg(SizeX64::xmmword, index);
build.vcvtss2sd(inst.regX64, inst.regX64, dword[bufferAddrOp(inst.a, inst.b)]);
build.vcvtss2sd(inst.regX64, inst.regX64, dword[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
break;
case IrCmd::BUFFER_WRITEF32:
storeDoubleAsFloat(dword[bufferAddrOp(inst.a, inst.b)], inst.c);
storeDoubleAsFloat(dword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], inst.c);
break;
case IrCmd::BUFFER_READF64:
inst.regX64 = regs.allocReg(SizeX64::xmmword, index);
build.vmovsd(inst.regX64, qword[bufferAddrOp(inst.a, inst.b)]);
build.vmovsd(inst.regX64, qword[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
break;
case IrCmd::BUFFER_WRITEF64:
@ -1967,11 +2044,11 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{
ScopedRegX64 tmp{regs, SizeX64::xmmword};
build.vmovsd(tmp.reg, build.f64(doubleOp(inst.c)));
build.vmovsd(qword[bufferAddrOp(inst.a, inst.b)], tmp.reg);
build.vmovsd(qword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], tmp.reg);
}
else if (inst.c.kind == IrOpKind::Inst)
{
build.vmovsd(qword[bufferAddrOp(inst.a, inst.b)], regOp(inst.c));
build.vmovsd(qword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], regOp(inst.c));
}
else
{
@ -2190,12 +2267,25 @@ RegisterX64 IrLoweringX64::regOp(IrOp op)
return inst.regX64;
}
OperandX64 IrLoweringX64::bufferAddrOp(IrOp bufferOp, IrOp indexOp)
OperandX64 IrLoweringX64::bufferAddrOp(IrOp bufferOp, IrOp indexOp, uint8_t tag)
{
if (indexOp.kind == IrOpKind::Inst)
return regOp(bufferOp) + qwordReg(regOp(indexOp)) + offsetof(Buffer, data);
else if (indexOp.kind == IrOpKind::Constant)
return regOp(bufferOp) + intOp(indexOp) + offsetof(Buffer, data);
if (FFlag::LuauCodegenUserdataOps)
{
CODEGEN_ASSERT(tag == LUA_TUSERDATA || tag == LUA_TBUFFER);
int dataOffset = tag == LUA_TBUFFER ? offsetof(Buffer, data) : offsetof(Udata, data);
if (indexOp.kind == IrOpKind::Inst)
return regOp(bufferOp) + qwordReg(regOp(indexOp)) + dataOffset;
else if (indexOp.kind == IrOpKind::Constant)
return regOp(bufferOp) + intOp(indexOp) + dataOffset;
}
else
{
if (indexOp.kind == IrOpKind::Inst)
return regOp(bufferOp) + qwordReg(regOp(indexOp)) + offsetof(Buffer, data);
else if (indexOp.kind == IrOpKind::Constant)
return regOp(bufferOp) + intOp(indexOp) + offsetof(Buffer, data);
}
CODEGEN_ASSERT(!"Unsupported instruction form");
return noreg;

View file

@ -50,7 +50,7 @@ struct IrLoweringX64
OperandX64 memRegUintOp(IrOp op);
OperandX64 memRegTagOp(IrOp op);
RegisterX64 regOp(IrOp op);
OperandX64 bufferAddrOp(IrOp bufferOp, IrOp indexOp);
OperandX64 bufferAddrOp(IrOp bufferOp, IrOp indexOp, uint8_t tag);
RegisterX64 vecOp(IrOp op, ScopedRegX64& tmp);
IrConst constOp(IrOp op) const;

Some files were not shown because too many files have changed in this diff Show more