commit 43e31b1f5c
Author: JohnnyMorganz
Date:   2024-07-07 13:37:33 +02:00

    Merge branch 'master' of https://github.com/Roblox/luau into store-class-type-location

198 changed files with 12446 additions and 4857 deletions

@@ -77,10 +77,12 @@ jobs:
       valgrind --tool=callgrind ./luau-compile --null -O1 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O1 | tee -a compile-output.txt
       valgrind --tool=callgrind ./luau-compile --null -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2 | tee -a compile-output.txt
       valgrind --tool=callgrind ./luau-compile --codegennull -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2-codegen | tee -a compile-output.txt
+      valgrind --tool=callgrind ./luau-compile --codegennull -O2 -t1 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2-t1-codegen | tee -a compile-output.txt
       valgrind --tool=callgrind ./luau-compile --null -O0 bench/other/regex.lua 2>&1 | filter regex-O0 | tee -a compile-output.txt
       valgrind --tool=callgrind ./luau-compile --null -O1 bench/other/regex.lua 2>&1 | filter regex-O1 | tee -a compile-output.txt
       valgrind --tool=callgrind ./luau-compile --null -O2 bench/other/regex.lua 2>&1 | filter regex-O2 | tee -a compile-output.txt
       valgrind --tool=callgrind ./luau-compile --codegennull -O2 bench/other/regex.lua 2>&1 | filter regex-O2-codegen | tee -a compile-output.txt
+      valgrind --tool=callgrind ./luau-compile --codegennull -O2 -t1 bench/other/regex.lua 2>&1 | filter regex-O2-t1-codegen | tee -a compile-output.txt
     - name: Checkout benchmark results
       uses: actions/checkout@v3

.gitignore

@@ -1,5 +1,7 @@
 /build/
 /build[.-]*/
+/cmake/
+/cmake[.-]*/
 /coverage/
 /.vs/
 /.vscode/

@@ -57,7 +57,7 @@ struct GeneralizationConstraint
 struct IterableConstraint
 {
     TypePackId iterator;
-    TypePackId variables;
+    std::vector<TypeId> variables;

     const AstNode* nextAstFragment;
     DenseHashMap<const AstNode*, TypeId>* astForInNextTypes;
@@ -179,23 +179,6 @@ struct HasPropConstraint
     bool suppressSimplification = false;
 };

-// result ~ setProp subjectType ["prop", "prop2", ...] propType
-//
-// If the subject is a table or table-like thing that already has the named
-// property chain, we unify propType with that existing property type.
-//
-// If the subject is a free table, we augment it in place.
-//
-// If the subject is an unsealed table, result is an augmented table that
-// includes that new prop.
-struct SetPropConstraint
-{
-    TypeId resultType;
-    TypeId subjectType;
-    std::vector<std::string> path;
-    TypeId propType;
-};
-
 // resultType ~ hasIndexer subjectType indexType
 //
 // If the subject type is a table or table-like thing that supports indexing,
@@ -209,46 +192,48 @@ struct HasIndexerConstraint
     TypeId indexType;
 };

-// result ~ setIndexer subjectType indexType propType
+// assignProp lhsType propName rhsType
 //
-// If the subject is a table or table-like thing that already has an indexer,
-// unify its indexType and propType with those from this constraint.
-//
-// If the table is a free or unsealed table, we augment it with a new indexer.
-struct SetIndexerConstraint
+// Assign a value of type rhsType into the named property of lhsType.
+struct AssignPropConstraint
 {
-    TypeId subjectType;
+    TypeId lhsType;
+    std::string propName;
+    TypeId rhsType;
+
+    /// The canonical write type of the property. It is _solely_ used to
+    /// populate astTypes during constraint resolution. Nothing should ever
+    /// block on it.
+    TypeId propType;
+
+    // When we generate constraints, we increment the remaining prop count on
+    // the table if we are able. This flag informs the solver as to whether or
+    // not it should in turn decrement the prop count when this constraint is
+    // dispatched.
+    bool decrementPropCount = false;
+};
+
+struct AssignIndexConstraint
+{
+    TypeId lhsType;
     TypeId indexType;
+    TypeId rhsType;
+
+    /// The canonical write type of the property. It is _solely_ used to
+    /// populate astTypes during constraint resolution. Nothing should ever
+    /// block on it.
     TypeId propType;
 };

-// resultType ~ unpack sourceTypePack
+// resultTypes ~ unpack sourceTypePack
 //
 // Similar to PackSubtypeConstraint, but with one important difference: If the
 // sourcePack is blocked, this constraint blocks.
 struct UnpackConstraint
 {
-    TypePackId resultPack;
+    std::vector<TypeId> resultPack;
     TypePackId sourcePack;
-
-    // UnpackConstraint is sometimes used to resolve the types of assignments.
-    // When this is the case, any LocalTypes in resultPack can have their
-    // domains extended by the corresponding type from sourcePack.
-    bool resultIsLValue = false;
-};
-
-// resultType ~ unpack sourceType
-//
-// The same as UnpackConstraint, but specialized for a pair of types as opposed to packs.
-struct Unpack1Constraint
-{
-    TypeId resultType;
-    TypeId sourceType;
-
-    // UnpackConstraint is sometimes used to resolve the types of assignments.
-    // When this is the case, any LocalTypes in resultPack can have their
-    // domains extended by the corresponding type from sourcePack.
-    bool resultIsLValue = false;
 };

 // ty ~ reduce ty
@@ -268,8 +253,8 @@ struct ReducePackConstraint
 };

 using ConstraintV = Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, IterableConstraint, NameConstraint,
-    TypeAliasExpansionConstraint, FunctionCallConstraint, FunctionCheckConstraint, PrimitiveTypeConstraint, HasPropConstraint, SetPropConstraint,
-    HasIndexerConstraint, SetIndexerConstraint, UnpackConstraint, Unpack1Constraint, ReduceConstraint, ReducePackConstraint, EqualityConstraint>;
+    TypeAliasExpansionConstraint, FunctionCallConstraint, FunctionCheckConstraint, PrimitiveTypeConstraint, HasPropConstraint, HasIndexerConstraint,
+    AssignPropConstraint, AssignIndexConstraint, UnpackConstraint, ReduceConstraint, ReducePackConstraint, EqualityConstraint>;

 struct Constraint
 {
@@ -284,11 +269,13 @@ struct Constraint
     std::vector<NotNull<Constraint>> dependencies;

-    DenseHashSet<TypeId> getFreeTypes() const;
+    DenseHashSet<TypeId> getMaybeMutatedFreeTypes() const;
 };

 using ConstraintPtr = std::unique_ptr<Constraint>;

+bool isReferenceCountedType(const TypeId typ);
+
 inline Constraint& asMutable(const Constraint& c)
 {
     return const_cast<Constraint&>(c);
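The hunks above replace the multi-purpose SetProp/SetIndexer constraints with assignment-specific ones. A minimal sketch of how the two new constraints line up with Luau source, using the field layout from the diff; the TypeIds (`tableTy`, `stringTy`, `numberTy`) and the statement-to-constraint mapping are illustrative assumptions, not code from this commit:

```cpp
// Sketch: for `t.x = 5`, constraint generation would emit something shaped like:
AssignPropConstraint assignProp{
    /* lhsType */ tableTy,          // typeof(t)
    /* propName */ "x",
    /* rhsType */ numberTy,         // typeof(5)
    /* propType */ numberTy,        // only populates astTypes; nothing blocks on it
    /* decrementPropCount */ true,  // generation bumped remainingProps on the table
};

// ...and for `t[k] = v`:
AssignIndexConstraint assignIndex{
    /* lhsType */ tableTy,
    /* indexType */ stringTy,       // typeof(k)
    /* rhsType */ numberTy,         // typeof(v)
    /* propType */ numberTy,
};
```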

@@ -118,6 +118,8 @@ struct ConstraintGenerator
     std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope;
     std::vector<RequireCycle> requireCycles;

+    DenseHashMap<TypeId, TypeIds> localTypes{nullptr};
+
     DcrLogger* logger;

     ConstraintGenerator(ModulePtr module, NotNull<Normalizer> normalizer, NotNull<ModuleResolver> moduleResolver, NotNull<BuiltinTypes> builtinTypes,
@@ -254,18 +256,11 @@ private:
     Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType);
     std::tuple<TypeId, TypeId, RefinementId> checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType);

-    struct LValueBounds
-    {
-        std::optional<TypeId> annotationTy;
-        std::optional<TypeId> assignedTy;
-    };
-
-    LValueBounds checkLValue(const ScopePtr& scope, AstExpr* expr);
-    LValueBounds checkLValue(const ScopePtr& scope, AstExprLocal* local);
-    LValueBounds checkLValue(const ScopePtr& scope, AstExprGlobal* global);
-    LValueBounds checkLValue(const ScopePtr& scope, AstExprIndexName* indexName);
-    LValueBounds checkLValue(const ScopePtr& scope, AstExprIndexExpr* indexExpr);
-    LValueBounds updateProperty(const ScopePtr& scope, AstExpr* expr);
+    void visitLValue(const ScopePtr& scope, AstExpr* expr, TypeId rhsType);
+    void visitLValue(const ScopePtr& scope, AstExprLocal* local, TypeId rhsType);
+    void visitLValue(const ScopePtr& scope, AstExprGlobal* global, TypeId rhsType);
+    void visitLValue(const ScopePtr& scope, AstExprIndexName* indexName, TypeId rhsType);
+    void visitLValue(const ScopePtr& scope, AstExprIndexExpr* indexExpr, TypeId rhsType);

     struct FunctionSignature
     {
@@ -361,6 +356,8 @@ private:
     */
     void prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program);

+    bool recordPropertyAssignment(TypeId ty);
+
     // Record the fact that a particular local has a particular type in at least
     // one of its states.
     void recordInferredBinding(AstLocal* local, TypeId ty);
@@ -373,7 +370,8 @@ private:
     */
     std::vector<std::optional<TypeId>> getExpectedCallTypesForFunctionOverloads(const TypeId fnType);

-    TypeId createFamilyInstance(TypeFamilyInstanceType instance, const ScopePtr& scope, Location location);
+    TypeId createTypeFamilyInstance(
+        const TypeFamily& family, std::vector<TypeId> typeArguments, std::vector<TypePackId> packArguments, const ScopePtr& scope, Location location);
 };

 /** Borrow a vector of pointers from a vector of owning pointers to constraints.

@@ -94,6 +94,10 @@ struct ConstraintSolver
     // Irreducible/uninhabited type families or type pack families.
     DenseHashSet<const void*> uninhabitedTypeFamilies{{}};

+    // The set of types that will definitely be unchanged by generalization.
+    DenseHashSet<TypeId> generalizedTypes_{nullptr};
+    const NotNull<DenseHashSet<TypeId>> generalizedTypes{&generalizedTypes_};
+
     // Recorded errors that take place within the solver.
     ErrorVec errors;
@@ -103,6 +107,8 @@ struct ConstraintSolver
     DcrLogger* logger;
     TypeCheckLimits limits;

+    DenseHashMap<TypeId, const Constraint*> typeFamiliesToFinalize{nullptr};
+
     explicit ConstraintSolver(NotNull<Normalizer> normalizer, NotNull<Scope> rootScope, std::vector<NotNull<Constraint>> constraints,
         ModuleName moduleName, NotNull<ModuleResolver> moduleResolver, std::vector<RequireCycle> requireCycles, DcrLogger* logger,
         TypeCheckLimits limits);
@@ -116,8 +122,35 @@ struct ConstraintSolver
     **/
     void run();

+    /**
+     * Attempts to perform one final reduction on type families after every constraint has been completed
+     **/
+    void finalizeTypeFamilies();
+
     bool isDone();

+private:
+    /**
+     * Bind a type variable to another type.
+     *
+     * A constraint is required and will validate that blockedTy is owned by this
+     * constraint. This prevents one constraint from interfering with another's
+     * blocked types.
+     *
+     * Bind will also unblock the type variable for you.
+     */
+    void bind(NotNull<const Constraint> constraint, TypeId ty, TypeId boundTo);
+    void bind(NotNull<const Constraint> constraint, TypePackId tp, TypePackId boundTo);
+
+    template<typename T, typename... Args>
+    void emplace(NotNull<const Constraint> constraint, TypeId ty, Args&&... args);
+
+    template<typename T, typename... Args>
+    void emplace(NotNull<const Constraint> constraint, TypePackId tp, Args&&... args);
+
+public:
     /** Attempt to dispatch a constraint. Returns true if it was successful. If
      * tryDispatch() returns false, the constraint remains in the unsolved set
      * and will be retried later.
@@ -134,20 +167,15 @@ struct ConstraintSolver
     bool tryDispatch(const FunctionCheckConstraint& c, NotNull<const Constraint> constraint);
     bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull<const Constraint> constraint);
     bool tryDispatch(const HasPropConstraint& c, NotNull<const Constraint> constraint);
-    bool tryDispatch(const SetPropConstraint& c, NotNull<const Constraint> constraint);

     bool tryDispatchHasIndexer(
         int& recursionDepth, NotNull<const Constraint> constraint, TypeId subjectType, TypeId indexType, TypeId resultType, Set<TypeId>& seen);
     bool tryDispatch(const HasIndexerConstraint& c, NotNull<const Constraint> constraint);

-    std::pair<bool, std::optional<TypeId>> tryDispatchSetIndexer(
-        NotNull<const Constraint> constraint, TypeId subjectType, TypeId indexType, TypeId propType, bool expandFreeTypeBounds);
-    bool tryDispatch(const SetIndexerConstraint& c, NotNull<const Constraint> constraint, bool force);
+    bool tryDispatch(const AssignPropConstraint& c, NotNull<const Constraint> constraint);
+    bool tryDispatch(const AssignIndexConstraint& c, NotNull<const Constraint> constraint);

-    bool tryDispatchUnpack1(NotNull<const Constraint> constraint, TypeId resultType, TypeId sourceType, bool resultIsLValue);
     bool tryDispatch(const UnpackConstraint& c, NotNull<const Constraint> constraint);
-    bool tryDispatch(const Unpack1Constraint& c, NotNull<const Constraint> constraint);

     bool tryDispatch(const ReduceConstraint& c, NotNull<const Constraint> constraint, bool force);
     bool tryDispatch(const ReducePackConstraint& c, NotNull<const Constraint> constraint, bool force);
     bool tryDispatch(const EqualityConstraint& c, NotNull<const Constraint> constraint, bool force);
@@ -157,14 +185,28 @@ struct ConstraintSolver
     bool tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);

     // for a, ... in next_function, t, ... do
-    bool tryDispatchIterableFunction(
-        TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
+    bool tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);

     std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(NotNull<const Constraint> constraint, TypeId subjectType,
         const std::string& propName, ValueContext context, bool inConditional = false, bool suppressSimplification = false);
     std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(NotNull<const Constraint> constraint, TypeId subjectType,
         const std::string& propName, ValueContext context, bool inConditional, bool suppressSimplification, DenseHashSet<TypeId>& seen);

+    /**
+     * Generate constraints to unpack the types of srcTypes and assign each
+     * value to the corresponding BlockedType in destTypes.
+     *
+     * This function also overwrites the owners of each BlockedType. This is
+     * okay because this function is only used to decompose IterableConstraint
+     * into an UnpackConstraint.
+     *
+     * @param destTypes A vector of types comprised of BlockedTypes.
+     * @param srcTypes A TypePack that represents rvalues to be assigned.
+     * @returns The underlying UnpackConstraint. There's a bit of code in
+     * iteration that needs to pass blocks on to this constraint.
+     */
+    NotNull<const Constraint> unpackAndAssign(const std::vector<TypeId> destTypes, TypePackId srcTypes, NotNull<const Constraint> constraint);
+
     void block(NotNull<const Constraint> target, NotNull<const Constraint> constraint);
     /**
      * Block a constraint on the resolution of a Type.
@@ -242,6 +284,24 @@ struct ConstraintSolver
     void reportError(TypeErrorData&& data, const Location& location);
     void reportError(TypeError e);

+    /**
+     * Shifts the count of references from `source` to `target`. This should be paired
+     * with any instance of binding a free type in order to maintain accurate refcounts.
+     * If `target` is not a free type, this is a noop.
+     * @param source the free type which is being bound
+     * @param target the type which the free type is being bound to
+     */
+    void shiftReferences(TypeId source, TypeId target);
+
+    /**
+     * Generalizes the given free type if the reference counting allows it.
+     * @param scope the scope to generalize in
+     * @param type the free type we want to generalize
+     * @returns a non-free type that generalizes the argument, or `std::nullopt` if one
+     * does not exist
+     */
+    std::optional<TypeId> generalizeFreeType(NotNull<Scope> scope, TypeId type, bool avoidSealingTables = false);
+
     /**
      * Checks the existing set of constraints to see if there exist any that contain
      * the provided free type, indicating that it is not yet ready to be replaced by
@@ -266,22 +326,6 @@ struct ConstraintSolver
     template<typename TID>
     bool unify(NotNull<const Constraint> constraint, TID subTy, TID superTy);

-private:
-    /**
-     * Bind a BlockedType to another type while taking care not to bind it to
-     * itself in the case that resultTy == blockedTy. This can happen if we
-     * have a tautological constraint. When it does, we must instead bind
-     * blockedTy to a fresh type belonging to an appropriate scope.
-     *
-     * To determine which scope is appropriate, we also accept rootTy, which is
-     * to be the type that contains blockedTy.
-     *
-     * A constraint is required and will validate that blockedTy is owned by this
-     * constraint. This prevents one constraint from interfering with another's
-     * blocked types.
-     */
-    void bindBlockedType(TypeId blockedTy, TypeId resultTy, TypeId rootTy, NotNull<const Constraint> constraint);
-
     /**
      * Marks a constraint as being blocked on a type or type pack. The constraint
      * solver will not attempt to dispatch blocked constraints until their
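Per the doc comments above, `shiftReferences` must accompany any binding of a free type, and `generalizeFreeType` consults those reference counts. A hedged sketch of how the pieces might compose inside the solver; the control flow and surrounding names are assumptions, not code from this commit:

```cpp
// Sketch: bind a free type and keep refcounts coherent.
// `constraint`, `scope`, and `freeTy` are assumed to be in hand.
if (std::optional<TypeId> generalized = generalizeFreeType(scope, freeTy))
{
    shiftReferences(freeTy, *generalized); // freeTy is the type being bound
    bind(constraint, freeTy, *generalized);
}
```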

@@ -191,7 +191,7 @@ struct Frontend
     void queueModuleCheck(const std::vector<ModuleName>& names);
     void queueModuleCheck(const ModuleName& name);
     std::vector<ModuleName> checkQueuedModules(std::optional<FrontendOptions> optionOverride = {},
-        std::function<void(std::function<void()> task)> executeTask = {}, std::function<void(size_t done, size_t total)> progress = {});
+        std::function<void(std::function<void()> task)> executeTask = {}, std::function<bool(size_t done, size_t total)> progress = {});

     std::optional<CheckResult> getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete = false);
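The progress callback's return type changes from `void` to `bool`; a plausible reading (an assumption, not stated in the diff) is that returning `false` lets the caller stop the queued check early. A usage sketch with hypothetical surroundings:

```cpp
std::atomic<bool> cancelled{false};

// Hypothetical call site; `frontend` is a Luau::Frontend.
std::vector<Luau::ModuleName> checked = frontend.checkQueuedModules(
    std::nullopt, // no FrontendOptions override
    {},           // execute tasks inline
    [&](size_t done, size_t total) -> bool {
        printf("checked %zu of %zu modules\n", done, total);
        return !cancelled.load(); // assumed: false aborts remaining work
    });
```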

@@ -0,0 +1,13 @@
+// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
+#pragma once
+
+#include "Luau/Scope.h"
+#include "Luau/NotNull.h"
+#include "Luau/TypeFwd.h"
+
+namespace Luau
+{
+
+std::optional<TypeId> generalize(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<Scope> scope,
+    NotNull<DenseHashSet<TypeId>> bakedTypes, TypeId ty, /* avoid sealing tables*/ bool avoidSealingTables = false);
+
+}
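A hedged usage sketch of this new free function; every surrounding name (`arena`, `builtinTypes`, `scope`, `generalizedTypes`, `ty`) is assumed to be in scope, and the meaning of a `std::nullopt` result is an inference, not documented here:

```cpp
// Sketch: generalize `ty` within `scope`, with `generalizedTypes` serving as
// the `bakedTypes` cache of types known to be unchanged by generalization.
std::optional<Luau::TypeId> g =
    Luau::generalize(arena, builtinTypes, scope, generalizedTypes, ty);
if (g)
    ty = *g; // the generalized (possibly unchanged) type
// assumed: nullopt signals that generalization could not be performed
```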

@@ -27,12 +27,16 @@ struct ReplaceGenerics : Substitution
     {
     }

+    void resetState(const TxnLog* log, TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, TypeLevel level, Scope* scope,
+        const std::vector<TypeId>& generics, const std::vector<TypePackId>& genericPacks);
+
     NotNull<BuiltinTypes> builtinTypes;

     TypeLevel level;
     Scope* scope;
     std::vector<TypeId> generics;
     std::vector<TypePackId> genericPacks;

     bool ignoreChildren(TypeId ty) override;
     bool isDirty(TypeId ty) override;
     bool isDirty(TypePackId tp) override;
@@ -48,13 +52,19 @@ struct Instantiation : Substitution
         , builtinTypes(builtinTypes)
         , level(level)
         , scope(scope)
+        , reusableReplaceGenerics(log, arena, builtinTypes, level, scope, {}, {})
     {
     }

+    void resetState(const TxnLog* log, TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, TypeLevel level, Scope* scope);
+
     NotNull<BuiltinTypes> builtinTypes;

     TypeLevel level;
     Scope* scope;
+    ReplaceGenerics reusableReplaceGenerics;

     bool ignoreChildren(TypeId ty) override;
     bool isDirty(TypeId ty) override;
     bool isDirty(TypePackId tp) override;
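`ReplaceGenerics` and `Instantiation` both gain `resetState`, and `Instantiation` now carries a `reusableReplaceGenerics` member, which points at a reuse-rather-than-reallocate pattern. A sketch of a plausible call site (the `reusableInstantiation` member appears later in this diff; the rest is assumed):

```cpp
// Sketch: reuse one Instantiation across many calls instead of constructing
// a fresh Substitution (and its Tarjan state) each time.
reusableInstantiation.resetState(log, arena, builtinTypes, level, scope);
if (std::optional<Luau::TypeId> instantiated = reusableInstantiation.substitute(ty))
    ty = *instantiated;
```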

@@ -102,6 +102,12 @@ struct Module
     DenseHashMap<const AstType*, TypeId> astResolvedTypes{nullptr};
     DenseHashMap<const AstTypePack*, TypePackId> astResolvedTypePacks{nullptr};

+    // The computed result type of a compound assignment. (eg foo += 1)
+    //
+    // Type checking uses this to check that the result of such an operation is
+    // actually compatible with the left-side operand.
+    DenseHashMap<const AstStat*, TypeId> astCompoundAssignResultTypes{nullptr};
+
     DenseHashMap<TypeId, std::vector<std::pair<Location, TypeId>>> upperBoundContributors{nullptr};

     // Map AST nodes to the scope they create. Cannot be NotNull<Scope> because
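Given the comment on `astCompoundAssignResultTypes`, a checking pass might consult the map roughly as follows; `module`, `stat`, `lhsType`, and the compatibility helper are assumed names:

```cpp
// Sketch: for `foo += 1`, fetch the computed result type of the compound
// assignment and verify it is compatible with the left-side operand.
if (Luau::TypeId* resultTy = module->astCompoundAssignResultTypes.find(stat))
    checkCompatible(*resultTy, lhsType, stat->location); // hypothetical helper
```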

@@ -307,6 +307,9 @@ struct NormalizedType
     /// Returns true if the type is a subtype of string(it could be a singleton). Behaves like Type::isString()
     bool isSubtypeOfString() const;

+    /// Returns true if the type is a subtype of boolean(it could be a singleton). Behaves like Type::isBoolean()
+    bool isSubtypeOfBooleans() const;
+
     /// Returns true if this type should result in error suppressing behavior.
     bool shouldSuppressErrors() const;
@@ -360,7 +363,6 @@ public:
     Normalizer& operator=(Normalizer&) = delete;

     // If this returns null, the typechecker should emit a "too complex" error
-    const NormalizedType* DEPRECATED_normalize(TypeId ty);
     std::shared_ptr<const NormalizedType> normalize(TypeId ty);
     void clearNormal(NormalizedType& norm);
@@ -395,6 +397,7 @@ public:
     TypeId negate(TypeId there);
     void subtractPrimitive(NormalizedType& here, TypeId ty);
     void subtractSingleton(NormalizedType& here, TypeId ty);
+    NormalizationResult intersectNormalWithNegationTy(TypeId toNegate, NormalizedType& intersect);

     // ------- Normalizing intersections
     TypeId intersectionOfTops(TypeId here, TypeId there);
@@ -403,8 +406,8 @@ public:
     void intersectClassesWithClass(NormalizedClassType& heres, TypeId there);
     void intersectStrings(NormalizedStringType& here, const NormalizedStringType& there);
     std::optional<TypePackId> intersectionOfTypePacks(TypePackId here, TypePackId there);
-    std::optional<TypeId> intersectionOfTables(TypeId here, TypeId there);
-    void intersectTablesWithTable(TypeIds& heres, TypeId there);
+    std::optional<TypeId> intersectionOfTables(TypeId here, TypeId there, Set<TypeId>& seenSet);
+    void intersectTablesWithTable(TypeIds& heres, TypeId there, Set<TypeId>& seenSetTypes);
     void intersectTables(TypeIds& heres, const TypeIds& theres);
     std::optional<TypeId> intersectionOfFunctions(TypeId here, TypeId there);
     void intersectFunctionsWithFunction(NormalizedFunctionType& heress, TypeId there);
@@ -412,7 +415,7 @@ public:
     NormalizationResult intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there, Set<TypeId>& seenSetTypes);
     NormalizationResult intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1);
     NormalizationResult intersectNormalWithTy(NormalizedType& here, TypeId there, Set<TypeId>& seenSetTypes);
-    NormalizationResult normalizeIntersections(const std::vector<TypeId>& intersections, NormalizedType& outType);
+    NormalizationResult normalizeIntersections(const std::vector<TypeId>& intersections, NormalizedType& outType, Set<TypeId>& seenSet);

     // Check for inhabitance
     NormalizationResult isInhabited(TypeId ty);
@@ -422,6 +425,7 @@ public:
     // Check for intersections being inhabited
     NormalizationResult isIntersectionInhabited(TypeId left, TypeId right);
+    NormalizationResult isIntersectionInhabited(TypeId left, TypeId right, Set<TypeId>& seenSet);

     // -------- Convert back from a normalized type to a type
     TypeId typeFromNormal(const NormalizedType& norm);

@@ -102,4 +102,12 @@ bool subsumesStrict(Scope* left, Scope* right);
 // outermost-possible scope.
 bool subsumes(Scope* left, Scope* right);

+inline Scope* max(Scope* left, Scope* right)
+{
+    if (subsumes(left, right))
+        return right;
+    else
+        return left;
+}
+
 } // namespace Luau
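Reading `max` together with `subsumes` (where `subsumes(left, right)` means `left` is the outer scope), the helper returns the deeper of the two scopes. A small usage sketch under that reading, with assumed variable names:

```cpp
// Sketch: choose the narrower of two candidate scopes for a fresh type.
Luau::Scope* narrower = Luau::max(typeScope, constraintScope); // assumed call site
```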

@@ -4,7 +4,6 @@
 #include "Luau/Common.h"
 #include "Luau/DenseHash.h"

-LUAU_FASTFLAG(LuauFixSetIter)
 LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)

 namespace Luau
@@ -143,11 +142,8 @@ public:
         : impl(impl_)
         , end(end_)
     {
-        if (FFlag::LuauFixSetIter || FFlag::DebugLuauDeferredConstraintResolution)
-        {
-            while (impl != end && impl->second == false)
-                ++impl;
-        }
+        while (impl != end && impl->second == false)
+            ++impl;
     }

     const T& operator*() const

@@ -5,6 +5,7 @@
 #include "Luau/DenseHash.h"
 #include "Luau/NotNull.h"
 #include "Luau/TypeFwd.h"
+#include <set>

 namespace Luau
 {
@@ -19,6 +20,8 @@ struct SimplifyResult
 };

 SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId ty, TypeId discriminant);
+SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, std::set<TypeId> parts);
+
 SimplifyResult simplifyUnion(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId ty, TypeId discriminant);

 enum class Relation

@@ -134,7 +134,8 @@ struct Tarjan
     TarjanResult visitRoot(TypeId ty);
     TarjanResult visitRoot(TypePackId ty);

-    void clearTarjan();
+    // Used to reuse the object for a new operation
+    void clearTarjan(const TxnLog* log);

     // Get/set the dirty bit for an index (grows the vector if needed)
     bool getDirty(int index);
@@ -212,6 +213,8 @@ public:
     std::optional<TypeId> substitute(TypeId ty);
     std::optional<TypePackId> substitute(TypePackId tp);

+    void resetState(const TxnLog* log, TypeArena* arena);
+
     TypeId replace(TypeId ty);
     TypePackId replace(TypePackId tp);

@@ -86,24 +86,6 @@ struct FreeType
     TypeId upperBound = nullptr;
 };

-/** A type that tracks the domain of a local variable.
- *
- * We consider each local's domain to be the union of all types assigned to it.
- * We accomplish this with LocalType. Each time we dispatch an assignment to a
- * local, we accumulate this union and decrement blockCount.
- *
- * When blockCount reaches 0, we can consider the LocalType to be "fully baked"
- * and replace it with the union we've built.
- */
-struct LocalType
-{
-    TypeId domain;
-    int blockCount = 0;
-
-    // Used for debugging
-    std::string name;
-};
-
 struct GenericType
 {
     // By default, generics are global, with a synthetic name
@@ -148,6 +130,7 @@ struct BlockedType
     Constraint* getOwner() const;
     void setOwner(Constraint* newOwner);
+    void replaceOwner(Constraint* newOwner);

 private:
     // The constraint that is intended to unblock this type. Other constraints
@@ -471,6 +454,11 @@ struct TableType
     // Methods of this table that have an untyped self will use the same shared self type.
     std::optional<TypeId> selfTy;

+    // We track the number of as-yet-unadded properties to unsealed tables.
+    // Some constraints will use this information to decide whether or not they
+    // are able to dispatch.
+    size_t remainingProps = 0;
 };

 // Represents a metatable attached to a table type. Somewhat analogous to a bound type.
@@ -672,9 +660,9 @@ struct NegationType

 using ErrorType = Unifiable::Error;

-using TypeVariant = Unifiable::Variant<TypeId, FreeType, LocalType, GenericType, PrimitiveType, BlockedType, PendingExpansionType, SingletonType,
-    FunctionType, TableType, MetatableType, ClassType, AnyType, UnionType, IntersectionType, LazyType, UnknownType, NeverType, NegationType,
-    TypeFamilyInstanceType>;
+using TypeVariant =
+    Unifiable::Variant<TypeId, FreeType, GenericType, PrimitiveType, SingletonType, BlockedType, PendingExpansionType, FunctionType, TableType,
+        MetatableType, ClassType, AnyType, UnionType, IntersectionType, LazyType, UnknownType, NeverType, NegationType, TypeFamilyInstanceType>;

 struct Type final
 {

@@ -6,7 +6,6 @@
 #include "Luau/NotNull.h"
 #include "Luau/TypeCheckLimits.h"
 #include "Luau/TypeFwd.h"
-#include "Luau/Variant.h"

 #include <functional>
 #include <string>
@@ -19,22 +18,6 @@ struct TypeArena;
 struct TxnLog;
 class Normalizer;

-struct TypeFamilyQueue
-{
-    NotNull<VecDeque<TypeId>> queuedTys;
-    NotNull<VecDeque<TypePackId>> queuedTps;
-
-    void add(TypeId instanceTy);
-    void add(TypePackId instanceTp);
-
-    template<typename T>
-    void add(const std::vector<T>& ts)
-    {
-        for (const T& t : ts)
-            enqueue(t);
-    }
-};
-
 struct TypeFamilyContext
 {
     NotNull<TypeArena> arena;
@@ -99,8 +82,8 @@ struct TypeFamilyReductionResult
 };

 template<typename T>
-using ReducerFunction = std::function<TypeFamilyReductionResult<T>(
-    T, NotNull<TypeFamilyQueue>, const std::vector<TypeId>&, const std::vector<TypePackId>&, NotNull<TypeFamilyContext>)>;
+using ReducerFunction =
+    std::function<TypeFamilyReductionResult<T>(T, const std::vector<TypeId>&, const std::vector<TypePackId>&, NotNull<TypeFamilyContext>)>;

 /// Represents a type function that may be applied to map a series of types and
 /// type packs to a single output type.
@@ -196,11 +179,12 @@ struct BuiltinTypeFamilies
     TypeFamily keyofFamily;
     TypeFamily rawkeyofFamily;

+    TypeFamily indexFamily;
+    TypeFamily rawgetFamily;
+
     void addToScope(NotNull<TypeArena> arena, NotNull<Scope> scope) const;
 };

-const BuiltinTypeFamilies kBuiltinTypeFamilies{};
+const BuiltinTypeFamilies& builtinTypeFunctions();

 } // namespace Luau
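The eagerly constructed global `kBuiltinTypeFamilies` becomes a `builtinTypeFunctions()` accessor. The declaration alone fixes the interface; a function-local static is the usual way such an accessor is implemented to get lazy, initialization-order-safe construction (the body below is an assumed sketch, though the call-site change is visible later in this diff):

```cpp
// Assumed implementation shape (Meyers singleton), not taken from the commit:
const BuiltinTypeFamilies& builtinTypeFunctions()
{
    static const BuiltinTypeFamilies kFamilies{};
    return kFamilies;
}
```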

@@ -4,6 +4,7 @@
 #include "Luau/Anyification.h"
 #include "Luau/ControlFlow.h"
 #include "Luau/Error.h"
+#include "Luau/Instantiation.h"
 #include "Luau/Module.h"
 #include "Luau/Predicate.h"
 #include "Luau/Substitution.h"
@@ -362,6 +363,8 @@ public:
     UnifierSharedState unifierState;
     Normalizer normalizer;

+    Instantiation reusableInstantiation;
+
     std::vector<RequireCycle> requireCycles;

     // Type inference limits

@@ -55,6 +55,9 @@ struct InConditionalContext

 using ScopePtr = std::shared_ptr<struct Scope>;

+std::optional<Property> findTableProperty(
+    NotNull<BuiltinTypes> builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location);
+
 std::optional<TypeId> findMetatableEntry(
     NotNull<BuiltinTypes> builtinTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location);
 std::optional<TypeId> findTablePropertyRespectingMeta(

@@ -69,7 +69,6 @@ struct Unifier2
     */
     bool unify(TypeId subTy, TypeId superTy);
     bool unifyFreeWithType(TypeId subTy, TypeId superTy);
-    bool unify(const LocalType* subTy, TypeId superFn);
     bool unify(TypeId subTy, const FunctionType* superFn);
     bool unify(const UnionType* subUnion, TypeId superTy);
     bool unify(TypeId subTy, const UnionType* superUnion);
@@ -78,6 +77,11 @@ struct Unifier2
     bool unify(TableType* subTable, const TableType* superTable);
     bool unify(const MetatableType* subMetatable, const MetatableType* superMetatable);

+    bool unify(const AnyType* subAny, const FunctionType* superFn);
+    bool unify(const FunctionType* subFn, const AnyType* superAny);
+    bool unify(const AnyType* subAny, const TableType* superTable);
+    bool unify(const TableType* subTable, const AnyType* superAny);
+
     // TODO think about this one carefully. We don't do unions or intersections of type packs
     bool unify(TypePackId subTp, TypePackId superTp);

@@ -100,10 +100,6 @@ struct GenericTypeVisitor
     {
         return visit(ty);
     }
-    virtual bool visit(TypeId ty, const LocalType& ftv)
-    {
-        return visit(ty);
-    }
     virtual bool visit(TypeId ty, const GenericType& gtv)
     {
         return visit(ty);
@@ -248,11 +244,6 @@ struct GenericTypeVisitor
             else
                 visit(ty, *ftv);
         }
-        else if (auto lt = get<LocalType>(ty))
-        {
-            if (visit(ty, *lt))
-                traverse(lt->domain);
-        }
         else if (auto gtv = get<GenericType>(ty))
             visit(ty, *gtv);
         else if (auto etv = get<ErrorType>(ty))
@@ -357,16 +348,38 @@ struct GenericTypeVisitor
         {
             if (visit(ty, *utv))
             {
+                bool unionChanged = false;
                 for (TypeId optTy : utv->options)
+                {
                     traverse(optTy);
+                    if (!get<UnionType>(follow(ty)))
+                    {
+                        unionChanged = true;
+                        break;
+                    }
+                }
+
+                if (unionChanged)
+                    traverse(ty);
             }
         }
         else if (auto itv = get<IntersectionType>(ty))
         {
             if (visit(ty, *itv))
             {
+                bool intersectionChanged = false;
                 for (TypeId partTy : itv->parts)
+                {
                     traverse(partTy);
+                    if (!get<IntersectionType>(follow(ty)))
+                    {
+                        intersectionChanged = true;
+                        break;
+                    }
+                }
+
+                if (intersectionChanged)
+                    traverse(ty);
             }
         }
         else if (auto ltv = get<LazyType>(ty))

@@ -8,6 +8,8 @@

 #include <math.h>

+LUAU_FASTFLAG(LuauDeclarationExtraPropData)
+
 namespace Luau
 {
@@ -735,8 +737,21 @@ struct AstJsonEncoder : public AstVisitor
     void write(class AstStatDeclareFunction* node)
     {
         writeNode(node, "AstStatDeclareFunction", [&]() {
+            // TODO: attributes
             PROP(name);
+
+            if (FFlag::LuauDeclarationExtraPropData)
+                PROP(nameLocation);
+
             PROP(params);
+
+            if (FFlag::LuauDeclarationExtraPropData)
+            {
+                PROP(paramNames);
+                PROP(vararg);
+                PROP(varargLocation);
+            }
+
             PROP(retTypes);
             PROP(generics);
             PROP(genericPacks);
@@ -747,6 +762,10 @@ struct AstJsonEncoder : public AstVisitor
     {
         writeNode(node, "AstStatDeclareGlobal", [&]() {
             PROP(name);
+
+            if (FFlag::LuauDeclarationExtraPropData)
+                PROP(nameLocation);
+
             PROP(type);
         });
     }
@@ -756,8 +775,16 @@ struct AstJsonEncoder : public AstVisitor
         writeRaw("{");
         bool c = pushComma();
         write("name", prop.name);
+
+        if (FFlag::LuauDeclarationExtraPropData)
+            write("nameLocation", prop.nameLocation);
+
         writeType("AstDeclaredClassProp");
         write("luauType", prop.ty);
+
+        if (FFlag::LuauDeclarationExtraPropData)
+            write("location", prop.location);
+
         popComma(c);
         writeRaw("}");
     }

@@ -12,6 +12,7 @@
 #include <algorithm>

 LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
+LUAU_FASTFLAGVARIABLE(LuauFixBindingForGlobalPos, false);

 namespace Luau
 {
@@ -332,6 +333,11 @@ std::optional<TypeId> findExpectedTypeAtPosition(const Module& module, const Sou

 static std::optional<AstStatLocal*> findBindingLocalStatement(const SourceModule& source, const Binding& binding)
 {
+    // Bindings coming from global sources (e.g., definition files) have a zero position.
+    // They cannot be defined from a local statement
+    if (FFlag::LuauFixBindingForGlobalPos && binding.location == Location{{0, 0}, {0, 0}})
+        return std::nullopt;
+
     std::vector<AstNode*> nodes = findAstAncestryOfPosition(source, binding.location.begin);
     auto iter = std::find_if(nodes.rbegin(), nodes.rend(), [](AstNode* node) {
         return node->is<AstStatLocal>();

@@ -1830,12 +1830,21 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName
     if (!sourceModule)
         return {};

-    ModulePtr module = frontend.moduleResolverForAutocomplete.getModule(moduleName);
+    ModulePtr module;
+    if (FFlag::DebugLuauDeferredConstraintResolution)
+        module = frontend.moduleResolver.getModule(moduleName);
+    else
+        module = frontend.moduleResolverForAutocomplete.getModule(moduleName);
+
     if (!module)
         return {};

     NotNull<BuiltinTypes> builtinTypes = frontend.builtinTypes;
-    Scope* globalScope = frontend.globalsForAutocomplete.globalScope.get();
+    Scope* globalScope;
+    if (FFlag::DebugLuauDeferredConstraintResolution)
+        globalScope = frontend.globals.globalScope.get();
+    else
+        globalScope = frontend.globalsForAutocomplete.globalScope.get();

     TypeArena typeArena;
     return autocomplete(*sourceModule, module, builtinTypes, &typeArena, globalScope, position, callback);

@@ -24,7 +24,6 @@
 */

 LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
-LUAU_FASTFLAGVARIABLE(LuauMakeStringMethodsChecked, false);

 namespace Luau
 {
@@ -217,7 +216,7 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
     NotNull<BuiltinTypes> builtinTypes = globals.builtinTypes;

     if (FFlag::DebugLuauDeferredConstraintResolution)
-        kBuiltinTypeFamilies.addToScope(NotNull{&arena}, NotNull{globals.globalScope.get()});
+        builtinTypeFunctions().addToScope(NotNull{&arena}, NotNull{globals.globalScope.get()});

     LoadDefinitionFileResult loadResult = frontend.loadDefinitionFile(
         globals, globals.globalScope, getBuiltinDefinitionSource(), "@luau", /* captureComments */ false, typeCheckForAutocomplete);
@@ -257,21 +256,44 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
     TypeId tableMetaMT = arena.addType(MetatableType{tabTy, genericMT});

+    // getmetatable : <MT>({ @metatable MT, {+ +} }) -> MT
     addGlobalBinding(globals, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau");

-    // clang-format off
-    // setmetatable<T: {}, MT>(T, MT) -> { @metatable MT, T }
-    addGlobalBinding(globals, "setmetatable",
-        arena.addType(
-            FunctionType{
-                {genericMT},
-                {},
-                arena.addTypePack(TypePack{{tabTy, genericMT}}),
-                arena.addTypePack(TypePack{{tableMetaMT}})
-            }
-        ), "@luau"
-    );
-    // clang-format on
+    if (FFlag::DebugLuauDeferredConstraintResolution)
+    {
+        TypeId genericT = arena.addType(GenericType{"T"});
+        TypeId tMetaMT = arena.addType(MetatableType{genericT, genericMT});
+
+        // clang-format off
+        // setmetatable<T: {}, MT>(T, MT) -> { @metatable MT, T }
+        addGlobalBinding(globals, "setmetatable",
+            arena.addType(
+                FunctionType{
+                    {genericT, genericMT},
+                    {},
+                    arena.addTypePack(TypePack{{genericT, genericMT}}),
+                    arena.addTypePack(TypePack{{tMetaMT}})
+                }
+            ), "@luau"
+        );
+        // clang-format on
+    }
+    else
+    {
+        // clang-format off
+        // setmetatable<T: {}, MT>(T, MT) -> { @metatable MT, T }
+        addGlobalBinding(globals, "setmetatable",
+            arena.addType(
+                FunctionType{
+                    {genericMT},
+                    {},
+                    arena.addTypePack(TypePack{{tabTy, genericMT}}),
+                    arena.addTypePack(TypePack{{tableMetaMT}})
+                }
+            ), "@luau"
+        );
+        // clang-format on
+    }

     for (const auto& pair : globals.globalScope->bindings)
     {
@@ -291,7 +313,7 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
     // declare function assert<T>(value: T, errorMessage: string?): intersect<T, ~(false?)>
     TypeId genericT = arena.addType(GenericType{"T"});
     TypeId refinedTy = arena.addType(TypeFamilyInstanceType{
-        NotNull{&kBuiltinTypeFamilies.intersectFamily}, {genericT, arena.addType(NegationType{builtinTypes->falsyType})}, {}});
+        NotNull{&builtinTypeFunctions().intersectFamily}, {genericT, arena.addType(NegationType{builtinTypes->falsyType})}, {}});

     TypeId assertTy = arena.addType(FunctionType{
         {genericT}, {}, arena.addTypePack(TypePack{{genericT, builtinTypes->optionalStringType}}), arena.addTypePack(TypePack{{refinedTy}})});
@@ -773,153 +795,87 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
     const TypePackId numberVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{numberType}});

-    if (FFlag::LuauMakeStringMethodsChecked)
-    {
     FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack};
     formatFTV.magicFunction = &magicFunctionFormat;
     formatFTV.isCheckedFunction = true;
     const TypeId formatFn = arena->addType(formatFTV);
     attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat);

     const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ true);

     const TypeId replArgType =
         arena->addType(UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)),
             makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ false)}});

     const TypeId gsubFunc =
         makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}, /* checked */ false);
     const TypeId gmatchFunc =
         makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}, /* checked */ true);
     attachMagicFunction(gmatchFunc, magicFunctionGmatch);
     attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch);

     FunctionType matchFuncTy{
         arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})};
     matchFuncTy.isCheckedFunction = true;
     const TypeId matchFunc = arena->addType(matchFuncTy);
     attachMagicFunction(matchFunc, magicFunctionMatch);
     attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch);

     FunctionType findFuncTy{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}),
         arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})};
     findFuncTy.isCheckedFunction = true;
     const TypeId findFunc = arena->addType(findFuncTy);
     attachMagicFunction(findFunc, magicFunctionFind);
     attachDcrMagicFunction(findFunc, dcrMagicFunctionFind);

     // string.byte : string -> number? -> number? -> ...number
     FunctionType stringDotByte{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList};
     stringDotByte.isCheckedFunction = true;

     // string.char : .... number -> string
     FunctionType stringDotChar{numberVariadicList, arena->addTypePack({stringType})};
     stringDotChar.isCheckedFunction = true;

     // string.unpack : string -> string -> number? -> ...any
     FunctionType stringDotUnpack{
         arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}),
         variadicTailPack,
     };
     stringDotUnpack.isCheckedFunction = true;

     TableType::Props stringLib = {
         {"byte", {arena->addType(stringDotByte)}},
         {"char", {arena->addType(stringDotChar)}},
         {"find", {findFunc}},
         {"format", {formatFn}}, // FIXME
         {"gmatch", {gmatchFunc}},
         {"gsub", {gsubFunc}},
         {"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType}, /* checked */ true)}},
         {"lower", {stringToStringType}},
         {"match", {matchFunc}},
         {"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType}, /* checked */ true)}},
         {"reverse", {stringToStringType}},
         {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType}, /* checked */ true)}},
         {"upper", {stringToStringType}},
         {"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {},
             {arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})},
             /* checked */ true)}},
         {"pack", {arena->addType(FunctionType{
                      arena->addTypePack(TypePack{{stringType}, variadicTailPack}),
                      oneStringPack,
                  })}},
         {"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType}, /* checked */ true)}},
         {"unpack", {arena->addType(stringDotUnpack)}},
     };

     assignPropDocumentationSymbols(stringLib, "@luau/global/string");

     TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed});

     if (TableType* ttv = getMutable<TableType>(tableType))
         ttv->name = "typeof(string)";

     return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed});
-    }
-    else
-    {
-        FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack};
-        formatFTV.magicFunction = &magicFunctionFormat;
-        const TypeId formatFn = arena->addType(formatFTV);
-        attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat);
-
-        const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType});
-        const TypeId replArgType = arena->addType(
-            UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)),
-                makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType})}});
-        const TypeId gsubFunc = makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType});
-        const TypeId gmatchFunc =
-            makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})});
-        attachMagicFunction(gmatchFunc, magicFunctionGmatch);
-        attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch);
-
-        const TypeId matchFunc = arena->addType(FunctionType{
-            arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})});
-        attachMagicFunction(matchFunc, magicFunctionMatch);
-        attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch);
-
-        const TypeId findFunc = arena->addType(FunctionType{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}),
-            arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})});
-        attachMagicFunction(findFunc, magicFunctionFind);
-        attachDcrMagicFunction(findFunc, dcrMagicFunctionFind);
-
-        TableType::Props stringLib = {
-            {"byte", {arena->addType(FunctionType{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}},
-            {"char", {arena->addType(FunctionType{numberVariadicList, arena->addTypePack({stringType})})}},
-            {"find", {findFunc}},
-            {"format", {formatFn}}, // FIXME
-            {"gmatch", {gmatchFunc}},
-            {"gsub", {gsubFunc}},
-            {"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}},
{"lower", {stringToStringType}},
{"match", {matchFunc}},
{"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}},
{"reverse", {stringToStringType}},
{"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}},
{"upper", {stringToStringType}},
{"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {},
{arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})})}},
{"pack", {arena->addType(FunctionType{
arena->addTypePack(TypePack{{stringType}, variadicTailPack}),
oneStringPack,
})}},
{"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}},
{"unpack", {arena->addType(FunctionType{
arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}),
variadicTailPack,
})}},
};
assignPropDocumentationSymbols(stringLib, "@luau/global/string");
TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed});
if (TableType* ttv = getMutable<TableType>(tableType))
ttv->name = "typeof(string)";
return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed});
}
}
static std::optional<WithPredicate<TypePackId>> magicFunctionSelect(
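Editorial note on the pattern above: `isCheckedFunction` is what opts a builtin signature into the new non-strict mode's call-site argument checking. A minimal sketch (not part of this commit; it reuses the `arena`, `stringType`, and `numberType` bindings shown above) of building one more checked signature by hand:

    // Hypothetical illustration: a checked `string -> number` signature,
    // mirroring the construction pattern used in the branch above.
    FunctionType lenTy{arena->addTypePack({stringType}), arena->addTypePack({numberType})};
    lenTy.isCheckedFunction = true; // participates in non-strict argument checks
    const TypeId len = arena->addType(lenTy);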


@@ -271,11 +271,6 @@ private:
t->upperBound = shallowClone(t->upperBound);
}

-void cloneChildren(LocalType* t)
-{
-    t->domain = shallowClone(t->domain);
-}
-
void cloneChildren(GenericType* t)
{
// TODO: clone upper bounds.


@@ -13,12 +13,12 @@ Constraint::Constraint(NotNull<Scope> scope, const Location& location, Constrain
{
}

-struct FreeTypeCollector : TypeOnceVisitor
+struct ReferenceCountInitializer : TypeOnceVisitor
{
    DenseHashSet<TypeId>* result;

-    FreeTypeCollector(DenseHashSet<TypeId>* result)
+    ReferenceCountInitializer(DenseHashSet<TypeId>* result)
        : result(result)
    {
    }
@@ -29,6 +29,18 @@ struct FreeTypeCollector : TypeOnceVisitor
    return false;
}

+bool visit(TypeId ty, const BlockedType&) override
+{
+    result->insert(ty);
+    return false;
+}
+
+bool visit(TypeId ty, const PendingExpansionType&) override
+{
+    result->insert(ty);
+    return false;
+}
+
bool visit(TypeId ty, const ClassType&) override
{
    // ClassTypes never contain free types.
@@ -36,26 +48,89 @@ struct FreeTypeCollector : TypeOnceVisitor
}
};

-DenseHashSet<TypeId> Constraint::getFreeTypes() const
-{
-    DenseHashSet<TypeId> types{{}};
-    FreeTypeCollector ftc{&types};
-
-    if (auto sc = get<SubtypeConstraint>(*this))
-    {
-        ftc.traverse(sc->subType);
-        ftc.traverse(sc->superType);
-    }
-    else if (auto psc = get<PackSubtypeConstraint>(*this))
-    {
-        ftc.traverse(psc->subPack);
-        ftc.traverse(psc->superPack);
-    }
-    else if (auto ptc = get<PrimitiveTypeConstraint>(*this))
-    {
-        // we need to take into account primitive type constraints to prevent type families from reducing on
-        // primitive whose types we have not yet selected to be singleton or not.
-        ftc.traverse(ptc->freeType);
-    }
+bool isReferenceCountedType(const TypeId typ)
+{
+    // n.b. this should match whatever `ReferenceCountInitializer` includes.
+    return get<FreeType>(typ) || get<BlockedType>(typ) || get<PendingExpansionType>(typ);
+}
+
+DenseHashSet<TypeId> Constraint::getMaybeMutatedFreeTypes() const
+{
+    // For the purpose of this function and reference counting in general, we are only considering
+    // mutations that affect the _bounds_ of the free type, and not something that may bind the free
+    // type itself to a new type. As such, `ReduceConstraint` and `GeneralizationConstraint` have no
+    // contribution to the output set here.
+
+    DenseHashSet<TypeId> types{{}};
+    ReferenceCountInitializer rci{&types};
+
+    if (auto ec = get<EqualityConstraint>(*this))
+    {
+        rci.traverse(ec->resultType);
+        // `EqualityConstraints` should not mutate `assignmentType`.
+    }
+    else if (auto sc = get<SubtypeConstraint>(*this))
+    {
+        rci.traverse(sc->subType);
+        rci.traverse(sc->superType);
+    }
+    else if (auto psc = get<PackSubtypeConstraint>(*this))
+    {
+        rci.traverse(psc->subPack);
+        rci.traverse(psc->superPack);
+    }
+    else if (auto itc = get<IterableConstraint>(*this))
+    {
+        for (TypeId ty : itc->variables)
+            rci.traverse(ty);
+        // `IterableConstraints` should not mutate `iterator`.
+    }
+    else if (auto nc = get<NameConstraint>(*this))
+    {
+        rci.traverse(nc->namedType);
+    }
+    else if (auto taec = get<TypeAliasExpansionConstraint>(*this))
+    {
+        rci.traverse(taec->target);
+    }
+    else if (auto fchc = get<FunctionCheckConstraint>(*this))
+    {
+        rci.traverse(fchc->argsPack);
+    }
+    else if (auto ptc = get<PrimitiveTypeConstraint>(*this))
+    {
+        rci.traverse(ptc->freeType);
+    }
+    else if (auto hpc = get<HasPropConstraint>(*this))
+    {
+        rci.traverse(hpc->resultType);
+        // `HasPropConstraints` should not mutate `subjectType`.
+    }
+    else if (auto hic = get<HasIndexerConstraint>(*this))
+    {
+        rci.traverse(hic->resultType);
+        // `HasIndexerConstraint` should not mutate `subjectType` or `indexType`.
+    }
+    else if (auto apc = get<AssignPropConstraint>(*this))
+    {
+        rci.traverse(apc->lhsType);
+        rci.traverse(apc->rhsType);
+    }
+    else if (auto aic = get<AssignIndexConstraint>(*this))
+    {
+        rci.traverse(aic->lhsType);
+        rci.traverse(aic->indexType);
+        rci.traverse(aic->rhsType);
+    }
+    else if (auto uc = get<UnpackConstraint>(*this))
+    {
+        for (TypeId ty : uc->resultPack)
+            rci.traverse(ty);
+        // `UnpackConstraint` should not mutate `sourcePack`.
+    }
+    else if (auto rpc = get<ReducePackConstraint>(*this))
+    {
+        rci.traverse(rpc->tp);
+    }

    return types;
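A sketch of how a solver might consume `getMaybeMutatedFreeTypes` for reference counting; the surrounding work list and map are hypothetical, only the member function itself is from this commit:

    // Count how many unsolved constraints may still mutate each free type's bounds.
    DenseHashMap<TypeId, size_t> mutationCounts{nullptr};
    for (NotNull<const Constraint> c : unsolvedConstraints) // assumed work list
    {
        for (TypeId ty : c->getMaybeMutatedFreeTypes())
            mutationCounts[follow(ty)]++;
    }
    // When a constraint is dispatched, the same set is used to decrement;
    // a count of zero means the free type's bounds can no longer change.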

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -763,7 +763,8 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprCall* c)
    for (AstExpr* arg : c->args)
        visitExpr(scope, arg);

-    return {defArena->freshCell(), nullptr};
+    // calls should be treated as subscripted.
+    return {defArena->freshCell(/* subscripted */ true), nullptr};
}

DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexName* i)


@@ -2,7 +2,7 @@
#include "Luau/BuiltinDefinitions.h"

LUAU_FASTFLAGVARIABLE(LuauCheckedEmbeddedDefinitions2, false);
-LUAU_FASTFLAG(LuauCheckedFunctionSyntax);
+LUAU_FASTFLAG(LuauAttributeSyntax);

namespace Luau
{

@@ -320,9 +320,9 @@ declare os: {
    clock: () -> number,
}

-declare function @checked require(target: any): any
-declare function @checked getfenv(target: any): { [string]: any }
+@checked declare function require(target: any): any
+@checked declare function getfenv(target: any): { [string]: any }

declare _G: any
declare _VERSION: string

@@ -364,7 +364,7 @@ declare function select<A...>(i: string | number, ...: A...): ...any
-- (nil, string).
declare function loadstring<A...>(src: string, chunkname: string?): (((A...) -> any)?, string?)

-declare function @checked newproxy(mt: boolean?): any
+@checked declare function newproxy(mt: boolean?): any

declare coroutine: {
    create: <A..., R...>(f: (A...) -> R...) -> thread,

@@ -452,7 +452,7 @@ std::string getBuiltinDefinitionSource()
std::string result = kBuiltinDefinitionLuaSrc;

// Annotates each non generic function as checked
-if (FFlag::LuauCheckedEmbeddedDefinitions2 && FFlag::LuauCheckedFunctionSyntax)
+if (FFlag::LuauCheckedEmbeddedDefinitions2 && FFlag::LuauAttributeSyntax)
    result = kBuiltinDefinitionLuaSrcChecked;

return result;


@@ -7,11 +7,14 @@
#include "Luau/NotNull.h"
#include "Luau/StringUtils.h"
#include "Luau/ToString.h"
+#include "Luau/Type.h"
+#include "Luau/TypeFamily.h"

#include <optional>
#include <stdexcept>
#include <string>
#include <type_traits>
+#include <unordered_set>

LUAU_FASTINTVARIABLE(LuauIndentTypeMismatchMaxTypeLength, 10)
@@ -61,6 +64,17 @@ static std::string wrongNumberOfArgsString(
namespace Luau
{
// this list of binary operator type families is used for better stringification of type families errors
static const std::unordered_map<std::string, const char*> kBinaryOps{{"add", "+"}, {"sub", "-"}, {"mul", "*"}, {"div", "/"}, {"idiv", "//"},
{"pow", "^"}, {"mod", "%"}, {"concat", ".."}, {"and", "and"}, {"or", "or"}, {"lt", "< or >="}, {"le", "<= or >"}, {"eq", "== or ~="}};
// this list of unary operator type families is used for better stringification of type families errors
static const std::unordered_map<std::string, const char*> kUnaryOps{{"unm", "-"}, {"len", "#"}, {"not", "not"}};
// this list of type families will receive a special error indicating that the user should file a bug on the GitHub repository
// putting a type family in this list indicates that it is expected to _always_ reduce
static const std::unordered_set<std::string> kUnreachableTypeFamilies{"refine", "singleton", "union", "intersect"};
struct ErrorConverter
{
    FileResolver* fileResolver = nullptr;

@@ -565,6 +579,108 @@ struct ErrorConverter
std::string operator()(const UninhabitedTypeFamily& e) const
{
auto tfit = get<TypeFamilyInstanceType>(e.ty);
LUAU_ASSERT(tfit); // Luau analysis has actually done something wrong if this type is not a type family.
if (!tfit)
return "Unexpected type " + Luau::toString(e.ty) + " flagged as an uninhabited type family.";
// unary operators
if (auto unaryString = kUnaryOps.find(tfit->family->name); unaryString != kUnaryOps.end())
{
std::string result = "Operator '" + std::string(unaryString->second) + "' could not be applied to ";
if (tfit->typeArguments.size() == 1 && tfit->packArguments.empty())
{
result += "operand of type " + Luau::toString(tfit->typeArguments[0]);
if (tfit->family->name != "not")
result += "; there is no corresponding overload for __" + tfit->family->name;
}
else
{
// if it's not the expected case, we ought to add a specialization later, but this is a sane default.
result += "operands of types ";
bool isFirst = true;
for (auto arg : tfit->typeArguments)
{
if (!isFirst)
result += ", ";
result += Luau::toString(arg);
isFirst = false;
}
for (auto packArg : tfit->packArguments)
result += ", " + Luau::toString(packArg);
}
return result;
}
// binary operators
if (auto binaryString = kBinaryOps.find(tfit->family->name); binaryString != kBinaryOps.end())
{
std::string result = "Operator '" + std::string(binaryString->second) + "' could not be applied to operands of types ";
if (tfit->typeArguments.size() == 2 && tfit->packArguments.empty())
{
// this is the expected case.
result += Luau::toString(tfit->typeArguments[0]) + " and " + Luau::toString(tfit->typeArguments[1]);
}
else
{
// if it's not the expected case, we ought to add a specialization later, but this is a sane default.
bool isFirst = true;
for (auto arg : tfit->typeArguments)
{
if (!isFirst)
result += ", ";
result += Luau::toString(arg);
isFirst = false;
}
for (auto packArg : tfit->packArguments)
result += ", " + Luau::toString(packArg);
}
result += "; there is no corresponding overload for __" + tfit->family->name;
return result;
}
// miscellaneous
if ("keyof" == tfit->family->name || "rawkeyof" == tfit->family->name)
{
if (tfit->typeArguments.size() == 1 && tfit->packArguments.empty())
return "Type '" + toString(tfit->typeArguments[0]) + "' does not have keys, so '" + Luau::toString(e.ty) + "' is invalid";
else
return "Type family instance " + Luau::toString(e.ty) + " is ill-formed, and thus invalid";
}
if ("index" == tfit->family->name || "rawget" == tfit->family->name)
{
if (tfit->typeArguments.size() != 2)
return "Type family instance " + Luau::toString(e.ty) + " is ill-formed, and thus invalid";
if (auto errType = get<ErrorType>(tfit->typeArguments[1])) // Second argument to (index | rawget)<_,_> is not a type
return "Second argument to " + tfit->family->name + "<" + Luau::toString(tfit->typeArguments[0]) + ", _> is not a valid index type";
else // Property `indexer` does not exist on type `indexee`
return "Property '" + Luau::toString(tfit->typeArguments[1]) + "' does not exist on type '" + Luau::toString(tfit->typeArguments[0]) +
"'";
}
if (kUnreachableTypeFamilies.count(tfit->family->name))
{
return "Type family instance " + Luau::toString(e.ty) + " is uninhabited\n" +
"This is likely to be a bug, please report it at https://github.com/luau-lang/luau/issues";
}
// Everything should be specialized above to report a more descriptive error that hopefully does not mention "type families" explicitly.
// If we produce this message, it's an indication that we've missed a specialization and it should be fixed!
return "Type family instance " + Luau::toString(e.ty) + " is uninhabited"; return "Type family instance " + Luau::toString(e.ty) + " is uninhabited";
} }
@@ -1205,7 +1321,7 @@ void copyError(T& e, TypeArena& destArena, CloneState& cloneState)
else if constexpr (std::is_same_v<T, ExplicitFunctionAnnotationRecommended>)
{
    e.recommendedReturn = clone(e.recommendedReturn);
-    for (auto [_, t] : e.recommendedArgs)
+    for (auto& [_, t] : e.recommendedArgs)
        t = clone(t);
}
else if constexpr (std::is_same_v<T, UninhabitedTypePackFamily>)
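The `kBinaryOps`/`kUnaryOps` tables above turn an internal type-family name into the operator the user actually wrote. A minimal illustration of the lookup (a hypothetical helper, not in the commit):

    // Maps a family name such as "add" or "unm" to its surface operator, e.g. "+" or "-".
    const char* opForFamily(const std::string& name)
    {
        if (auto it = kBinaryOps.find(name); it != kBinaryOps.end())
            return it->second;
        if (auto it = kUnaryOps.find(name); it != kUnaryOps.end())
            return it->second;
        return nullptr; // not an operator family
    }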


@@ -34,6 +34,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit)
LUAU_FASTINT(LuauTarjanChildLimit)
LUAU_FASTFLAG(LuauInferInNoCheckMode)
LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false)
+LUAU_FASTFLAGVARIABLE(LuauCancelFromProgress, false)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false)
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJsonFile, false)

@@ -440,6 +441,8 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());

FrontendOptions frontendOptions = optionOverride.value_or(options);
+if (FFlag::DebugLuauDeferredConstraintResolution)
+    frontendOptions.forAutocomplete = false;

if (std::optional<CheckResult> result = getCheckResult(name, true, frontendOptions.forAutocomplete))
    return std::move(*result);
@@ -492,9 +495,11 @@ void Frontend::queueModuleCheck(const ModuleName& name)
}

std::vector<ModuleName> Frontend::checkQueuedModules(std::optional<FrontendOptions> optionOverride,
-    std::function<void(std::function<void()> task)> executeTask, std::function<void(size_t done, size_t total)> progress)
+    std::function<void(std::function<void()> task)> executeTask, std::function<bool(size_t done, size_t total)> progress)
{
    FrontendOptions frontendOptions = optionOverride.value_or(options);
+    if (FFlag::DebugLuauDeferredConstraintResolution)
+        frontendOptions.forAutocomplete = false;

    // By taking data into locals, we make sure queue is cleared at the end, even if an ICE or a different exception is thrown
    std::vector<ModuleName> currModuleQueue;
@@ -673,7 +678,17 @@ std::vector<ModuleName> Frontend::checkQueuedModules(std::optional<FrontendOptio
}

if (progress)
-    progress(buildQueueItems.size() - remaining, buildQueueItems.size());
+{
+    if (FFlag::LuauCancelFromProgress)
+    {
+        if (!progress(buildQueueItems.size() - remaining, buildQueueItems.size()))
+            cancelled = true;
+    }
+    else
+    {
+        progress(buildQueueItems.size() - remaining, buildQueueItems.size());
+    }
+}

// Items cannot be submitted while holding the lock
for (size_t i : nextItems)
@@ -707,6 +722,9 @@ std::vector<ModuleName> Frontend::checkQueuedModules(std::optional<FrontendOptio
std::optional<CheckResult> Frontend::getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete)
{
+    if (FFlag::DebugLuauDeferredConstraintResolution)
+        forAutocomplete = false;

    auto it = sourceNodes.find(name);

    if (it == sourceNodes.end() || it->second->hasDirtyModule(forAutocomplete))
@@ -1003,11 +1021,10 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item)
    module->astForInNextTypes.clear();
    module->astResolvedTypes.clear();
    module->astResolvedTypePacks.clear();
+    module->astCompoundAssignResultTypes.clear();
    module->astScopes.clear();
    module->upperBoundContributors.clear();
-    module->scopes.clear();
+    if (!FFlag::DebugLuauDeferredConstraintResolution)
+        module->scopes.clear();
}

if (mode != Mode::NoCheck)
@@ -1196,12 +1213,6 @@ struct InternalTypeFinder : TypeOnceVisitor
    return false;
}

-bool visit(TypeId, const LocalType&) override
-{
-    LUAU_ASSERT(false);
-    return false;
-}
-
bool visit(TypePackId, const BlockedTypePack&) override
{
    LUAU_ASSERT(false);
@@ -1297,6 +1308,30 @@ ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vector<R
result->type = sourceModule.type;
result->upperBoundContributors = std::move(cs.upperBoundContributors);
+if (result->timeout || result->cancelled)
+{
+    // If solver was interrupted, skip typechecking and replace all module results with error-suppressing types to avoid leaking blocked/pending
+    // types
+    ScopePtr moduleScope = result->getModuleScope();
+    moduleScope->returnType = builtinTypes->errorRecoveryTypePack();
+
+    for (auto& [name, ty] : result->declaredGlobals)
+        ty = builtinTypes->errorRecoveryType();
+
+    for (auto& [name, tf] : result->exportedTypeBindings)
+        tf.type = builtinTypes->errorRecoveryType();
+}
+else
+{
+    if (mode == Mode::Nonstrict)
+        Luau::checkNonStrict(builtinTypes, iceHandler, NotNull{&unifierState}, NotNull{&dfg}, NotNull{&limits}, sourceModule, result.get());
+    else
+        Luau::check(builtinTypes, NotNull{&unifierState}, NotNull{&limits}, logger.get(), sourceModule, result.get());
+}
+
+unfreeze(result->interfaceTypes);
+result->clonePublicInterface(builtinTypes, *iceHandler);
if (FFlag::DebugLuauForbidInternalTypes)
{
    InternalTypeFinder finder;

@@ -1325,30 +1360,6 @@ ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vector<R
    finder.traverse(tp);
}
-if (result->timeout || result->cancelled)
-{
-    // If solver was interrupted, skip typechecking and replace all module results with error-suppressing types to avoid leaking blocked/pending
-    // types
-    ScopePtr moduleScope = result->getModuleScope();
-    moduleScope->returnType = builtinTypes->errorRecoveryTypePack();
-
-    for (auto& [name, ty] : result->declaredGlobals)
-        ty = builtinTypes->errorRecoveryType();
-
-    for (auto& [name, tf] : result->exportedTypeBindings)
-        tf.type = builtinTypes->errorRecoveryType();
-}
-else
-{
-    if (mode == Mode::Nonstrict)
-        Luau::checkNonStrict(builtinTypes, iceHandler, NotNull{&unifierState}, NotNull{&dfg}, NotNull{&limits}, sourceModule, result.get());
-    else
-        Luau::check(builtinTypes, NotNull{&unifierState}, NotNull{&limits}, logger.get(), sourceModule, result.get());
-}
-
-unfreeze(result->interfaceTypes);
-result->clonePublicInterface(builtinTypes, *iceHandler);
// It would be nice if we could freeze the arenas before doing type
// checking, but we'll have to do some work to get there.
//
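With `LuauCancelFromProgress` enabled, the progress callback's boolean return value can stop the remaining queued modules. A hypothetical caller (everything except the `checkQueuedModules` signature is an assumption):

    // Assumed host-side state: std::atomic<bool> stopRequested; an executeTask functor.
    frontend.checkQueuedModules(std::nullopt, executeTask,
        [&](size_t done, size_t total) -> bool {
            reportStatus(done, total);    // assumed progress UI hook
            return !stopRequested.load(); // returning false sets `cancelled`
        });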


@@ -0,0 +1,910 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Generalization.h"
#include "Luau/Scope.h"
#include "Luau/Type.h"
#include "Luau/ToString.h"
#include "Luau/TypeArena.h"
#include "Luau/TypePack.h"
#include "Luau/VisitType.h"
namespace Luau
{
struct MutatingGeneralizer : TypeOnceVisitor
{
NotNull<BuiltinTypes> builtinTypes;
NotNull<Scope> scope;
NotNull<DenseHashSet<TypeId>> cachedTypes;
DenseHashMap<const void*, size_t> positiveTypes;
DenseHashMap<const void*, size_t> negativeTypes;
std::vector<TypeId> generics;
std::vector<TypePackId> genericPacks;
bool isWithinFunction = false;
bool avoidSealingTables = false;
MutatingGeneralizer(NotNull<BuiltinTypes> builtinTypes, NotNull<Scope> scope, NotNull<DenseHashSet<TypeId>> cachedTypes,
DenseHashMap<const void*, size_t> positiveTypes, DenseHashMap<const void*, size_t> negativeTypes, bool avoidSealingTables)
: TypeOnceVisitor(/* skipBoundTypes */ true)
, builtinTypes(builtinTypes)
, scope(scope)
, cachedTypes(cachedTypes)
, positiveTypes(std::move(positiveTypes))
, negativeTypes(std::move(negativeTypes))
, avoidSealingTables(avoidSealingTables)
{
}
static void replace(DenseHashSet<TypeId>& seen, TypeId haystack, TypeId needle, TypeId replacement)
{
haystack = follow(haystack);
if (seen.find(haystack))
return;
seen.insert(haystack);
if (UnionType* ut = getMutable<UnionType>(haystack))
{
for (auto iter = ut->options.begin(); iter != ut->options.end();)
{
// FIXME: I bet this function has reentrancy problems
TypeId option = follow(*iter);
if (option == needle && get<NeverType>(replacement))
{
iter = ut->options.erase(iter);
continue;
}
if (option == needle)
{
*iter = replacement;
iter++;
continue;
}
// advance the iterator, nothing after this can use it.
iter++;
if (seen.find(option))
continue;
seen.insert(option);
if (get<UnionType>(option))
replace(seen, option, needle, haystack);
else if (get<IntersectionType>(option))
replace(seen, option, needle, haystack);
}
if (ut->options.size() == 1)
{
TypeId onlyType = ut->options[0];
LUAU_ASSERT(onlyType != haystack);
emplaceType<BoundType>(asMutable(haystack), onlyType);
}
return;
}
if (IntersectionType* it = getMutable<IntersectionType>(needle))
{
for (auto iter = it->parts.begin(); iter != it->parts.end();)
{
// FIXME: I bet this function has reentrancy problems
TypeId part = follow(*iter);
if (part == needle && get<UnknownType>(replacement))
{
iter = it->parts.erase(iter);
continue;
}
if (part == needle)
{
*iter = replacement;
iter++;
continue;
}
// advance the iterator, nothing after this can use it.
iter++;
if (seen.find(part))
continue;
seen.insert(part);
if (get<UnionType>(part))
replace(seen, part, needle, haystack);
else if (get<IntersectionType>(part))
replace(seen, part, needle, haystack);
}
if (it->parts.size() == 1)
{
TypeId onlyType = it->parts[0];
LUAU_ASSERT(onlyType != needle);
emplaceType<BoundType>(asMutable(needle), onlyType);
}
return;
}
}
bool visit(TypeId ty, const FunctionType& ft) override
{
if (cachedTypes->contains(ty))
return false;
const bool oldValue = isWithinFunction;
isWithinFunction = true;
traverse(ft.argTypes);
traverse(ft.retTypes);
isWithinFunction = oldValue;
return false;
}
bool visit(TypeId ty, const FreeType&) override
{
LUAU_ASSERT(!cachedTypes->contains(ty));
const FreeType* ft = get<FreeType>(ty);
LUAU_ASSERT(ft);
traverse(ft->lowerBound);
traverse(ft->upperBound);
// It is possible for the above traverse() calls to cause ty to be
// transmuted. We must reacquire ft if this happens.
ty = follow(ty);
ft = get<FreeType>(ty);
if (!ft)
return false;
const size_t positiveCount = getCount(positiveTypes, ty);
const size_t negativeCount = getCount(negativeTypes, ty);
if (!positiveCount && !negativeCount)
return false;
const bool hasLowerBound = !get<NeverType>(follow(ft->lowerBound));
const bool hasUpperBound = !get<UnknownType>(follow(ft->upperBound));
DenseHashSet<TypeId> seen{nullptr};
seen.insert(ty);
if (!hasLowerBound && !hasUpperBound)
{
if (!isWithinFunction || (positiveCount + negativeCount == 1))
emplaceType<BoundType>(asMutable(ty), builtinTypes->unknownType);
else
{
emplaceType<GenericType>(asMutable(ty), scope);
generics.push_back(ty);
}
}
// It is possible that this free type has other free types in its upper
// or lower bounds. If this is the case, we must replace those
// references with never (for the lower bound) or unknown (for the upper
// bound).
//
// If we do not do this, we get tautological bounds like a <: a <: unknown.
else if (positiveCount && !hasUpperBound)
{
TypeId lb = follow(ft->lowerBound);
if (FreeType* lowerFree = getMutable<FreeType>(lb); lowerFree && lowerFree->upperBound == ty)
lowerFree->upperBound = builtinTypes->unknownType;
else
{
DenseHashSet<TypeId> replaceSeen{nullptr};
replace(replaceSeen, lb, ty, builtinTypes->unknownType);
}
if (lb != ty)
emplaceType<BoundType>(asMutable(ty), lb);
else if (!isWithinFunction || (positiveCount + negativeCount == 1))
emplaceType<BoundType>(asMutable(ty), builtinTypes->unknownType);
else
{
// if the lower bound is the type in question, we don't actually have a lower bound.
emplaceType<GenericType>(asMutable(ty), scope);
generics.push_back(ty);
}
}
else
{
TypeId ub = follow(ft->upperBound);
if (FreeType* upperFree = getMutable<FreeType>(ub); upperFree && upperFree->lowerBound == ty)
upperFree->lowerBound = builtinTypes->neverType;
else
{
DenseHashSet<TypeId> replaceSeen{nullptr};
replace(replaceSeen, ub, ty, builtinTypes->neverType);
}
if (ub != ty)
emplaceType<BoundType>(asMutable(ty), ub);
else if (!isWithinFunction || (positiveCount + negativeCount == 1))
emplaceType<BoundType>(asMutable(ty), builtinTypes->unknownType);
else
{
// if the upper bound is the type in question, we don't actually have an upper bound.
emplaceType<GenericType>(asMutable(ty), scope);
generics.push_back(ty);
}
}
return false;
}
size_t getCount(const DenseHashMap<const void*, size_t>& map, const void* ty)
{
if (const size_t* count = map.find(ty))
return *count;
else
return 0;
}
bool visit(TypeId ty, const TableType&) override
{
if (cachedTypes->contains(ty))
return false;
const size_t positiveCount = getCount(positiveTypes, ty);
const size_t negativeCount = getCount(negativeTypes, ty);
// FIXME: Free tables should probably just be replaced by upper bounds on free types.
//
// eg never <: 'a <: {x: number} & {z: boolean}
if (!positiveCount && !negativeCount)
return true;
TableType* tt = getMutable<TableType>(ty);
LUAU_ASSERT(tt);
if (!avoidSealingTables)
tt->state = TableState::Sealed;
return true;
}
bool visit(TypePackId tp, const FreeTypePack& ftp) override
{
if (!subsumes(scope, ftp.scope))
return true;
tp = follow(tp);
const size_t positiveCount = getCount(positiveTypes, tp);
const size_t negativeCount = getCount(negativeTypes, tp);
if (1 == positiveCount + negativeCount)
emplaceTypePack<BoundTypePack>(asMutable(tp), builtinTypes->unknownTypePack);
else
{
emplaceTypePack<GenericTypePack>(asMutable(tp), scope);
genericPacks.push_back(tp);
}
return true;
}
};
struct FreeTypeSearcher : TypeVisitor
{
NotNull<Scope> scope;
NotNull<DenseHashSet<TypeId>> cachedTypes;
explicit FreeTypeSearcher(NotNull<Scope> scope, NotNull<DenseHashSet<TypeId>> cachedTypes)
: TypeVisitor(/*skipBoundTypes*/ true)
, scope(scope)
, cachedTypes(cachedTypes)
{
}
enum Polarity
{
Positive,
Negative,
Both,
};
Polarity polarity = Positive;
void flip()
{
switch (polarity)
{
case Positive:
polarity = Negative;
break;
case Negative:
polarity = Positive;
break;
case Both:
break;
}
}
DenseHashSet<const void*> seenPositive{nullptr};
DenseHashSet<const void*> seenNegative{nullptr};
bool seenWithPolarity(const void* ty)
{
switch (polarity)
{
case Positive:
{
if (seenPositive.contains(ty))
return true;
seenPositive.insert(ty);
return false;
}
case Negative:
{
if (seenNegative.contains(ty))
return true;
seenNegative.insert(ty);
return false;
}
case Both:
{
if (seenPositive.contains(ty) && seenNegative.contains(ty))
return true;
seenPositive.insert(ty);
seenNegative.insert(ty);
return false;
}
}
return false;
}
// The keys in these maps are either TypeIds or TypePackIds. It's safe to
// mix them because we only use these pointers as unique keys. We never
// indirect them.
DenseHashMap<const void*, size_t> negativeTypes{0};
DenseHashMap<const void*, size_t> positiveTypes{0};
bool visit(TypeId ty) override
{
if (cachedTypes->contains(ty) || seenWithPolarity(ty))
return false;
LUAU_ASSERT(ty);
return true;
}
bool visit(TypeId ty, const FreeType& ft) override
{
if (cachedTypes->contains(ty) || seenWithPolarity(ty))
return false;
if (!subsumes(scope, ft.scope))
return true;
switch (polarity)
{
case Positive:
positiveTypes[ty]++;
break;
case Negative:
negativeTypes[ty]++;
break;
case Both:
positiveTypes[ty]++;
negativeTypes[ty]++;
break;
}
return true;
}
bool visit(TypeId ty, const TableType& tt) override
{
if (cachedTypes->contains(ty) || seenWithPolarity(ty))
return false;
if ((tt.state == TableState::Free || tt.state == TableState::Unsealed) && subsumes(scope, tt.scope))
{
switch (polarity)
{
case Positive:
positiveTypes[ty]++;
break;
case Negative:
negativeTypes[ty]++;
break;
case Both:
positiveTypes[ty]++;
negativeTypes[ty]++;
break;
}
}
for (const auto& [_name, prop] : tt.props)
{
if (prop.isReadOnly())
traverse(*prop.readTy);
else
{
LUAU_ASSERT(prop.isShared());
Polarity p = polarity;
polarity = Both;
traverse(prop.type());
polarity = p;
}
}
if (tt.indexer)
{
traverse(tt.indexer->indexType);
traverse(tt.indexer->indexResultType);
}
return false;
}
bool visit(TypeId ty, const FunctionType& ft) override
{
if (cachedTypes->contains(ty) || seenWithPolarity(ty))
return false;
flip();
traverse(ft.argTypes);
flip();
traverse(ft.retTypes);
return false;
}
bool visit(TypeId, const ClassType&) override
{
return false;
}
bool visit(TypePackId tp, const FreeTypePack& ftp) override
{
if (seenWithPolarity(tp))
return false;
if (!subsumes(scope, ftp.scope))
return true;
switch (polarity)
{
case Positive:
positiveTypes[tp]++;
break;
case Negative:
negativeTypes[tp]++;
break;
case Both:
positiveTypes[tp]++;
negativeTypes[tp]++;
break;
}
return true;
}
};
// We keep a running set of types that will not change under generalization and
// only have outgoing references to types that are the same. We use this to
// short circuit generalization. It improves performance quite a lot.
//
// We do this by tracing through the type and searching for types that are
// uncacheable. If a type has a reference to an uncacheable type, it is itself
// uncacheable.
//
// If a type has no outbound references to uncacheable types, we add it to the
// cache.
struct TypeCacher : TypeOnceVisitor
{
NotNull<DenseHashSet<TypeId>> cachedTypes;
DenseHashSet<TypeId> uncacheable{nullptr};
DenseHashSet<TypePackId> uncacheablePacks{nullptr};
explicit TypeCacher(NotNull<DenseHashSet<TypeId>> cachedTypes)
: TypeOnceVisitor(/* skipBoundTypes */ true)
, cachedTypes(cachedTypes)
{
}
void cache(TypeId ty)
{
cachedTypes->insert(ty);
}
bool isCached(TypeId ty) const
{
return cachedTypes->contains(ty);
}
void markUncacheable(TypeId ty)
{
uncacheable.insert(ty);
}
void markUncacheable(TypePackId tp)
{
uncacheablePacks.insert(tp);
}
bool isUncacheable(TypeId ty) const
{
return uncacheable.contains(ty);
}
bool isUncacheable(TypePackId tp) const
{
return uncacheablePacks.contains(tp);
}
bool visit(TypeId ty) override
{
if (isUncacheable(ty) || isCached(ty))
return false;
return true;
}
bool visit(TypeId ty, const FreeType& ft) override
{
// Free types are never cacheable.
LUAU_ASSERT(!isCached(ty));
if (!isUncacheable(ty))
{
traverse(ft.lowerBound);
traverse(ft.upperBound);
markUncacheable(ty);
}
return false;
}
bool visit(TypeId ty, const GenericType&) override
{
cache(ty);
return false;
}
bool visit(TypeId ty, const PrimitiveType&) override
{
cache(ty);
return false;
}
bool visit(TypeId ty, const SingletonType&) override
{
cache(ty);
return false;
}
bool visit(TypeId ty, const BlockedType&) override
{
markUncacheable(ty);
return false;
}
bool visit(TypeId ty, const PendingExpansionType&) override
{
markUncacheable(ty);
return false;
}
bool visit(TypeId ty, const FunctionType& ft) override
{
if (isCached(ty) || isUncacheable(ty))
return false;
traverse(ft.argTypes);
traverse(ft.retTypes);
for (TypeId gen : ft.generics)
traverse(gen);
bool uncacheable = false;
if (isUncacheable(ft.argTypes))
uncacheable = true;
else if (isUncacheable(ft.retTypes))
uncacheable = true;
for (TypeId argTy : ft.argTypes)
{
if (isUncacheable(argTy))
{
uncacheable = true;
break;
}
}
for (TypeId retTy : ft.retTypes)
{
if (isUncacheable(retTy))
{
uncacheable = true;
break;
}
}
for (TypeId g : ft.generics)
{
if (isUncacheable(g))
{
uncacheable = true;
break;
}
}
if (uncacheable)
markUncacheable(ty);
else
cache(ty);
return false;
}
bool visit(TypeId ty, const TableType& tt) override
{
if (isCached(ty) || isUncacheable(ty))
return false;
if (tt.boundTo)
{
traverse(*tt.boundTo);
if (isUncacheable(*tt.boundTo))
{
markUncacheable(ty);
return false;
}
}
bool uncacheable = false;
// This logic runs immediately after generalization, so any remaining
// unsealed tables are assuredly not cacheable. They may yet have
// properties added to them.
if (tt.state == TableState::Free || tt.state == TableState::Unsealed)
uncacheable = true;
for (const auto& [_name, prop] : tt.props)
{
if (prop.readTy)
{
traverse(*prop.readTy);
if (isUncacheable(*prop.readTy))
uncacheable = true;
}
if (prop.writeTy && prop.writeTy != prop.readTy)
{
traverse(*prop.writeTy);
if (isUncacheable(*prop.writeTy))
uncacheable = true;
}
}
if (tt.indexer)
{
traverse(tt.indexer->indexType);
if (isUncacheable(tt.indexer->indexType))
uncacheable = true;
traverse(tt.indexer->indexResultType);
if (isUncacheable(tt.indexer->indexResultType))
uncacheable = true;
}
if (uncacheable)
markUncacheable(ty);
else
cache(ty);
return false;
}
bool visit(TypeId ty, const ClassType&) override
{
cache(ty);
return false;
}
bool visit(TypeId ty, const AnyType&) override
{
cache(ty);
return false;
}
bool visit(TypeId ty, const UnionType& ut) override
{
if (isUncacheable(ty) || isCached(ty))
return false;
bool uncacheable = false;
for (TypeId partTy : ut.options)
{
traverse(partTy);
uncacheable |= isUncacheable(partTy);
}
if (uncacheable)
markUncacheable(ty);
else
cache(ty);
return false;
}
bool visit(TypeId ty, const IntersectionType& it) override
{
if (isUncacheable(ty) || isCached(ty))
return false;
bool uncacheable = false;
for (TypeId partTy : it.parts)
{
traverse(partTy);
uncacheable |= isUncacheable(partTy);
}
if (uncacheable)
markUncacheable(ty);
else
cache(ty);
return false;
}
bool visit(TypeId ty, const UnknownType&) override
{
cache(ty);
return false;
}
bool visit(TypeId ty, const NeverType&) override
{
cache(ty);
return false;
}
bool visit(TypeId ty, const NegationType& nt) override
{
if (!isCached(ty) && !isUncacheable(ty))
{
traverse(nt.ty);
if (isUncacheable(nt.ty))
markUncacheable(ty);
else
cache(ty);
}
return false;
}
bool visit(TypeId ty, const TypeFamilyInstanceType& tfit) override
{
if (isCached(ty) || isUncacheable(ty))
return false;
bool uncacheable = false;
for (TypeId argTy : tfit.typeArguments)
{
traverse(argTy);
if (isUncacheable(argTy))
uncacheable = true;
}
for (TypePackId argPack : tfit.packArguments)
{
traverse(argPack);
if (isUncacheable(argPack))
uncacheable = true;
}
if (uncacheable)
markUncacheable(ty);
else
cache(ty);
return false;
}
bool visit(TypePackId tp, const FreeTypePack&) override
{
markUncacheable(tp);
return false;
}
bool visit(TypePackId tp, const VariadicTypePack& vtp) override
{
if (isUncacheable(tp))
return false;
traverse(vtp.ty);
if (isUncacheable(vtp.ty))
markUncacheable(tp);
return false;
}
bool visit(TypePackId tp, const BlockedTypePack&) override
{
markUncacheable(tp);
return false;
}
bool visit(TypePackId tp, const TypeFamilyInstanceTypePack&) override
{
markUncacheable(tp);
return false;
}
};
std::optional<TypeId> generalize(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<Scope> scope,
NotNull<DenseHashSet<TypeId>> cachedTypes, TypeId ty, bool avoidSealingTables)
{
ty = follow(ty);
if (ty->owningArena != arena || ty->persistent)
return ty;
if (const FunctionType* ft = get<FunctionType>(ty); ft && (!ft->generics.empty() || !ft->genericPacks.empty()))
return ty;
FreeTypeSearcher fts{scope, cachedTypes};
fts.traverse(ty);
MutatingGeneralizer gen{builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), std::move(fts.negativeTypes), avoidSealingTables};
gen.traverse(ty);
/* MutatingGeneralizer mutates types in place, so it is possible that ty has
* been transmuted to a BoundType. We must follow it again and verify that
* we are allowed to mutate it before we attach generics to it.
*/
ty = follow(ty);
if (ty->owningArena != arena || ty->persistent)
return ty;
TypeCacher cacher{cachedTypes};
cacher.traverse(ty);
FunctionType* ftv = getMutable<FunctionType>(ty);
if (ftv)
{
ftv->generics = std::move(gen.generics);
ftv->genericPacks = std::move(gen.genericPacks);
}
return ty;
}
} // namespace Luau
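A sketch of calling the new `generalize` entry point; the surrounding `arena`, `builtinTypes`, `scope`, and `ty` objects are assumed to exist, and `avoidSealingTables` is passed explicitly in case the header declares a default for it:

    DenseHashSet<TypeId> cachedTypes{nullptr};
    std::optional<TypeId> result =
        generalize(NotNull{&arena}, builtinTypes, NotNull{scope}, NotNull{&cachedTypes}, ty, /* avoidSealingTables */ false);
    if (result)
        ty = *result; // free types in ty are now bound, or promoted to generics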


@@ -11,10 +11,23 @@
#include <algorithm>

LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
+LUAU_FASTFLAG(LuauReusableSubstitutions)

namespace Luau
{
void Instantiation::resetState(const TxnLog* log, TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, TypeLevel level, Scope* scope)
{
LUAU_ASSERT(FFlag::LuauReusableSubstitutions);
Substitution::resetState(log, arena);
this->builtinTypes = builtinTypes;
this->level = level;
this->scope = scope;
}
bool Instantiation::isDirty(TypeId ty)
{
    if (const FunctionType* ftv = log->getMutable<FunctionType>(ty))
@@ -58,13 +71,26 @@ TypeId Instantiation::clean(TypeId ty)
clone.argNames = ftv->argNames;
TypeId result = addType(std::move(clone));

-// Annoyingly, we have to do this even if there are no generics,
-// to replace any generic tables.
-ReplaceGenerics replaceGenerics{log, arena, builtinTypes, level, scope, ftv->generics, ftv->genericPacks};
-
-// TODO: What to do if this returns nullopt?
-// We don't have access to the error-reporting machinery
-result = replaceGenerics.substitute(result).value_or(result);
+if (FFlag::LuauReusableSubstitutions)
+{
+    // Annoyingly, we have to do this even if there are no generics,
+    // to replace any generic tables.
+    reusableReplaceGenerics.resetState(log, arena, builtinTypes, level, scope, ftv->generics, ftv->genericPacks);
+
+    // TODO: What to do if this returns nullopt?
+    // We don't have access to the error-reporting machinery
+    result = reusableReplaceGenerics.substitute(result).value_or(result);
+}
+else
+{
+    // Annoyingly, we have to do this even if there are no generics,
+    // to replace any generic tables.
+    ReplaceGenerics replaceGenerics{log, arena, builtinTypes, level, scope, ftv->generics, ftv->genericPacks};
+
+    // TODO: What to do if this returns nullopt?
+    // We don't have access to the error-reporting machinery
+    result = replaceGenerics.substitute(result).value_or(result);
+}

asMutable(result)->documentationSymbol = ty->documentationSymbol;
return result;
@@ -76,6 +102,22 @@ TypePackId Instantiation::clean(TypePackId tp)
    return tp;
}
void ReplaceGenerics::resetState(const TxnLog* log, TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, TypeLevel level, Scope* scope,
const std::vector<TypeId>& generics, const std::vector<TypePackId>& genericPacks)
{
LUAU_ASSERT(FFlag::LuauReusableSubstitutions);
Substitution::resetState(log, arena);
this->builtinTypes = builtinTypes;
this->level = level;
this->scope = scope;
this->generics = generics;
this->genericPacks = genericPacks;
}
bool ReplaceGenerics::ignoreChildren(TypeId ty)
{
    if (const FunctionType* ftv = log->getMutable<FunctionType>(ty))
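`resetState` lets one substitution object be reused across many instantiations instead of being rebuilt each time. An illustrative reuse loop; the work list and result handling are hypothetical, only the `ReplaceGenerics` API is from this commit:

    // Construct once, then reset per function type being instantiated.
    ReplaceGenerics replaceGenerics{log, arena, builtinTypes, level, scope, {}, {}};
    for (const auto& [ftv, ty] : pendingInstantiations) // assumed: (FunctionType*, TypeId) pairs
    {
        replaceGenerics.resetState(log, arena, builtinTypes, level, scope, ftv->generics, ftv->genericPacks);
        // substitute() returns nullopt on resource limits; fall back to the original type.
        TypeId instantiated = replaceGenerics.substitute(ty).value_or(ty);
    }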


@@ -16,6 +16,11 @@ LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)

+LUAU_FASTFLAG(LuauAttributeSyntax)
+LUAU_FASTFLAG(LuauAttribute)
+LUAU_FASTFLAG(LuauNativeAttribute)
+LUAU_FASTFLAGVARIABLE(LintRedundantNativeAttribute, false)

namespace Luau
{

@@ -2922,6 +2927,64 @@ static void lintComments(LintContext& context, const std::vector<HotComment>& ho
    }
}
static bool hasNativeCommentDirective(const std::vector<HotComment>& hotcomments)
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
LUAU_ASSERT(FFlag::LuauNativeAttribute);
LUAU_ASSERT(FFlag::LintRedundantNativeAttribute);
for (const HotComment& hc : hotcomments)
{
if (hc.content.empty() || hc.content[0] == ' ' || hc.content[0] == '\t')
continue;
if (hc.header)
{
size_t space = hc.content.find_first_of(" \t");
std::string_view first = std::string_view(hc.content).substr(0, space);
if (first == "native")
return true;
}
}
return false;
}
struct LintRedundantNativeAttribute : AstVisitor
{
public:
LUAU_NOINLINE static void process(LintContext& context)
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
LUAU_ASSERT(FFlag::LuauNativeAttribute);
LUAU_ASSERT(FFlag::LintRedundantNativeAttribute);
LintRedundantNativeAttribute pass;
pass.context = &context;
context.root->visit(&pass);
}
private:
LintContext* context;
bool visit(AstExprFunction* node) override
{
node->body->visit(this);
for (const auto attribute : node->attributes)
{
if (attribute->type == AstAttr::Type::Native)
{
emitWarning(*context, LintWarning::Code_RedundantNativeAttribute, attribute->location,
"native attribute on a function is redundant in a native module; consider removing it");
}
}
return false;
}
};
std::vector<LintWarning> lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module,
    const std::vector<HotComment>& hotcomments, const LintOptions& options)
{

@@ -3008,6 +3071,13 @@ std::vector<LintWarning> lint(AstStat* root, const AstNameTable& names, const Sc
if (context.warningEnabled(LintWarning::Code_ComparisonPrecedence))
    LintComparisonPrecedence::process(context);
if (FFlag::LuauAttributeSyntax && FFlag::LuauNativeAttribute && FFlag::LintRedundantNativeAttribute &&
context.warningEnabled(LintWarning::Code_RedundantNativeAttribute))
{
if (hasNativeCommentDirective(hotcomments))
LintRedundantNativeAttribute::process(context);
}
std::sort(context.result.begin(), context.result.end(), WarningComparator());

return context.result;


@@ -17,23 +17,23 @@
LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false)

LUAU_FASTFLAGVARIABLE(LuauNormalizeAwayUninhabitableTables, false)
-LUAU_FASTFLAGVARIABLE(LuauFixNormalizeCaching, false);
LUAU_FASTFLAGVARIABLE(LuauNormalizeNotUnknownIntersection, false);
-LUAU_FASTFLAGVARIABLE(LuauFixCyclicUnionsOfIntersections, false);
+LUAU_FASTFLAGVARIABLE(LuauFixReduceStackPressure, false);
+LUAU_FASTFLAGVARIABLE(LuauFixCyclicTablesBlowingStack, false);

// This could theoretically be 2000 on amd64, but x86 requires this.
LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200);
LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000);

LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);

-static bool fixNormalizeCaching()
-{
-    return FFlag::LuauFixNormalizeCaching || FFlag::DebugLuauDeferredConstraintResolution;
-}
-
-static bool fixCyclicUnionsOfIntersections()
-{
-    return FFlag::LuauFixCyclicUnionsOfIntersections || FFlag::DebugLuauDeferredConstraintResolution;
-}
+static bool fixReduceStackPressure()
+{
+    return FFlag::LuauFixReduceStackPressure || FFlag::DebugLuauDeferredConstraintResolution;
+}
+
+static bool fixCyclicTablesBlowingStack()
+{
+    return FFlag::LuauFixCyclicTablesBlowingStack || FFlag::DebugLuauDeferredConstraintResolution;
+}

namespace Luau

@@ -45,6 +45,14 @@ static bool normalizeAwayUninhabitableTables()
    return FFlag::LuauNormalizeAwayUninhabitableTables || FFlag::DebugLuauDeferredConstraintResolution;
}
static bool shouldEarlyExit(NormalizationResult res)
{
// if res is hit limits, return control flow
if (res == NormalizationResult::HitLimits || res == NormalizationResult::False)
return true;
return false;
}
TypeIds::TypeIds(std::initializer_list<TypeId> tys)
{
    for (TypeId ty : tys)

@@ -339,6 +347,12 @@ bool NormalizedType::isSubtypeOfString() const
    !hasBuffers() && !hasTables() && !hasFunctions() && !hasTyvars();
}
bool NormalizedType::isSubtypeOfBooleans() const
{
return hasBooleans() && !hasTops() && !hasClasses() && !hasErrors() && !hasNils() && !hasNumbers() && !hasStrings() && !hasThreads() &&
!hasBuffers() && !hasTables() && !hasFunctions() && !hasTyvars();
}
bool NormalizedType::shouldSuppressErrors() const
{
    return hasErrors() || get<AnyType>(tops);
@@ -547,22 +561,21 @@ NormalizationResult Normalizer::isInhabited(TypeId ty, Set<TypeId>& seen)
    return isInhabited(mtv->metatable, seen);
}

-if (fixNormalizeCaching())
-{
-    std::shared_ptr<const NormalizedType> norm = normalize(ty);
-    return isInhabited(norm.get(), seen);
-}
-else
-{
-    const NormalizedType* norm = DEPRECATED_normalize(ty);
-    return isInhabited(norm, seen);
-}
+std::shared_ptr<const NormalizedType> norm = normalize(ty);
+return isInhabited(norm.get(), seen);
}
NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right)
+{
+    Set<TypeId> seen{nullptr};
+    return isIntersectionInhabited(left, right, seen);
+}
+
+NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right, Set<TypeId>& seenSet)
{
    left = follow(left);
    right = follow(right);
+    // We're asking if intersection is inhabited between left and right but we've already seen them ....
    if (cacheInhabitance)
    {
@@ -570,12 +583,8 @@ NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId righ
    return *result ? NormalizationResult::True : NormalizationResult::False;
}

-Set<TypeId> seen{nullptr};
-seen.insert(left);
-seen.insert(right);
-
NormalizedType norm{builtinTypes};
-NormalizationResult res = normalizeIntersections({left, right}, norm);
+NormalizationResult res = normalizeIntersections({left, right}, norm, seenSet);
if (res != NormalizationResult::True)
{
    if (cacheInhabitance && res == NormalizationResult::False)

@@ -584,7 +593,7 @@ NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId righ
    return res;
}

-NormalizationResult result = isInhabited(&norm, seen);
+NormalizationResult result = isInhabited(&norm, seenSet);
if (cacheInhabitance && result == NormalizationResult::True)
    cachedIsInhabitedIntersection[{left, right}] = true;
@@ -856,31 +865,6 @@ Normalizer::Normalizer(TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, Not
{
}

-const NormalizedType* Normalizer::DEPRECATED_normalize(TypeId ty)
-{
-    if (!arena)
-        sharedState->iceHandler->ice("Normalizing types outside a module");
-
-    auto found = cachedNormals.find(ty);
-    if (found != cachedNormals.end())
-        return found->second.get();
-
-    NormalizedType norm{builtinTypes};
-    Set<TypeId> seenSetTypes{nullptr};
-    NormalizationResult res = unionNormalWithTy(norm, ty, seenSetTypes);
-    if (res != NormalizationResult::True)
-        return nullptr;
-
-    if (norm.isUnknown())
-    {
-        clearNormal(norm);
-        norm.tops = builtinTypes->unknownType;
-    }
-
-    std::shared_ptr<NormalizedType> shared = std::make_shared<NormalizedType>(std::move(norm));
-    const NormalizedType* result = shared.get();
-    cachedNormals[ty] = std::move(shared);
-
-    return result;
-}
static bool isCacheable(TypeId ty, Set<TypeId>& seen);

static bool isCacheable(TypePackId tp, Set<TypeId>& seen)

@@ -935,9 +919,6 @@ static bool isCacheable(TypeId ty)
static bool isCacheable(TypeId ty)
{
-    if (!fixNormalizeCaching())
-        return true;
-
    Set<TypeId> seen{nullptr};
    return isCacheable(ty, seen);
}
@@ -971,7 +952,7 @@ std::shared_ptr<const NormalizedType> Normalizer::normalize(TypeId ty)
    return shared;
}

-NormalizationResult Normalizer::normalizeIntersections(const std::vector<TypeId>& intersections, NormalizedType& outType)
+NormalizationResult Normalizer::normalizeIntersections(const std::vector<TypeId>& intersections, NormalizedType& outType, Set<TypeId>& seenSet)
{
    if (!arena)
        sharedState->iceHandler->ice("Normalizing types outside a module");

@@ -981,7 +962,7 @@ NormalizationResult Normalizer::normalizeIntersections(const std::vector<TypeId>
-    Set<TypeId> seenSetTypes{nullptr};
    for (auto ty : intersections)
    {
-        NormalizationResult res = intersectNormalWithTy(norm, ty, seenSetTypes);
+        NormalizationResult res = intersectNormalWithTy(norm, ty, seenSet);
        if (res != NormalizationResult::True)
            return res;
    }
@@ -1729,6 +1710,20 @@ bool Normalizer::withinResourceLimits()
    return true;
}
NormalizationResult Normalizer::intersectNormalWithNegationTy(TypeId toNegate, NormalizedType& intersect)
{
std::optional<NormalizedType> negated;
std::shared_ptr<const NormalizedType> normal = normalize(toNegate);
negated = negateNormal(*normal);
if (!negated)
return NormalizationResult::False;
intersectNormals(intersect, *negated);
return NormalizationResult::True;
}
// See above for an explanation of `ignoreSmallerTyvars`.
NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, Set<TypeId>& seenSetTypes, int ignoreSmallerTyvars)
{
@@ -1775,12 +1770,9 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
}
else if (const IntersectionType* itv = get<IntersectionType>(there))
{
-    if (fixCyclicUnionsOfIntersections())
-    {
-        if (seenSetTypes.count(there))
-            return NormalizationResult::True;
-        seenSetTypes.insert(there);
-    }
+    if (seenSetTypes.count(there))
+        return NormalizationResult::True;
+    seenSetTypes.insert(there);

    NormalizedType norm{builtinTypes};
    norm.tops = builtinTypes->anyType;
@ -1789,14 +1781,12 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
NormalizationResult res = intersectNormalWithTy(norm, *it, seenSetTypes); NormalizationResult res = intersectNormalWithTy(norm, *it, seenSetTypes);
if (res != NormalizationResult::True) if (res != NormalizationResult::True)
{ {
if (fixCyclicUnionsOfIntersections()) seenSetTypes.erase(there);
seenSetTypes.erase(there);
return res; return res;
} }
} }
if (fixCyclicUnionsOfIntersections()) seenSetTypes.erase(there);
seenSetTypes.erase(there);
return unionNormals(here, norm); return unionNormals(here, norm);
} }
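The hunk above unconditionally adopts the guard that previously sat behind fixCyclicUnionsOfIntersections(): before recursing into an intersection, the type is recorded in a "seen" set so a cyclic type terminates instead of recursing forever, and the entry is erased again on every exit path. A minimal sketch of the pattern, using std::unordered_set and an int handle as stand-ins for Luau's Set<TypeId> and TypeId:

    #include <unordered_set>
    #include <vector>

    using TypeId = int; // stand-in handle

    struct Graph
    {
        // children[t] lists the component types of t; cycles are allowed.
        std::vector<std::vector<TypeId>> children;
    };

    // A revisited node is treated as trivially true instead of recursing
    // forever, mirroring the seenSetTypes guard in unionNormalWithTy.
    bool traverse(const Graph& g, TypeId t, std::unordered_set<TypeId>& seen)
    {
        if (seen.count(t))
            return true; // already in progress: break the cycle
        seen.insert(t);
        bool ok = true;
        for (TypeId child : g.children[t])
            ok = ok && traverse(g, child, seen);
        seen.erase(t); // cleanup so sibling paths may revisit t
        return ok;
    }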
@@ -1814,12 +1804,6 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
        if (!isCacheable(there))
            here.isCacheable = false;
    }
-    else if (auto lt = get<LocalType>(there))
-    {
-        // FIXME? This is somewhat questionable.
-        // Maybe we should assert because this should never happen?
-        unionNormalWithTy(here, lt->domain, seenSetTypes, ignoreSmallerTyvars);
-    }
    else if (get<FunctionType>(there))
        unionFunctionsWithFunction(here.functions, there);
    else if (get<TableType>(there) || get<MetatableType>(there))

@@ -1876,16 +1860,8 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
    {
        std::optional<NormalizedType> tn;

-        if (fixNormalizeCaching())
-        {
-            std::shared_ptr<const NormalizedType> thereNormal = normalize(ntv->ty);
-            tn = negateNormal(*thereNormal);
-        }
-        else
-        {
-            const NormalizedType* thereNormal = DEPRECATED_normalize(ntv->ty);
-            tn = negateNormal(*thereNormal);
-        }
+        std::shared_ptr<const NormalizedType> thereNormal = normalize(ntv->ty);
+        tn = negateNormal(*thereNormal);

        if (!tn)
            return NormalizationResult::False;

@@ -2484,7 +2460,7 @@ std::optional<TypePackId> Normalizer::intersectionOfTypePacks(TypePackId here, T
    return arena->addTypePack({});
}

-std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there)
+std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there, Set<TypeId>& seenSet)
{
    if (here == there)
        return here;

@@ -2541,8 +2517,9 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
        state = tttv->state;

    TypeLevel level = max(httv->level, tttv->level);
-    TableType result{state, level};
+    Scope* scope = max(httv->scope, tttv->scope);
+    std::unique_ptr<TableType> result = nullptr;
    bool hereSubThere = true;
    bool thereSubHere = true;

@@ -2563,8 +2540,43 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
            if (tprop.readTy.has_value())
            {
                // if the intersection of the read types of a property is uninhabited, the whole table is `never`.
-                if (normalizeAwayUninhabitableTables() && NormalizationResult::False == isIntersectionInhabited(*hprop.readTy, *tprop.readTy))
-                    return {builtinTypes->neverType};
+                if (fixReduceStackPressure())
+                {
+                    // We've seen these table prop elements before and we're about to ask if their intersection
+                    // is inhabited
+                    if (fixCyclicTablesBlowingStack())
+                    {
+                        if (seenSet.contains(*hprop.readTy) && seenSet.contains(*tprop.readTy))
+                        {
+                            seenSet.erase(*hprop.readTy);
+                            seenSet.erase(*tprop.readTy);
+                            return {builtinTypes->neverType};
+                        }
+                        else
+                        {
+                            seenSet.insert(*hprop.readTy);
+                            seenSet.insert(*tprop.readTy);
+                        }
+                    }
+
+                    NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy, seenSet);
+
+                    // Cleanup
+                    if (fixCyclicTablesBlowingStack())
+                    {
+                        seenSet.erase(*hprop.readTy);
+                        seenSet.erase(*tprop.readTy);
+                    }
+
+                    if (normalizeAwayUninhabitableTables() && NormalizationResult::True != res)
+                        return {builtinTypes->neverType};
+                }
+                else
+                {
+                    if (normalizeAwayUninhabitableTables() &&
+                        NormalizationResult::False == isIntersectionInhabited(*hprop.readTy, *tprop.readTy))
+                        return {builtinTypes->neverType};
+                }

                TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result;
                prop.readTy = ty;

@@ -2614,14 +2626,21 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
        // TODO: string indexers

        if (prop.readTy || prop.writeTy)
-            result.props[name] = prop;
+        {
+            if (!result.get())
+                result = std::make_unique<TableType>(TableType{state, level, scope});
+            result->props[name] = prop;
+        }
    }

    for (const auto& [name, tprop] : tttv->props)
    {
        if (httv->props.count(name) == 0)
        {
-            result.props[name] = tprop;
+            if (!result.get())
+                result = std::make_unique<TableType>(TableType{state, level, scope});
+            result->props[name] = tprop;
            hereSubThere = false;
        }
    }

@@ -2631,18 +2650,24 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
        // TODO: What should intersection of indexes be?
        TypeId index = unionType(httv->indexer->indexType, tttv->indexer->indexType);
        TypeId indexResult = intersectionType(httv->indexer->indexResultType, tttv->indexer->indexResultType);
-        result.indexer = {index, indexResult};
+        if (!result.get())
+            result = std::make_unique<TableType>(TableType{state, level, scope});
+        result->indexer = {index, indexResult};
        hereSubThere &= (httv->indexer->indexType == index) && (httv->indexer->indexResultType == indexResult);
        thereSubHere &= (tttv->indexer->indexType == index) && (tttv->indexer->indexResultType == indexResult);
    }
    else if (httv->indexer)
    {
-        result.indexer = httv->indexer;
+        if (!result.get())
+            result = std::make_unique<TableType>(TableType{state, level, scope});
+        result->indexer = httv->indexer;
        thereSubHere = false;
    }
    else if (tttv->indexer)
    {
-        result.indexer = tttv->indexer;
+        if (!result.get())
+            result = std::make_unique<TableType>(TableType{state, level, scope});
+        result->indexer = tttv->indexer;
        hereSubThere = false;
    }

@@ -2652,12 +2677,17 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
    else if (thereSubHere)
        table = ttable;
    else
-        table = arena->addType(std::move(result));
+    {
+        if (result.get())
+            table = arena->addType(std::move(*result));
+        else
+            table = arena->addType(TableType{state, level, scope});
+    }

    if (tmtable && hmtable)
    {
        // NOTE: this assumes metatables are invariant
-        if (std::optional<TypeId> mtable = intersectionOfTables(hmtable, tmtable))
+        if (std::optional<TypeId> mtable = intersectionOfTables(hmtable, tmtable, seenSet))
        {
            if (table == htable && *mtable == hmtable)
                return here;
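Two things change in intersectionOfTables above: a seenSet is threaded through the recursion (including into the metatable case), and the result TableType is now built lazily behind a std::unique_ptr, so the common subset cases (hereSubThere / thereSubHere) never construct a table they would immediately discard. A minimal sketch of the lazy-materialization idiom, with a toy TableType standing in for the real one:

    #include <map>
    #include <memory>
    #include <string>

    struct TableType
    {
        std::map<std::string, int> props; // toy stand-in for Luau's property map
    };

    // Materialize the result only when a write actually happens; read-only
    // paths never pay for construction or copying.
    void addProp(std::unique_ptr<TableType>& result, const std::string& name, int ty)
    {
        if (!result)
            result = std::make_unique<TableType>();
        result->props[name] = ty;
    }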
@@ -2687,12 +2717,12 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
    return table;
}

-void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there)
+void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there, Set<TypeId>& seenSetTypes)
{
    TypeIds tmp;
    for (TypeId here : heres)
    {
-        if (std::optional<TypeId> inter = intersectionOfTables(here, there))
+        if (std::optional<TypeId> inter = intersectionOfTables(here, there, seenSetTypes))
            tmp.insert(*inter);
    }
    heres.retain(tmp);

@@ -2706,7 +2736,8 @@ void Normalizer::intersectTables(TypeIds& heres, const TypeIds& theres)
    {
        for (TypeId there : theres)
        {
-            if (std::optional<TypeId> inter = intersectionOfTables(here, there))
+            Set<TypeId> seenSetTypes{nullptr};
+            if (std::optional<TypeId> inter = intersectionOfTables(here, there, seenSetTypes))
                tmp.insert(*inter);
        }
    }

@@ -3047,7 +3078,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
        return NormalizationResult::True;
    }
    else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) ||
-             get<TypeFamilyInstanceType>(there) || get<LocalType>(there))
+             get<TypeFamilyInstanceType>(there))
    {
        NormalizedType thereNorm{builtinTypes};
        NormalizedType topNorm{builtinTypes};

@@ -3056,10 +3087,6 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
        here.isCacheable = false;
        return intersectNormals(here, thereNorm);
    }
-    else if (auto lt = get<LocalType>(there))
-    {
-        return intersectNormalWithTy(here, lt->domain, seenSetTypes);
-    }

    NormalizedTyvars tyvars = std::move(here.tyvars);

@@ -3074,7 +3101,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
    {
        TypeIds tables = std::move(here.tables);
        clearNormal(here);
-        intersectTablesWithTable(tables, there);
+        intersectTablesWithTable(tables, there, seenSetTypes);
        here.tables = std::move(tables);
    }
    else if (get<ClassType>(there))

@@ -3148,60 +3175,17 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
            subtractSingleton(here, follow(ntv->ty));
        else if (get<ClassType>(t))
        {
-            if (fixNormalizeCaching())
-            {
-                std::shared_ptr<const NormalizedType> normal = normalize(t);
-                std::optional<NormalizedType> negated = negateNormal(*normal);
-                if (!negated)
-                    return NormalizationResult::False;
-                intersectNormals(here, *negated);
-            }
-            else
-            {
-                const NormalizedType* normal = DEPRECATED_normalize(t);
-                std::optional<NormalizedType> negated = negateNormal(*normal);
-                if (!negated)
-                    return NormalizationResult::False;
-                intersectNormals(here, *negated);
-            }
+            NormalizationResult res = intersectNormalWithNegationTy(t, here);
+            if (shouldEarlyExit(res))
+                return res;
        }
        else if (const UnionType* itv = get<UnionType>(t))
        {
-            if (fixNormalizeCaching())
-            {
-                for (TypeId part : itv->options)
-                {
-                    std::shared_ptr<const NormalizedType> normalPart = normalize(part);
-                    std::optional<NormalizedType> negated = negateNormal(*normalPart);
-                    if (!negated)
-                        return NormalizationResult::False;
-                    intersectNormals(here, *negated);
-                }
-            }
-            else
-            {
-                if (fixNormalizeCaching())
-                {
-                    for (TypeId part : itv->options)
-                    {
-                        std::shared_ptr<const NormalizedType> normalPart = normalize(part);
-                        std::optional<NormalizedType> negated = negateNormal(*normalPart);
-                        if (!negated)
-                            return NormalizationResult::False;
-                        intersectNormals(here, *negated);
-                    }
-                }
-                else
-                {
-                    for (TypeId part : itv->options)
-                    {
-                        const NormalizedType* normalPart = DEPRECATED_normalize(part);
-                        std::optional<NormalizedType> negated = negateNormal(*normalPart);
-                        if (!negated)
-                            return NormalizationResult::False;
-                        intersectNormals(here, *negated);
-                    }
-                }
-            }
+            for (TypeId part : itv->options)
+            {
+                NormalizationResult res = intersectNormalWithNegationTy(part, here);
+                if (shouldEarlyExit(res))
+                    return res;
+            }
        }
        else if (get<AnyType>(t))
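The last hunk replaces several near-identical normalize-negate-intersect blocks with the intersectNormalWithNegationTy helper introduced earlier plus an early-exit check. A compressed sketch of the resulting control flow, with trivial stand-ins for the normalized-type machinery; shouldEarlyExit is assumed here to treat anything other than True as an exit, which matches how the call sites use it:

    enum class NormalizationResult { True, False, HitLimits };

    // Assumption: any non-True result aborts the enclosing loop or function.
    inline bool shouldEarlyExit(NormalizationResult res)
    {
        return res != NormalizationResult::True;
    }

    // Apply a negate-then-intersect step to every part, stopping at the first
    // failure or resource limit, as in the UnionType branch above.
    template<typename Range, typename Helper>
    NormalizationResult intersectAllNegations(const Range& parts, Helper intersectNegation)
    {
        for (const auto& part : parts)
        {
            NormalizationResult res = intersectNegation(part);
            if (shouldEarlyExit(res))
                return res;
        }
        return NormalizationResult::True;
    }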
@@ -1,5 +0,0 @@
-// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
-#include "Luau/Common.h"
-
-LUAU_FASTFLAGVARIABLE(LuauFixSetIter, false)
@@ -1255,6 +1255,10 @@ TypeId TypeSimplifier::union_(TypeId left, TypeId right)
        case Relation::Coincident:
        case Relation::Superset:
            return left;
+        case Relation::Subset:
+            newParts.insert(right);
+            changed = true;
+            break;
        default:
            newParts.insert(part);
            newParts.insert(right);

@@ -1364,6 +1368,17 @@ SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<
    return SimplifyResult{res, std::move(s.blockedTypes)};
}

+SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, std::set<TypeId> parts)
+{
+    LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution);
+
+    TypeSimplifier s{builtinTypes, arena};
+
+    TypeId res = s.intersectFromParts(std::move(parts));
+
+    return SimplifyResult{res, std::move(s.blockedTypes)};
+}
+
SimplifyResult simplifyUnion(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId left, TypeId right)
{
    LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution);
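The new simplifyIntersection overload accepts an already-flattened std::set<TypeId> of parts and hands it to TypeSimplifier::intersectFromParts, rather than folding nested binary intersections one pair at a time. A toy rendition of folding a part set with a pairwise rule (TypeId is an int handle here and the "simplification" is a placeholder, not Luau's real rules):

    #include <cstdio>
    #include <set>

    using TypeId = int; // stand-in handle

    TypeId intersectPair(TypeId a, TypeId b)
    {
        return a < b ? a : b; // placeholder pairwise rule
    }

    TypeId intersectFromParts(std::set<TypeId> parts, TypeId emptyCase = -1)
    {
        if (parts.empty())
            return emptyCase;

        TypeId acc = *parts.begin();
        for (TypeId part : parts)
            acc = intersectPair(acc, part);
        return acc;
    }

    int main()
    {
        std::printf("%d\n", intersectFromParts({3, 1, 2})); // prints 1
    }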
@@ -11,6 +11,7 @@
LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTINTVARIABLE(LuauTarjanPreallocationSize, 256);
+LUAU_FASTFLAG(LuauReusableSubstitutions)

namespace Luau
{

@@ -24,8 +25,6 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a
    // We decline to copy them.
    if constexpr (std::is_same_v<T, FreeType>)
        return ty;
-    else if constexpr (std::is_same_v<T, LocalType>)
-        return ty;
    else if constexpr (std::is_same_v<T, BoundType>)
    {
        // This should never happen, but visit() cannot see it.

@@ -148,6 +147,8 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a
}

Tarjan::Tarjan()
+    : typeToIndex(nullptr, FFlag::LuauReusableSubstitutions ? FInt::LuauTarjanPreallocationSize : 0)
+    , packToIndex(nullptr, FFlag::LuauReusableSubstitutions ? FInt::LuauTarjanPreallocationSize : 0)
{
    nodes.reserve(FInt::LuauTarjanPreallocationSize);
    stack.reserve(FInt::LuauTarjanPreallocationSize);

@@ -448,14 +449,31 @@ TarjanResult Tarjan::visitRoot(TypePackId tp)
    return loop();
}

-void Tarjan::clearTarjan()
+void Tarjan::clearTarjan(const TxnLog* log)
{
-    typeToIndex.clear();
-    packToIndex.clear();
+    if (FFlag::LuauReusableSubstitutions)
+    {
+        typeToIndex.clear(~0u);
+        packToIndex.clear(~0u);
+    }
+    else
+    {
+        typeToIndex.clear();
+        packToIndex.clear();
+    }

    nodes.clear();
    stack.clear();

+    if (FFlag::LuauReusableSubstitutions)
+    {
+        childCount = 0;
+        // childLimit setting stays the same
+
+        this->log = log;
+    }
+
    edgesTy.clear();
    edgesTp.clear();
    worklist.clear();

@@ -530,7 +548,6 @@ Substitution::Substitution(const TxnLog* log_, TypeArena* arena)
{
    log = log_;
    LUAU_ASSERT(log);
-    LUAU_ASSERT(arena);
}

void Substitution::dontTraverseInto(TypeId ty)

@@ -548,7 +565,7 @@ std::optional<TypeId> Substitution::substitute(TypeId ty)
    ty = log->follow(ty);

    // clear algorithm state for reentrancy
-    clearTarjan();
+    clearTarjan(log);

    auto result = findDirty(ty);
    if (result != TarjanResult::Ok)

@@ -581,7 +598,7 @@ std::optional<TypePackId> Substitution::substitute(TypePackId tp)
    tp = log->follow(tp);

    // clear algorithm state for reentrancy
-    clearTarjan();
+    clearTarjan(log);

    auto result = findDirty(tp);
    if (result != TarjanResult::Ok)

@@ -609,6 +626,23 @@ std::optional<TypePackId> Substitution::substitute(TypePackId tp)
    return newTp;
}

+void Substitution::resetState(const TxnLog* log, TypeArena* arena)
+{
+    LUAU_ASSERT(FFlag::LuauReusableSubstitutions);
+
+    clearTarjan(log);
+
+    this->arena = arena;
+
+    newTypes.clear();
+    newPacks.clear();
+    replacedTypes.clear();
+    replacedTypePacks.clear();
+
+    noTraverseTypes.clear();
+    noTraverseTypePacks.clear();
+}
+
TypeId Substitution::clone(TypeId ty)
{
    return shallowClone(ty, *arena, log, /* alwaysClone */ true);
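resetState exists so a single Substitution object can be reused across many substitutions under FFlag::LuauReusableSubstitutions: the scratch containers are cleared but keep their capacity, which avoids reallocating them on every call (TypeChecker's reusableInstantiation member is the consumer, as the TypeInfer.cpp hunks further below show). A minimal sketch of the reset-and-reuse idiom:

    #include <vector>

    struct Worker
    {
        std::vector<int> scratch;

        void resetState()
        {
            scratch.clear(); // capacity is retained, so reuse avoids reallocation
        }

        int run(int input)
        {
            resetState();
            scratch.push_back(input);
            return static_cast<int>(scratch.size());
        }
    };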
@@ -1438,6 +1438,7 @@ SubtypingResult Subtyping::isCovariantWith(
    result.andAlso(isCovariantWith(env, subNorm->strings, superNorm->strings));
    result.andAlso(isCovariantWith(env, subNorm->strings, superNorm->tables));
    result.andAlso(isCovariantWith(env, subNorm->threads, superNorm->threads));
+    result.andAlso(isCovariantWith(env, subNorm->buffers, superNorm->buffers));
    result.andAlso(isCovariantWith(env, subNorm->tables, superNorm->tables));
    result.andAlso(isCovariantWith(env, subNorm->functions, superNorm->functions));
    // isCovariantWith(subNorm->tyvars, superNorm->tyvars);
@@ -337,7 +337,9 @@ TypeId matchLiteralType(NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes,
                TypeId matchedType = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier,
                    expectedTableTy->indexer->indexResultType, *propTy, item.value, toBlock);

-                tableTy->indexer->indexResultType = matchedType;
+                // if the index result type is the prop type, we can replace it with the matched type here.
+                if (tableTy->indexer->indexResultType == *propTy)
+                    tableTy->indexer->indexResultType = matchedType;
            }
        }
        else if (item.kind == AstExprTable::Item::General)
@@ -262,14 +262,6 @@ void StateDot::visitChildren(TypeId ty, int index)
            visitChild(t.upperBound, index, "[upperBound]");
        }
    }
-    else if constexpr (std::is_same_v<T, LocalType>)
-    {
-        formatAppend(result, "LocalType");
-        finishNodeLabel(ty);
-        finishNode();
-
-        visitChild(t.domain, 1, "[domain]");
-    }
    else if constexpr (std::is_same_v<T, AnyType>)
    {
        formatAppend(result, "AnyType %d", index);
@@ -20,7 +20,6 @@
#include <string>

LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
-LUAU_FASTFLAGVARIABLE(LuauToStringiteTypesSingleLine, false)

/*
 * Enables increasing levels of verbosity for Luau type names when stringifying.

@@ -101,16 +100,6 @@ struct FindCyclicTypes final : TypeVisitor
        return false;
    }

-    bool visit(TypeId ty, const LocalType& lt) override
-    {
-        if (!visited.insert(ty))
-            return false;
-
-        traverse(lt.domain);
-
-        return false;
-    }
-
    bool visit(TypeId ty, const TableType& ttv) override
    {
        if (!visited.insert(ty))

@@ -526,21 +515,6 @@ struct TypeStringifier
        }
    }

-    void operator()(TypeId ty, const LocalType& lt)
-    {
-        state.emit("l-");
-        state.emit(lt.name);
-        if (FInt::DebugLuauVerboseTypeNames >= 1)
-        {
-            state.emit("[");
-            state.emit(lt.blockCount);
-            state.emit("]");
-        }
-        state.emit("=[");
-        stringify(lt.domain);
-        state.emit("]");
-    }
-
    void operator()(TypeId, const BoundType& btv)
    {
        stringify(btv.boundTo);

@@ -1725,6 +1699,18 @@ std::string generateName(size_t i)
    return n;
}

+std::string toStringVector(const std::vector<TypeId>& types, ToStringOptions& opts)
+{
+    std::string s;
+    for (TypeId ty : types)
+    {
+        if (!s.empty())
+            s += ", ";
+        s += toString(ty, opts);
+    }
+    return s;
+}
+
std::string toString(const Constraint& constraint, ToStringOptions& opts)
{
    auto go = [&opts](auto&& c) -> std::string {
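toStringVector replaces tos(...) for constraints whose operands are now std::vector<TypeId> rather than a type pack (IterableConstraint::variables and UnpackConstraint::resultPack in the next hunks). The joining behavior is the usual comma-join; a stand-alone rendition:

    #include <cstdio>
    #include <string>
    #include <vector>

    std::string joinComma(const std::vector<std::string>& items)
    {
        std::string s;
        for (const std::string& item : items)
        {
            if (!s.empty())
                s += ", ";
            s += item;
        }
        return s;
    }

    int main()
    {
        std::printf("%s\n", joinComma({"number", "string"}).c_str()); // number, string
    }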
@@ -1755,7 +1741,7 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts)
        else if constexpr (std::is_same_v<T, IterableConstraint>)
        {
            std::string iteratorStr = tos(c.iterator);
-            std::string variableStr = tos(c.variables);
+            std::string variableStr = toStringVector(c.variables, opts);

            return variableStr + " ~ iterate " + iteratorStr;
        }

@@ -1788,23 +1774,16 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts)
        {
            return tos(c.resultType) + " ~ hasProp " + tos(c.subjectType) + ", \"" + c.prop + "\" ctx=" + std::to_string(int(c.context));
        }
-        else if constexpr (std::is_same_v<T, SetPropConstraint>)
-        {
-            const std::string pathStr = c.path.size() == 1 ? "\"" + c.path[0] + "\"" : "[\"" + join(c.path, "\", \"") + "\"]";
-            return tos(c.resultType) + " ~ setProp " + tos(c.subjectType) + ", " + pathStr + " " + tos(c.propType);
-        }
        else if constexpr (std::is_same_v<T, HasIndexerConstraint>)
        {
            return tos(c.resultType) + " ~ hasIndexer " + tos(c.subjectType) + " " + tos(c.indexType);
        }
-        else if constexpr (std::is_same_v<T, SetIndexerConstraint>)
-        {
-            return "setIndexer " + tos(c.subjectType) + " [ " + tos(c.indexType) + " ] " + tos(c.propType);
-        }
+        else if constexpr (std::is_same_v<T, AssignPropConstraint>)
+            return "assignProp " + tos(c.lhsType) + " " + c.propName + " " + tos(c.rhsType);
+        else if constexpr (std::is_same_v<T, AssignIndexConstraint>)
+            return "assignIndex " + tos(c.lhsType) + " " + tos(c.indexType) + " " + tos(c.rhsType);
        else if constexpr (std::is_same_v<T, UnpackConstraint>)
-            return tos(c.resultPack) + " ~ ...unpack " + tos(c.sourcePack);
-        else if constexpr (std::is_same_v<T, Unpack1Constraint>)
-            return tos(c.resultType) + " ~ unpack " + tos(c.sourceType);
+            return toStringVector(c.resultPack, opts) + " ~ ...unpack " + tos(c.sourcePack);
        else if constexpr (std::is_same_v<T, ReduceConstraint>)
            return "reduce " + tos(c.ty);
        else if constexpr (std::is_same_v<T, ReducePackConstraint>)
@@ -1182,11 +1182,11 @@ std::string toString(AstNode* node)
    Printer printer(writer);
    printer.writeTypes = true;

-    if (auto statNode = dynamic_cast<AstStat*>(node))
+    if (auto statNode = node->asStat())
        printer.visualize(*statNode);
-    else if (auto exprNode = dynamic_cast<AstExpr*>(node))
+    else if (auto exprNode = node->asExpr())
        printer.visualize(*exprNode);
-    else if (auto typeNode = dynamic_cast<AstType*>(node))
+    else if (auto typeNode = node->asType())
        printer.visualizeTypeAnnotation(*typeNode);

    return writer.str();
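Swapping dynamic_cast for asStat()/asExpr()/asType() relies on virtual accessors where each subclass overrides only its own as*() to return this, so the downcast needs no RTTI. A minimal sketch of the pattern (the names mirror the Luau AST; everything else is illustrative):

    struct AstStat;
    struct AstExpr;

    struct AstNode
    {
        virtual ~AstNode() = default;
        // Base returns null; each subclass claims only its own accessor.
        virtual AstStat* asStat() { return nullptr; }
        virtual AstExpr* asExpr() { return nullptr; }
    };

    struct AstStat : AstNode
    {
        AstStat* asStat() override { return this; }
    };

    struct AstExpr : AstNode
    {
        AstExpr* asExpr() override { return this; }
    };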
@@ -561,6 +561,11 @@ void BlockedType::setOwner(Constraint* newOwner)
    owner = newOwner;
}

+void BlockedType::replaceOwner(Constraint* newOwner)
+{
+    owner = newOwner;
+}
+
PendingExpansionType::PendingExpansionType(
    std::optional<AstName> prefix, AstName name, std::vector<TypeId> typeArguments, std::vector<TypePackId> packArguments)
    : prefix(prefix)
@@ -338,10 +338,6 @@ public:
    {
        return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("free"), std::nullopt, Location());
    }
-    AstType* operator()(const LocalType& lt)
-    {
-        return Luau::visit(*this, lt.domain->ty);
-    }
    AstType* operator()(const UnionType& uv)
    {
        AstArray<AstType*> unionTypes;
@@ -446,7 +446,6 @@ struct TypeChecker2
                        .errors;
        if (!isErrorSuppressing(location, instance))
            reportErrors(std::move(errors));
-
        return instance;
    }

@@ -1108,10 +1107,13 @@ struct TypeChecker2
    void visit(AstStatCompoundAssign* stat)
    {
        AstExprBinary fake{stat->location, stat->op, stat->var, stat->value};
-        TypeId resultTy = visit(&fake, stat);
+        visit(&fake, stat);
+
+        TypeId* resultTy = module->astCompoundAssignResultTypes.find(stat);
+        LUAU_ASSERT(resultTy);

        TypeId varTy = lookupType(stat->var);
-        testIsSubtype(resultTy, varTy, stat->location);
+        testIsSubtype(*resultTy, varTy, stat->location);
    }

    void visit(AstStatFunction* stat)

@@ -1242,13 +1244,14 @@ struct TypeChecker2
    void visit(AstExprConstantBool* expr)
    {
-#if defined(LUAU_ENABLE_ASSERT)
+        // booleans use specialized inference logic for singleton types, which can lead to real type errors here.
        const TypeId bestType = expr->value ? builtinTypes->trueType : builtinTypes->falseType;
        const TypeId inferredType = lookupType(expr);

        const SubtypingResult r = subtyping->isSubtype(bestType, inferredType);
-        LUAU_ASSERT(r.isSubtype || isErrorSuppressing(expr->location, inferredType));
-#endif
+        if (!r.isSubtype && !isErrorSuppressing(expr->location, inferredType))
+            reportError(TypeMismatch{inferredType, bestType}, expr->location);
    }

    void visit(AstExprConstantNumber* expr)

@@ -1264,13 +1267,14 @@ struct TypeChecker2
    void visit(AstExprConstantString* expr)
    {
-#if defined(LUAU_ENABLE_ASSERT)
+        // strings use specialized inference logic for singleton types, which can lead to real type errors here.
        const TypeId bestType = module->internalTypes.addType(SingletonType{StringSingleton{std::string{expr->value.data, expr->value.size}}});
        const TypeId inferredType = lookupType(expr);

        const SubtypingResult r = subtyping->isSubtype(bestType, inferredType);
-        LUAU_ASSERT(r.isSubtype || isErrorSuppressing(expr->location, inferredType));
-#endif
+        if (!r.isSubtype && !isErrorSuppressing(expr->location, inferredType))
+            reportError(TypeMismatch{inferredType, bestType}, expr->location);
    }

    void visit(AstExprLocal* expr)

@@ -1280,7 +1284,9 @@ struct TypeChecker2
    void visit(AstExprGlobal* expr)
    {
-        // TODO!
+        NotNull<Scope> scope = stack.back();
+        if (!scope->lookup(expr->name))
+            reportError(UnknownSymbol{expr->name.value, UnknownSymbol::Binding}, expr->location);
    }

    void visit(AstExprVarargs* expr)

@@ -1534,6 +1540,24 @@ struct TypeChecker2
        visitExprName(indexName->expr, indexName->location, indexName->index.value, context, builtinTypes->stringType);
    }

+    void indexExprMetatableHelper(AstExprIndexExpr* indexExpr, const MetatableType* metaTable, TypeId exprType, TypeId indexType)
+    {
+        if (auto tt = get<TableType>(follow(metaTable->table)); tt && tt->indexer)
+            testIsSubtype(indexType, tt->indexer->indexType, indexExpr->index->location);
+        else if (auto mt = get<MetatableType>(follow(metaTable->table)))
+            indexExprMetatableHelper(indexExpr, mt, exprType, indexType);
+        else if (auto tmt = get<TableType>(follow(metaTable->metatable)); tmt && tmt->indexer)
+            testIsSubtype(indexType, tmt->indexer->indexType, indexExpr->index->location);
+        else if (auto mtmt = get<MetatableType>(follow(metaTable->metatable)))
+            indexExprMetatableHelper(indexExpr, mtmt, exprType, indexType);
+        else
+        {
+            LUAU_ASSERT(tt || get<PrimitiveType>(follow(metaTable->table)));
+
+            reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location);
+        }
+    }
+
    void visit(AstExprIndexExpr* indexExpr, ValueContext context)
    {
        if (auto str = indexExpr->index->as<AstExprConstantString>())
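indexExprMetatableHelper checks the wrapped table's indexer first, recurses when either side of the MetatableType is itself a metatable, and only reports CannotExtendTable when no indexer is found anywhere. A toy model of that lookup order, with optional ints standing in for indexers and a simplified shape for the metatable pair:

    #include <optional>

    // Toy model: 'table' and 'metatable' can each carry an indexer or be
    // another nested Meta, mirroring the four branches in the helper above.
    struct Meta
    {
        std::optional<int> tableIndexer;
        const Meta* nestedTable = nullptr;
        std::optional<int> metatableIndexer;
        const Meta* nestedMetatable = nullptr;
    };

    std::optional<int> findIndexer(const Meta& m)
    {
        if (m.tableIndexer)
            return m.tableIndexer;
        if (m.nestedTable)
            return findIndexer(*m.nestedTable);
        if (m.metatableIndexer)
            return m.metatableIndexer;
        if (m.nestedMetatable)
            return findIndexer(*m.nestedMetatable);
        return std::nullopt; // caller reports CannotExtendTable
    }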
@@ -1557,6 +1581,10 @@ struct TypeChecker2
            else
                reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location);
        }
+        else if (auto mt = get<MetatableType>(exprType))
+        {
+            return indexExprMetatableHelper(indexExpr, mt, exprType, indexType);
+        }
        else if (auto cls = get<ClassType>(exprType))
        {
            if (cls->indexer)

@@ -1577,6 +1605,19 @@ struct TypeChecker2
                reportError(OptionalValueAccess{exprType}, indexExpr->location);
            }
        }
+        else if (auto exprIntersection = get<IntersectionType>(exprType))
+        {
+            for (TypeId part : exprIntersection)
+            {
+                (void)part;
+            }
+        }
+        else if (get<NeverType>(exprType) || isErrorSuppressing(indexExpr->location, exprType))
+        {
+            // Nothing
+        }
+        else
+            reportError(NotATable{exprType}, indexExpr->location);
    }

    void visit(AstExprFunction* fn)

@@ -1589,7 +1630,6 @@ struct TypeChecker2
        functionDeclStack.push_back(inferredFnTy);

        std::shared_ptr<const NormalizedType> normalizedFnTy = normalizer.normalize(inferredFnTy);
-        const FunctionType* inferredFtv = get<FunctionType>(normalizedFnTy->functions.parts.front());
        if (!normalizedFnTy)
        {
            reportError(CodeTooComplex{}, fn->location);

@@ -1684,16 +1724,23 @@ struct TypeChecker2
        if (fn->returnAnnotation)
            visit(*fn->returnAnnotation);

        // If the function type has a family annotation, we need to see if we can suggest an annotation
-        TypeFamilyReductionGuesser guesser{NotNull{&module->internalTypes}, builtinTypes, NotNull{&normalizer}};
-        for (TypeId retTy : inferredFtv->retTypes)
-        {
-            if (get<TypeFamilyInstanceType>(follow(retTy)))
-            {
-                TypeFamilyReductionGuessResult result = guesser.guessTypeFamilyReductionForFunction(*fn, inferredFtv, retTy);
-                if (result.shouldRecommendAnnotation)
-                    reportError(
-                        ExplicitFunctionAnnotationRecommended{std::move(result.guessedFunctionAnnotations), result.guessedReturnType}, fn->location);
-            }
-        }
+        if (normalizedFnTy)
+        {
+            const FunctionType* inferredFtv = get<FunctionType>(normalizedFnTy->functions.parts.front());
+            LUAU_ASSERT(inferredFtv);
+
+            TypeFamilyReductionGuesser guesser{NotNull{&module->internalTypes}, builtinTypes, NotNull{&normalizer}};
+            for (TypeId retTy : inferredFtv->retTypes)
+            {
+                if (get<TypeFamilyInstanceType>(follow(retTy)))
+                {
+                    TypeFamilyReductionGuessResult result = guesser.guessTypeFamilyReductionForFunction(*fn, inferredFtv, retTy);
+                    if (result.shouldRecommendAnnotation)
+                        reportError(ExplicitFunctionAnnotationRecommended{std::move(result.guessedFunctionAnnotations), result.guessedReturnType},
+                            fn->location);
+                }
+            }
+        }

@@ -1822,7 +1869,7 @@ struct TypeChecker2
        bool isStringOperation =
            (normLeft ? normLeft->isSubtypeOfString() : isString(leftType)) && (normRight ? normRight->isSubtypeOfString() : isString(rightType));

+        leftType = follow(leftType);
        if (get<AnyType>(leftType) || get<ErrorType>(leftType) || get<NeverType>(leftType))
            return leftType;
        else if (get<AnyType>(rightType) || get<ErrorType>(rightType) || get<NeverType>(rightType))

@@ -2091,24 +2138,39 @@ struct TypeChecker2
        TypeId annotationType = lookupAnnotation(expr->annotation);
        TypeId computedType = lookupType(expr->expr);

-        // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case.
-        if (subtyping->isSubtype(annotationType, computedType).isSubtype)
-            return;
-
-        if (subtyping->isSubtype(computedType, annotationType).isSubtype)
-            return;
-
        switch (shouldSuppressErrors(NotNull{&normalizer}, computedType).orElse(shouldSuppressErrors(NotNull{&normalizer}, annotationType)))
        {
        case ErrorSuppression::Suppress:
            return;
        case ErrorSuppression::NormalizationFailed:
            reportError(NormalizationTooComplex{}, expr->location);
+            return;
        case ErrorSuppression::DoNotSuppress:
            break;
        }

-        reportError(TypesAreUnrelated{computedType, annotationType}, expr->location);
+        switch (normalizer.isInhabited(computedType))
+        {
+        case NormalizationResult::True:
+            break;
+        case NormalizationResult::False:
+            return;
+        case NormalizationResult::HitLimits:
+            reportError(NormalizationTooComplex{}, expr->location);
+            return;
+        }
+
+        switch (normalizer.isIntersectionInhabited(computedType, annotationType))
+        {
+        case NormalizationResult::True:
+            return;
+        case NormalizationResult::False:
+            reportError(TypesAreUnrelated{computedType, annotationType}, expr->location);
+            break;
+        case NormalizationResult::HitLimits:
+            reportError(NormalizationTooComplex{}, expr->location);
+            break;
+        }
    }

    void visit(AstExprIfElse* expr)
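The rewritten '::' check above drops the subtyping fast-paths and instead asks the normalizer two questions: is the computed type inhabited at all (an uninhabited source casts vacuously), and is the intersection of the computed and annotation types inhabited (an empty intersection means TypesAreUnrelated). A compact stand-in for that decision table:

    enum class NormalizationResult { True, False, HitLimits };

    enum class CastVerdict { Ok, Unrelated, TooComplex };

    CastVerdict checkCast(NormalizationResult sourceInhabited, NormalizationResult intersectionInhabited)
    {
        if (sourceInhabited == NormalizationResult::HitLimits)
            return CastVerdict::TooComplex;
        if (sourceInhabited == NormalizationResult::False)
            return CastVerdict::Ok; // never-typed source casts to anything

        switch (intersectionInhabited)
        {
        case NormalizationResult::True:
            return CastVerdict::Ok;
        case NormalizationResult::False:
            return CastVerdict::Unrelated;
        default:
            return CastVerdict::TooComplex;
        }
    }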
@@ -2710,6 +2772,8 @@ struct TypeChecker2
            fetch(builtinTypes->stringType);
        if (normValid)
            fetch(norm->threads);
+        if (normValid)
+            fetch(norm->buffers);

        if (normValid)
        {
[File diff suppressed because it is too large]
@@ -33,13 +33,12 @@ LUAU_FASTFLAG(LuauKnowsTheDataModel3)
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false)
LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false)
LUAU_FASTFLAG(LuauInstantiateInSubtyping)
-LUAU_FASTFLAGVARIABLE(LuauMetatableInstantiationCloneCheck, false)
LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false)
LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false)
LUAU_FASTFLAGVARIABLE(LuauRemoveBadRelationalOperatorWarning, false)
-LUAU_FASTFLAGVARIABLE(LuauForbidAliasNamedTypeof, false)
LUAU_FASTFLAGVARIABLE(LuauOkWithIteratingOverTableProperties, false)
-LUAU_FASTFLAG(LuauFixNormalizeCaching)
+LUAU_FASTFLAGVARIABLE(LuauReusableSubstitutions, false)
+LUAU_FASTFLAG(LuauDeclarationExtraPropData)

namespace Luau
{

@@ -216,6 +215,7 @@ TypeChecker::TypeChecker(const ScopePtr& globalScope, ModuleResolver* resolver,
    , iceHandler(iceHandler)
    , unifierState(iceHandler)
    , normalizer(nullptr, builtinTypes, NotNull{&unifierState})
+    , reusableInstantiation(TxnLog::empty(), nullptr, builtinTypes, {}, nullptr)
    , nilType(builtinTypes->nilType)
    , numberType(builtinTypes->numberType)
    , stringType(builtinTypes->stringType)

@@ -668,7 +668,7 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std
    {
        if (const auto& typealias = stat->as<AstStatTypeAlias>())
        {
-            if (typealias->name == kParseNameError || (FFlag::LuauForbidAliasNamedTypeof && typealias->name == "typeof"))
+            if (typealias->name == kParseNameError || typealias->name == "typeof")
                continue;

            auto& bindings = typealias->exported ? scope->exportedTypeBindings : scope->privateTypeBindings;

@@ -1536,7 +1536,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& ty
    if (name == kParseNameError)
        return ControlFlow::None;

-    if (FFlag::LuauForbidAliasNamedTypeof && name == "typeof")
+    if (name == "typeof")
    {
        reportError(typealias.location, GenericError{"Type aliases cannot be named typeof"});
        return ControlFlow::None;

@@ -1657,7 +1657,7 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea
    // If the alias is missing a name, we can't do anything with it. Ignore it.
    // Also, typeof is not a valid type alias name. We will report an error for
    // this in check()
-    if (name == kParseNameError || (FFlag::LuauForbidAliasNamedTypeof && name == "typeof"))
+    if (name == kParseNameError || name == "typeof")
        return;

    std::optional<TypeFun> binding;

@@ -1784,12 +1784,55 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass&
                ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}});
                ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes});
                ftv->hasSelf = true;
+
+                if (FFlag::LuauDeclarationExtraPropData)
+                {
+                    FunctionDefinition defn;
+
+                    defn.definitionModuleName = currentModule->name;
+                    defn.definitionLocation = prop.location;
+                    // No data is preserved for varargLocation
+                    defn.originalNameLocation = prop.nameLocation;
+
+                    ftv->definition = defn;
+                }
            }
        }

        if (assignTo.count(propName) == 0)
        {
-            assignTo[propName] = {propTy};
+            if (FFlag::LuauDeclarationExtraPropData)
+                assignTo[propName] = {propTy, /*deprecated*/ false, /*deprecatedSuggestion*/ "", prop.location};
+            else
+                assignTo[propName] = {propTy};
+        }
+        else if (FFlag::LuauDeclarationExtraPropData)
+        {
+            Luau::Property& prop = assignTo[propName];
+            TypeId currentTy = prop.type();
+
+            // We special-case this logic to keep the intersection flat; otherwise we
+            // would create a ton of nested intersection types.
+            if (const IntersectionType* itv = get<IntersectionType>(currentTy))
+            {
+                std::vector<TypeId> options = itv->parts;
+                options.push_back(propTy);
+                TypeId newItv = addType(IntersectionType{std::move(options)});
+
+                prop.readTy = newItv;
+                prop.writeTy = newItv;
+            }
+            else if (get<FunctionType>(currentTy))
+            {
+                TypeId intersection = addType(IntersectionType{{currentTy, propTy}});
+
+                prop.readTy = intersection;
+                prop.writeTy = intersection;
+            }
+            else
+            {
+                reportError(declaredClass.location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())});
+            }
        }
        else
        {

@@ -1841,7 +1884,18 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFuncti
    TypePackId argPack = resolveTypePack(funScope, global.params);
    TypePackId retPack = resolveTypePack(funScope, global.retTypes);

-    TypeId fnType = addType(FunctionType{funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack});
+    FunctionDefinition defn;
+
+    if (FFlag::LuauDeclarationExtraPropData)
+    {
+        defn.definitionModuleName = currentModule->name;
+        defn.definitionLocation = global.location;
+        defn.varargLocation = global.vararg ? std::make_optional(global.varargLocation) : std::nullopt;
+        defn.originalNameLocation = global.nameLocation;
+    }
+
+    TypeId fnType = addType(FunctionType{funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack, defn});
    FunctionType* ftv = getMutable<FunctionType>(fnType);

    ftv->argNames.reserve(global.paramNames.size);

@@ -2649,24 +2703,12 @@ static std::optional<bool> areEqComparable(NotNull<TypeArena> arena, NotNull<Nor
    NormalizationResult nr;

-    if (FFlag::LuauFixNormalizeCaching)
-    {
-        TypeId c = arena->addType(IntersectionType{{a, b}});
-        std::shared_ptr<const NormalizedType> n = normalizer->normalize(c);
-        if (!n)
-            return std::nullopt;
-
-        nr = normalizer->isInhabited(n.get());
-    }
-    else
-    {
-        TypeId c = arena->addType(IntersectionType{{a, b}});
-        const NormalizedType* n = normalizer->DEPRECATED_normalize(c);
-        if (!n)
-            return std::nullopt;
-
-        nr = normalizer->isInhabited(n);
-    }
+    TypeId c = arena->addType(IntersectionType{{a, b}});
+    std::shared_ptr<const NormalizedType> n = normalizer->normalize(c);
+    if (!n)
+        return std::nullopt;
+
+    nr = normalizer->isInhabited(n.get());

    switch (nr)
    {

@@ -4879,12 +4921,27 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat
    if (ftv && ftv->hasNoFreeOrGenericTypes)
        return ty;

-    Instantiation instantiation{log, &currentModule->internalTypes, builtinTypes, scope->level, /*scope*/ nullptr};
-    if (instantiationChildLimit)
-        instantiation.childLimit = *instantiationChildLimit;
-
-    std::optional<TypeId> instantiated = instantiation.substitute(ty);
+    std::optional<TypeId> instantiated;
+
+    if (FFlag::LuauReusableSubstitutions)
+    {
+        reusableInstantiation.resetState(log, &currentModule->internalTypes, builtinTypes, scope->level, /*scope*/ nullptr);
+
+        if (instantiationChildLimit)
+            reusableInstantiation.childLimit = *instantiationChildLimit;
+
+        instantiated = reusableInstantiation.substitute(ty);
+    }
+    else
+    {
+        Instantiation instantiation{log, &currentModule->internalTypes, builtinTypes, scope->level, /*scope*/ nullptr};
+
+        if (instantiationChildLimit)
+            instantiation.childLimit = *instantiationChildLimit;
+
+        instantiated = instantiation.substitute(ty);
+    }

    if (instantiated.has_value())
        return *instantiated;
    else

@@ -5633,8 +5690,8 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf,
    TypeId instantiated = *maybeInstantiated;

    TypeId target = follow(instantiated);
-    const TableType* tfTable = FFlag::LuauMetatableInstantiationCloneCheck ? getTableType(tf.type) : nullptr;
-    bool needsClone = follow(tf.type) == target || (FFlag::LuauMetatableInstantiationCloneCheck && tfTable != nullptr && tfTable == getTableType(target));
+    const TableType* tfTable = getTableType(tf.type);
+    bool needsClone = follow(tf.type) == target || (tfTable != nullptr && tfTable == getTableType(target));
    bool shouldMutate = getTableType(tf.type);
    TableType* ttv = getMutableTableType(target);
@@ -38,6 +38,59 @@ bool occursCheck(TypeId needle, TypeId haystack)
    return false;
}

+// FIXME: Property is quite large.
+//
+// Returning it on the stack like this isn't great. We'd like to just return a
+// const Property*, but we mint a property of type any if the subject type is
+// any.
+std::optional<Property> findTableProperty(NotNull<BuiltinTypes> builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location)
+{
+    if (get<AnyType>(ty))
+        return Property::rw(ty);
+
+    if (const TableType* tableType = getTableType(ty))
+    {
+        const auto& it = tableType->props.find(name);
+        if (it != tableType->props.end())
+            return it->second;
+    }
+
+    std::optional<TypeId> mtIndex = findMetatableEntry(builtinTypes, errors, ty, "__index", location);
+    int count = 0;
+    while (mtIndex)
+    {
+        TypeId index = follow(*mtIndex);
+
+        if (count >= 100)
+            return std::nullopt;
+
+        ++count;
+
+        if (const auto& itt = getTableType(index))
+        {
+            const auto& fit = itt->props.find(name);
+            if (fit != itt->props.end())
+                return fit->second.type();
+        }
+        else if (const auto& itf = get<FunctionType>(index))
+        {
+            std::optional<TypeId> r = first(follow(itf->retTypes));
+            if (!r)
+                return builtinTypes->nilType;
+            else
+                return *r;
+        }
+        else if (get<AnyType>(index))
+            return builtinTypes->anyType;
+        else
+            errors.push_back(TypeError{location, GenericError{"__index should either be a function or table. Got " + toString(index)}});
+
+        mtIndex = findMetatableEntry(builtinTypes, errors, *mtIndex, "__index", location);
+    }
+
+    return std::nullopt;
+}
+
std::optional<TypeId> findMetatableEntry(
    NotNull<BuiltinTypes> builtinTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location)
{
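findTableProperty bounds its '__index' chase at 100 hops so a cyclic metatable chain cannot loop forever, and treats a function or any '__index' specially. The bounded-walk skeleton in isolation, with a toy node type in place of Luau's type graph:

    #include <optional>

    struct Node
    {
        std::optional<int> prop;
        const Node* index = nullptr; // stand-in for the __index metatable entry
    };

    // Follow at most 100 links; a cycle simply exhausts the budget and
    // returns nullopt instead of spinning.
    std::optional<int> findProp(const Node* n)
    {
        for (int count = 0; n && count < 100; ++count)
        {
            if (n->prop)
                return n->prop;
            n = n->index;
        }
        return std::nullopt;
    }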
@ -23,7 +23,7 @@ LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAGVARIABLE(LuauFixIndexerSubtypingOrdering, false) LUAU_FASTFLAGVARIABLE(LuauFixIndexerSubtypingOrdering, false)
LUAU_FASTFLAGVARIABLE(LuauUnifierShouldNotCopyError, false) LUAU_FASTFLAGVARIABLE(LuauUnifierShouldNotCopyError, false)
LUAU_FASTFLAG(LuauFixNormalizeCaching) LUAU_FASTFLAGVARIABLE(LuauUnifierRecursionOnRestart, false)
namespace Luau namespace Luau
{ {
@ -580,28 +580,14 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
{ {
if (normalize) if (normalize)
{ {
if (FFlag::LuauFixNormalizeCaching) // TODO: there are probably cheaper ways to check if any <: T.
{ std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
// TODO: there are probably cheaper ways to check if any <: T.
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
if (!superNorm) if (!superNorm)
return reportError(location, NormalizationTooComplex{}); return reportError(location, NormalizationTooComplex{});
if (!log.get<AnyType>(superNorm->tops)) if (!log.get<AnyType>(superNorm->tops))
failure = true; failure = true;
}
else
{
// TODO: there are probably cheaper ways to check if any <: T.
const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy);
if (!superNorm)
return reportError(location, NormalizationTooComplex{});
if (!log.get<AnyType>(superNorm->tops))
failure = true;
}
} }
else else
failure = true; failure = true;
@ -962,30 +948,15 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp
// We deal with this by type normalization. // We deal with this by type normalization.
Unifier innerState = makeChildUnifier(); Unifier innerState = makeChildUnifier();
if (FFlag::LuauFixNormalizeCaching) std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy);
{ std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy); if (!subNorm || !superNorm)
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy); return reportError(location, NormalizationTooComplex{});
if (!subNorm || !superNorm) else if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
return reportError(location, NormalizationTooComplex{}); innerState.tryUnifyNormalizedTypes(
else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption);
innerState.tryUnifyNormalizedTypes(
subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption);
else
innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible");
}
else else
{ innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible");
const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy);
const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy);
if (!subNorm || !superNorm)
return reportError(location, NormalizationTooComplex{});
else if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
innerState.tryUnifyNormalizedTypes(
subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption);
else
innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible");
}
if (!innerState.failure) if (!innerState.failure)
log.concat(std::move(innerState.log)); log.concat(std::move(innerState.log));
@ -999,30 +970,14 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp
// It is possible that T <: A | B even though T </: A and T </:B // It is possible that T <: A | B even though T </: A and T </:B
// for example boolean <: true | false. // for example boolean <: true | false.
// We deal with this by type normalization. // We deal with this by type normalization.
if (FFlag::LuauFixNormalizeCaching) std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy);
{ std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy); if (!subNorm || !superNorm)
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy); reportError(location, NormalizationTooComplex{});
if (!subNorm || !superNorm) else if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
reportError(location, NormalizationTooComplex{}); tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption);
else if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
tryUnifyNormalizedTypes(
subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption);
else
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible");
}
else else
{ tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible");
const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy);
const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy);
if (!subNorm || !superNorm)
reportError(location, NormalizationTooComplex{});
else if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
tryUnifyNormalizedTypes(
subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption);
else
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible");
}
} }
else if (!found) else if (!found)
{ {
@ -1125,24 +1080,12 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType*
// It is possible that A & B <: T even though A </: T and B </: T // It is possible that A & B <: T even though A </: T and B </: T
// for example (string?) & ~nil <: string. // for example (string?) & ~nil <: string.
// We deal with this by type normalization. // We deal with this by type normalization.
if (FFlag::LuauFixNormalizeCaching) std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy);
{ std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy); if (subNorm && superNorm)
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy); tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible");
if (subNorm && superNorm)
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible");
else
reportError(location, NormalizationTooComplex{});
}
else else
{ reportError(location, NormalizationTooComplex{});
const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy);
const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy);
if (subNorm && superNorm)
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible");
else
reportError(location, NormalizationTooComplex{});
}
return; return;
} }
@@ -1192,24 +1135,12 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType*
// for example string? & number? <: nil.
// We deal with this by type normalization.

if (FFlag::LuauFixNormalizeCaching)
{
std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy);
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
if (subNorm && superNorm)
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible");
else
reportError(location, NormalizationTooComplex{});
}
else
{
const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy);
const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy);
if (subNorm && superNorm)
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible");
else
reportError(location, NormalizationTooComplex{});
}

std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy);
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
if (subNorm && superNorm)
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible");
else
reportError(location, NormalizationTooComplex{});

}
else if (!found)
{
@@ -2249,7 +2180,18 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection,
// If one of the types stopped being a table altogether, we need to restart from the top
if ((superTy != superTyNew || activeSubTy != subTyNew) && errors.empty())
return tryUnify(subTy, superTy, false, isIntersection);
{
if (FFlag::LuauUnifierRecursionOnRestart)
{
RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit);
tryUnify(subTy, superTy, false, isIntersection);
return;
}
else
{
return tryUnify(subTy, superTy, false, isIntersection);
}
}

// Otherwise, restart only the table unification
TableType* newSuperTable = log.getMutable<TableType>(superTyNew);
@@ -2328,7 +2270,18 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection,
// If one of the types stopped being a table altogether, we need to restart from the top
if ((superTy != superTyNew || activeSubTy != subTyNew) && errors.empty())
return tryUnify(subTy, superTy, false, isIntersection);
{
if (FFlag::LuauUnifierRecursionOnRestart)
{
RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit);
tryUnify(subTy, superTy, false, isIntersection);
return;
}
else
{
return tryUnify(subTy, superTy, false, isIntersection);
}
}

// Recursive unification can change the txn log, and invalidate the old
// table. If we detect that this has happened, we start over, with the updated
@@ -2712,32 +2665,16 @@ void Unifier::tryUnifyNegations(TypeId subTy, TypeId superTy)
if (!log.get<NegationType>(subTy) && !log.get<NegationType>(superTy))
ice("tryUnifyNegations superTy or subTy must be a negation type");

if (FFlag::LuauFixNormalizeCaching)
{
std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy);
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
if (!subNorm || !superNorm)
return reportError(location, NormalizationTooComplex{});

// T </: ~U iff T <: U
Unifier state = makeChildUnifier();
state.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "");

if (state.errors.empty())
reportError(location, TypeMismatch{superTy, subTy, mismatchContext()});
}
else
{
const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy);
const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy);
if (!subNorm || !superNorm)
return reportError(location, NormalizationTooComplex{});

// T </: ~U iff T <: U
Unifier state = makeChildUnifier();
state.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "");

if (state.errors.empty())
reportError(location, TypeMismatch{superTy, subTy, mismatchContext()});
}

std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy);
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
if (!subNorm || !superNorm)
return reportError(location, NormalizationTooComplex{});

// T </: ~U iff T <: U
Unifier state = makeChildUnifier();
state.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "");

if (state.errors.empty())
reportError(location, TypeMismatch{superTy, subTy, mismatchContext()});
}

static void queueTypePack(std::vector<TypeId>& queue, DenseHashSet<TypePackId>& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack)
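The negation rule above (T </: ~U iff T <: U) is easiest to see from the refinement side; a small Luau sketch, not part of the commit, with illustrative names:

-- in the else branch, x has type (string | number) & ~string;
-- the return is accepted because number <: ~string holds
local function describe(x: string | number): number
    if typeof(x) == "string" then
        return #x
    else
        return x
    end
end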

View file

@@ -158,12 +158,6 @@ bool Unifier2::unify(TypeId subTy, TypeId superTy)
if (subFree || superFree)
return true;
if (auto subLocal = getMutable<LocalType>(subTy))
{
subLocal->domain = mkUnion(subLocal->domain, superTy);
expandedFreeTypes[subTy].push_back(superTy);
}
auto subFn = get<FunctionType>(subTy);
auto superFn = get<FunctionType>(superTy);
if (subFn && superFn)
@@ -204,25 +198,21 @@ bool Unifier2::unify(TypeId subTy, TypeId superTy)
auto subAny = get<AnyType>(subTy);
auto superAny = get<AnyType>(superTy);
if (subAny && superAny)
return true;
else if (subAny && superFn)
{
// If `any` is the subtype, then we can propagate that inward.
bool argResult = unify(superFn->argTypes, builtinTypes->anyTypePack);
bool retResult = unify(builtinTypes->anyTypePack, superFn->retTypes);
return argResult && retResult;
}
else if (subFn && superAny)
{
// If `any` is the supertype, then we can propagate that inward.
bool argResult = unify(builtinTypes->anyTypePack, subFn->argTypes);
bool retResult = unify(subFn->retTypes, builtinTypes->anyTypePack);
return argResult && retResult;
}
auto subTable = getMutable<TableType>(subTy);
auto superTable = get<TableType>(superTy);
if (subAny && superAny)
return true;
else if (subAny && superFn)
return unify(subAny, superFn);
else if (subFn && superAny)
return unify(subFn, superAny);
else if (subAny && superTable)
return unify(subAny, superTable);
else if (subTable && superAny)
return unify(subTable, superAny);
if (subTable && superTable)
{
// `boundTo` works like a bound type, and therefore we'd replace it
@@ -451,7 +441,16 @@ bool Unifier2::unify(TableType* subTable, const TableType* superTable)
* an indexer, we therefore conclude that the unsealed table has the
* same indexer.
*/
subTable->indexer = *superTable->indexer;
TypeId indexType = superTable->indexer->indexType;
if (TypeId* subst = genericSubstitutions.find(indexType))
indexType = *subst;
TypeId indexResultType = superTable->indexer->indexResultType;
if (TypeId* subst = genericSubstitutions.find(indexResultType))
indexResultType = *subst;
subTable->indexer = TableIndexer{indexType, indexResultType};
}

return result;
@@ -462,6 +461,62 @@ bool Unifier2::unify(const MetatableType* subMetatable, const MetatableType* sup
return unify(subMetatable->metatable, superMetatable->metatable) && unify(subMetatable->table, superMetatable->table);
}
bool Unifier2::unify(const AnyType* subAny, const FunctionType* superFn)
{
// If `any` is the subtype, then we can propagate that inward.
bool argResult = unify(superFn->argTypes, builtinTypes->anyTypePack);
bool retResult = unify(builtinTypes->anyTypePack, superFn->retTypes);
return argResult && retResult;
}
bool Unifier2::unify(const FunctionType* subFn, const AnyType* superAny)
{
// If `any` is the supertype, then we can propagate that inward.
bool argResult = unify(builtinTypes->anyTypePack, subFn->argTypes);
bool retResult = unify(subFn->retTypes, builtinTypes->anyTypePack);
return argResult && retResult;
}
bool Unifier2::unify(const AnyType* subAny, const TableType* superTable)
{
for (const auto& [propName, prop] : superTable->props)
{
if (prop.readTy)
unify(builtinTypes->anyType, *prop.readTy);
if (prop.writeTy)
unify(*prop.writeTy, builtinTypes->anyType);
}
if (superTable->indexer)
{
unify(builtinTypes->anyType, superTable->indexer->indexType);
unify(builtinTypes->anyType, superTable->indexer->indexResultType);
}
return true;
}
bool Unifier2::unify(const TableType* subTable, const AnyType* superAny)
{
for (const auto& [propName, prop] : subTable->props)
{
if (prop.readTy)
unify(*prop.readTy, builtinTypes->anyType);
if (prop.writeTy)
unify(builtinTypes->anyType, *prop.writeTy);
}
if (subTable->indexer)
{
unify(subTable->indexer->indexType, builtinTypes->anyType);
unify(subTable->indexer->indexResultType, builtinTypes->anyType);
}
return true;
}
// FIXME? This should probably return an ErrorVec or an optional<TypeError>
// rather than a boolean to signal an occurs check failure.
bool Unifier2::unify(TypePackId subTp, TypePackId superTp)
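A rough Luau illustration (not part of the diff) of what the new any/function and any/table overloads accept, assuming the corresponding solver path is enabled:

-- `any` propagates inward instead of failing outright
local f: any = function(n: number): number
    return n + 1
end
local g: (number) -> number = f -- any vs. function: ok
local t: { x: number } = ({} :: any) -- any vs. table: ok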
@@ -596,6 +651,43 @@ struct FreeTypeSearcher : TypeVisitor
}
}
DenseHashSet<const void*> seenPositive{nullptr};
DenseHashSet<const void*> seenNegative{nullptr};
bool seenWithPolarity(const void* ty)
{
switch (polarity)
{
case Positive:
{
if (seenPositive.contains(ty))
return true;
seenPositive.insert(ty);
return false;
}
case Negative:
{
if (seenNegative.contains(ty))
return true;
seenNegative.insert(ty);
return false;
}
case Both:
{
if (seenPositive.contains(ty) && seenNegative.contains(ty))
return true;
seenPositive.insert(ty);
seenNegative.insert(ty);
return false;
}
}
return false;
}
// The keys in these maps are either TypeIds or TypePackIds. It's safe to
// mix them because we only use these pointers as unique keys. We never
// indirect them.
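Polarity here refers to where a type occurs: parameter positions are negative, return positions positive. A minimal Luau sketch (illustrative, not from the diff) of a type seen with both polarities:

-- the free type of `x` is seen negatively (argument position) and
-- positively (result position), so it lands in both maps
local function id(x)
    return x
end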
@@ -604,12 +696,18 @@ struct FreeTypeSearcher : TypeVisitor
bool visit(TypeId ty) override
{
if (seenWithPolarity(ty))
return false;
LUAU_ASSERT(ty);
return true;
}

bool visit(TypeId ty, const FreeType& ft) override
{
if (seenWithPolarity(ty))
return false;
if (!subsumes(scope, ft.scope))
return true;
@@ -632,6 +730,9 @@ struct FreeTypeSearcher : TypeVisitor
bool visit(TypeId ty, const TableType& tt) override
{
if (seenWithPolarity(ty))
return false;
if ((tt.state == TableState::Free || tt.state == TableState::Unsealed) && subsumes(scope, tt.scope))
{
switch (polarity)
@@ -675,6 +776,9 @@ struct FreeTypeSearcher : TypeVisitor
bool visit(TypeId ty, const FunctionType& ft) override
{
if (seenWithPolarity(ty))
return false;
flip();
traverse(ft.argTypes);
flip();
@@ -691,6 +795,9 @@ struct FreeTypeSearcher : TypeVisitor
bool visit(TypePackId tp, const FreeTypePack& ftp) override
{
if (seenWithPolarity(tp))
return false;
if (!subsumes(scope, ftp.scope))
return true;
@@ -712,315 +819,6 @@ struct FreeTypeSearcher : TypeVisitor
}
};
struct MutatingGeneralizer : TypeOnceVisitor
{
NotNull<BuiltinTypes> builtinTypes;
NotNull<Scope> scope;
DenseHashMap<const void*, size_t> positiveTypes;
DenseHashMap<const void*, size_t> negativeTypes;
std::vector<TypeId> generics;
std::vector<TypePackId> genericPacks;
bool isWithinFunction = false;
MutatingGeneralizer(NotNull<BuiltinTypes> builtinTypes, NotNull<Scope> scope, DenseHashMap<const void*, size_t> positiveTypes,
DenseHashMap<const void*, size_t> negativeTypes)
: TypeOnceVisitor(/* skipBoundTypes */ true)
, builtinTypes(builtinTypes)
, scope(scope)
, positiveTypes(std::move(positiveTypes))
, negativeTypes(std::move(negativeTypes))
{
}
static void replace(DenseHashSet<TypeId>& seen, TypeId haystack, TypeId needle, TypeId replacement)
{
haystack = follow(haystack);
if (seen.find(haystack))
return;
seen.insert(haystack);
if (UnionType* ut = getMutable<UnionType>(haystack))
{
for (auto iter = ut->options.begin(); iter != ut->options.end();)
{
// FIXME: I bet this function has reentrancy problems
TypeId option = follow(*iter);
if (option == needle && get<NeverType>(replacement))
{
iter = ut->options.erase(iter);
continue;
}
if (option == needle)
{
*iter = replacement;
iter++;
continue;
}
// advance the iterator, nothing after this can use it.
iter++;
if (seen.find(option))
continue;
seen.insert(option);
if (get<UnionType>(option))
replace(seen, option, needle, haystack);
else if (get<IntersectionType>(option))
replace(seen, option, needle, haystack);
}
if (ut->options.size() == 1)
{
TypeId onlyType = ut->options[0];
LUAU_ASSERT(onlyType != haystack);
emplaceType<BoundType>(asMutable(haystack), onlyType);
}
return;
}
if (IntersectionType* it = getMutable<IntersectionType>(needle))
{
for (auto iter = it->parts.begin(); iter != it->parts.end();)
{
// FIXME: I bet this function has reentrancy problems
TypeId part = follow(*iter);
if (part == needle && get<UnknownType>(replacement))
{
iter = it->parts.erase(iter);
continue;
}
if (part == needle)
{
*iter = replacement;
iter++;
continue;
}
// advance the iterator, nothing after this can use it.
iter++;
if (seen.find(part))
continue;
seen.insert(part);
if (get<UnionType>(part))
replace(seen, part, needle, haystack);
else if (get<IntersectionType>(part))
replace(seen, part, needle, haystack);
}
if (it->parts.size() == 1)
{
TypeId onlyType = it->parts[0];
LUAU_ASSERT(onlyType != needle);
emplaceType<BoundType>(asMutable(needle), onlyType);
}
return;
}
}
bool visit(TypeId ty, const FunctionType& ft) override
{
const bool oldValue = isWithinFunction;
isWithinFunction = true;
traverse(ft.argTypes);
traverse(ft.retTypes);
isWithinFunction = oldValue;
return false;
}
bool visit(TypeId ty, const FreeType&) override
{
const FreeType* ft = get<FreeType>(ty);
LUAU_ASSERT(ft);
traverse(ft->lowerBound);
traverse(ft->upperBound);
// It is possible for the above traverse() calls to cause ty to be
// transmuted. We must reacquire ft if this happens.
ty = follow(ty);
ft = get<FreeType>(ty);
if (!ft)
return false;
const size_t positiveCount = getCount(positiveTypes, ty);
const size_t negativeCount = getCount(negativeTypes, ty);
if (!positiveCount && !negativeCount)
return false;
const bool hasLowerBound = !get<NeverType>(follow(ft->lowerBound));
const bool hasUpperBound = !get<UnknownType>(follow(ft->upperBound));
DenseHashSet<TypeId> seen{nullptr};
seen.insert(ty);
if (!hasLowerBound && !hasUpperBound)
{
if (!isWithinFunction || (positiveCount + negativeCount == 1))
emplaceType<BoundType>(asMutable(ty), builtinTypes->unknownType);
else
{
emplaceType<GenericType>(asMutable(ty), scope);
generics.push_back(ty);
}
}
// It is possible that this free type has other free types in its upper
// or lower bounds. If this is the case, we must replace those
// references with never (for the lower bound) or unknown (for the upper
// bound).
//
// If we do not do this, we get tautological bounds like a <: a <: unknown.
else if (positiveCount && !hasUpperBound)
{
TypeId lb = follow(ft->lowerBound);
if (FreeType* lowerFree = getMutable<FreeType>(lb); lowerFree && lowerFree->upperBound == ty)
lowerFree->upperBound = builtinTypes->unknownType;
else
{
DenseHashSet<TypeId> replaceSeen{nullptr};
replace(replaceSeen, lb, ty, builtinTypes->unknownType);
}
if (lb != ty)
emplaceType<BoundType>(asMutable(ty), lb);
else if (!isWithinFunction || (positiveCount + negativeCount == 1))
emplaceType<BoundType>(asMutable(ty), builtinTypes->unknownType);
else
{
// if the lower bound is the type in question, we don't actually have a lower bound.
emplaceType<GenericType>(asMutable(ty), scope);
generics.push_back(ty);
}
}
else
{
TypeId ub = follow(ft->upperBound);
if (FreeType* upperFree = getMutable<FreeType>(ub); upperFree && upperFree->lowerBound == ty)
upperFree->lowerBound = builtinTypes->neverType;
else
{
DenseHashSet<TypeId> replaceSeen{nullptr};
replace(replaceSeen, ub, ty, builtinTypes->neverType);
}
if (ub != ty)
emplaceType<BoundType>(asMutable(ty), ub);
else if (!isWithinFunction || (positiveCount + negativeCount == 1))
emplaceType<BoundType>(asMutable(ty), builtinTypes->unknownType);
else
{
// if the upper bound is the type in question, we don't actually have an upper bound.
emplaceType<GenericType>(asMutable(ty), scope);
generics.push_back(ty);
}
}
return false;
}
size_t getCount(const DenseHashMap<const void*, size_t>& map, const void* ty)
{
if (const size_t* count = map.find(ty))
return *count;
else
return 0;
}
bool visit(TypeId ty, const TableType&) override
{
const size_t positiveCount = getCount(positiveTypes, ty);
const size_t negativeCount = getCount(negativeTypes, ty);
// FIXME: Free tables should probably just be replaced by upper bounds on free types.
//
// eg never <: 'a <: {x: number} & {z: boolean}
if (!positiveCount && !negativeCount)
return true;
TableType* tt = getMutable<TableType>(ty);
LUAU_ASSERT(tt);
tt->state = TableState::Sealed;
return true;
}
bool visit(TypePackId tp, const FreeTypePack& ftp) override
{
if (!subsumes(scope, ftp.scope))
return true;
tp = follow(tp);
const size_t positiveCount = getCount(positiveTypes, tp);
const size_t negativeCount = getCount(negativeTypes, tp);
if (1 == positiveCount + negativeCount)
emplaceTypePack<BoundTypePack>(asMutable(tp), builtinTypes->unknownTypePack);
else
{
emplaceTypePack<GenericTypePack>(asMutable(tp), scope);
genericPacks.push_back(tp);
}
return true;
}
};
std::optional<TypeId> Unifier2::generalize(TypeId ty)
{
ty = follow(ty);
if (ty->owningArena != arena || ty->persistent)
return ty;
if (const FunctionType* ft = get<FunctionType>(ty); ft && (!ft->generics.empty() || !ft->genericPacks.empty()))
return ty;
FreeTypeSearcher fts{scope};
fts.traverse(ty);
MutatingGeneralizer gen{builtinTypes, scope, std::move(fts.positiveTypes), std::move(fts.negativeTypes)};
gen.traverse(ty);
/* MutatingGeneralizer mutates types in place, so it is possible that ty has
* been transmuted to a BoundType. We must follow it again and verify that
* we are allowed to mutate it before we attach generics to it.
*/
ty = follow(ty);
if (ty->owningArena != arena || ty->persistent)
return ty;
FunctionType* ftv = getMutable<FunctionType>(ty);
if (ftv)
{
ftv->generics = std::move(gen.generics);
ftv->genericPacks = std::move(gen.genericPacks);
}
return ty;
}
TypeId Unifier2::mkUnion(TypeId left, TypeId right)
{
left = follow(left);
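For context, the MutatingGeneralizer removed above implemented the step that turns leftover free types into generics. Roughly, in Luau terms (a sketch under assumed inference behavior, not a guarantee of the exact printed type):

-- a bounded free type used once collapses to its bound, while one
-- reachable from the function signature becomes a generic: <a>(a) -> {a}
local function singleton(v)
    return { v }
end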

View file

@@ -60,6 +60,8 @@ class AstStat;
class AstStatBlock;
class AstExpr;
class AstTypePack;
class AstAttr;
class AstExprTable;
struct AstLocal
{
@@ -172,6 +174,10 @@ public:
{
return nullptr;
}
virtual AstAttr* asAttr()
{
return nullptr;
}
template<typename T>
bool is() const
@@ -193,6 +199,29 @@ public:
Location location;
};
class AstAttr : public AstNode
{
public:
LUAU_RTTI(AstAttr)
enum Type
{
Checked,
Native,
};
AstAttr(const Location& location, Type type);
AstAttr* asAttr() override
{
return this;
}
void visit(AstVisitor* visitor) override;
Type type;
};
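The new AstAttr node models the attribute syntax this change introduces. In Luau source it looks like the following sketch (assuming the LuauAttributeSyntax and LuauNativeAttribute flags are enabled; names are illustrative):

@native
local function hot(n: number): number -- compiled natively where supported
    return n * 2
end

-- only valid where declaration syntax is allowed (definition files):
@checked declare function getenv(name: string): string?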
class AstExpr : public AstNode
{
public:
@@ -384,13 +413,17 @@ class AstExprFunction : public AstExpr
public:
LUAU_RTTI(AstExprFunction)

AstExprFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
AstLocal* self, const AstArray<AstLocal*>& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth,
const AstName& debugname, const std::optional<AstTypeList>& returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr,
const std::optional<Location>& argLocation = std::nullopt);

AstExprFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, AstLocal* self, const AstArray<AstLocal*>& args, bool vararg,
const Location& varargLocation, AstStatBlock* body, size_t functionDepth, const AstName& debugname,
const std::optional<AstTypeList>& returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr,
const std::optional<Location>& argLocation = std::nullopt);

void visit(AstVisitor* visitor) override;

bool hasNativeAttribute() const;

AstArray<AstAttr*> attributes;
AstArray<AstGenericType> generics;
AstArray<AstGenericTypePack> genericPacks;
AstLocal* self;
@@ -793,11 +826,12 @@ class AstStatDeclareGlobal : public AstStat
public:
LUAU_RTTI(AstStatDeclareGlobal)

AstStatDeclareGlobal(const Location& location, const AstName& name, AstType* type);
AstStatDeclareGlobal(const Location& location, const AstName& name, const Location& nameLocation, AstType* type);

void visit(AstVisitor* visitor) override;

AstName name;
Location nameLocation;
AstType* type;
};
@@ -806,31 +840,38 @@ class AstStatDeclareFunction : public AstStat
public:
LUAU_RTTI(AstStatDeclareFunction)

AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params, const AstArray<AstArgumentName>& paramNames,
const AstTypeList& retTypes);
AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params, const AstArray<AstArgumentName>& paramNames,
const AstTypeList& retTypes, bool checkedFunction);

AstStatDeclareFunction(const Location& location, const AstName& name, const Location& nameLocation, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params, const AstArray<AstArgumentName>& paramNames, bool vararg,
const Location& varargLocation, const AstTypeList& retTypes);
AstStatDeclareFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstName& name, const Location& nameLocation,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params,
const AstArray<AstArgumentName>& paramNames, bool vararg, const Location& varargLocation, const AstTypeList& retTypes);

void visit(AstVisitor* visitor) override;

bool isCheckedFunction() const;

AstArray<AstAttr*> attributes;
AstName name;
Location nameLocation;
AstArray<AstGenericType> generics;
AstArray<AstGenericTypePack> genericPacks;
AstTypeList params;
AstArray<AstArgumentName> paramNames;
bool vararg = false;
Location varargLocation;
AstTypeList retTypes;
bool checkedFunction;
};
struct AstDeclaredClassProp
{
AstName name;
Location nameLocation;
AstType* ty = nullptr;
bool isMethod = false;
Location location;
};
enum class AstTableAccess
@@ -936,17 +977,20 @@ public:
AstTypeFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const AstTypeList& returnTypes);

AstTypeFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const AstTypeList& returnTypes, bool checkedFunction);

AstTypeFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames,
const AstTypeList& returnTypes);

void visit(AstVisitor* visitor) override;

bool isCheckedFunction() const;

AstArray<AstAttr*> attributes;
AstArray<AstGenericType> generics;
AstArray<AstGenericTypePack> genericPacks;
AstTypeList argTypes;
AstArray<std::optional<AstArgumentName>> argNames;
AstTypeList returnTypes;
bool checkedFunction;
};
class AstTypeTypeof : public AstType
@@ -1105,6 +1149,11 @@ public:
return true;
}
virtual bool visit(class AstAttr* node)
{
return visit(static_cast<AstNode*>(node));
}
virtual bool visit(class AstExpr* node)
{
return visit(static_cast<AstNode*>(node));

View file

@@ -87,6 +87,8 @@ struct Lexeme
Comment,
BlockComment,
Attribute,
BrokenString,
BrokenComment,
BrokenUnicode,
@@ -115,14 +117,20 @@ struct Lexeme
ReservedTrue,
ReservedUntil,
ReservedWhile,
ReservedChecked,
Reserved_END
};

Type type;
Location location;
// Field declared here, before the union, to ensure that Lexeme size is 32 bytes.
private:
// length is used to extract a slice from the input buffer.
// This field is only valid for certain lexeme types which don't duplicate portions of input
// but instead store a pointer to a location in the input buffer and the length of lexeme.
unsigned int length;
public:
union
{
const char* data; // String, Number, Comment
@@ -135,9 +143,13 @@ struct Lexeme
Lexeme(const Location& location, Type type, const char* data, size_t size);
Lexeme(const Location& location, Type type, const char* name);
unsigned int getLength() const;
std::string toString() const;
};
static_assert(sizeof(Lexeme) <= 32, "Size of `Lexeme` struct should be up to 32 bytes.");
class AstNameTable
{
public:

View file

@@ -82,8 +82,8 @@ private:
// if exp then block {elseif exp then block} [else block] end |
// for Name `=' exp `,' exp [`,' exp] do block end |
// for namelist in explist do block end |
// function funcname funcbody |
// local function Name funcbody |
// [attributes] function funcname funcbody |
// [attributes] local function Name funcbody |
// local namelist [`=' explist]
// laststat ::= return [explist] | break
AstStat* parseStat();
@@ -114,11 +114,25 @@ private:
AstExpr* parseFunctionName(Location start, bool& hasself, AstName& debugname);

// function funcname funcbody
AstStat* parseFunctionStat();
LUAU_FORCEINLINE AstStat* parseFunctionStat(const AstArray<AstAttr*>& attributes = {nullptr, 0});
std::pair<bool, AstAttr::Type> validateAttribute(const char* attributeName, const TempVector<AstAttr*>& attributes);
// attribute ::= '@' NAME
void parseAttribute(TempVector<AstAttr*>& attribute);
// attributes ::= {attribute}
AstArray<AstAttr*> parseAttributes();
// attributes local function Name funcbody
// attributes function funcname funcbody
// attributes `declare function' Name`(' [parlist] `)' [`:` Type]
// declare Name '{' Name ':' attributes `(' [parlist] `)' [`:` Type] '}'
AstStat* parseAttributeStat();
// local function Name funcbody |
// local namelist [`=' explist]
AstStat* parseLocal();
AstStat* parseLocal(const AstArray<AstAttr*>& attributes);
// return [explist]
AstStat* parseReturn();
@@ -130,7 +144,7 @@ private:
// `declare global' Name: Type |
// `declare function' Name`(' [parlist] `)' [`:` Type]
AstStat* parseDeclaration(const Location& start);
AstStat* parseDeclaration(const Location& start, const AstArray<AstAttr*>& attributes);

// varlist `=' explist
AstStat* parseAssignment(AstExpr* initial);
@@ -143,7 +157,7 @@ private:
// funcbodyhead ::= `(' [namelist [`,' `...'] | `...'] `)' [`:` Type]
// funcbody ::= funcbodyhead block end
std::pair<AstExprFunction*, AstLocal*> parseFunctionBody(
bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName);
bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName, const AstArray<AstAttr*>& attributes);

// explist ::= {exp `,'} exp
void parseExprList(TempVector<AstExpr*>& result);
@@ -176,10 +190,10 @@ private:
AstTableIndexer* parseTableIndexer(AstTableAccess access, std::optional<Location> accessLocation);

AstTypeOrPack parseFunctionType(bool allowPack, bool isCheckedFunction = false);
AstType* parseFunctionTypeTail(const Lexeme& begin, AstArray<AstGenericType> generics, AstArray<AstGenericTypePack> genericPacks,
AstArray<AstType*> params, AstArray<std::optional<AstArgumentName>> paramNames, AstTypePack* varargAnnotation,
bool isCheckedFunction = false);

AstTypeOrPack parseFunctionType(bool allowPack, const AstArray<AstAttr*>& attributes);
AstType* parseFunctionTypeTail(const Lexeme& begin, const AstArray<AstAttr*>& attributes, AstArray<AstGenericType> generics,
AstArray<AstGenericTypePack> genericPacks, AstArray<AstType*> params, AstArray<std::optional<AstArgumentName>> paramNames,
AstTypePack* varargAnnotation);

AstType* parseTableType(bool inDeclarationContext = false);
AstTypeOrPack parseSimpleType(bool allowPack, bool inDeclarationContext = false);
@@ -220,7 +234,7 @@ private:
// asexp -> simpleexp [`::' Type]
AstExpr* parseAssertionExpr();

// simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | FUNCTION body | primaryexp
// simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | [attributes] FUNCTION body | primaryexp
AstExpr* parseSimpleExpr();

// args ::= `(' [explist] `)' | tableconstructor | String
@@ -393,6 +407,7 @@ private:
std::vector<unsigned int> matchRecoveryStopOnToken;
std::vector<AstAttr*> scratchAttr;
std::vector<AstStat*> scratchStat;
std::vector<AstArray<char>> scratchString;
std::vector<AstExpr*> scratchExpr;

View file

@@ -134,6 +134,14 @@ struct ThreadContext
static constexpr size_t kEventFlushLimit = 8192;
};
using ThreadContextProvider = ThreadContext& (*)();
inline ThreadContextProvider& threadContextProvider()
{
static ThreadContextProvider handler = nullptr;
return handler;
}
ThreadContext& getThreadContext();

struct Scope

View file

@@ -3,6 +3,8 @@
#include "Luau/Common.h"
LUAU_FASTFLAG(LuauAttributeSyntax);
LUAU_FASTFLAG(LuauNativeAttribute);
namespace Luau
{
@@ -16,6 +18,17 @@ static void visitTypeList(AstVisitor* visitor, const AstTypeList& list)
list.tailType->visit(visitor);
}
AstAttr::AstAttr(const Location& location, Type type)
: AstNode(ClassIndex(), location)
, type(type)
{
}
void AstAttr::visit(AstVisitor* visitor)
{
visitor->visit(this);
}
int gAstRttiIndex = 0;

AstExprGroup::AstExprGroup(const Location& location, AstExpr* expr)
@@ -161,11 +174,12 @@ void AstExprIndexExpr::visit(AstVisitor* visitor)
}
}

AstExprFunction::AstExprFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
AstLocal* self, const AstArray<AstLocal*>& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth,
const AstName& debugname, const std::optional<AstTypeList>& returnAnnotation, AstTypePack* varargAnnotation,
const std::optional<Location>& argLocation)
AstExprFunction::AstExprFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, AstLocal* self, const AstArray<AstLocal*>& args, bool vararg, const Location& varargLocation,
AstStatBlock* body, size_t functionDepth, const AstName& debugname, const std::optional<AstTypeList>& returnAnnotation,
AstTypePack* varargAnnotation, const std::optional<Location>& argLocation)
: AstExpr(ClassIndex(), location)
, attributes(attributes)
, generics(generics)
, genericPacks(genericPacks)
, self(self)
@@ -201,6 +215,18 @@ void AstExprFunction::visit(AstVisitor* visitor)
}
}
bool AstExprFunction::hasNativeAttribute() const
{
LUAU_ASSERT(FFlag::LuauNativeAttribute);
for (const auto attribute : attributes)
{
if (attribute->type == AstAttr::Type::Native)
return true;
}
return false;
}
AstExprTable::AstExprTable(const Location& location, const AstArray<Item>& items)
: AstExpr(ClassIndex(), location)
, items(items)
@@ -679,9 +705,10 @@ void AstStatTypeAlias::visit(AstVisitor* visitor)
}
}

AstStatDeclareGlobal::AstStatDeclareGlobal(const Location& location, const AstName& name, AstType* type)
AstStatDeclareGlobal::AstStatDeclareGlobal(const Location& location, const AstName& name, const Location& nameLocation, AstType* type)
: AstStat(ClassIndex(), location)
, name(name)
, nameLocation(nameLocation)
, type(type)
{
}
@@ -692,31 +719,37 @@ void AstStatDeclareGlobal::visit(AstVisitor* visitor)
type->visit(visitor);
}

AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params, const AstArray<AstArgumentName>& paramNames,
const AstTypeList& retTypes)
AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const Location& nameLocation,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params,
const AstArray<AstArgumentName>& paramNames, bool vararg, const Location& varargLocation, const AstTypeList& retTypes)
: AstStat(ClassIndex(), location)
, attributes()
, name(name)
, nameLocation(nameLocation)
, generics(generics)
, genericPacks(genericPacks)
, params(params)
, paramNames(paramNames)
, vararg(vararg)
, varargLocation(varargLocation)
, retTypes(retTypes)
, checkedFunction(false)
{
}

AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params, const AstArray<AstArgumentName>& paramNames,
const AstTypeList& retTypes, bool checkedFunction)
AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstName& name,
const Location& nameLocation, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& params, const AstArray<AstArgumentName>& paramNames, bool vararg, const Location& varargLocation, const AstTypeList& retTypes)
: AstStat(ClassIndex(), location)
, attributes(attributes)
, name(name)
, nameLocation(nameLocation)
, generics(generics)
, genericPacks(genericPacks)
, params(params)
, paramNames(paramNames)
, vararg(vararg)
, varargLocation(varargLocation)
, retTypes(retTypes)
, checkedFunction(checkedFunction)
{
}
@@ -729,6 +762,19 @@ void AstStatDeclareFunction::visit(AstVisitor* visitor)
}
}
bool AstStatDeclareFunction::isCheckedFunction() const
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
for (const AstAttr* attr : attributes)
{
if (attr->type == AstAttr::Type::Checked)
return true;
}
return false;
}
AstStatDeclareClass::AstStatDeclareClass(const Location& location, const AstName& name, std::optional<AstName> superName,
const AstArray<AstDeclaredClassProp>& props, AstTableIndexer* indexer)
: AstStat(ClassIndex(), location)
@@ -820,25 +866,26 @@ void AstTypeTable::visit(AstVisitor* visitor)
AstTypeFunction::AstTypeFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const AstTypeList& returnTypes)
: AstType(ClassIndex(), location)
, attributes()
, generics(generics)
, genericPacks(genericPacks)
, argTypes(argTypes)
, argNames(argNames)
, returnTypes(returnTypes)
, checkedFunction(false)
{
LUAU_ASSERT(argNames.size == 0 || argNames.size == argTypes.types.size);
}

AstTypeFunction::AstTypeFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const AstTypeList& returnTypes, bool checkedFunction)
AstTypeFunction::AstTypeFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames,
const AstTypeList& returnTypes)
: AstType(ClassIndex(), location)
, attributes(attributes)
, generics(generics)
, genericPacks(genericPacks)
, argTypes(argTypes)
, argNames(argNames)
, returnTypes(returnTypes)
, checkedFunction(checkedFunction)
{
LUAU_ASSERT(argNames.size == 0 || argNames.size == argTypes.types.size);
}
@@ -852,6 +899,19 @@ void AstTypeFunction::visit(AstVisitor* visitor)
}
}
bool AstTypeFunction::isCheckedFunction() const
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
for (const AstAttr* attr : attributes)
{
if (attr->type == AstAttr::Type::Checked)
return true;
}
return false;
}
AstTypeTypeof::AstTypeTypeof(const Location& location, AstExpr* expr)
: AstType(ClassIndex(), location)
, expr(expr)

View file

@@ -8,7 +8,7 @@
#include <limits.h>

LUAU_FASTFLAGVARIABLE(LuauLexerLookaheadRemembersBraceType, false)
LUAU_FASTFLAGVARIABLE(LuauCheckedFunctionSyntax, false)
LUAU_FASTFLAGVARIABLE(LuauAttributeSyntax, false)
namespace Luau
{
@@ -103,11 +103,19 @@ Lexeme::Lexeme(const Location& location, Type type, const char* name)
, length(0)
, name(name)
{
LUAU_ASSERT(type == Name || (type >= Reserved_BEGIN && type < Lexeme::Reserved_END));
LUAU_ASSERT(type == Name || type == Attribute || (type >= Reserved_BEGIN && type < Lexeme::Reserved_END));
}
unsigned int Lexeme::getLength() const
{
LUAU_ASSERT(type == RawString || type == QuotedString || type == InterpStringBegin || type == InterpStringMid || type == InterpStringEnd ||
type == InterpStringSimple || type == BrokenInterpDoubleBrace || type == Number || type == Comment || type == BlockComment);
return length;
} }
static const char* kReserved[] = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in", "local", "nil", "not", "or",
"repeat", "return", "then", "true", "until", "while", "@checked"};
"repeat", "return", "then", "true", "until", "while"};
std::string Lexeme::toString() const
{
@@ -192,6 +200,10 @@ std::string Lexeme::toString() const
case Comment:
return "comment";
case Attribute:
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
return name ? format("'%s'", name) : "attribute";
case BrokenString:
return "malformed string";
@@ -279,7 +291,7 @@ std::pair<AstName, Lexeme::Type> AstNameTable::getOrAddWithType(const char* name
nameData[length] = 0;

const_cast<Entry&>(entry).value = AstName(nameData);
const_cast<Entry&>(entry).type = Lexeme::Name;
const_cast<Entry&>(entry).type = (name[0] == '@' ? Lexeme::Attribute : Lexeme::Name);

return std::make_pair(entry.value, entry.type);
}
@@ -995,16 +1007,10 @@ Lexeme Lexer::readNext()
}

case '@':
{
if (FFlag::LuauCheckedFunctionSyntax)
{
// We're trying to lex the token @checked
LUAU_ASSERT(peekch() == '@');

std::pair<AstName, Lexeme::Type> maybeChecked = readName();
if (maybeChecked.second != Lexeme::ReservedChecked)
return Lexeme(Location(start, position()), Lexeme::Error);

return Lexeme(Location(start, position()), maybeChecked.second, maybeChecked.first.value);
}

if (FFlag::LuauAttributeSyntax)
{
std::pair<AstName, Lexeme::Type> attribute = readName();
return Lexeme(Location(start, position()), Lexeme::Attribute, attribute.first.value);
}
}

default:

View file

@@ -16,13 +16,24 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100)
// Warning: If you are introducing new syntax, ensure that it is behind a separate
// flag so that we don't break production games by reverting syntax changes.
// See docs/SyntaxChanges.md for an explanation.
LUAU_FASTFLAG(LuauCheckedFunctionSyntax)
LUAU_FASTFLAGVARIABLE(LuauReadWritePropertySyntax, false)
LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false)
LUAU_FASTFLAG(LuauAttributeSyntax)
LUAU_FASTFLAGVARIABLE(LuauLeadingBarAndAmpersand2, false)
LUAU_FASTFLAGVARIABLE(LuauNativeAttribute, false)
LUAU_FASTFLAGVARIABLE(LuauAttributeSyntaxFunExpr, false)
LUAU_FASTFLAGVARIABLE(LuauDeclarationExtraPropData, false)
namespace Luau
{
struct AttributeEntry
{
const char* name;
AstAttr::Type type;
};
AttributeEntry kAttributeEntries[] = {{"@checked", AstAttr::Type::Checked}, {"@native", AstAttr::Type::Native}, {nullptr, AstAttr::Type::Checked}};
ParseError::ParseError(const Location& location, const std::string& message)
: location(location)
, message(message)
@@ -281,7 +292,9 @@ AstStatBlock* Parser::parseBlockNoScope()
// for binding `=' exp `,' exp [`,' exp] do block end |
// for namelist in explist do block end |
// function funcname funcbody |
// attributes function funcname funcbody |
// local function Name funcbody |
// local attributes function Name funcbody |
// local namelist [`=' explist]
// laststat ::= return [explist] | break
AstStat* Parser::parseStat()
@@ -300,13 +313,16 @@ AstStat* Parser::parseStat()
case Lexeme::ReservedRepeat:
return parseRepeat();
case Lexeme::ReservedFunction:
return parseFunctionStat();
return parseFunctionStat(AstArray<AstAttr*>({nullptr, 0}));
case Lexeme::ReservedLocal:
return parseLocal();
return parseLocal(AstArray<AstAttr*>({nullptr, 0}));
case Lexeme::ReservedReturn:
return parseReturn();
case Lexeme::ReservedBreak:
return parseBreak();
case Lexeme::Attribute:
if (FFlag::LuauAttributeSyntax)
return parseAttributeStat();
default:;
}
@@ -344,7 +360,7 @@ AstStat* Parser::parseStat()
if (options.allowDeclarationSyntax)
{
if (ident == "declare")
return parseDeclaration(expr->location);
return parseDeclaration(expr->location, AstArray<AstAttr*>({nullptr, 0}));
}

// skip unexpected symbol if lexer couldn't advance at all (statements are parsed in a loop)
@@ -653,7 +669,7 @@ AstExpr* Parser::parseFunctionName(Location start, bool& hasself, AstName& debug
}

// function funcname funcbody
AstStat* Parser::parseFunctionStat()
AstStat* Parser::parseFunctionStat(const AstArray<AstAttr*>& attributes)
{
Location start = lexer.current().location;
@@ -666,16 +682,129 @@ AstStat* Parser::parseFunctionStat()
matchRecoveryStopOnToken[Lexeme::ReservedEnd]++;

AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, nullptr).first;
AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, nullptr, attributes).first;

matchRecoveryStopOnToken[Lexeme::ReservedEnd]--;

return allocator.alloc<AstStatFunction>(Location(start, body->location), expr, body);
}
std::pair<bool, AstAttr::Type> Parser::validateAttribute(const char* attributeName, const TempVector<AstAttr*>& attributes)
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
AstAttr::Type type;
// check if the attribute name is valid
bool found = false;
for (int i = 0; kAttributeEntries[i].name; ++i)
{
found = !strcmp(attributeName, kAttributeEntries[i].name);
if (found)
{
type = kAttributeEntries[i].type;
if (!FFlag::LuauNativeAttribute && type == AstAttr::Type::Native)
found = false;
break;
}
}
if (!found)
{
if (strlen(attributeName) == 1)
report(lexer.current().location, "Attribute name is missing");
else
report(lexer.current().location, "Invalid attribute '%s'", attributeName);
}
else
{
// check that attribute is not duplicated
for (const AstAttr* attr : attributes)
{
if (attr->type == type)
{
report(lexer.current().location, "Cannot duplicate attribute '%s'", attributeName);
}
}
}
return {found, type};
}
// attribute ::= '@' NAME
void Parser::parseAttribute(TempVector<AstAttr*>& attributes)
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
LUAU_ASSERT(lexer.current().type == Lexeme::Type::Attribute);
Location loc = lexer.current().location;
const char* name = lexer.current().name;
const auto [found, type] = validateAttribute(name, attributes);
nextLexeme();
if (found)
attributes.push_back(allocator.alloc<AstAttr>(loc, type));
}
// attributes ::= {attribute}
AstArray<AstAttr*> Parser::parseAttributes()
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
Lexeme::Type type = lexer.current().type;
LUAU_ASSERT(type == Lexeme::Attribute);
TempVector<AstAttr*> attributes(scratchAttr);
while (lexer.current().type == Lexeme::Attribute)
parseAttribute(attributes);
return copy(attributes);
}
// attributes local function Name funcbody
// attributes function funcname funcbody
// attributes `declare function' Name`(' [parlist] `)' [`:` Type]
// declare Name '{' Name ':' attributes `(' [parlist] `)' [`:` Type] '}'
AstStat* Parser::parseAttributeStat()
{
LUAU_ASSERT(FFlag::LuauAttributeSyntax);
AstArray<AstAttr*> attributes = parseAttributes();
Lexeme::Type type = lexer.current().type;
switch (type)
{
case Lexeme::Type::ReservedFunction:
return parseFunctionStat(attributes);
case Lexeme::Type::ReservedLocal:
return parseLocal(attributes);
case Lexeme::Type::Name:
if (options.allowDeclarationSyntax && !strcmp("declare", lexer.current().data))
{
AstExpr* expr = parsePrimaryExpr(/* asStatement= */ true);
return parseDeclaration(expr->location, attributes);
}
default:
return reportStatError(lexer.current().location, {}, {},
"Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got %s intead",
lexer.current().toString().c_str());
}
}
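Putting validateAttribute and parseAttributeStat together, the accepted and rejected shapes look roughly like this Luau sketch (error strings as added above; names are illustrative):

@checked @native local function f(x: number): number return x end -- ok: two distinct attributes
@checked @checked function g() end -- rejected: "Cannot duplicate attribute '@checked'"
@typo function h() end -- rejected: "Invalid attribute '@typo'"
@native local x = 5 -- rejected: an attribute must precede a function declaration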
// local function Name funcbody |
// local bindinglist [`=' explist]
AstStat* Parser::parseLocal()
AstStat* Parser::parseLocal(const AstArray<AstAttr*>& attributes)
{
Location start = lexer.current().location;
@@ -695,7 +824,7 @@ AstStat* Parser::parseLocal()
matchRecoveryStopOnToken[Lexeme::ReservedEnd]++;

auto [body, var] = parseFunctionBody(false, matchFunction, name.name, &name);
auto [body, var] = parseFunctionBody(false, matchFunction, name.name, &name, attributes);

matchRecoveryStopOnToken[Lexeme::ReservedEnd]--;
@@ -705,6 +834,12 @@ AstStat* Parser::parseLocal()
}
else
{
if (FFlag::LuauAttributeSyntax && attributes.size != 0)
{
return reportStatError(lexer.current().location, {}, {}, "Expected 'function' after local declaration with attribute, but got %s intead",
lexer.current().toString().c_str());
}
matchRecoveryStopOnToken['=']++;

TempVector<Binding> names(scratchBinding);
@@ -775,8 +910,16 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported)
AstDeclaredClassProp Parser::parseDeclaredClassMethod()
{
Location start;
if (FFlag::LuauDeclarationExtraPropData)
start = lexer.current().location;
nextLexeme();
Location start = lexer.current().location;
if (!FFlag::LuauDeclarationExtraPropData)
start = lexer.current().location;
Name fnName = parseName("function name"); Name fnName = parseName("function name");
// TODO: generic method declarations CLI-39909 // TODO: generic method declarations CLI-39909
@@ -801,15 +944,15 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod()
expectMatchAndConsume(')', matchParen);

AstTypeList retTypes = parseOptionalReturnType().value_or(AstTypeList{copy<AstType*>(nullptr, 0), nullptr});
Location end = lexer.current().location;
Location end = FFlag::LuauDeclarationExtraPropData ? lexer.previousLocation() : lexer.current().location;

TempVector<AstType*> vars(scratchType);
TempVector<std::optional<AstArgumentName>> varNames(scratchOptArgName);

if (args.size() == 0 || args[0].name.name != "self" || args[0].annotation != nullptr)
{
return AstDeclaredClassProp{
fnName.name, reportTypeError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), true};
return AstDeclaredClassProp{fnName.name, FFlag::LuauDeclarationExtraPropData ? fnName.location : Location{},
reportTypeError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), true};
}

// Skip the first index.
@@ -829,21 +972,21 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod()
AstType* fnType = allocator.alloc<AstTypeFunction>(
Location(start, end), generics, genericPacks, AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes);

return AstDeclaredClassProp{fnName.name, fnType, true};
return AstDeclaredClassProp{fnName.name, FFlag::LuauDeclarationExtraPropData ? fnName.location : Location{}, fnType, true,
FFlag::LuauDeclarationExtraPropData ? Location(start, end) : Location{}};
}

AstStat* Parser::parseDeclaration(const Location& start)
AstStat* Parser::parseDeclaration(const Location& start, const AstArray<AstAttr*>& attributes)
{
// `declare` token is already parsed at this point

if (FFlag::LuauAttributeSyntax && (attributes.size != 0) && (lexer.current().type != Lexeme::ReservedFunction))
return reportStatError(lexer.current().location, {}, {}, "Expected a function type declaration after attribute, but got %s instead",
lexer.current().toString().c_str());

if (lexer.current().type == Lexeme::ReservedFunction)
{
nextLexeme();
bool checkedFunction = false;
if (FFlag::LuauCheckedFunctionSyntax && lexer.current().type == Lexeme::ReservedChecked)
{
checkedFunction = true;
nextLexeme();
}
Name globalName = parseName("global function name"); Name globalName = parseName("global function name");
auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false);
@ -881,8 +1024,12 @@ AstStat* Parser::parseDeclaration(const Location& start)
if (vararg && !varargAnnotation) if (vararg && !varargAnnotation)
return reportStatError(Location(start, end), {}, {}, "All declaration parameters must be annotated"); return reportStatError(Location(start, end), {}, {}, "All declaration parameters must be annotated");
return allocator.alloc<AstStatDeclareFunction>(Location(start, end), globalName.name, generics, genericPacks, if (FFlag::LuauDeclarationExtraPropData)
AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes, checkedFunction); return allocator.alloc<AstStatDeclareFunction>(Location(start, end), attributes, globalName.name, globalName.location, generics,
genericPacks, AstTypeList{copy(vars), varargAnnotation}, copy(varNames), vararg, varargLocation, retTypes);
else
return allocator.alloc<AstStatDeclareFunction>(Location(start, end), attributes, globalName.name, Location{}, generics, genericPacks,
AstTypeList{copy(vars), varargAnnotation}, copy(varNames), false, Location{}, retTypes);
} }
else if (AstName(lexer.current().name) == "class") else if (AstName(lexer.current().name) == "class")
{ {
@ -912,19 +1059,42 @@ AstStat* Parser::parseDeclaration(const Location& start)
const Lexeme begin = lexer.current(); const Lexeme begin = lexer.current();
nextLexeme(); // [ nextLexeme(); // [
std::optional<AstArray<char>> chars = parseCharArray(); if (FFlag::LuauDeclarationExtraPropData)
{
const Location nameBegin = lexer.current().location;
std::optional<AstArray<char>> chars = parseCharArray();
expectMatchAndConsume(']', begin); const Location nameEnd = lexer.previousLocation();
expectAndConsume(':', "property type annotation");
AstType* type = parseType();
// since AstName contains a char*, it can't contain null expectMatchAndConsume(']', begin);
bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); expectAndConsume(':', "property type annotation");
AstType* type = parseType();
if (chars && !containsNull) // since AstName contains a char*, it can't contain null
props.push_back(AstDeclaredClassProp{AstName(chars->data), type, false}); bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size);
if (chars && !containsNull)
props.push_back(AstDeclaredClassProp{
AstName(chars->data), Location(nameBegin, nameEnd), type, false, Location(begin.location, lexer.previousLocation())});
else
report(begin.location, "String literal contains malformed escape sequence or \\0");
}
else else
report(begin.location, "String literal contains malformed escape sequence or \\0"); {
std::optional<AstArray<char>> chars = parseCharArray();
expectMatchAndConsume(']', begin);
expectAndConsume(':', "property type annotation");
AstType* type = parseType();
// since AstName contains a char*, it can't contain null
bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size);
if (chars && !containsNull)
props.push_back(AstDeclaredClassProp{AstName(chars->data), Location{}, type, false});
else
report(begin.location, "String literal contains malformed escape sequence or \\0");
}
} }
else if (lexer.current().type == '[') else if (lexer.current().type == '[')
{ {
@ -942,12 +1112,21 @@ AstStat* Parser::parseDeclaration(const Location& start)
indexer = parseTableIndexer(AstTableAccess::ReadWrite, std::nullopt); indexer = parseTableIndexer(AstTableAccess::ReadWrite, std::nullopt);
} }
} }
else if (FFlag::LuauDeclarationExtraPropData)
{
Location propStart = lexer.current().location;
Name propName = parseName("property name");
expectAndConsume(':', "property type annotation");
AstType* propType = parseType();
props.push_back(
AstDeclaredClassProp{propName.name, propName.location, propType, false, Location(propStart, lexer.previousLocation())});
}
else else
{ {
Name propName = parseName("property name"); Name propName = parseName("property name");
expectAndConsume(':', "property type annotation"); expectAndConsume(':', "property type annotation");
AstType* propType = parseType(); AstType* propType = parseType();
props.push_back(AstDeclaredClassProp{propName.name, propType, false}); props.push_back(AstDeclaredClassProp{propName.name, Location{}, propType, false});
} }
} }
@ -961,7 +1140,8 @@ AstStat* Parser::parseDeclaration(const Location& start)
expectAndConsume(':', "global variable declaration"); expectAndConsume(':', "global variable declaration");
AstType* type = parseType(/* in declaration context */ true); AstType* type = parseType(/* in declaration context */ true);
return allocator.alloc<AstStatDeclareGlobal>(Location(start, type->location), globalName->name, type); return allocator.alloc<AstStatDeclareGlobal>(
Location(start, type->location), globalName->name, FFlag::LuauDeclarationExtraPropData ? globalName->location : Location{}, type);
} }
else else
{ {
@ -1036,7 +1216,7 @@ std::pair<AstLocal*, AstArray<AstLocal*>> Parser::prepareFunctionArguments(const
// funcbody ::= `(' [parlist] `)' [`:' ReturnType] block end // funcbody ::= `(' [parlist] `)' [`:' ReturnType] block end
// parlist ::= bindinglist [`,' `...'] | `...' // parlist ::= bindinglist [`,' `...'] | `...'
std::pair<AstExprFunction*, AstLocal*> Parser::parseFunctionBody( std::pair<AstExprFunction*, AstLocal*> Parser::parseFunctionBody(
bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName) bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName, const AstArray<AstAttr*>& attributes)
{ {
Location start = matchFunction.location; Location start = matchFunction.location;
@ -1088,7 +1268,7 @@ std::pair<AstExprFunction*, AstLocal*> Parser::parseFunctionBody(
bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchFunction); bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchFunction);
body->hasEnd = hasEnd; body->hasEnd = hasEnd;
return {allocator.alloc<AstExprFunction>(Location(start, end), generics, genericPacks, self, vars, vararg, varargLocation, body, return {allocator.alloc<AstExprFunction>(Location(start, end), attributes, generics, genericPacks, self, vars, vararg, varargLocation, body,
functionStack.size(), debugname, typelist, varargAnnotation, argLocation), functionStack.size(), debugname, typelist, varargAnnotation, argLocation),
funLocal}; funLocal};
} }
@ -1297,7 +1477,7 @@ std::pair<Location, AstTypeList> Parser::parseReturnType()
return {location, AstTypeList{copy(result), varargAnnotation}}; return {location, AstTypeList{copy(result), varargAnnotation}};
} }
AstType* tail = parseFunctionTypeTail(begin, {}, {}, copy(result), copy(resultNames), varargAnnotation); AstType* tail = parseFunctionTypeTail(begin, {nullptr, 0}, {}, {}, copy(result), copy(resultNames), varargAnnotation);
return {Location{location, tail->location}, AstTypeList{copy(&tail, 1), varargAnnotation}}; return {Location{location, tail->location}, AstTypeList{copy(&tail, 1), varargAnnotation}};
} }
@ -1340,22 +1520,19 @@ AstType* Parser::parseTableType(bool inDeclarationContext)
AstTableAccess access = AstTableAccess::ReadWrite; AstTableAccess access = AstTableAccess::ReadWrite;
std::optional<Location> accessLocation; std::optional<Location> accessLocation;
if (FFlag::LuauReadWritePropertySyntax || FFlag::DebugLuauDeferredConstraintResolution) if (lexer.current().type == Lexeme::Name && lexer.lookahead().type != ':')
{ {
if (lexer.current().type == Lexeme::Name && lexer.lookahead().type != ':') if (AstName(lexer.current().name) == "read")
{ {
if (AstName(lexer.current().name) == "read") accessLocation = lexer.current().location;
{ access = AstTableAccess::Read;
accessLocation = lexer.current().location; lexer.next();
access = AstTableAccess::Read; }
lexer.next(); else if (AstName(lexer.current().name) == "write")
} {
else if (AstName(lexer.current().name) == "write") accessLocation = lexer.current().location;
{ access = AstTableAccess::Write;
accessLocation = lexer.current().location; lexer.next();
access = AstTableAccess::Write;
lexer.next();
}
} }
} }
@ -1439,7 +1616,7 @@ AstType* Parser::parseTableType(bool inDeclarationContext)
// ReturnType ::= Type | `(' TypeList `)' // ReturnType ::= Type | `(' TypeList `)'
// FunctionType ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType // FunctionType ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType
AstTypeOrPack Parser::parseFunctionType(bool allowPack, bool isCheckedFunction) AstTypeOrPack Parser::parseFunctionType(bool allowPack, const AstArray<AstAttr*>& attributes)
{ {
incrementRecursionCounter("type annotation"); incrementRecursionCounter("type annotation");
@ -1487,11 +1664,12 @@ AstTypeOrPack Parser::parseFunctionType(bool allowPack, bool isCheckedFunction)
AstArray<std::optional<AstArgumentName>> paramNames = copy(names); AstArray<std::optional<AstArgumentName>> paramNames = copy(names);
return {parseFunctionTypeTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation, isCheckedFunction), {}}; return {parseFunctionTypeTail(begin, attributes, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}};
} }
AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray<AstGenericType> generics, AstArray<AstGenericTypePack> genericPacks, AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, const AstArray<AstAttr*>& attributes, AstArray<AstGenericType> generics,
AstArray<AstType*> params, AstArray<std::optional<AstArgumentName>> paramNames, AstTypePack* varargAnnotation, bool isCheckedFunction) AstArray<AstGenericTypePack> genericPacks, AstArray<AstType*> params, AstArray<std::optional<AstArgumentName>> paramNames,
AstTypePack* varargAnnotation)
{ {
incrementRecursionCounter("type annotation"); incrementRecursionCounter("type annotation");
@ -1516,7 +1694,7 @@ AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray<AstGenericT
AstTypeList paramTypes = AstTypeList{params, varargAnnotation}; AstTypeList paramTypes = AstTypeList{params, varargAnnotation};
return allocator.alloc<AstTypeFunction>( return allocator.alloc<AstTypeFunction>(
Location(begin.location, endLocation), generics, genericPacks, paramTypes, paramNames, returnTypeList, isCheckedFunction); Location(begin.location, endLocation), attributes, generics, genericPacks, paramTypes, paramNames, returnTypeList);
} }
// Type ::= // Type ::=
@ -1528,7 +1706,11 @@ AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray<AstGenericT
AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin) AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin)
{ {
TempVector<AstType*> parts(scratchType); TempVector<AstType*> parts(scratchType);
parts.push_back(type);
if (!FFlag::LuauLeadingBarAndAmpersand2 || type != nullptr)
{
parts.push_back(type);
}
incrementRecursionCounter("type annotation"); incrementRecursionCounter("type annotation");
@ -1553,6 +1735,8 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin)
} }
else if (c == '?') else if (c == '?')
{ {
LUAU_ASSERT(parts.size() >= 1);
Location loc = lexer.current().location; Location loc = lexer.current().location;
nextLexeme(); nextLexeme();
@ -1585,7 +1769,7 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin)
} }
if (parts.size() == 1) if (parts.size() == 1)
return type; return FFlag::LuauLeadingBarAndAmpersand2 ? parts[0] : type;
if (isUnion && isIntersection) if (isUnion && isIntersection)
{ {
@ -1628,15 +1812,34 @@ AstTypeOrPack Parser::parseTypeOrPack()
AstType* Parser::parseType(bool inDeclarationContext) AstType* Parser::parseType(bool inDeclarationContext)
{ {
unsigned int oldRecursionCount = recursionCounter; unsigned int oldRecursionCount = recursionCounter;
// recursion counter is incremented in parseSimpleType // recursion counter is incremented in parseSimpleType and/or parseTypeSuffix
Location begin = lexer.current().location; Location begin = lexer.current().location;
AstType* type = parseSimpleType(/* allowPack= */ false, /* in declaration context */ inDeclarationContext).type; if (FFlag::LuauLeadingBarAndAmpersand2)
{
AstType* type = nullptr;
recursionCounter = oldRecursionCount; Lexeme::Type c = lexer.current().type;
if (c != '|' && c != '&')
{
type = parseSimpleType(/* allowPack= */ false, /* in declaration context */ inDeclarationContext).type;
recursionCounter = oldRecursionCount;
}
return parseTypeSuffix(type, begin); AstType* typeWithSuffix = parseTypeSuffix(type, begin);
recursionCounter = oldRecursionCount;
return typeWithSuffix;
}
else
{
AstType* type = parseSimpleType(/* allowPack= */ false, /* in declaration context */ inDeclarationContext).type;
recursionCounter = oldRecursionCount;
return parseTypeSuffix(type, begin);
}
} }
// Type ::= nil | Name[`.' Name] [ `<' Type [`,' ...] `>' ] | `typeof' `(' expr `)' | `{' [PropList] `}' // Type ::= nil | Name[`.' Name] [ `<' Type [`,' ...] `>' ] | `typeof' `(' expr `)' | `{' [PropList] `}'
@ -1647,7 +1850,21 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext)
Location start = lexer.current().location; Location start = lexer.current().location;
if (lexer.current().type == Lexeme::ReservedNil) AstArray<AstAttr*> attributes{nullptr, 0};
if (lexer.current().type == Lexeme::Attribute)
{
if (!inDeclarationContext || !FFlag::LuauAttributeSyntax)
{
return {reportTypeError(start, {}, "attributes are not allowed in declaration context")};
}
else
{
attributes = Parser::parseAttributes();
return parseFunctionType(allowPack, attributes);
}
}
else if (lexer.current().type == Lexeme::ReservedNil)
{ {
nextLexeme(); nextLexeme();
return {allocator.alloc<AstTypeReference>(start, std::nullopt, nameNil, std::nullopt, start), {}}; return {allocator.alloc<AstTypeReference>(start, std::nullopt, nameNil, std::nullopt, start), {}};
@ -1735,15 +1952,9 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext)
{ {
return {parseTableType(/* inDeclarationContext */ inDeclarationContext), {}}; return {parseTableType(/* inDeclarationContext */ inDeclarationContext), {}};
} }
else if (FFlag::LuauCheckedFunctionSyntax && inDeclarationContext && lexer.current().type == Lexeme::ReservedChecked)
{
LUAU_ASSERT(FFlag::LuauCheckedFunctionSyntax);
nextLexeme();
return parseFunctionType(allowPack, /* isCheckedFunction */ true);
}
else if (lexer.current().type == '(' || lexer.current().type == '<') else if (lexer.current().type == '(' || lexer.current().type == '<')
{ {
return parseFunctionType(allowPack); return parseFunctionType(allowPack, AstArray<AstAttr*>({nullptr, 0}));
} }
else if (lexer.current().type == Lexeme::ReservedFunction) else if (lexer.current().type == Lexeme::ReservedFunction)
{ {
@ -2213,11 +2424,24 @@ static ConstantNumberParseResult parseDouble(double& result, const char* data)
return ConstantNumberParseResult::Ok; return ConstantNumberParseResult::Ok;
} }
// simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | FUNCTION body | primaryexp // simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | [attributes] FUNCTION body | primaryexp
AstExpr* Parser::parseSimpleExpr() AstExpr* Parser::parseSimpleExpr()
{ {
Location start = lexer.current().location; Location start = lexer.current().location;
AstArray<AstAttr*> attributes{nullptr, 0};
if (FFlag::LuauAttributeSyntax && FFlag::LuauAttributeSyntaxFunExpr && lexer.current().type == Lexeme::Attribute)
{
attributes = parseAttributes();
if (lexer.current().type != Lexeme::ReservedFunction)
{
return reportExprError(
start, {}, "Expected 'function' declaration after attribute, but got %s intead", lexer.current().toString().c_str());
}
}
if (lexer.current().type == Lexeme::ReservedNil) if (lexer.current().type == Lexeme::ReservedNil)
{ {
nextLexeme(); nextLexeme();
@ -2241,7 +2465,7 @@ AstExpr* Parser::parseSimpleExpr()
Lexeme matchFunction = lexer.current(); Lexeme matchFunction = lexer.current();
nextLexeme(); nextLexeme();
return parseFunctionBody(false, matchFunction, AstName(), nullptr).first; return parseFunctionBody(false, matchFunction, AstName(), nullptr, attributes).first;
} }
else if (lexer.current().type == Lexeme::Number) else if (lexer.current().type == Lexeme::Number)
{ {
@ -2671,7 +2895,7 @@ std::optional<AstArray<char>> Parser::parseCharArray()
LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString || LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString ||
lexer.current().type == Lexeme::InterpStringSimple); lexer.current().type == Lexeme::InterpStringSimple);
scratchData.assign(lexer.current().data, lexer.current().length); scratchData.assign(lexer.current().data, lexer.current().getLength());
if (lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::InterpStringSimple) if (lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::InterpStringSimple)
{ {
@ -2716,7 +2940,7 @@ AstExpr* Parser::parseInterpString()
endLocation = currentLexeme.location; endLocation = currentLexeme.location;
scratchData.assign(currentLexeme.data, currentLexeme.length); scratchData.assign(currentLexeme.data, currentLexeme.getLength());
if (!Lexer::fixupQuotedString(scratchData)) if (!Lexer::fixupQuotedString(scratchData))
{ {
@ -2789,7 +3013,7 @@ AstExpr* Parser::parseNumber()
{ {
Location start = lexer.current().location; Location start = lexer.current().location;
scratchData.assign(lexer.current().data, lexer.current().length); scratchData.assign(lexer.current().data, lexer.current().getLength());
// Remove all internal _ - they don't hold any meaning and this allows parsing code to just pass the string pointer to strtod et al // Remove all internal _ - they don't hold any meaning and this allows parsing code to just pass the string pointer to strtod et al
if (scratchData.find('_') != std::string::npos) if (scratchData.find('_') != std::string::npos)
@ -3144,11 +3368,11 @@ void Parser::nextLexeme()
return; return;
// Comments starting with ! are called "hot comments" and contain directives for type checking / linting / compiling // Comments starting with ! are called "hot comments" and contain directives for type checking / linting / compiling
if (lexeme.type == Lexeme::Comment && lexeme.length && lexeme.data[0] == '!') if (lexeme.type == Lexeme::Comment && lexeme.getLength() && lexeme.data[0] == '!')
{ {
const char* text = lexeme.data; const char* text = lexeme.data;
unsigned int end = lexeme.length; unsigned int end = lexeme.getLength();
while (end > 0 && isSpace(text[end - 1])) while (end > 0 && isSpace(text[end - 1]))
--end; --end;


@ -250,6 +250,10 @@ void flushEvents(GlobalContext& context, uint32_t threadId, const std::vector<Ev
ThreadContext& getThreadContext() ThreadContext& getThreadContext()
{ {
// Check for a custom provider that might implement a custom TLS
if (auto provider = threadContextProvider())
return provider();
thread_local ThreadContext context; thread_local ThreadContext context;
return context; return context;
} }
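A minimal sketch of how a host might use the new provider hook, assuming threadContextProvider() returns a mutable function-pointer slot (the null check above implies as much); myThreadContext is an illustrative name:

// Host-supplied context storage, e.g. backed by a custom TLS scheme
static Luau::TimeTrace::ThreadContext& myThreadContext()
{
    static thread_local Luau::TimeTrace::ThreadContext context;
    return context;
}

// During startup, before any tracing happens:
// Luau::TimeTrace::threadContextProvider() = &myThreadContext;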


@ -317,6 +317,7 @@ static bool compileFile(const char* name, CompileFormat format, Luau::CodeGen::A
{ {
options.includeAssembly = format != CompileFormat::CodegenIr; options.includeAssembly = format != CompileFormat::CodegenIr;
options.includeIr = format != CompileFormat::CodegenAsm; options.includeIr = format != CompileFormat::CodegenAsm;
options.includeIrTypes = format != CompileFormat::CodegenAsm;
options.includeOutlinedCode = format == CompileFormat::CodegenVerbose; options.includeOutlinedCode = format == CompileFormat::CodegenVerbose;
} }


@ -144,7 +144,10 @@ static int lua_require(lua_State* L)
if (luau_load(ML, resolvedRequire.chunkName.c_str(), bytecode.data(), bytecode.size(), 0) == 0) if (luau_load(ML, resolvedRequire.chunkName.c_str(), bytecode.data(), bytecode.size(), 0) == 0)
{ {
if (codegen) if (codegen)
Luau::CodeGen::compile(ML, -1); {
Luau::CodeGen::CompilationOptions nativeOptions;
Luau::CodeGen::compile(ML, -1, nativeOptions);
}
if (coverageActive()) if (coverageActive())
coverageTrack(ML, -1); coverageTrack(ML, -1);
@ -253,12 +256,16 @@ void setupState(lua_State* L)
void setupArguments(lua_State* L, int argc, char** argv) void setupArguments(lua_State* L, int argc, char** argv)
{ {
lua_checkstack(L, argc);
for (int i = 0; i < argc; ++i) for (int i = 0; i < argc; ++i)
lua_pushstring(L, argv[i]); lua_pushstring(L, argv[i]);
} }
std::string runCode(lua_State* L, const std::string& source) std::string runCode(lua_State* L, const std::string& source)
{ {
lua_checkstack(L, LUA_MINSTACK);
std::string bytecode = Luau::compile(source, copts()); std::string bytecode = Luau::compile(source, copts());
if (luau_load(L, "=stdin", bytecode.data(), bytecode.size(), 0) != 0) if (luau_load(L, "=stdin", bytecode.data(), bytecode.size(), 0) != 0)
@ -429,6 +436,8 @@ static void completeIndexer(lua_State* L, const std::string& editBuffer, const A
std::string_view lookup = editBuffer; std::string_view lookup = editBuffer;
bool completeOnlyFunctions = false; bool completeOnlyFunctions = false;
lua_checkstack(L, LUA_MINSTACK);
// Push the global variable table to begin the search // Push the global variable table to begin the search
lua_pushvalue(L, LUA_GLOBALSINDEX); lua_pushvalue(L, LUA_GLOBALSINDEX);
@ -602,7 +611,10 @@ static bool runFile(const char* name, lua_State* GL, bool repl)
if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0)
{ {
if (codegen) if (codegen)
Luau::CodeGen::compile(L, -1); {
Luau::CodeGen::CompilationOptions nativeOptions;
Luau::CodeGen::compile(L, -1, nativeOptions);
}
if (coverageActive()) if (coverageActive())
coverageTrack(L, -1); coverageTrack(L, -1);


@ -229,6 +229,7 @@ if(LUAU_BUILD_TESTS)
target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler Luau.CodeGen) target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler Luau.CodeGen)
target_compile_options(Luau.Conformance PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Conformance PRIVATE ${LUAU_OPTIONS})
target_compile_definitions(Luau.Conformance PRIVATE DOCTEST_CONFIG_DOUBLE_STRINGIFY)
target_include_directories(Luau.Conformance PRIVATE extern) target_include_directories(Luau.Conformance PRIVATE extern)
target_link_libraries(Luau.Conformance PRIVATE Luau.Analysis Luau.Compiler Luau.CodeGen Luau.VM) target_link_libraries(Luau.Conformance PRIVATE Luau.Analysis Luau.Compiler Luau.CodeGen Luau.VM)
if(CMAKE_SYSTEM_NAME MATCHES "Android|iOS") if(CMAKE_SYSTEM_NAME MATCHES "Android|iOS")


@ -13,10 +13,11 @@ namespace CodeGen
{ {
struct IrFunction; struct IrFunction;
struct HostIrHooks;
void loadBytecodeTypeInfo(IrFunction& function); void loadBytecodeTypeInfo(IrFunction& function);
void buildBytecodeBlocks(IrFunction& function, const std::vector<uint8_t>& jumpTargets); void buildBytecodeBlocks(IrFunction& function, const std::vector<uint8_t>& jumpTargets);
void analyzeBytecodeTypes(IrFunction& function); void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks);
} // namespace CodeGen } // namespace CodeGen
} // namespace Luau } // namespace Luau


@ -12,6 +12,12 @@
struct lua_State; struct lua_State;
#if defined(__x86_64__) || defined(_M_X64)
#define CODEGEN_TARGET_X64
#elif defined(__aarch64__) || defined(_M_ARM64)
#define CODEGEN_TARGET_A64
#endif
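These macros centralize the architecture checks that were previously repeated at each call site; a small usage sketch (codegenTargetName is an illustrative helper, not part of this change):

static const char* codegenTargetName()
{
#if defined(CODEGEN_TARGET_X64)
    return "x64";
#elif defined(CODEGEN_TARGET_A64)
    return "a64";
#else
    return "unsupported";
#endif
}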
namespace Luau namespace Luau
{ {
namespace CodeGen namespace CodeGen
@ -40,8 +46,12 @@ enum class CodeGenCompilationResult
CodeGenAssemblerFinalizationFailure = 7, // Failure during assembler finalization CodeGenAssemblerFinalizationFailure = 7, // Failure during assembler finalization
CodeGenLoweringFailure = 8, // Lowering failed CodeGenLoweringFailure = 8, // Lowering failed
AllocationFailed = 9, // Native codegen failed due to an allocation error AllocationFailed = 9, // Native codegen failed due to an allocation error
Count = 10,
}; };
std::string toString(const CodeGenCompilationResult& result);
struct ProtoCompilationFailure struct ProtoCompilationFailure
{ {
CodeGenCompilationResult result = CodeGenCompilationResult::Success; CodeGenCompilationResult result = CodeGenCompilationResult::Success;
@ -62,6 +72,97 @@ struct CompilationResult
} }
}; };
struct IrBuilder;
struct IrOp;
using HostVectorOperationBytecodeType = uint8_t (*)(const char* member, size_t memberLength);
using HostVectorAccessHandler = bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos);
using HostVectorNamecallHandler = bool (*)(
IrBuilder& builder, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos);
enum class HostMetamethod
{
Add,
Sub,
Mul,
Div,
Idiv,
Mod,
Pow,
Minus,
Equal,
LessThan,
LessEqual,
Length,
Concat,
};
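A hedged sketch of a bytecode-type hook built on this enum, suggesting that adding two values of one custom userdata type yields the same type; the index 0 and the LBC_TYPE_TAGGED_USERDATA_BASE offset are assumptions about the bytecode headers, not taken from this change:

static uint8_t vec2MetamethodType(uint8_t lhsTy, uint8_t rhsTy, Luau::CodeGen::HostMetamethod method)
{
    // Assumption: tagged userdata types arrive offset by their index in userdataTypes
    const uint8_t kVec2 = LBC_TYPE_TAGGED_USERDATA_BASE + 0;

    if (lhsTy == kVec2 && rhsTy == kVec2 && method == Luau::CodeGen::HostMetamethod::Add)
        return kVec2;

    return LBC_TYPE_ANY;
}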
using HostUserdataOperationBytecodeType = uint8_t (*)(uint8_t type, const char* member, size_t memberLength);
using HostUserdataMetamethodBytecodeType = uint8_t (*)(uint8_t lhsTy, uint8_t rhsTy, HostMetamethod method);
using HostUserdataAccessHandler = bool (*)(
IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos);
using HostUserdataMetamethodHandler = bool (*)(
IrBuilder& builder, uint8_t lhsTy, uint8_t rhsTy, int resultReg, IrOp lhs, IrOp rhs, HostMetamethod method, int pcpos);
using HostUserdataNamecallHandler = bool (*)(
IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos);
struct HostIrHooks
{
// Suggest result type of a vector field access
HostVectorOperationBytecodeType vectorAccessBytecodeType = nullptr;
// Suggest result type of a vector function namecall
HostVectorOperationBytecodeType vectorNamecallBytecodeType = nullptr;
// Handle vector value field access
// 'sourceReg' is guaranteed to be a vector
// Guards should take a VM exit to 'pcpos'
HostVectorAccessHandler vectorAccess = nullptr;
// Handle namecall performed on a vector value
// 'sourceReg' (self argument) is guaranteed to be a vector
// All other arguments can be of any type
// Guards should take a VM exit to 'pcpos'
HostVectorNamecallHandler vectorNamecall = nullptr;
// Suggest result type of a userdata field access
HostUserdataOperationBytecodeType userdataAccessBytecodeType = nullptr;
// Suggest result type of a metamethod call
HostUserdataMetamethodBytecodeType userdataMetamethodBytecodeType = nullptr;
// Suggest result type of a userdata namecall
HostUserdataOperationBytecodeType userdataNamecallBytecodeType = nullptr;
// Handle userdata value field access
// 'sourceReg' is guaranteed to be a userdata, but tag has to be checked
// Write to 'resultReg' might invalidate 'sourceReg'
// Guards should take a VM exit to 'pcpos'
HostUserdataAccessHandler userdataAccess = nullptr;
// Handle metamethod operation on a userdata value
// 'lhs' and 'rhs' operands can be VM registers or constants
// Operand types have to be checked and userdata operand tags have to be checked
// Write to 'resultReg' might invalidate source operands
// Guards should take a VM exit to 'pcpos'
HostUserdataMetamethodHandler userdataMetamethod = nullptr;
// Handle namecall performed on a userdata value
// 'sourceReg' (self argument) is guaranteed to be a userdata, but tag has to be checked
// All other arguments can be of any type
// Guards should take a VM exit to 'pcpos'
HostUserdataNamecallHandler userdataNamecall = nullptr;
};
struct CompilationOptions
{
unsigned int flags = 0;
HostIrHooks hooks;
// null-terminated array of userdata type names that might have custom lowering
const char* const* userdataTypes = nullptr;
};
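A minimal sketch of filling these options from an embedder, assuming the LBC_TYPE_* constants from Luau's bytecode headers; the hook and the type list are illustrative:

// Suggest that v.x/v.y/v.z reads produce numbers so codegen can specialize them
static uint8_t vectorFieldType(const char* member, size_t memberLength)
{
    if (memberLength == 1 && (member[0] == 'x' || member[0] == 'y' || member[0] == 'z'))
        return LBC_TYPE_NUMBER;

    return LBC_TYPE_ANY;
}

static const char* const kUserdataTypes[] = {"Vec2", "Mat4", nullptr};

static Luau::CodeGen::CompilationOptions makeOptions()
{
    Luau::CodeGen::CompilationOptions options;
    options.hooks.vectorAccessBytecodeType = vectorFieldType;
    options.userdataTypes = kUserdataTypes;
    return options;
}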
struct CompilationStats struct CompilationStats
{ {
size_t bytecodeSizeBytes = 0; size_t bytecodeSizeBytes = 0;
@ -101,8 +202,17 @@ using UniqueSharedCodeGenContext = std::unique_ptr<SharedCodeGenContext, SharedC
// SharedCodeGenContext must be destroyed before this function is called. // SharedCodeGenContext must be destroyed before this function is called.
void destroySharedCodeGenContext(const SharedCodeGenContext* codeGenContext) noexcept; void destroySharedCodeGenContext(const SharedCodeGenContext* codeGenContext) noexcept;
void create(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext); // Initializes native code-gen on the provided Luau VM, using a VM-specific
// code-gen context and either the default allocator parameters or custom
// allocator parameters.
void create(lua_State* L); void create(lua_State* L);
void create(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext);
void create(lua_State* L, size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext);
// Initializes native code-gen on the provided Luau VM, using the provided
// SharedCodeGenContext. Note that after this function is called, the
// SharedCodeGenContext must not be destroyed until after the Luau VM L is
// destroyed via lua_close.
void create(lua_State* L, SharedCodeGenContext* codeGenContext); void create(lua_State* L, SharedCodeGenContext* codeGenContext);
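Typical single-VM setup against the simplest overload; a sketch using isSupported() as the usual guard:

// Create a fresh VM with native code generation enabled when available
lua_State* newStateWithCodegen()
{
    lua_State* L = luaL_newstate();

    if (Luau::CodeGen::isSupported())
        Luau::CodeGen::create(L); // VM-specific context, default allocator parameters

    return L;
}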
// Check if native execution is enabled // Check if native execution is enabled
@ -111,11 +221,20 @@ void create(lua_State* L, SharedCodeGenContext* codeGenContext);
// Enable or disable native execution according to `enabled` argument // Enable or disable native execution according to `enabled` argument
void setNativeExecutionEnabled(lua_State* L, bool enabled); void setNativeExecutionEnabled(lua_State* L, bool enabled);
// Given a name, this function must return the index of the type which matches the type array used in all CompilationOptions and AssemblyOptions
// If the type is unknown, 0xff has to be returned
using UserdataRemapperCallback = uint8_t(void* context, const char* name, size_t nameLength);
void setUserdataRemapper(lua_State* L, void* context, UserdataRemapperCallback cb);
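A sketch of a remapper consistent with the contract above; kUserdataTypes is the same illustrative null-terminated list shown for CompilationOptions, and strlen/memcmp come from &lt;cstring&gt;:

static uint8_t remapUserdataType(void* context, const char* name, size_t nameLength)
{
    for (uint8_t i = 0; kUserdataTypes[i] != nullptr; ++i)
    {
        if (strlen(kUserdataTypes[i]) == nameLength && memcmp(kUserdataTypes[i], name, nameLength) == 0)
            return i;
    }

    return 0xff; // unknown type
}

// After VM creation: Luau::CodeGen::setUserdataRemapper(L, nullptr, remapUserdataType);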
using ModuleId = std::array<uint8_t, 16>; using ModuleId = std::array<uint8_t, 16>;
// Builds target function and all inner functions // Builds target function and all inner functions
CompilationResult compile(lua_State* L, int idx, unsigned int flags = 0, CompilationStats* stats = nullptr); CompilationResult compile(lua_State* L, int idx, unsigned int flags, CompilationStats* stats = nullptr);
CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags = 0, CompilationStats* stats = nullptr); CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats = nullptr);
CompilationResult compile(lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats = nullptr);
CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats = nullptr);
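Usage sketch for the options-based overload, collecting statistics and checking the aggregate result against the enum and toString() declared above:

Luau::CodeGen::CompilationOptions options;
Luau::CodeGen::CompilationStats stats;
Luau::CodeGen::CompilationResult result = Luau::CodeGen::compile(L, -1, options, &stats);

if (result.result != Luau::CodeGen::CodeGenCompilationResult::Success)
    fprintf(stderr, "codegen failed: %s\n", Luau::CodeGen::toString(result.result).c_str());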
using AnnotatorFn = void (*)(void* context, std::string& result, int fid, int instpos); using AnnotatorFn = void (*)(void* context, std::string& result, int fid, int instpos);
@ -160,7 +279,7 @@ struct AssemblyOptions
Target target = Host; Target target = Host;
unsigned int flags = 0; CompilationOptions compilationOptions;
bool outputBinary = false; bool outputBinary = false;
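Since AssemblyOptions now embeds a full CompilationOptions rather than a raw flags field, assembly dumps share hooks and flags with compilation; a sketch (CodeGen_ColdFunctions is an assumed flag name from the public flags enum, not part of this change):

Luau::CodeGen::AssemblyOptions asmOptions;
asmOptions.compilationOptions.flags = Luau::CodeGen::CodeGen_ColdFunctions;
asmOptions.includeAssembly = true;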


@ -16,11 +16,11 @@ namespace Luau
namespace CodeGen namespace CodeGen
{ {
struct AssemblyOptions; struct HostIrHooks;
struct IrBuilder struct IrBuilder
{ {
IrBuilder(); IrBuilder(const HostIrHooks& hostHooks);
void buildFunctionIr(Proto* proto); void buildFunctionIr(Proto* proto);
@ -54,6 +54,7 @@ struct IrBuilder
IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d); IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d);
IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e); IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e);
IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f); IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f);
IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f, IrOp g);
IrOp block(IrBlockKind kind); // Requested kind can be ignored if we are in an outlined sequence IrOp block(IrBlockKind kind); // Requested kind can be ignored if we are in an outlined sequence
IrOp blockAtInst(uint32_t index); IrOp blockAtInst(uint32_t index);
@ -64,13 +65,17 @@ struct IrBuilder
IrOp vmExit(uint32_t pcpos); IrOp vmExit(uint32_t pcpos);
const HostIrHooks& hostHooks;
bool inTerminatedBlock = false; bool inTerminatedBlock = false;
bool interruptRequested = false; bool interruptRequested = false;
bool activeFastcallFallback = false; bool activeFastcallFallback = false;
IrOp fastcallFallbackReturn; IrOp fastcallFallbackReturn;
int fastcallSkipTarget = -1;
// Force builder to skip source commands
int cmdSkipTarget = -1;
IrFunction function; IrFunction function;


@ -31,7 +31,7 @@ enum
// * Rn - VM stack register slot, n in 0..254 // * Rn - VM stack register slot, n in 0..254
// * Kn - VM proto constant slot, n in 0..2^23-1 // * Kn - VM proto constant slot, n in 0..2^23-1
// * UPn - VM function upvalue slot, n in 0..199 // * UPn - VM function upvalue slot, n in 0..199
// * A, B, C, D, E are instruction arguments // * A, B, C, D, E, F, G are instruction arguments
enum class IrCmd : uint8_t enum class IrCmd : uint8_t
{ {
NOP, NOP,
@ -179,6 +179,10 @@ enum class IrCmd : uint8_t
// A: double // A: double
ABS_NUM, ABS_NUM,
// Get the sign of the argument (math.sign)
// A: double
SIGN_NUM,
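For reference, a scalar sketch of the semantics SIGN_NUM mirrors (math.sign): -1 for negative inputs, 1 for positive, 0 for zero:

static double signNum(double v)
{
    return v > 0.0 ? 1.0 : v < 0.0 ? -1.0 : 0.0;
}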
// Add/Sub/Mul/Div/Idiv two vectors // Add/Sub/Mul/Div/Idiv two vectors
// A, B: TValue // A, B: TValue
ADD_VEC, ADD_VEC,
@ -290,6 +294,11 @@ enum class IrCmd : uint8_t
// C: block // C: block
TRY_CALL_FASTGETTM, TRY_CALL_FASTGETTM,
// Create new tagged userdata
// A: int (size)
// B: int (tag)
NEW_USERDATA,
// Convert integer into a double number // Convert integer into a double number
// A: int // A: int
INT_TO_NUM, INT_TO_NUM,
@ -321,13 +330,12 @@ enum class IrCmd : uint8_t
// This is used to recover after calling a variadic function // This is used to recover after calling a variadic function
ADJUST_STACK_TO_TOP, ADJUST_STACK_TO_TOP,
// Execute fastcall builtin function in-place // Execute fastcall builtin function with 1 argument in-place
// This is used for a few builtins that can have more than 1 result and cannot be represented as a regular instruction
// A: unsigned int (builtin id) // A: unsigned int (builtin id)
// B: Rn (result start) // B: Rn (result start)
// C: Rn (argument start) // C: Rn (first argument)
// D: Rn or Kn or undef (optional second argument) // D: int (result count)
// E: int (argument count)
// F: int (result count)
FASTCALL, FASTCALL,
// Call the fastcall builtin function // Call the fastcall builtin function
@ -335,8 +343,9 @@ enum class IrCmd : uint8_t
// B: Rn (result start) // B: Rn (result start)
// C: Rn (argument start) // C: Rn (argument start)
// D: Rn or Kn or undef (optional second argument) // D: Rn or Kn or undef (optional second argument)
// E: int (argument count or -1 to use all arguments up to stack top) // E: Rn or Kn or undef (optional third argument)
// F: int (result count or -1 to preserve all results and adjust stack top) // F: int (argument count or -1 to use all arguments up to stack top)
// G: int (result count or -1 to preserve all results and adjust stack top)
INVOKE_FASTCALL, INVOKE_FASTCALL,
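With the new E operand and the counts shifted into F/G, a three-argument builtin (LOP_FASTCALL3) can be invoked roughly as below; a hedged sketch where build is an IrBuilder, ra and arg1..arg3 are VM register indices, and LBF_MATH_MIN stands in for the builtin id:

IrOp res = build.inst(IrCmd::INVOKE_FASTCALL, build.constUint(LBF_MATH_MIN), build.vmReg(ra), build.vmReg(arg1),
    build.vmReg(arg2), build.vmReg(arg3), build.constInt(3), build.constInt(1));
build.inst(IrCmd::CHECK_FASTCALL_RES, res, build.block(IrBlockKind::Fallback));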
// Check that fastcall builtin function invocation was successful (negative result count jumps to fallback) // Check that fastcall builtin function invocation was successful (negative result count jumps to fallback)
@ -460,6 +469,13 @@ enum class IrCmd : uint8_t
// When undef is specified instead of a block, execution is aborted on check failure // When undef is specified instead of a block, execution is aborted on check failure
CHECK_BUFFER_LEN, CHECK_BUFFER_LEN,
// Guard against userdata tag mismatch
// A: pointer (userdata)
// B: int (tag)
// C: block/vmexit/undef
// When undef is specified instead of a block, execution is aborted on check failure
CHECK_USERDATA_TAG,
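A sketch of the guard in use during specialized userdata lowering; build is an IrBuilder, kVec2Tag is an assumed host tag, and the VM-exit operand follows the description above:

IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg));
build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kVec2Tag), build.vmExit(pcpos));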
// Special operations // Special operations
// Check interrupt handler // Check interrupt handler
@ -857,6 +873,7 @@ struct IrInst
IrOp d; IrOp d;
IrOp e; IrOp e;
IrOp f; IrOp f;
IrOp g;
uint32_t lastUse = 0; uint32_t lastUse = 0;
uint16_t useCount = 0; uint16_t useCount = 0;
@ -911,6 +928,7 @@ struct IrInstHash
h = mix(h, key.d); h = mix(h, key.d);
h = mix(h, key.e); h = mix(h, key.e);
h = mix(h, key.f); h = mix(h, key.f);
h = mix(h, key.g);
// MurmurHash2 tail // MurmurHash2 tail
h ^= h >> 13; h ^= h >> 13;
@ -925,7 +943,7 @@ struct IrInstEq
{ {
bool operator()(const IrInst& a, const IrInst& b) const bool operator()(const IrInst& a, const IrInst& b) const
{ {
return a.cmd == b.cmd && a.a == b.a && a.b == b.b && a.c == b.c && a.d == b.d && a.e == b.e && a.f == b.f; return a.cmd == b.cmd && a.a == b.a && a.b == b.b && a.c == b.c && a.d == b.d && a.e == b.e && a.f == b.f && a.g == b.g;
} }
}; };


@ -31,9 +31,11 @@ void toString(IrToStringContext& ctx, IrOp op);
void toString(std::string& result, IrConst constant); void toString(std::string& result, IrConst constant);
const char* getBytecodeTypeName(uint8_t type); const char* getBytecodeTypeName_DEPRECATED(uint8_t type);
const char* getBytecodeTypeName(uint8_t type, const char* const* userdataTypes);
void toString(std::string& result, const BytecodeTypes& bcTypes); void toString_DEPRECATED(std::string& result, const BytecodeTypes& bcTypes);
void toString(std::string& result, const BytecodeTypes& bcTypes, const char* const* userdataTypes);
void toStringDetailed( void toStringDetailed(
IrToStringContext& ctx, const IrBlock& block, uint32_t blockIdx, const IrInst& inst, uint32_t instIdx, IncludeUseInfo includeUseInfo); IrToStringContext& ctx, const IrBlock& block, uint32_t blockIdx, const IrInst& inst, uint32_t instIdx, IncludeUseInfo includeUseInfo);


@ -11,6 +11,7 @@ namespace CodeGen
{ {
struct IrBuilder; struct IrBuilder;
enum class HostMetamethod;
inline bool isJumpD(LuauOpcode op) inline bool isJumpD(LuauOpcode op)
{ {
@ -63,6 +64,7 @@ inline bool isFastCall(LuauOpcode op)
case LOP_FASTCALL1: case LOP_FASTCALL1:
case LOP_FASTCALL2: case LOP_FASTCALL2:
case LOP_FASTCALL2K: case LOP_FASTCALL2K:
case LOP_FASTCALL3:
return true; return true;
default: default:
@ -129,6 +131,7 @@ inline bool isNonTerminatingJump(IrCmd cmd)
case IrCmd::CHECK_NODE_NO_NEXT: case IrCmd::CHECK_NODE_NO_NEXT:
case IrCmd::CHECK_NODE_VALUE: case IrCmd::CHECK_NODE_VALUE:
case IrCmd::CHECK_BUFFER_LEN: case IrCmd::CHECK_BUFFER_LEN:
case IrCmd::CHECK_USERDATA_TAG:
return true; return true;
default: default:
break; break;
@ -168,6 +171,7 @@ inline bool hasResult(IrCmd cmd)
case IrCmd::ROUND_NUM: case IrCmd::ROUND_NUM:
case IrCmd::SQRT_NUM: case IrCmd::SQRT_NUM:
case IrCmd::ABS_NUM: case IrCmd::ABS_NUM:
case IrCmd::SIGN_NUM:
case IrCmd::ADD_VEC: case IrCmd::ADD_VEC:
case IrCmd::SUB_VEC: case IrCmd::SUB_VEC:
case IrCmd::MUL_VEC: case IrCmd::MUL_VEC:
@ -182,6 +186,7 @@ inline bool hasResult(IrCmd cmd)
case IrCmd::DUP_TABLE: case IrCmd::DUP_TABLE:
case IrCmd::TRY_NUM_TO_INDEX: case IrCmd::TRY_NUM_TO_INDEX:
case IrCmd::TRY_CALL_FASTGETTM: case IrCmd::TRY_CALL_FASTGETTM:
case IrCmd::NEW_USERDATA:
case IrCmd::INT_TO_NUM: case IrCmd::INT_TO_NUM:
case IrCmd::UINT_TO_NUM: case IrCmd::UINT_TO_NUM:
case IrCmd::NUM_TO_INT: case IrCmd::NUM_TO_INT:
@ -241,6 +246,12 @@ IrValueKind getCmdValueKind(IrCmd cmd);
bool isGCO(uint8_t tag); bool isGCO(uint8_t tag);
// Optional bit has to be cleared at call site, otherwise, this will return 'false' for 'userdata?'
bool isUserdataBytecodeType(uint8_t ty);
bool isCustomUserdataBytecodeType(uint8_t ty);
HostMetamethod tmToHostMetamethod(int tm);
// Manually add or remove use of an operand // Manually add or remove use of an operand
void addUse(IrFunction& function, IrOp op); void addUse(IrFunction& function, IrOp op);
void removeUse(IrFunction& function, IrOp op); void removeUse(IrFunction& function, IrOp op);


@ -4,7 +4,7 @@
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/IrData.h" #include "Luau/IrData.h"
LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) LUAU_FASTFLAG(LuauCodegenFastcall3)
namespace Luau namespace Luau
{ {
@ -112,12 +112,48 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i
visitor.useRange(vmRegOp(inst.a), function.intOp(inst.b)); visitor.useRange(vmRegOp(inst.a), function.intOp(inst.b));
break; break;
// TODO: FASTCALL is more restrictive than INVOKE_FASTCALL; we should either determine the exact semantics, or rework it
case IrCmd::FASTCALL: case IrCmd::FASTCALL:
case IrCmd::INVOKE_FASTCALL: if (FFlag::LuauCodegenFastcall3)
if (int count = function.intOp(inst.e); count != -1)
{ {
if (count >= 3) visitor.use(inst.c);
if (int nresults = function.intOp(inst.d); nresults != -1)
visitor.defRange(vmRegOp(inst.b), nresults);
}
else
{
if (int count = function.intOp(inst.e); count != -1)
{
if (count >= 3)
{
CODEGEN_ASSERT(inst.d.kind == IrOpKind::VmReg && vmRegOp(inst.d) == vmRegOp(inst.c) + 1);
visitor.useRange(vmRegOp(inst.c), count);
}
else
{
if (count >= 1)
visitor.use(inst.c);
if (count >= 2)
visitor.maybeUse(inst.d); // Argument can also be a VmConst
}
}
else
{
visitor.useVarargs(vmRegOp(inst.c));
}
// Multiple return sequences (count == -1) are defined by ADJUST_STACK_TO_REG
if (int count = function.intOp(inst.f); count != -1)
visitor.defRange(vmRegOp(inst.b), count);
}
break;
case IrCmd::INVOKE_FASTCALL:
if (int count = function.intOp(FFlag::LuauCodegenFastcall3 ? inst.f : inst.e); count != -1)
{
// Only LOP_FASTCALL3 lowering is allowed to have third optional argument
if (count >= 3 && (!FFlag::LuauCodegenFastcall3 || inst.e.kind == IrOpKind::Undef))
{ {
CODEGEN_ASSERT(inst.d.kind == IrOpKind::VmReg && vmRegOp(inst.d) == vmRegOp(inst.c) + 1); CODEGEN_ASSERT(inst.d.kind == IrOpKind::VmReg && vmRegOp(inst.d) == vmRegOp(inst.c) + 1);
@ -130,6 +166,9 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i
if (count >= 2) if (count >= 2)
visitor.maybeUse(inst.d); // Argument can also be a VmConst visitor.maybeUse(inst.d); // Argument can also be a VmConst
if (FFlag::LuauCodegenFastcall3 && count >= 3)
visitor.maybeUse(inst.e); // Argument can also be a VmConst
} }
} }
else else
@ -138,7 +177,7 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i
} }
// Multiple return sequences (count == -1) are defined by ADJUST_STACK_TO_REG // Multiple return sequences (count == -1) are defined by ADJUST_STACK_TO_REG
if (int count = function.intOp(inst.f); count != -1) if (int count = function.intOp(FFlag::LuauCodegenFastcall3 ? inst.g : inst.f); count != -1)
visitor.defRange(vmRegOp(inst.b), count); visitor.defRange(vmRegOp(inst.b), count);
break; break;
case IrCmd::FORGLOOP: case IrCmd::FORGLOOP:
@ -188,15 +227,8 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i
visitor.def(inst.b); visitor.def(inst.b);
break; break;
case IrCmd::FALLBACK_FORGPREP: case IrCmd::FALLBACK_FORGPREP:
if (FFlag::LuauCodegenRemoveDeadStores5) // This instruction doesn't always redefine Rn, Rn+1, Rn+2, so we have to mark it as implicit use
{ visitor.useRange(vmRegOp(inst.b), 3);
// This instruction doesn't always redefine Rn, Rn+1, Rn+2, so we have to mark it as implicit use
visitor.useRange(vmRegOp(inst.b), 3);
}
else
{
visitor.use(inst.b);
}
visitor.defRange(vmRegOp(inst.b), 3); visitor.defRange(vmRegOp(inst.b), 3);
break; break;
@ -214,12 +246,6 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i
visitor.use(inst.a); visitor.use(inst.a);
break; break;
// After optimizations with DebugLuauAbortingChecks enabled, CHECK_TAG Rn, tag, block instructions are generated
case IrCmd::CHECK_TAG:
if (!FFlag::LuauCodegenRemoveDeadStores5)
visitor.maybeUse(inst.a);
break;
default: default:
// All instructions which reference registers have to be handled explicitly // All instructions which reference registers have to be handled explicitly
CODEGEN_ASSERT(inst.a.kind != IrOpKind::VmReg); CODEGEN_ASSERT(inst.a.kind != IrOpKind::VmReg);
@ -228,6 +254,7 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i
CODEGEN_ASSERT(inst.d.kind != IrOpKind::VmReg); CODEGEN_ASSERT(inst.d.kind != IrOpKind::VmReg);
CODEGEN_ASSERT(inst.e.kind != IrOpKind::VmReg); CODEGEN_ASSERT(inst.e.kind != IrOpKind::VmReg);
CODEGEN_ASSERT(inst.f.kind != IrOpKind::VmReg); CODEGEN_ASSERT(inst.f.kind != IrOpKind::VmReg);
CODEGEN_ASSERT(inst.g.kind != IrOpKind::VmReg);
break; break;
} }
} }


@ -16,7 +16,7 @@ namespace CodeGen
{ {
// This value is used in 'finishFunction' to mark the function that spans to the end of the whole code block // This value is used in 'finishFunction' to mark the function that spans to the end of the whole code block
static uint32_t kFullBlockFuncton = ~0u; static uint32_t kFullBlockFunction = ~0u;
class UnwindBuilder class UnwindBuilder
{ {
@ -52,11 +52,10 @@ public:
virtual void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list<X64::RegisterX64> gpr, virtual void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list<X64::RegisterX64> gpr,
const std::vector<X64::RegisterX64>& simd) = 0; const std::vector<X64::RegisterX64>& simd) = 0;
virtual size_t getSize() const = 0; virtual size_t getUnwindInfoSize(size_t blockSize) const = 0;
virtual size_t getFunctionCount() const = 0;
// This will place the unwinding data at the target address and might update values of some fields // This will place the unwinding data at the target address and might update values of some fields
virtual void finalize(char* target, size_t offset, void* funcAddress, size_t funcSize) const = 0; virtual size_t finalize(char* target, size_t offset, void* funcAddress, size_t blockSize) const = 0;
}; };
} // namespace CodeGen } // namespace CodeGen


@ -33,10 +33,9 @@ public:
void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list<X64::RegisterX64> gpr, void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list<X64::RegisterX64> gpr,
const std::vector<X64::RegisterX64>& simd) override; const std::vector<X64::RegisterX64>& simd) override;
size_t getSize() const override; size_t getUnwindInfoSize(size_t blockSize = 0) const override;
size_t getFunctionCount() const override;
void finalize(char* target, size_t offset, void* funcAddress, size_t funcSize) const override; size_t finalize(char* target, size_t offset, void* funcAddress, size_t blockSize) const override;
private: private:
size_t beginOffset = 0; size_t beginOffset = 0;


@ -53,10 +53,9 @@ public:
void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list<X64::RegisterX64> gpr, void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list<X64::RegisterX64> gpr,
const std::vector<X64::RegisterX64>& simd) override; const std::vector<X64::RegisterX64>& simd) override;
size_t getSize() const override; size_t getUnwindInfoSize(size_t blockSize = 0) const override;
size_t getFunctionCount() const override;
void finalize(char* target, size_t offset, void* funcAddress, size_t funcSize) const override; size_t finalize(char* target, size_t offset, void* funcAddress, size_t blockSize) const override;
private: private:
size_t beginOffset = 0; size_t beginOffset = 0;


@ -826,7 +826,7 @@ void AssemblyBuilderX64::vcvtss2sd(OperandX64 dst, OperandX64 src1, OperandX64 s
else else
CODEGEN_ASSERT(src2.memSize == SizeX64::dword); CODEGEN_ASSERT(src2.memSize == SizeX64::dword);
placeAvx("vcvtsd2ss", dst, src1, src2, 0x5a, false, AVX_0F, AVX_F3); placeAvx("vcvtss2sd", dst, src1, src2, 0x5a, false, AVX_0F, AVX_F3);
} }
void AssemblyBuilderX64::vroundsd(OperandX64 dst, OperandX64 src1, OperandX64 src2, RoundingModeX64 roundingMode) void AssemblyBuilderX64::vroundsd(OperandX64 dst, OperandX64 src1, OperandX64 src2, RoundingModeX64 roundingMode)


@ -2,35 +2,26 @@
#include "Luau/BytecodeAnalysis.h" #include "Luau/BytecodeAnalysis.h"
#include "Luau/BytecodeUtils.h" #include "Luau/BytecodeUtils.h"
#include "Luau/CodeGen.h"
#include "Luau/IrData.h" #include "Luau/IrData.h"
#include "Luau/IrUtils.h" #include "Luau/IrUtils.h"
#include "lobject.h" #include "lobject.h"
#include "lstate.h"
#include <algorithm> #include <algorithm>
LUAU_FASTFLAG(LuauCodegenDirectUserdataFlow) LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataOps, false)
LUAU_FASTFLAG(LuauLoadTypeInfo) // Because new VM typeinfo load changes the format used by Codegen, same flag is used LUAU_FASTFLAGVARIABLE(LuauCodegenFastcall3, false)
LUAU_FASTFLAGVARIABLE(LuauCodegenTypeInfo, false) // New analysis is flagged separately
LUAU_FASTFLAG(LuauTypeInfoLookupImprovement)
namespace Luau namespace Luau
{ {
namespace CodeGen namespace CodeGen
{ {
static bool hasTypedParameters(Proto* proto)
{
CODEGEN_ASSERT(!FFlag::LuauLoadTypeInfo);
return proto->typeinfo && proto->numparams != 0;
}
template<typename T> template<typename T>
static T read(uint8_t* data, size_t& offset) static T read(uint8_t* data, size_t& offset)
{ {
CODEGEN_ASSERT(FFlag::LuauLoadTypeInfo);
T result; T result;
memcpy(&result, data + offset, sizeof(T)); memcpy(&result, data + offset, sizeof(T));
offset += sizeof(T); offset += sizeof(T);
@ -40,8 +31,6 @@ static T read(uint8_t* data, size_t& offset)
static uint32_t readVarInt(uint8_t* data, size_t& offset) static uint32_t readVarInt(uint8_t* data, size_t& offset)
{ {
CODEGEN_ASSERT(FFlag::LuauLoadTypeInfo);
uint32_t result = 0; uint32_t result = 0;
uint32_t shift = 0; uint32_t shift = 0;
@ -59,25 +48,15 @@ static uint32_t readVarInt(uint8_t* data, size_t& offset)
void loadBytecodeTypeInfo(IrFunction& function) void loadBytecodeTypeInfo(IrFunction& function)
{ {
CODEGEN_ASSERT(FFlag::LuauLoadTypeInfo);
Proto* proto = function.proto; Proto* proto = function.proto;
if (FFlag::LuauTypeInfoLookupImprovement) if (!proto)
{ return;
if (!proto)
return;
}
else
{
if (!proto || !proto->typeinfo)
return;
}
BytecodeTypeInfo& typeInfo = function.bcTypeInfo; BytecodeTypeInfo& typeInfo = function.bcTypeInfo;
// If there is no typeinfo, we generate default values for arguments and upvalues // If there is no typeinfo, we generate default values for arguments and upvalues
if (FFlag::LuauTypeInfoLookupImprovement && !proto->typeinfo) if (!proto->typeinfo)
{ {
typeInfo.argumentTypes.resize(proto->numparams, LBC_TYPE_ANY); typeInfo.argumentTypes.resize(proto->numparams, LBC_TYPE_ANY);
typeInfo.upvalueTypes.resize(proto->nups, LBC_TYPE_ANY); typeInfo.upvalueTypes.resize(proto->nups, LBC_TYPE_ANY);
@ -91,8 +70,6 @@ void loadBytecodeTypeInfo(IrFunction& function)
uint32_t upvalCount = readVarInt(data, offset); uint32_t upvalCount = readVarInt(data, offset);
uint32_t localCount = readVarInt(data, offset); uint32_t localCount = readVarInt(data, offset);
CODEGEN_ASSERT(upvalCount == unsigned(proto->nups));
if (typeSize != 0) if (typeSize != 0)
{ {
uint8_t* types = (uint8_t*)data + offset; uint8_t* types = (uint8_t*)data + offset;
@ -110,6 +87,8 @@ void loadBytecodeTypeInfo(IrFunction& function)
if (upvalCount != 0) if (upvalCount != 0)
{ {
CODEGEN_ASSERT(upvalCount == unsigned(proto->nups));
typeInfo.upvalueTypes.resize(upvalCount); typeInfo.upvalueTypes.resize(upvalCount);
uint8_t* types = (uint8_t*)data + offset; uint8_t* types = (uint8_t*)data + offset;
@ -137,8 +116,6 @@ void loadBytecodeTypeInfo(IrFunction& function)
static void prepareRegTypeInfoLookups(BytecodeTypeInfo& typeInfo) static void prepareRegTypeInfoLookups(BytecodeTypeInfo& typeInfo)
{ {
CODEGEN_ASSERT(FFlag::LuauTypeInfoLookupImprovement);
// Sort by register first, then by end PC // Sort by register first, then by end PC
std::sort(typeInfo.regTypes.begin(), typeInfo.regTypes.end(), [](const BytecodeRegTypeInfo& a, const BytecodeRegTypeInfo& b) { std::sort(typeInfo.regTypes.begin(), typeInfo.regTypes.end(), [](const BytecodeRegTypeInfo& a, const BytecodeRegTypeInfo& b) {
if (a.reg != b.reg) if (a.reg != b.reg)
@ -171,47 +148,30 @@ static void prepareRegTypeInfoLookups(BytecodeTypeInfo& typeInfo)
static BytecodeRegTypeInfo* findRegType(BytecodeTypeInfo& info, uint8_t reg, int pc) static BytecodeRegTypeInfo* findRegType(BytecodeTypeInfo& info, uint8_t reg, int pc)
{ {
CODEGEN_ASSERT(FFlag::LuauCodegenTypeInfo); auto b = info.regTypes.begin() + info.regTypeOffsets[reg];
auto e = info.regTypes.begin() + info.regTypeOffsets[reg + 1];
if (FFlag::LuauTypeInfoLookupImprovement)
{
auto b = info.regTypes.begin() + info.regTypeOffsets[reg];
auto e = info.regTypes.begin() + info.regTypeOffsets[reg + 1];
// Doesn't have info
if (b == e)
return nullptr;
// No info after the last live range
if (pc >= (e - 1)->endpc)
return nullptr;
for (auto it = b; it != e; ++it)
{
CODEGEN_ASSERT(it->reg == reg);
if (pc >= it->startpc && pc < it->endpc)
return &*it;
}
// Doesn't have info
if (b == e)
return nullptr; return nullptr;
}
else
{
for (BytecodeRegTypeInfo& el : info.regTypes)
{
if (reg == el.reg && pc >= el.startpc && pc < el.endpc)
return &el;
}
// No info after the last live range
if (pc >= (e - 1)->endpc)
return nullptr; return nullptr;
for (auto it = b; it != e; ++it)
{
CODEGEN_ASSERT(it->reg == reg);
if (pc >= it->startpc && pc < it->endpc)
return &*it;
} }
return nullptr;
} }
static void refineRegType(BytecodeTypeInfo& info, uint8_t reg, int pc, uint8_t ty) static void refineRegType(BytecodeTypeInfo& info, uint8_t reg, int pc, uint8_t ty)
{ {
CODEGEN_ASSERT(FFlag::LuauCodegenTypeInfo);
if (ty != LBC_TYPE_ANY) if (ty != LBC_TYPE_ANY)
{ {
if (BytecodeRegTypeInfo* regType = findRegType(info, reg, pc)) if (BytecodeRegTypeInfo* regType = findRegType(info, reg, pc))
@ -220,7 +180,7 @@ static void refineRegType(BytecodeTypeInfo& info, uint8_t reg, int pc, uint8_t t
if (regType->type == LBC_TYPE_ANY) if (regType->type == LBC_TYPE_ANY)
regType->type = ty; regType->type = ty;
} }
else if (FFlag::LuauTypeInfoLookupImprovement && reg < info.argumentTypes.size()) else if (reg < info.argumentTypes.size())
{ {
if (info.argumentTypes[reg] == LBC_TYPE_ANY) if (info.argumentTypes[reg] == LBC_TYPE_ANY)
info.argumentTypes[reg] = ty; info.argumentTypes[reg] = ty;
@ -230,8 +190,6 @@ static void refineRegType(BytecodeTypeInfo& info, uint8_t reg, int pc, uint8_t t
static void refineUpvalueType(BytecodeTypeInfo& info, int up, uint8_t ty) static void refineUpvalueType(BytecodeTypeInfo& info, int up, uint8_t ty)
{ {
CODEGEN_ASSERT(FFlag::LuauCodegenTypeInfo);
if (ty != LBC_TYPE_ANY) if (ty != LBC_TYPE_ANY)
{ {
if (size_t(up) < info.upvalueTypes.size()) if (size_t(up) < info.upvalueTypes.size())
@ -558,6 +516,49 @@ static void applyBuiltinCall(int bfid, BytecodeTypes& types)
} }
} }
static HostMetamethod opcodeToHostMetamethod(LuauOpcode op)
{
switch (op)
{
case LOP_ADD:
return HostMetamethod::Add;
case LOP_SUB:
return HostMetamethod::Sub;
case LOP_MUL:
return HostMetamethod::Mul;
case LOP_DIV:
return HostMetamethod::Div;
case LOP_IDIV:
return HostMetamethod::Idiv;
case LOP_MOD:
return HostMetamethod::Mod;
case LOP_POW:
return HostMetamethod::Pow;
case LOP_ADDK:
return HostMetamethod::Add;
case LOP_SUBK:
return HostMetamethod::Sub;
case LOP_MULK:
return HostMetamethod::Mul;
case LOP_DIVK:
return HostMetamethod::Div;
case LOP_IDIVK:
return HostMetamethod::Idiv;
case LOP_MODK:
return HostMetamethod::Mod;
case LOP_POWK:
return HostMetamethod::Pow;
case LOP_SUBRK:
return HostMetamethod::Sub;
case LOP_DIVRK:
return HostMetamethod::Div;
default:
CODEGEN_ASSERT(!"opcode is not assigned to a host metamethod");
}
return HostMetamethod::Add;
}
 void buildBytecodeBlocks(IrFunction& function, const std::vector<uint8_t>& jumpTargets)
 {
     Proto* proto = function.proto;
@@ -607,15 +608,14 @@ void buildBytecodeBlocks(IrFunction& function, const std::vector<uint8_t>& jumpT
     }
 }
 
-void analyzeBytecodeTypes(IrFunction& function)
+void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
 {
     Proto* proto = function.proto;
     CODEGEN_ASSERT(proto);
 
     BytecodeTypeInfo& bcTypeInfo = function.bcTypeInfo;
 
-    if (FFlag::LuauTypeInfoLookupImprovement)
-        prepareRegTypeInfoLookups(bcTypeInfo);
+    prepareRegTypeInfoLookups(bcTypeInfo);
 
     // Setup our current knowledge of type tags based on arguments
     uint8_t regTags[256];
@@ -631,48 +631,31 @@ void analyzeBytecodeTypes(IrFunction& function)
         // At the block start, reset our knowledge to the starting state
         // In the future we might be able to propagate some info between the blocks as well
-        if (FFlag::LuauLoadTypeInfo)
-        {
         for (size_t i = 0; i < bcTypeInfo.argumentTypes.size(); i++)
         {
             uint8_t et = bcTypeInfo.argumentTypes[i];
 
             // TODO: if argument is optional, this might force a VM exit unnecessarily
             regTags[i] = et & ~LBC_TYPE_OPTIONAL_BIT;
         }
-        }
-        else
-        {
-            if (hasTypedParameters(proto))
-            {
-                for (int i = 0; i < proto->numparams; ++i)
-                {
-                    uint8_t et = proto->typeinfo[2 + i];
-
-                    // TODO: if argument is optional, this might force a VM exit unnecessarily
-                    regTags[i] = et & ~LBC_TYPE_OPTIONAL_BIT;
-                }
-            }
-        }
 
         for (int i = proto->numparams; i < proto->maxstacksize; ++i)
             regTags[i] = LBC_TYPE_ANY;
 
+        LuauBytecodeType knownNextCallResult = LBC_TYPE_ANY;
+
         for (int i = block.startpc; i <= block.finishpc;)
         {
             const Instruction* pc = &proto->code[i];
             LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc));
 
-            if (FFlag::LuauCodegenTypeInfo)
-            {
             // Assign known register types from local type information
             // TODO: this is an expensive walk for each instruction
             // TODO: it's best to lookup when register is actually used in the instruction
             for (BytecodeRegTypeInfo& el : bcTypeInfo.regTypes)
             {
                 if (el.type != LBC_TYPE_ANY && i >= el.startpc && i < el.endpc)
                     regTags[el.reg] = el.type;
             }
-            }
             BytecodeTypes& bcType = function.bcTypes[i];
@@ -694,8 +677,7 @@ void analyzeBytecodeTypes(IrFunction& function)
                 regTags[ra] = LBC_TYPE_BOOLEAN;
                 bcType.result = regTags[ra];
 
-                if (FFlag::LuauCodegenTypeInfo)
-                    refineRegType(bcTypeInfo, ra, i, bcType.result);
+                refineRegType(bcTypeInfo, ra, i, bcType.result);
 
                 break;
             }
             case LOP_LOADN:
@@ -704,8 +686,7 @@ void analyzeBytecodeTypes(IrFunction& function)
                 regTags[ra] = LBC_TYPE_NUMBER;
                 bcType.result = regTags[ra];
 
-                if (FFlag::LuauCodegenTypeInfo)
-                    refineRegType(bcTypeInfo, ra, i, bcType.result);
+                refineRegType(bcTypeInfo, ra, i, bcType.result);
 
                 break;
             }
             case LOP_LOADK:
@@ -716,8 +697,7 @@ void analyzeBytecodeTypes(IrFunction& function)
                 regTags[ra] = bcType.a;
                 bcType.result = regTags[ra];
 
-                if (FFlag::LuauCodegenTypeInfo)
-                    refineRegType(bcTypeInfo, ra, i, bcType.result);
+                refineRegType(bcTypeInfo, ra, i, bcType.result);
 
                 break;
             }
             case LOP_LOADKX:
@@ -728,8 +708,7 @@ void analyzeBytecodeTypes(IrFunction& function)
                 regTags[ra] = bcType.a;
                 bcType.result = regTags[ra];
 
-                if (FFlag::LuauCodegenTypeInfo)
-                    refineRegType(bcTypeInfo, ra, i, bcType.result);
+                refineRegType(bcTypeInfo, ra, i, bcType.result);
 
                 break;
             }
             case LOP_MOVE:
@@ -740,8 +719,7 @@ void analyzeBytecodeTypes(IrFunction& function)
                 regTags[ra] = regTags[rb];
                 bcType.result = regTags[ra];
 
-                if (FFlag::LuauCodegenTypeInfo)
-                    refineRegType(bcTypeInfo, ra, i, bcType.result);
+                refineRegType(bcTypeInfo, ra, i, bcType.result);
 
                 break;
             }
             case LOP_GETTABLE:
case LOP_GETTABLE: case LOP_GETTABLE:
@ -771,10 +749,51 @@ void analyzeBytecodeTypes(IrFunction& function)
regTags[ra] = LBC_TYPE_ANY; regTags[ra] = LBC_TYPE_ANY;
// Assuming that vector component is being indexed if (FFlag::LuauCodegenUserdataOps)
// TODO: check what key is used {
if (bcType.a == LBC_TYPE_VECTOR) TString* str = gco2ts(function.proto->k[kc].value.gc);
regTags[ra] = LBC_TYPE_NUMBER; const char* field = getstr(str);
if (bcType.a == LBC_TYPE_VECTOR)
{
if (str->len == 1)
{
// Same handling as LOP_GETTABLEKS block in lvmexecute.cpp - case-insensitive comparison with "X" / "Y" / "Z"
char ch = field[0] | ' ';
if (ch == 'x' || ch == 'y' || ch == 'z')
regTags[ra] = LBC_TYPE_NUMBER;
}
if (regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType)
regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len);
}
else if (isCustomUserdataBytecodeType(bcType.a))
{
if (regTags[ra] == LBC_TYPE_ANY && hostHooks.userdataAccessBytecodeType)
regTags[ra] = hostHooks.userdataAccessBytecodeType(bcType.a, field, str->len);
}
}
else
{
if (bcType.a == LBC_TYPE_VECTOR)
{
TString* str = gco2ts(function.proto->k[kc].value.gc);
const char* field = getstr(str);
if (str->len == 1)
{
// Same handling as LOP_GETTABLEKS block in lvmexecute.cpp - case-insensitive comparison with "X" / "Y" / "Z"
char ch = field[0] | ' ';
if (ch == 'x' || ch == 'y' || ch == 'z')
regTags[ra] = LBC_TYPE_NUMBER;
}
if (regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType)
regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len);
}
}
bcType.result = regTags[ra]; bcType.result = regTags[ra];
break; break;
@@ -810,6 +829,9 @@ void analyzeBytecodeTypes(IrFunction& function)
                     regTags[ra] = LBC_TYPE_NUMBER;
                 else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR)
                     regTags[ra] = LBC_TYPE_VECTOR;
+                else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
+                         (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
+                    regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
 
                 bcType.result = regTags[ra];
                 break;
@@ -839,6 +861,11 @@ void analyzeBytecodeTypes(IrFunction& function)
                     if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR)
                         regTags[ra] = LBC_TYPE_VECTOR;
                 }
+                else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
+                         (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
+                {
+                    regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
+                }
 
                 bcType.result = regTags[ra];
                 break;
@@ -857,6 +884,9 @@ void analyzeBytecodeTypes(IrFunction& function)
                 if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER)
                     regTags[ra] = LBC_TYPE_NUMBER;
+                else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
+                         (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
+                    regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
 
                 bcType.result = regTags[ra];
                 break;
@@ -877,6 +907,9 @@ void analyzeBytecodeTypes(IrFunction& function)
                     regTags[ra] = LBC_TYPE_NUMBER;
                 else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR)
                     regTags[ra] = LBC_TYPE_VECTOR;
+                else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
+                         (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
+                    regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
 
                 bcType.result = regTags[ra];
                 break;
@@ -906,6 +939,11 @@ void analyzeBytecodeTypes(IrFunction& function)
                     if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR)
                         regTags[ra] = LBC_TYPE_VECTOR;
                 }
+                else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
+                         (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
+                {
+                    regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
+                }
 
                 bcType.result = regTags[ra];
                 break;
@@ -924,6 +962,9 @@ void analyzeBytecodeTypes(IrFunction& function)
                 if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER)
                     regTags[ra] = LBC_TYPE_NUMBER;
+                else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
+                         (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
+                    regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
 
                 bcType.result = regTags[ra];
                 break;
@@ -943,6 +984,9 @@ void analyzeBytecodeTypes(IrFunction& function)
                     regTags[ra] = LBC_TYPE_NUMBER;
                 else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR)
                     regTags[ra] = LBC_TYPE_VECTOR;
+                else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
+                         (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
+                    regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
 
                 bcType.result = regTags[ra];
                 break;
@@ -970,6 +1014,11 @@ void analyzeBytecodeTypes(IrFunction& function)
                     if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR)
                         regTags[ra] = LBC_TYPE_VECTOR;
                 }
+                else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType &&
+                         (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
+                {
+                    regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
+                }
 
                 bcType.result = regTags[ra];
                 break;
@@ -998,6 +1047,8 @@ void analyzeBytecodeTypes(IrFunction& function)
                     regTags[ra] = LBC_TYPE_NUMBER;
                 else if (bcType.a == LBC_TYPE_VECTOR)
                     regTags[ra] = LBC_TYPE_VECTOR;
+                else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && isCustomUserdataBytecodeType(bcType.a))
+                    regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, LBC_TYPE_ANY, HostMetamethod::Minus);
 
                 bcType.result = regTags[ra];
                 break;
@@ -1036,8 +1087,7 @@ void analyzeBytecodeTypes(IrFunction& function)
                     regTags[ra + 3] = bcType.c;
                 regTags[ra] = bcType.result;
 
-                if (FFlag::LuauCodegenTypeInfo)
-                    refineRegType(bcTypeInfo, ra, i, bcType.result);
+                refineRegType(bcTypeInfo, ra, i, bcType.result);
 
                 break;
             }
             case LOP_FASTCALL1:
@@ -1055,8 +1105,7 @@ void analyzeBytecodeTypes(IrFunction& function)
                 regTags[LUAU_INSN_B(*pc)] = bcType.a;
                 regTags[ra] = bcType.result;
 
-                if (FFlag::LuauCodegenTypeInfo)
-                    refineRegType(bcTypeInfo, ra, i, bcType.result);
+                refineRegType(bcTypeInfo, ra, i, bcType.result);
 
                 break;
             }
             case LOP_FASTCALL2:
@@ -1074,8 +1123,29 @@ void analyzeBytecodeTypes(IrFunction& function)
                 regTags[int(pc[1])] = bcType.b;
                 regTags[ra] = bcType.result;
 
-                if (FFlag::LuauCodegenTypeInfo)
-                    refineRegType(bcTypeInfo, ra, i, bcType.result);
+                refineRegType(bcTypeInfo, ra, i, bcType.result);
+
+                break;
+            }
+            case LOP_FASTCALL3:
+            {
+                CODEGEN_ASSERT(FFlag::LuauCodegenFastcall3);
+
+                int bfid = LUAU_INSN_A(*pc);
+                int skip = LUAU_INSN_C(*pc);
+                int aux = pc[1];
+
+                Instruction call = pc[skip + 1];
+                CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL);
+                int ra = LUAU_INSN_A(call);
+
+                applyBuiltinCall(bfid, bcType);
+
+                regTags[LUAU_INSN_B(*pc)] = bcType.a;
+                regTags[aux & 0xff] = bcType.b;
+                regTags[(aux >> 8) & 0xff] = bcType.c;
+                regTags[ra] = bcType.result;
+
+                refineRegType(bcTypeInfo, ra, i, bcType.result);
 
                 break;
             }
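Editorial aside: unlike FASTCALL2, the new LOP_FASTCALL3 case cannot fit all three argument registers in the instruction word, so the second and third registers travel in the auxiliary word, exactly as decoded above. A minimal restatement of that decode, with 'insn' and 'aux' standing for the instruction and its aux word:

// Sketch mirroring the LOP_FASTCALL3 register decode in the case above.
void decodeFastcall3Args(uint32_t insn, uint32_t aux, int& arg1, int& arg2, int& arg3)
{
    arg1 = LUAU_INSN_B(insn); // first argument register: B field of the instruction
    arg2 = aux & 0xff;        // second argument register: low byte of aux
    arg3 = (aux >> 8) & 0xff; // third argument register: second byte of aux
}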
             case LOP_FORNPREP:
@@ -1086,12 +1156,9 @@ void analyzeBytecodeTypes(IrFunction& function)
                 regTags[ra + 1] = LBC_TYPE_NUMBER;
                 regTags[ra + 2] = LBC_TYPE_NUMBER;
 
-                if (FFlag::LuauCodegenTypeInfo)
-                {
-                    refineRegType(bcTypeInfo, ra, i, regTags[ra]);
-                    refineRegType(bcTypeInfo, ra + 1, i, regTags[ra + 1]);
-                    refineRegType(bcTypeInfo, ra + 2, i, regTags[ra + 2]);
-                }
+                refineRegType(bcTypeInfo, ra, i, regTags[ra]);
+                refineRegType(bcTypeInfo, ra + 1, i, regTags[ra + 1]);
+                refineRegType(bcTypeInfo, ra + 2, i, regTags[ra + 2]);
 
                 break;
             }
             case LOP_FORNLOOP:
@@ -1121,61 +1188,88 @@ void analyzeBytecodeTypes(IrFunction& function)
             }
             case LOP_NAMECALL:
             {
-                if (FFlag::LuauCodegenDirectUserdataFlow)
-                {
                 int ra = LUAU_INSN_A(*pc);
                 int rb = LUAU_INSN_B(*pc);
                 uint32_t kc = pc[1];
 
                 bcType.a = regTags[rb];
                 bcType.b = getBytecodeConstantTag(proto, kc);
 
                 // While namecall might result in a callable table, we assume the function fast path
                 regTags[ra] = LBC_TYPE_FUNCTION;
 
                 // Namecall places source register into target + 1
                 regTags[ra + 1] = bcType.a;
 
                 bcType.result = LBC_TYPE_FUNCTION;
-                }
+
+                if (FFlag::LuauCodegenUserdataOps)
+                {
+                    TString* str = gco2ts(function.proto->k[kc].value.gc);
+                    const char* field = getstr(str);
+
+                    if (bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType)
+                        knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len));
+                    else if (isCustomUserdataBytecodeType(bcType.a) && hostHooks.userdataNamecallBytecodeType)
+                        knownNextCallResult = LuauBytecodeType(hostHooks.userdataNamecallBytecodeType(bcType.a, field, str->len));
+                }
+                else
+                {
+                    if (bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType)
+                    {
+                        TString* str = gco2ts(function.proto->k[kc].value.gc);
+                        const char* field = getstr(str);
+
+                        knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len));
+                    }
+                }
+
+                break;
+            }
+            case LOP_CALL:
+            {
+                int ra = LUAU_INSN_A(*pc);
+
+                if (knownNextCallResult != LBC_TYPE_ANY)
+                {
+                    bcType.result = knownNextCallResult;
+
+                    knownNextCallResult = LBC_TYPE_ANY;
+
+                    regTags[ra] = bcType.result;
+                }
+
+                refineRegType(bcTypeInfo, ra, i, bcType.result);
 
                 break;
             }
             case LOP_GETUPVAL:
             {
-                if (FFlag::LuauCodegenTypeInfo)
-                {
                 int ra = LUAU_INSN_A(*pc);
                 int up = LUAU_INSN_B(*pc);
 
                 bcType.a = LBC_TYPE_ANY;
 
                 if (size_t(up) < bcTypeInfo.upvalueTypes.size())
                 {
                     uint8_t et = bcTypeInfo.upvalueTypes[up];
 
                     // TODO: if argument is optional, this might force a VM exit unnecessarily
                     bcType.a = et & ~LBC_TYPE_OPTIONAL_BIT;
                 }
 
                 regTags[ra] = bcType.a;
                 bcType.result = regTags[ra];
-                }
 
                 break;
             }
             case LOP_SETUPVAL:
             {
-                if (FFlag::LuauCodegenTypeInfo)
-                {
                 int ra = LUAU_INSN_A(*pc);
                 int up = LUAU_INSN_B(*pc);
 
                 refineUpvalueType(bcTypeInfo, up, regTags[ra]);
-                }
 
                 break;
             }
             case LOP_GETGLOBAL:
             case LOP_SETGLOBAL:
-            case LOP_CALL:
             case LOP_RETURN:
             case LOP_JUMP:
             case LOP_JUMPBACK:
View file

@@ -8,6 +8,8 @@
 #include "lobject.h"
 #include "lstate.h"
 
+LUAU_FASTFLAG(LuauNativeAttribute)
+
 namespace Luau
 {
 namespace CodeGen
@@ -56,7 +58,10 @@ std::vector<FunctionBytecodeSummary> summarizeBytecode(lua_State* L, int idx, un
     Proto* root = clvalue(func)->l.p;
 
     std::vector<Proto*> protos;
-    gatherFunctions(protos, root, CodeGen_ColdFunctions);
+    if (FFlag::LuauNativeAttribute)
+        gatherFunctions(protos, root, CodeGen_ColdFunctions, root->flags & LPF_NATIVE_FUNCTION);
+    else
+        gatherFunctions_DEPRECATED(protos, root, CodeGen_ColdFunctions);
 
     std::vector<FunctionBytecodeSummary> summaries;
     summaries.reserve(protos.size());

View file

@@ -7,7 +7,7 @@
 #include <string.h>
 #include <stdlib.h>
 
-#if defined(_WIN32) && defined(_M_X64)
+#if defined(_WIN32) && defined(CODEGEN_TARGET_X64)
 
 #ifndef WIN32_LEAN_AND_MEAN
 #define WIN32_LEAN_AND_MEAN
@@ -26,7 +26,7 @@ extern "C" void __deregister_frame(const void*) __attribute__((weak));
 extern "C" void __unw_add_dynamic_fde() __attribute__((weak));
 #endif
 
-#if defined(__APPLE__) && defined(__aarch64__)
+#if defined(__APPLE__) && defined(CODEGEN_TARGET_A64)
 #include <sys/sysctl.h>
 #include <mach-o/loader.h>
 #include <dlfcn.h>
@@ -48,7 +48,7 @@ namespace Luau
 namespace CodeGen
 {
 
-#if defined(__APPLE__) && defined(__aarch64__)
+#if defined(__APPLE__) && defined(CODEGEN_TARGET_A64)
 static int findDynamicUnwindSections(uintptr_t addr, unw_dynamic_unwind_sections_t* info)
 {
     // Define a minimal mach header for JIT'd code.
@@ -102,17 +102,17 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz
     UnwindBuilder* unwind = (UnwindBuilder*)context;
 
     // All unwinding related data is placed together at the start of the block
-    size_t unwindSize = unwind->getSize();
+    size_t unwindSize = unwind->getUnwindInfoSize(blockSize);
     unwindSize = (unwindSize + (kCodeAlignment - 1)) & ~(kCodeAlignment - 1); // Match code allocator alignment
     CODEGEN_ASSERT(blockSize >= unwindSize);
 
     char* unwindData = (char*)block;
-    unwind->finalize(unwindData, unwindSize, block, blockSize);
+    [[maybe_unused]] size_t functionCount = unwind->finalize(unwindData, unwindSize, block, blockSize);
 
-#if defined(_WIN32) && defined(_M_X64)
+#if defined(_WIN32) && defined(CODEGEN_TARGET_X64)
 
 #if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP | WINAPI_PARTITION_SYSTEM)
-    if (!RtlAddFunctionTable((RUNTIME_FUNCTION*)block, uint32_t(unwind->getFunctionCount()), uintptr_t(block)))
+    if (!RtlAddFunctionTable((RUNTIME_FUNCTION*)block, uint32_t(functionCount), uintptr_t(block)))
     {
         CODEGEN_ASSERT(!"Failed to allocate function table");
         return nullptr;
@@ -126,7 +126,7 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz
     visitFdeEntries(unwindData, __register_frame);
 #endif
 
-#if defined(__APPLE__) && defined(__aarch64__)
+#if defined(__APPLE__) && defined(CODEGEN_TARGET_A64)
     // Starting from macOS 14, we need to register unwind section callback to state that our ABI doesn't require pointer authentication
     // This might conflict with other JITs that do the same; unfortunately this is the best we can do for now.
     static unw_add_find_dynamic_unwind_sections_t unw_add_find_dynamic_unwind_sections =
@@ -141,7 +141,7 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz
 void destroyBlockUnwindInfo(void* context, void* unwindData)
 {
-#if defined(_WIN32) && defined(_M_X64)
+#if defined(_WIN32) && defined(CODEGEN_TARGET_X64)
 
 #if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP | WINAPI_PARTITION_SYSTEM)
     if (!RtlDeleteFunctionTable((RUNTIME_FUNCTION*)unwindData))
@@ -161,12 +161,12 @@ void destroyBlockUnwindInfo(void* context, void* unwindData)
 bool isUnwindSupported()
 {
-#if defined(_WIN32) && defined(_M_X64)
+#if defined(_WIN32) && defined(CODEGEN_TARGET_X64)
     return true;
 #elif defined(__ANDROID__)
     // Current unwind information is not compatible with Android
     return false;
-#elif defined(__APPLE__) && defined(__aarch64__)
+#elif defined(__APPLE__) && defined(CODEGEN_TARGET_A64)
     char ver[256];
     size_t verLength = sizeof(ver);
     // libunwind on macOS 12 and earlier (which maps to osrelease 21) assumes JIT frames use pointer authentication without a way to override that
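Editorial aside: this file and several below swap raw architecture checks for CODEGEN_TARGET_* macros. The commit does not show where they are defined, but given the defines they replace, a shared header presumably derives them roughly like this (a sketch, not the actual header):

// Sketch of how CODEGEN_TARGET_X64 / CODEGEN_TARGET_A64 are presumably
// derived from the compiler-provided defines they replace in this diff.
#if defined(__x86_64__) || defined(_M_X64)
#define CODEGEN_TARGET_X64
#elif defined(__aarch64__)
#define CODEGEN_TARGET_A64
#endif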

View file

@@ -27,7 +27,7 @@
 #include <memory>
 #include <optional>
 
-#if defined(__x86_64__) || defined(_M_X64)
+#if defined(CODEGEN_TARGET_X64)
 #ifdef _MSC_VER
 #include <intrin.h> // __cpuid
 #else
@@ -35,7 +35,7 @@
 #endif
 #endif
 
-#if defined(__aarch64__)
+#if defined(CODEGEN_TARGET_A64)
 #ifdef __APPLE__
 #include <sys/sysctl.h>
 #endif
@@ -58,186 +58,41 @@ LUAU_FASTINTVARIABLE(CodegenHeuristicsBlockLimit, 32'768) // 32 K
 // Current value is based on some member variables being limited to 16 bits
 LUAU_FASTINTVARIABLE(CodegenHeuristicsBlockInstructionLimit, 65'536) // 64 K
 
-LUAU_FASTFLAG(LuauCodegenContext)
-
 namespace Luau
 {
 namespace CodeGen
 {
 
-static const Instruction kCodeEntryInsn = LOP_NATIVECALL;
-
-void* gPerfLogContext = nullptr;
-PerfLogFn gPerfLogFn = nullptr;
-
-struct OldNativeProto
-{
-    Proto* p;
-    void* execdata;
-    uintptr_t exectarget;
-};
-
-// Additional data attached to Proto::execdata
-// Guaranteed to be aligned to 16 bytes
-struct ExtraExecData
-{
-    size_t execDataSize;
-    size_t codeSize;
-};
-
-static int alignTo(int value, int align)
-{
-    CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
-    CODEGEN_ASSERT(align > 0 && (align & (align - 1)) == 0);
-    return (value + (align - 1)) & ~(align - 1);
-}
-
-// Returns the size of execdata required to store all code offsets and ExtraExecData structure at proper alignment
-// Always a multiple of 4 bytes
-static int calculateExecDataSize(Proto* proto)
-{
-    CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
-    int size = proto->sizecode * sizeof(uint32_t);
-
-    size = alignTo(size, 16);
-    size += sizeof(ExtraExecData);
-
-    return size;
-}
-
-// Returns pointer to the ExtraExecData inside the Proto::execdata
-// Even though 'execdata' is a field in Proto, we require it to support cases where it's not attached to Proto during construction
-ExtraExecData* getExtraExecData(Proto* proto, void* execdata)
-{
-    CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
-    int size = proto->sizecode * sizeof(uint32_t);
-
-    size = alignTo(size, 16);
-
-    return reinterpret_cast<ExtraExecData*>(reinterpret_cast<char*>(execdata) + size);
-}
-
-static OldNativeProto createOldNativeProto(Proto* proto, const IrBuilder& ir)
-{
-    CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
-    int execDataSize = calculateExecDataSize(proto);
-    CODEGEN_ASSERT(execDataSize % 4 == 0);
-
-    uint32_t* execData = new uint32_t[execDataSize / 4];
-    uint32_t instTarget = ir.function.entryLocation;
-
-    for (int i = 0; i < proto->sizecode; i++)
-    {
-        CODEGEN_ASSERT(ir.function.bcMapping[i].asmLocation >= instTarget);
-
-        execData[i] = ir.function.bcMapping[i].asmLocation - instTarget;
-    }
-
-    // Set first instruction offset to 0 so that entering this function still executes any generated entry code.
-    execData[0] = 0;
-
-    ExtraExecData* extra = getExtraExecData(proto, execData);
-    memset(extra, 0, sizeof(ExtraExecData));
-
-    extra->execDataSize = execDataSize;
-
-    // entry target will be relocated when assembly is finalized
-    return {proto, execData, instTarget};
-}
-
-static void destroyExecData(void* execdata)
-{
-    CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
-    delete[] static_cast<uint32_t*>(execdata);
-}
-
-static void logPerfFunction(Proto* p, uintptr_t addr, unsigned size)
-{
-    CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
-    CODEGEN_ASSERT(p->source);
-
-    const char* source = getstr(p->source);
-    source = (source[0] == '=' || source[0] == '@') ? source + 1 : "[string]";
-
-    char name[256];
-    snprintf(name, sizeof(name), "<luau> %s:%d %s", source, p->linedefined, p->debugname ? getstr(p->debugname) : "");
-
-    if (gPerfLogFn)
-        gPerfLogFn(gPerfLogContext, addr, size, name);
-}
-
-template<typename AssemblyBuilder>
-static std::optional<OldNativeProto> createNativeFunction(
-    AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, uint32_t& totalIrInstCount, CodeGenCompilationResult& result)
-{
-    CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
-    IrBuilder ir;
-    ir.buildFunctionIr(proto);
-
-    unsigned instCount = unsigned(ir.function.instructions.size());
-
-    if (totalIrInstCount + instCount >= unsigned(FInt::CodegenHeuristicsInstructionLimit.value))
-    {
-        result = CodeGenCompilationResult::CodeGenOverflowInstructionLimit;
-        return std::nullopt;
-    }
-    totalIrInstCount += instCount;
-
-    if (!lowerFunction(ir, build, helpers, proto, {}, /* stats */ nullptr, result))
-        return std::nullopt;
-
-    return createOldNativeProto(proto, ir);
-}
-
-static NativeState* getNativeState(lua_State* L)
-{
-    CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
-    return static_cast<NativeState*>(L->global->ecb.context);
-}
-
-static void onCloseState(lua_State* L)
-{
-    CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
-    delete getNativeState(L);
-    L->global->ecb = lua_ExecutionCallbacks();
-}
-
-static void onDestroyFunction(lua_State* L, Proto* proto)
-{
-    CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
-    destroyExecData(proto->execdata);
-    proto->execdata = nullptr;
-    proto->exectarget = 0;
-    proto->codeentry = proto->code;
-}
-
-static int onEnter(lua_State* L, Proto* proto)
-{
-    CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
-    NativeState* data = getNativeState(L);
-
-    CODEGEN_ASSERT(proto->execdata);
-    CODEGEN_ASSERT(L->ci->savedpc >= proto->code && L->ci->savedpc < proto->code + proto->sizecode);
-
-    uintptr_t target = proto->exectarget + static_cast<uint32_t*>(proto->execdata)[L->ci->savedpc - proto->code];
-
-    // Returns 1 to finish the function in the VM
-    return GateFn(data->context.gateEntry)(L, proto, target, &data->context);
-}
-
-// used to disable native execution, unconditionally
-static int onEnterDisabled(lua_State* L, Proto* proto)
-{
-    CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
-    return 1;
-}
+std::string toString(const CodeGenCompilationResult& result)
+{
+    switch (result)
+    {
+    case CodeGenCompilationResult::Success:
+        return "Success";
+    case CodeGenCompilationResult::NothingToCompile:
+        return "NothingToCompile";
+    case CodeGenCompilationResult::NotNativeModule:
+        return "NotNativeModule";
+    case CodeGenCompilationResult::CodeGenNotInitialized:
+        return "CodeGenNotInitialized";
+    case CodeGenCompilationResult::CodeGenOverflowInstructionLimit:
+        return "CodeGenOverflowInstructionLimit";
+    case CodeGenCompilationResult::CodeGenOverflowBlockLimit:
+        return "CodeGenOverflowBlockLimit";
+    case CodeGenCompilationResult::CodeGenOverflowBlockInstructionLimit:
+        return "CodeGenOverflowBlockInstructionLimit";
+    case CodeGenCompilationResult::CodeGenAssemblerFinalizationFailure:
+        return "CodeGenAssemblerFinalizationFailure";
+    case CodeGenCompilationResult::CodeGenLoweringFailure:
+        return "CodeGenLoweringFailure";
+    case CodeGenCompilationResult::AllocationFailed:
+        return "AllocationFailed";
+    case CodeGenCompilationResult::Count:
+        return "Count";
+    }
+
+    CODEGEN_ASSERT(false);
+    return "";
+}
 void onDisable(lua_State* L, Proto* proto)
@@ -279,18 +134,7 @@ void onDisable(lua_State* L, Proto* proto)
     });
 }
 
-static size_t getMemorySize(lua_State* L, Proto* proto)
-{
-    CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
-    ExtraExecData* extra = getExtraExecData(proto, proto->execdata);
-
-    // While execDataSize is exactly the size of the allocation we made and hold for 'execdata' field, the code size is approximate
-    // This is because code+data page is shared and owned by all Proto from a single module and each one can keep the whole region alive
-    // So individual Proto being freed by GC will not reflect memory use by native code correctly
-    return extra->execDataSize + extra->codeSize;
-}
-
-#if defined(__aarch64__)
+#if defined(CODEGEN_TARGET_A64)
 unsigned int getCpuFeaturesA64()
 {
     unsigned int result = 0;
@@ -326,7 +170,7 @@ bool isSupported()
     return false;
 #endif
 
-#if defined(__x86_64__) || defined(_M_X64)
+#if defined(CODEGEN_TARGET_X64)
     int cpuinfo[4] = {};
 #ifdef _MSC_VER
     __cpuid(cpuinfo, 1);
@@ -341,273 +185,12 @@ bool isSupported()
         return false;
 
     return true;
-#elif defined(__aarch64__)
+#elif defined(CODEGEN_TARGET_A64)
     return true;
 #else
     return false;
 #endif
 }
-static void create_OLD(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext)
-{
-    CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
-    CODEGEN_ASSERT(isSupported());
-
-    std::unique_ptr<NativeState> data = std::make_unique<NativeState>(allocationCallback, allocationCallbackContext);
-
-#if defined(_WIN32)
-    data->unwindBuilder = std::make_unique<UnwindBuilderWin>();
-#else
-    data->unwindBuilder = std::make_unique<UnwindBuilderDwarf2>();
-#endif
-
-    data->codeAllocator.context = data->unwindBuilder.get();
-    data->codeAllocator.createBlockUnwindInfo = createBlockUnwindInfo;
-    data->codeAllocator.destroyBlockUnwindInfo = destroyBlockUnwindInfo;
-
-    initFunctions(*data);
-
-#if defined(__x86_64__) || defined(_M_X64)
-    if (!X64::initHeaderFunctions(*data))
-        return;
-#elif defined(__aarch64__)
-    if (!A64::initHeaderFunctions(*data))
-        return;
-#endif
-
-    if (gPerfLogFn)
-        gPerfLogFn(gPerfLogContext, uintptr_t(data->context.gateEntry), 4096, "<luau gate>");
-
-    lua_ExecutionCallbacks* ecb = &L->global->ecb;
-
-    ecb->context = data.release();
-    ecb->close = onCloseState;
-    ecb->destroy = onDestroyFunction;
-    ecb->enter = onEnter;
-    ecb->disable = onDisable;
-    ecb->getmemorysize = getMemorySize;
-}
-
-void create(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext)
-{
-    if (FFlag::LuauCodegenContext)
-    {
-        create_NEW(L, allocationCallback, allocationCallbackContext);
-    }
-    else
-    {
-        create_OLD(L, allocationCallback, allocationCallbackContext);
-    }
-}
-
-void create(lua_State* L)
-{
-    if (FFlag::LuauCodegenContext)
-    {
-        create_NEW(L);
-    }
-    else
-    {
-        create(L, nullptr, nullptr);
-    }
-}
-
-void create(lua_State* L, SharedCodeGenContext* codeGenContext)
-{
-    CODEGEN_ASSERT(FFlag::LuauCodegenContext);
-    create_NEW(L, codeGenContext);
-}
-
-[[nodiscard]] bool isNativeExecutionEnabled(lua_State* L)
-{
-    if (FFlag::LuauCodegenContext)
-    {
-        return isNativeExecutionEnabled_NEW(L);
-    }
-    else
-    {
-        return getNativeState(L) ? (L->global->ecb.enter == onEnter) : false;
-    }
-}
-
-void setNativeExecutionEnabled(lua_State* L, bool enabled)
-{
-    if (FFlag::LuauCodegenContext)
-    {
-        setNativeExecutionEnabled_NEW(L, enabled);
-    }
-    else
-    {
-        if (getNativeState(L))
-            L->global->ecb.enter = enabled ? onEnter : onEnterDisabled;
-    }
-}
-
-static CompilationResult compile_OLD(lua_State* L, int idx, unsigned int flags, CompilationStats* stats)
-{
-    CompilationResult compilationResult;
-
-    CODEGEN_ASSERT(lua_isLfunction(L, idx));
-    const TValue* func = luaA_toobject(L, idx);
-
-    Proto* root = clvalue(func)->l.p;
-
-    if ((flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0)
-    {
-        compilationResult.result = CodeGenCompilationResult::NotNativeModule;
-        return compilationResult;
-    }
-
-    // If initialization has failed, do not compile any functions
-    NativeState* data = getNativeState(L);
-    if (!data)
-    {
-        compilationResult.result = CodeGenCompilationResult::CodeGenNotInitialized;
-        return compilationResult;
-    }
-
-    std::vector<Proto*> protos;
-    gatherFunctions(protos, root, flags);
-
-    // Skip protos that have been compiled during previous invocations of CodeGen::compile
-    protos.erase(std::remove_if(protos.begin(), protos.end(),
-        [](Proto* p) {
-            return p == nullptr || p->execdata != nullptr;
-        }),
-        protos.end());
-
-    if (protos.empty())
-    {
-        compilationResult.result = CodeGenCompilationResult::NothingToCompile;
-        return compilationResult;
-    }
-
-    if (stats != nullptr)
-        stats->functionsTotal = uint32_t(protos.size());
-
-#if defined(__aarch64__)
-    static unsigned int cpuFeatures = getCpuFeaturesA64();
-    A64::AssemblyBuilderA64 build(/* logText= */ false, cpuFeatures);
-#else
-    X64::AssemblyBuilderX64 build(/* logText= */ false);
-#endif
-
-    ModuleHelpers helpers;
-#if defined(__aarch64__)
-    A64::assembleHelpers(build, helpers);
-#else
-    X64::assembleHelpers(build, helpers);
-#endif
-
-    std::vector<OldNativeProto> results;
-    results.reserve(protos.size());
-
-    uint32_t totalIrInstCount = 0;
-
-    for (Proto* p : protos)
-    {
-        CodeGenCompilationResult protoResult = CodeGenCompilationResult::Success;
-
-        if (std::optional<OldNativeProto> np = createNativeFunction(build, helpers, p, totalIrInstCount, protoResult))
-            results.push_back(*np);
-        else
-            compilationResult.protoFailures.push_back({protoResult, p->debugname ? getstr(p->debugname) : "", p->linedefined});
-    }
-
-    // Very large modules might result in overflowing a jump offset; in this case we currently abandon the entire module
-    if (!build.finalize())
-    {
-        for (OldNativeProto result : results)
-            destroyExecData(result.execdata);
-
-        compilationResult.result = CodeGenCompilationResult::CodeGenAssemblerFinalizationFailure;
-        return compilationResult;
-    }
-
-    // If no functions were assembled, we don't need to allocate/copy executable pages for helpers
-    if (results.empty())
-        return compilationResult;
-
-    uint8_t* nativeData = nullptr;
-    size_t sizeNativeData = 0;
-    uint8_t* codeStart = nullptr;
-    if (!data->codeAllocator.allocate(build.data.data(), int(build.data.size()), reinterpret_cast<const uint8_t*>(build.code.data()),
-            int(build.code.size() * sizeof(build.code[0])), nativeData, sizeNativeData, codeStart))
-    {
-        for (OldNativeProto result : results)
-            destroyExecData(result.execdata);
-
-        compilationResult.result = CodeGenCompilationResult::AllocationFailed;
-        return compilationResult;
-    }
-
-    if (gPerfLogFn && results.size() > 0)
-        gPerfLogFn(gPerfLogContext, uintptr_t(codeStart), uint32_t(results[0].exectarget), "<luau helpers>");
-
-    for (size_t i = 0; i < results.size(); ++i)
-    {
-        uint32_t begin = uint32_t(results[i].exectarget);
-        uint32_t end = i + 1 < results.size() ? uint32_t(results[i + 1].exectarget) : uint32_t(build.code.size() * sizeof(build.code[0]));
-        CODEGEN_ASSERT(begin < end);
-
-        if (gPerfLogFn)
-            logPerfFunction(results[i].p, uintptr_t(codeStart) + begin, end - begin);
-
-        ExtraExecData* extra = getExtraExecData(results[i].p, results[i].execdata);
-        extra->codeSize = end - begin;
-    }
-
-    for (const OldNativeProto& result : results)
-    {
-        // the memory is now managed by VM and will be freed via onDestroyFunction
-        result.p->execdata = result.execdata;
-        result.p->exectarget = uintptr_t(codeStart) + result.exectarget;
-        result.p->codeentry = &kCodeEntryInsn;
-    }
-
-    if (stats != nullptr)
-    {
-        for (const OldNativeProto& result : results)
-        {
-            stats->bytecodeSizeBytes += result.p->sizecode * sizeof(Instruction);
-
-            // Account for the native -> bytecode instruction offsets mapping:
-            stats->nativeMetadataSizeBytes += result.p->sizecode * sizeof(uint32_t);
-        }
-
-        stats->functionsCompiled += uint32_t(results.size());
-        stats->nativeCodeSizeBytes += build.code.size();
-        stats->nativeDataSizeBytes += build.data.size();
-    }
-
-    return compilationResult;
-}
-
-CompilationResult compile(lua_State* L, int idx, unsigned int flags, CompilationStats* stats)
-{
-    if (FFlag::LuauCodegenContext)
-    {
-        return compile_NEW(L, idx, flags, stats);
-    }
-    else
-    {
-        return compile_OLD(L, idx, flags, stats);
-    }
-}
-
-CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats)
-{
-    CODEGEN_ASSERT(FFlag::LuauCodegenContext);
-    return compile_NEW(moduleId, L, idx, flags, stats);
-}
-
-void setPerfLog(void* context, PerfLogFn logFn)
-{
-    gPerfLogContext = context;
-    gPerfLogFn = logFn;
-}
-
 } // namespace CodeGen
 } // namespace Luau

View file

@@ -253,44 +253,11 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde
     // Our entry function is special, it spans the whole remaining code area
     unwind.startFunction();
     unwind.prologueA64(prologueSize, kStackSize, {x29, x30, x19, x20, x21, x22, x23, x24, x25});
-    unwind.finishFunction(build.getLabelOffset(locations.start), kFullBlockFuncton);
+    unwind.finishFunction(build.getLabelOffset(locations.start), kFullBlockFunction);
 
     return locations;
 }
 
-bool initHeaderFunctions(NativeState& data)
-{
-    AssemblyBuilderA64 build(/* logText= */ false);
-    UnwindBuilder& unwind = *data.unwindBuilder.get();
-
-    unwind.startInfo(UnwindBuilder::A64);
-
-    EntryLocations entryLocations = buildEntryFunction(build, unwind);
-    build.finalize();
-
-    unwind.finishInfo();
-
-    CODEGEN_ASSERT(build.data.empty());
-
-    uint8_t* codeStart = nullptr;
-    if (!data.codeAllocator.allocate(build.data.data(), int(build.data.size()), reinterpret_cast<const uint8_t*>(build.code.data()),
-            int(build.code.size() * sizeof(build.code[0])), data.gateData, data.gateDataSize, codeStart))
-    {
-        CODEGEN_ASSERT(!"Failed to create entry function");
-        return false;
-    }
-
-    // Set the offset at the beginning so that functions in new blocks will not overlay the locations
-    // specified by the unwind information of the entry function
-    unwind.setBeginOffset(build.getLabelOffset(entryLocations.prologueEnd));
-
-    data.context.gateEntry = codeStart + build.getLabelOffset(entryLocations.start);
-    data.context.gateExit = codeStart + build.getLabelOffset(entryLocations.epilogueStart);
-
-    return true;
-}
-
 bool initHeaderFunctions(BaseCodeGenContext& codeGenContext)
 {
     AssemblyBuilderA64 build(/* logText= */ false);

View file

@@ -7,7 +7,6 @@ namespace CodeGen
 {
 
 class BaseCodeGenContext;
-struct NativeState;
 struct ModuleHelpers;
 
 namespace A64
@@ -15,7 +14,6 @@ namespace A64
 
 class AssemblyBuilderA64;
 
-bool initHeaderFunctions(NativeState& data);
 bool initHeaderFunctions(BaseCodeGenContext& codeGenContext);
 void assembleHelpers(AssemblyBuilderA64& build, ModuleHelpers& helpers);

View file

@@ -12,13 +12,50 @@
 #include "lapi.h"
 
-LUAU_FASTFLAG(LuauCodegenTypeInfo)
+LUAU_FASTFLAG(LuauLoadUserdataInfo)
+LUAU_FASTFLAG(LuauNativeAttribute)
 
 namespace Luau
 {
 namespace CodeGen
 {
 
+static const LocVar* tryFindLocal(const Proto* proto, int reg, int pcpos)
+{
+    for (int i = 0; i < proto->sizelocvars; i++)
+    {
+        const LocVar& local = proto->locvars[i];
+
+        if (reg == local.reg && pcpos >= local.startpc && pcpos < local.endpc)
+            return &local;
+    }
+
+    return nullptr;
+}
+
+const char* tryFindLocalName(const Proto* proto, int reg, int pcpos)
+{
+    const LocVar* var = tryFindLocal(proto, reg, pcpos);
+
+    if (var && var->varname)
+        return getstr(var->varname);
+
+    return nullptr;
+}
+
+const char* tryFindUpvalueName(const Proto* proto, int upval)
+{
+    if (proto->upvalues)
+    {
+        CODEGEN_ASSERT(upval < proto->sizeupvalues);
+
+        if (proto->upvalues[upval])
+            return getstr(proto->upvalues[upval]);
+    }
+
+    return nullptr;
+}
+
 template<typename AssemblyBuilder>
 static void logFunctionHeader(AssemblyBuilder& build, Proto* proto)
 {
@@ -29,10 +66,8 @@ static void logFunctionHeader(AssemblyBuilder& build, Proto* proto)
     for (int i = 0; i < proto->numparams; i++)
     {
-        LocVar* var = proto->locvars ? &proto->locvars[proto->sizelocvars - proto->numparams + i] : nullptr;
-
-        if (var && var->varname)
-            build.logAppend("%s%s", i == 0 ? "" : ", ", getstr(var->varname));
+        if (const char* name = tryFindLocalName(proto, i, 0))
+            build.logAppend("%s%s", i == 0 ? "" : ", ", name);
         else
             build.logAppend("%s$arg%d", i == 0 ? "" : ", ", i);
     }
@@ -49,9 +84,9 @@ static void logFunctionHeader(AssemblyBuilder& build, Proto* proto)
 }
 
 template<typename AssemblyBuilder>
-static void logFunctionTypes(AssemblyBuilder& build, const IrFunction& function)
+static void logFunctionTypes_DEPRECATED(AssemblyBuilder& build, const IrFunction& function)
 {
-    CODEGEN_ASSERT(FFlag::LuauCodegenTypeInfo);
+    CODEGEN_ASSERT(!FFlag::LuauLoadUserdataInfo);
 
     const BytecodeTypeInfo& typeInfo = function.bcTypeInfo;
 
@@ -60,7 +95,12 @@ static void logFunctionTypes(AssemblyBuilder& build, const IrFunction& function)
         uint8_t ty = typeInfo.argumentTypes[i];
 
         if (ty != LBC_TYPE_ANY)
-            build.logAppend("; R%d: %s [argument]\n", int(i), getBytecodeTypeName(ty));
+        {
+            if (const char* name = tryFindLocalName(function.proto, int(i), 0))
+                build.logAppend("; R%d: %s [argument '%s']\n", int(i), getBytecodeTypeName_DEPRECATED(ty), name);
+            else
+                build.logAppend("; R%d: %s [argument]\n", int(i), getBytecodeTypeName_DEPRECATED(ty));
+        }
     }
 
     for (size_t i = 0; i < typeInfo.upvalueTypes.size(); i++)
@@ -68,12 +108,73 @@ static void logFunctionTypes(AssemblyBuilder& build, const IrFunction& function)
         uint8_t ty = typeInfo.upvalueTypes[i];
 
         if (ty != LBC_TYPE_ANY)
-            build.logAppend("; U%d: %s\n", int(i), getBytecodeTypeName(ty));
+        {
+            if (const char* name = tryFindUpvalueName(function.proto, int(i)))
+                build.logAppend("; U%d: %s ['%s']\n", int(i), getBytecodeTypeName_DEPRECATED(ty), name);
+            else
+                build.logAppend("; U%d: %s\n", int(i), getBytecodeTypeName_DEPRECATED(ty));
+        }
     }
 
     for (const BytecodeRegTypeInfo& el : typeInfo.regTypes)
     {
-        build.logAppend("; R%d: %s from %d to %d\n", el.reg, getBytecodeTypeName(el.type), el.startpc, el.endpc);
+        // Using last active position as the PC because 'startpc' for type info is before local is initialized
+        if (const char* name = tryFindLocalName(function.proto, el.reg, el.endpc - 1))
+            build.logAppend("; R%d: %s from %d to %d [local '%s']\n", el.reg, getBytecodeTypeName_DEPRECATED(el.type), el.startpc, el.endpc, name);
+        else
+            build.logAppend("; R%d: %s from %d to %d\n", el.reg, getBytecodeTypeName_DEPRECATED(el.type), el.startpc, el.endpc);
     }
 }
+template<typename AssemblyBuilder>
+static void logFunctionTypes(AssemblyBuilder& build, const IrFunction& function, const char* const* userdataTypes)
+{
+    CODEGEN_ASSERT(FFlag::LuauLoadUserdataInfo);
+
+    const BytecodeTypeInfo& typeInfo = function.bcTypeInfo;
+
+    for (size_t i = 0; i < typeInfo.argumentTypes.size(); i++)
+    {
+        uint8_t ty = typeInfo.argumentTypes[i];
+        const char* type = getBytecodeTypeName(ty, userdataTypes);
+        const char* optional = (ty & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" : "";
+
+        if (ty != LBC_TYPE_ANY)
+        {
+            if (const char* name = tryFindLocalName(function.proto, int(i), 0))
+                build.logAppend("; R%d: %s%s [argument '%s']\n", int(i), type, optional, name);
+            else
+                build.logAppend("; R%d: %s%s [argument]\n", int(i), type, optional);
+        }
+    }
+
+    for (size_t i = 0; i < typeInfo.upvalueTypes.size(); i++)
+    {
+        uint8_t ty = typeInfo.upvalueTypes[i];
+        const char* type = getBytecodeTypeName(ty, userdataTypes);
+        const char* optional = (ty & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" : "";
+
+        if (ty != LBC_TYPE_ANY)
+        {
+            if (const char* name = tryFindUpvalueName(function.proto, int(i)))
+                build.logAppend("; U%d: %s%s ['%s']\n", int(i), type, optional, name);
+            else
+                build.logAppend("; U%d: %s%s\n", int(i), type, optional);
+        }
+    }
+
+    for (const BytecodeRegTypeInfo& el : typeInfo.regTypes)
+    {
+        const char* type = getBytecodeTypeName(el.type, userdataTypes);
+        const char* optional = (el.type & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" : "";
+
+        // Using last active position as the PC because 'startpc' for type info is before local is initialized
+        if (const char* name = tryFindLocalName(function.proto, el.reg, el.endpc - 1))
+            build.logAppend("; R%d: %s%s from %d to %d [local '%s']\n", el.reg, type, optional, el.startpc, el.endpc, name);
+        else
+            build.logAppend("; R%d: %s%s from %d to %d\n", el.reg, type, optional, el.startpc, el.endpc);
+    }
+}
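Editorial aside: given the format strings above, the type section of an annotated listing now reads along these lines. The sample below is illustrative only; 'vec2' stands in for a hypothetical host-registered userdata type name supplied via userdataTypes:

; R0: vec2? [argument 'p']
; U0: vector ['offset']
; R2: vec2 from 4 to 18 [local 'q']

The trailing '?' renders the LBC_TYPE_OPTIONAL_BIT, and the bracketed suffixes come from the new tryFindLocalName / tryFindUpvalueName lookups.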
@@ -93,11 +194,14 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A
 {
     Proto* root = clvalue(func)->l.p;
 
-    if ((options.flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0)
+    if ((options.compilationOptions.flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0)
         return std::string();
 
     std::vector<Proto*> protos;
-    gatherFunctions(protos, root, options.flags);
+    if (FFlag::LuauNativeAttribute)
+        gatherFunctions(protos, root, options.compilationOptions.flags, root->flags & LPF_NATIVE_FUNCTION);
+    else
+        gatherFunctions_DEPRECATED(protos, root, options.compilationOptions.flags);
 
     protos.erase(std::remove_if(protos.begin(), protos.end(),
         [](Proto* p) {
@@ -125,7 +229,7 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A
     for (Proto* p : protos)
     {
-        IrBuilder ir;
+        IrBuilder ir(options.compilationOptions.hooks);
         ir.buildFunctionIr(p);
 
         unsigned asmSize = build.getCodeSize();
         unsigned asmCount = build.getInstructionCount();
@@ -133,8 +237,13 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A
         if (options.includeAssembly || options.includeIr)
             logFunctionHeader(build, p);
 
-        if (FFlag::LuauCodegenTypeInfo && options.includeIrTypes)
-            logFunctionTypes(build, ir.function);
+        if (options.includeIrTypes)
+        {
+            if (FFlag::LuauLoadUserdataInfo)
+                logFunctionTypes(build, ir.function, options.compilationOptions.userdataTypes);
+            else
+                logFunctionTypes_DEPRECATED(build, ir.function);
+        }
 
         CodeGenCompilationResult result = CodeGenCompilationResult::Success;
@@ -189,7 +298,7 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A
     return build.text;
 }
 
-#if defined(__aarch64__)
+#if defined(CODEGEN_TARGET_A64)
 unsigned int getCpuFeaturesA64();
 #endif
@@ -202,7 +311,7 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options, Lowering
 {
     case AssemblyOptions::Host:
     {
-#if defined(__aarch64__)
+#if defined(CODEGEN_TARGET_A64)
         static unsigned int cpuFeatures = getCpuFeaturesA64();
         A64::AssemblyBuilderA64 build(/* logText= */ options.includeAssembly, cpuFeatures);
 #else

View file

@@ -12,12 +12,9 @@
 #include "lapi.h"
 
-LUAU_FASTFLAGVARIABLE(LuauCodegenContext, false)
-LUAU_FASTFLAGVARIABLE(LuauCodegenCheckNullContext, false)
-
-LUAU_FASTINT(LuauCodeGenBlockSize)
-LUAU_FASTINT(LuauCodeGenMaxTotalSize)
+LUAU_FASTINTVARIABLE(LuauCodeGenBlockSize, 4 * 1024 * 1024)
+LUAU_FASTINTVARIABLE(LuauCodeGenMaxTotalSize, 256 * 1024 * 1024)
+LUAU_FASTFLAG(LuauNativeAttribute)
 
 namespace Luau
 {
@@ -27,14 +24,19 @@ namespace CodeGen
 
 static const Instruction kCodeEntryInsn = LOP_NATIVECALL;
 
-// From CodeGen.cpp
-extern void* gPerfLogContext;
-extern PerfLogFn gPerfLogFn;
+static void* gPerfLogContext = nullptr;
+static PerfLogFn gPerfLogFn = nullptr;
 
 unsigned int getCpuFeaturesA64();
 
+void setPerfLog(void* context, PerfLogFn logFn)
+{
+    gPerfLogContext = context;
+    gPerfLogFn = logFn;
+}
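Editorial aside: setPerfLog and the perf-log function pointers move here from CodeGen.cpp and become internal to this translation unit. Registering a profiler callback stays the same for embedders; a minimal sketch, with the callback shape inferred from the gPerfLogFn call sites in this file:

#include <cstdio>

// Sketch: a perf-log callback receiving each JIT'd region's address, size, and name.
static void myPerfLog(void* context, uintptr_t addr, unsigned size, const char* name)
{
    printf("jit code at %p, %u bytes: %s\n", (void*)addr, size, name);
}

// somewhere during VM setup:
// Luau::CodeGen::setPerfLog(/* context */ nullptr, myPerfLog);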
 static void logPerfFunction(Proto* p, uintptr_t addr, unsigned size)
 {
-    CODEGEN_ASSERT(FFlag::LuauCodegenContext);
     CODEGEN_ASSERT(p->source);
 
     const char* source = getstr(p->source);
@@ -50,8 +52,6 @@ static void logPerfFunction(Proto* p, uintptr_t addr, unsigned size)
 static void logPerfFunctions(
     const std::vector<Proto*>& moduleProtos, const uint8_t* nativeModuleBaseAddress, const std::vector<NativeProtoExecDataPtr>& nativeProtos)
 {
-    CODEGEN_ASSERT(FFlag::LuauCodegenContext);
-
     if (gPerfLogFn == nullptr)
         return;
 
@@ -83,8 +83,6 @@ static void logPerfFunctions(
 template<bool Release, typename NativeProtosVector>
 [[nodiscard]] static uint32_t bindNativeProtos(const std::vector<Proto*>& moduleProtos, NativeProtosVector& nativeProtos)
 {
-    CODEGEN_ASSERT(FFlag::LuauCodegenContext);
-
     uint32_t protosBound = 0;
 
     auto protoIt = moduleProtos.begin();
@@ -125,7 +123,6 @@ template<bool Release, typename NativeProtosVector>
 BaseCodeGenContext::BaseCodeGenContext(size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext)
     : codeAllocator{blockSize, maxTotalSize, allocationCallback, allocationCallbackContext}
 {
-    CODEGEN_ASSERT(FFlag::LuauCodegenContext);
     CODEGEN_ASSERT(isSupported());
 
 #if defined(_WIN32)
@@ -143,12 +140,10 @@ BaseCodeGenContext::BaseCodeGenContext(size_t blockSize, size_t maxTotalSize, Al
 [[nodiscard]] bool BaseCodeGenContext::initHeaderFunctions()
 {
-    CODEGEN_ASSERT(FFlag::LuauCodegenContext);
-
-#if defined(__x86_64__) || defined(_M_X64)
+#if defined(CODEGEN_TARGET_X64)
     if (!X64::initHeaderFunctions(*this))
         return false;
-#elif defined(__aarch64__)
+#elif defined(CODEGEN_TARGET_A64)
     if (!A64::initHeaderFunctions(*this))
         return false;
 #endif
@@ -164,13 +159,10 @@ StandaloneCodeGenContext::StandaloneCodeGenContext(
     size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext)
     : BaseCodeGenContext{blockSize, maxTotalSize, allocationCallback, allocationCallbackContext}
 {
-    CODEGEN_ASSERT(FFlag::LuauCodegenContext);
 }
 
 [[nodiscard]] std::optional<ModuleBindResult> StandaloneCodeGenContext::tryBindExistingModule(const ModuleId&, const std::vector<Proto*>&)
 {
-    CODEGEN_ASSERT(FFlag::LuauCodegenContext);
-
     // The StandaloneCodeGenContext does not support sharing of native code
     return {};
 }
@@ -178,8 +170,6 @@ StandaloneCodeGenContext::StandaloneCodeGenContext(
 [[nodiscard]] ModuleBindResult StandaloneCodeGenContext::bindModule(const std::optional<ModuleId>&, const std::vector<Proto*>& moduleProtos,
     std::vector<NativeProtoExecDataPtr> nativeProtos, const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize)
 {
-    CODEGEN_ASSERT(FFlag::LuauCodegenContext);
-
     uint8_t* nativeData = nullptr;
     size_t sizeNativeData = 0;
     uint8_t* codeStart = nullptr;
@@ -205,8 +195,6 @@ StandaloneCodeGenContext::StandaloneCodeGenContext(
 void StandaloneCodeGenContext::onCloseState() noexcept
 {
-    CODEGEN_ASSERT(FFlag::LuauCodegenContext);
-
     // The StandaloneCodeGenContext is owned by the one VM that owns it, so when
     // that VM is destroyed, we destroy *this as well:
     delete this;
@@ -214,8 +202,6 @@ void StandaloneCodeGenContext::onCloseState() noexcept
 void StandaloneCodeGenContext::onDestroyFunction(void* execdata) noexcept
 {
-    CODEGEN_ASSERT(FFlag::LuauCodegenContext);
-
     destroyNativeProtoExecData(static_cast<uint32_t*>(execdata));
 }
 
@@ -225,14 +211,11 @@ SharedCodeGenContext::SharedCodeGenContext(
     : BaseCodeGenContext{blockSize, maxTotalSize, allocationCallback, allocationCallbackContext}
     , sharedAllocator{&codeAllocator}
 {
-    CODEGEN_ASSERT(FFlag::LuauCodegenContext);
 }
 
 [[nodiscard]] std::optional<ModuleBindResult> SharedCodeGenContext::tryBindExistingModule(
     const ModuleId& moduleId, const std::vector<Proto*>& moduleProtos)
 {
-    CODEGEN_ASSERT(FFlag::LuauCodegenContext);
-
     NativeModuleRef nativeModule = sharedAllocator.tryGetNativeModule(moduleId);
     if (nativeModule.empty())
     {
@@ -249,8 +232,6 @@ SharedCodeGenContext::SharedCodeGenContext(
 [[nodiscard]] ModuleBindResult SharedCodeGenContext::bindModule(const std::optional<ModuleId>& moduleId, const std::vector<Proto*>& moduleProtos,
     std::vector<NativeProtoExecDataPtr> nativeProtos, const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize)
 {
-    CODEGEN_ASSERT(FFlag::LuauCodegenContext);
-
     const std::pair<NativeModuleRef, bool> insertionResult = [&]() -> std::pair<NativeModuleRef, bool> {
         if (moduleId.has_value())
         {
@@ -279,8 +260,6 @@ SharedCodeGenContext::SharedCodeGenContext(
 void SharedCodeGenContext::onCloseState() noexcept
 {
-    CODEGEN_ASSERT(FFlag::LuauCodegenContext);
-
     // The lifetime of the SharedCodeGenContext is managed separately from the
     // VMs that use it. When a VM is destroyed, we don't need to do anything
     // here.
@@ -288,23 +267,17 @@ void SharedCodeGenContext::onCloseState() noexcept
 void SharedCodeGenContext::onDestroyFunction(void* execdata) noexcept
 {
-    CODEGEN_ASSERT(FFlag::LuauCodegenContext);
-
     getNativeProtoExecDataHeader(static_cast<const uint32_t*>(execdata)).nativeModule->release();
 }
 
 [[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext()
 {
-    CODEGEN_ASSERT(FFlag::LuauCodegenContext);
-
     return createSharedCodeGenContext(size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), nullptr, nullptr);
 }
 
 [[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext(AllocationCallback* allocationCallback, void* allocationCallbackContext)
 {
-    CODEGEN_ASSERT(FFlag::LuauCodegenContext);
-
     return createSharedCodeGenContext(
         size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), allocationCallback, allocationCallbackContext);
 }
@@ -312,8 +285,6 @@ void SharedCodeGenContext::onDestroyFunction(void* execdata) noexcept
 [[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext(
     size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext)
 {
-    CODEGEN_ASSERT(FFlag::LuauCodegenContext);
-
     UniqueSharedCodeGenContext codeGenContext{new SharedCodeGenContext{blockSize, maxTotalSize, nullptr, nullptr}};
if (!codeGenContext->initHeaderFunctions()) if (!codeGenContext->initHeaderFunctions())
@ -324,38 +295,28 @@ void SharedCodeGenContext::onDestroyFunction(void* execdata) noexcept
void destroySharedCodeGenContext(const SharedCodeGenContext* codeGenContext) noexcept void destroySharedCodeGenContext(const SharedCodeGenContext* codeGenContext) noexcept
{ {
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
delete codeGenContext; delete codeGenContext;
} }
void SharedCodeGenContextDeleter::operator()(const SharedCodeGenContext* codeGenContext) const noexcept void SharedCodeGenContextDeleter::operator()(const SharedCodeGenContext* codeGenContext) const noexcept
{ {
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
destroySharedCodeGenContext(codeGenContext); destroySharedCodeGenContext(codeGenContext);
} }
[[nodiscard]] static BaseCodeGenContext* getCodeGenContext(lua_State* L) noexcept [[nodiscard]] static BaseCodeGenContext* getCodeGenContext(lua_State* L) noexcept
{ {
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
return static_cast<BaseCodeGenContext*>(L->global->ecb.context); return static_cast<BaseCodeGenContext*>(L->global->ecb.context);
} }
static void onCloseState(lua_State* L) noexcept static void onCloseState(lua_State* L) noexcept
{ {
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
getCodeGenContext(L)->onCloseState(); getCodeGenContext(L)->onCloseState();
L->global->ecb = lua_ExecutionCallbacks{}; L->global->ecb = lua_ExecutionCallbacks{};
} }
static void onDestroyFunction(lua_State* L, Proto* proto) noexcept static void onDestroyFunction(lua_State* L, Proto* proto) noexcept
{ {
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
getCodeGenContext(L)->onDestroyFunction(proto->execdata); getCodeGenContext(L)->onDestroyFunction(proto->execdata);
proto->execdata = nullptr; proto->execdata = nullptr;
proto->exectarget = 0; proto->exectarget = 0;
@ -364,8 +325,6 @@ static void onDestroyFunction(lua_State* L, Proto* proto) noexcept
static int onEnter(lua_State* L, Proto* proto) static int onEnter(lua_State* L, Proto* proto)
{ {
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
BaseCodeGenContext* codeGenContext = getCodeGenContext(L); BaseCodeGenContext* codeGenContext = getCodeGenContext(L);
CODEGEN_ASSERT(proto->execdata); CODEGEN_ASSERT(proto->execdata);
@ -379,8 +338,6 @@ static int onEnter(lua_State* L, Proto* proto)
static int onEnterDisabled(lua_State* L, Proto* proto) static int onEnterDisabled(lua_State* L, Proto* proto)
{ {
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
return 1; return 1;
} }
@ -389,8 +346,6 @@ void onDisable(lua_State* L, Proto* proto);
static size_t getMemorySize(lua_State* L, Proto* proto) static size_t getMemorySize(lua_State* L, Proto* proto)
{ {
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
const NativeProtoExecDataHeader& execDataHeader = getNativeProtoExecDataHeader(static_cast<const uint32_t*>(proto->execdata)); const NativeProtoExecDataHeader& execDataHeader = getNativeProtoExecDataHeader(static_cast<const uint32_t*>(proto->execdata));
const size_t execDataSize = sizeof(NativeProtoExecDataHeader) + execDataHeader.bytecodeInstructionCount * sizeof(Instruction); const size_t execDataSize = sizeof(NativeProtoExecDataHeader) + execDataHeader.bytecodeInstructionCount * sizeof(Instruction);
@ -403,8 +358,7 @@ static size_t getMemorySize(lua_State* L, Proto* proto)
static void initializeExecutionCallbacks(lua_State* L, BaseCodeGenContext* codeGenContext) noexcept static void initializeExecutionCallbacks(lua_State* L, BaseCodeGenContext* codeGenContext) noexcept
{ {
CODEGEN_ASSERT(FFlag::LuauCodegenContext); CODEGEN_ASSERT(codeGenContext != nullptr);
CODEGEN_ASSERT(!FFlag::LuauCodegenCheckNullContext || codeGenContext != nullptr);
lua_ExecutionCallbacks* ecb = &L->global->ecb; lua_ExecutionCallbacks* ecb = &L->global->ecb;
@ -416,24 +370,18 @@ static void initializeExecutionCallbacks(lua_State* L, BaseCodeGenContext* codeG
ecb->getmemorysize = getMemorySize; ecb->getmemorysize = getMemorySize;
} }
void create_NEW(lua_State* L) void create(lua_State* L)
{ {
CODEGEN_ASSERT(FFlag::LuauCodegenContext); return create(L, size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), nullptr, nullptr);
return create_NEW(L, size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), nullptr, nullptr);
} }
void create_NEW(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext) void create(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext)
{ {
CODEGEN_ASSERT(FFlag::LuauCodegenContext); return create(L, size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), allocationCallback, allocationCallbackContext);
return create_NEW(L, size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), allocationCallback, allocationCallbackContext);
} }
void create_NEW(lua_State* L, size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext) void create(lua_State* L, size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext)
{ {
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
std::unique_ptr<StandaloneCodeGenContext> codeGenContext = std::unique_ptr<StandaloneCodeGenContext> codeGenContext =
std::make_unique<StandaloneCodeGenContext>(blockSize, maxTotalSize, allocationCallback, allocationCallbackContext); std::make_unique<StandaloneCodeGenContext>(blockSize, maxTotalSize, allocationCallback, allocationCallbackContext);
@ -443,17 +391,13 @@ void create_NEW(lua_State* L, size_t blockSize, size_t maxTotalSize, AllocationC
initializeExecutionCallbacks(L, codeGenContext.release()); initializeExecutionCallbacks(L, codeGenContext.release());
} }
void create_NEW(lua_State* L, SharedCodeGenContext* codeGenContext) void create(lua_State* L, SharedCodeGenContext* codeGenContext)
{ {
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
initializeExecutionCallbacks(L, codeGenContext); initializeExecutionCallbacks(L, codeGenContext);
} }
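For context, a minimal sketch (not part of this diff) of how an embedder would use the renamed entry points, assuming a build where codegen is supported on the target and the default allocator parameters are acceptable:

#include "Luau/CodeGen.h"
#include "lua.h"

static void setupNativeCodegen(lua_State* L)
{
    if (!Luau::CodeGen::isSupported())
        return;

    // Per-VM StandaloneCodeGenContext with default block/total sizes
    Luau::CodeGen::create(L);

    // Execution can be toggled later without recreating the context
    Luau::CodeGen::setNativeExecutionEnabled(L, true);
}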
[[nodiscard]] static NativeProtoExecDataPtr createNativeProtoExecData(Proto* proto, const IrBuilder& ir) [[nodiscard]] static NativeProtoExecDataPtr createNativeProtoExecData(Proto* proto, const IrBuilder& ir)
{ {
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
NativeProtoExecDataPtr nativeExecData = createNativeProtoExecData(proto->sizecode); NativeProtoExecDataPtr nativeExecData = createNativeProtoExecData(proto->sizecode);
uint32_t instTarget = ir.function.entryLocation; uint32_t instTarget = ir.function.entryLocation;
@ -478,12 +422,10 @@ void create_NEW(lua_State* L, SharedCodeGenContext* codeGenContext)
} }
template<typename AssemblyBuilder> template<typename AssemblyBuilder>
[[nodiscard]] static NativeProtoExecDataPtr createNativeFunction( [[nodiscard]] static NativeProtoExecDataPtr createNativeFunction(AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto,
AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, uint32_t& totalIrInstCount, CodeGenCompilationResult& result) uint32_t& totalIrInstCount, const HostIrHooks& hooks, CodeGenCompilationResult& result)
{ {
CODEGEN_ASSERT(FFlag::LuauCodegenContext); IrBuilder ir(hooks);
IrBuilder ir;
ir.buildFunctionIr(proto); ir.buildFunctionIr(proto);
unsigned instCount = unsigned(ir.function.instructions.size()); unsigned instCount = unsigned(ir.function.instructions.size());
@ -505,15 +447,14 @@ template<typename AssemblyBuilder>
} }
[[nodiscard]] static CompilationResult compileInternal( [[nodiscard]] static CompilationResult compileInternal(
const std::optional<ModuleId>& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats) const std::optional<ModuleId>& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats)
{ {
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
CODEGEN_ASSERT(lua_isLfunction(L, idx)); CODEGEN_ASSERT(lua_isLfunction(L, idx));
const TValue* func = luaA_toobject(L, idx); const TValue* func = luaA_toobject(L, idx);
Proto* root = clvalue(func)->l.p; Proto* root = clvalue(func)->l.p;
if ((flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0) if ((options.flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0 && (root->flags & LPF_NATIVE_FUNCTION) == 0)
return CompilationResult{CodeGenCompilationResult::NotNativeModule}; return CompilationResult{CodeGenCompilationResult::NotNativeModule};
BaseCodeGenContext* codeGenContext = getCodeGenContext(L); BaseCodeGenContext* codeGenContext = getCodeGenContext(L);
@ -521,7 +462,10 @@ template<typename AssemblyBuilder>
return CompilationResult{CodeGenCompilationResult::CodeGenNotInitialized}; return CompilationResult{CodeGenCompilationResult::CodeGenNotInitialized};
std::vector<Proto*> protos; std::vector<Proto*> protos;
gatherFunctions(protos, root, flags); if (FFlag::LuauNativeAttribute)
gatherFunctions(protos, root, options.flags, root->flags & LPF_NATIVE_FUNCTION);
else
gatherFunctions_DEPRECATED(protos, root, options.flags);
// Skip protos that have been compiled during previous invocations of CodeGen::compile // Skip protos that have been compiled during previous invocations of CodeGen::compile
protos.erase(std::remove_if(protos.begin(), protos.end(), protos.erase(std::remove_if(protos.begin(), protos.end(),
@ -547,7 +491,7 @@ template<typename AssemblyBuilder>
} }
} }
#if defined(__aarch64__) #if defined(CODEGEN_TARGET_A64)
static unsigned int cpuFeatures = getCpuFeaturesA64(); static unsigned int cpuFeatures = getCpuFeaturesA64();
A64::AssemblyBuilderA64 build(/* logText= */ false, cpuFeatures); A64::AssemblyBuilderA64 build(/* logText= */ false, cpuFeatures);
#else #else
@ -555,7 +499,7 @@ template<typename AssemblyBuilder>
#endif #endif
ModuleHelpers helpers; ModuleHelpers helpers;
#if defined(__aarch64__) #if defined(CODEGEN_TARGET_A64)
A64::assembleHelpers(build, helpers); A64::assembleHelpers(build, helpers);
#else #else
X64::assembleHelpers(build, helpers); X64::assembleHelpers(build, helpers);
@ -572,7 +516,7 @@ template<typename AssemblyBuilder>
{ {
CodeGenCompilationResult protoResult = CodeGenCompilationResult::Success; CodeGenCompilationResult protoResult = CodeGenCompilationResult::Success;
NativeProtoExecDataPtr nativeExecData = createNativeFunction(build, helpers, protos[i], totalIrInstCount, protoResult); NativeProtoExecDataPtr nativeExecData = createNativeFunction(build, helpers, protos[i], totalIrInstCount, options.hooks, protoResult);
if (nativeExecData != nullptr) if (nativeExecData != nullptr)
{ {
nativeProtos.push_back(std::move(nativeExecData)); nativeProtos.push_back(std::move(nativeExecData));
@ -639,34 +583,60 @@ template<typename AssemblyBuilder>
return compilationResult; return compilationResult;
} }
CompilationResult compile_NEW(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats) CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats)
{ {
CODEGEN_ASSERT(FFlag::LuauCodegenContext); return compileInternal(moduleId, L, idx, options, stats);
return compileInternal(moduleId, L, idx, flags, stats);
} }
CompilationResult compile_NEW(lua_State* L, int idx, unsigned int flags, CompilationStats* stats) CompilationResult compile(lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats)
{ {
CODEGEN_ASSERT(FFlag::LuauCodegenContext); return compileInternal({}, L, idx, options, stats);
return compileInternal({}, L, idx, flags, stats);
} }
[[nodiscard]] bool isNativeExecutionEnabled_NEW(lua_State* L) CompilationResult compile(lua_State* L, int idx, unsigned int flags, CompilationStats* stats)
{ {
CODEGEN_ASSERT(FFlag::LuauCodegenContext); return compileInternal({}, L, idx, CompilationOptions{flags}, stats);
}
CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats)
{
return compileInternal(moduleId, L, idx, CompilationOptions{flags}, stats);
}
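The unsigned-flags overloads above are thin compatibility wrappers; a caller migrating to the new interface can build the options struct directly. A sketch, assuming the function to compile sits at the top of the stack:

#include "Luau/CodeGen.h"

static Luau::CodeGen::CompilationResult compileTopFunction(lua_State* L)
{
    // Equivalent to calling the flags-based overload with CodeGen_ColdFunctions
    Luau::CodeGen::CompilationOptions options{Luau::CodeGen::CodeGen_ColdFunctions};
    return Luau::CodeGen::compile(L, /*idx=*/-1, options, /*stats=*/nullptr);
}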
[[nodiscard]] bool isNativeExecutionEnabled(lua_State* L)
{
return getCodeGenContext(L) != nullptr && L->global->ecb.enter == onEnter; return getCodeGenContext(L) != nullptr && L->global->ecb.enter == onEnter;
} }
void setNativeExecutionEnabled_NEW(lua_State* L, bool enabled) void setNativeExecutionEnabled(lua_State* L, bool enabled)
{ {
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
if (getCodeGenContext(L) != nullptr) if (getCodeGenContext(L) != nullptr)
L->global->ecb.enter = enabled ? onEnter : onEnterDisabled; L->global->ecb.enter = enabled ? onEnter : onEnterDisabled;
} }
static uint8_t userdataRemapperWrap(lua_State* L, const char* str, size_t len)
{
if (BaseCodeGenContext* codegenCtx = getCodeGenContext(L))
{
uint8_t index = codegenCtx->userdataRemapper(codegenCtx->userdataRemappingContext, str, len);
if (index < (LBC_TYPE_TAGGED_USERDATA_END - LBC_TYPE_TAGGED_USERDATA_BASE))
return LBC_TYPE_TAGGED_USERDATA_BASE + index;
}
return LBC_TYPE_USERDATA;
}
void setUserdataRemapper(lua_State* L, void* context, UserdataRemapperCallback cb)
{
if (BaseCodeGenContext* codegenCtx = getCodeGenContext(L))
{
codegenCtx->userdataRemappingContext = context;
codegenCtx->userdataRemapper = cb;
L->global->ecb.gettypemapping = cb ? userdataRemapperWrap : nullptr;
}
}
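A hedged usage sketch for the new remapper hook: the callback receives a userdata type name from bytecode type information and returns a dense index; anything outside the tagged range falls back to plain LBC_TYPE_USERDATA in the wrapper above. The type names here are placeholders:

#include <cstdint>
#include <cstring>

static const char* kKnownUserdataTypes[] = {"Vec2", "Color"};

static uint8_t remapUserdataType(void* /*context*/, const char* name, size_t nameLength)
{
    for (uint8_t i = 0; i < 2; i++)
    {
        if (strlen(kKnownUserdataTypes[i]) == nameLength && memcmp(kKnownUserdataTypes[i], name, nameLength) == 0)
            return i;
    }

    return 0xff; // out of range: userdataRemapperWrap returns LBC_TYPE_USERDATA
}

// Registered once per VM: Luau::CodeGen::setUserdataRemapper(L, nullptr, remapUserdataType);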
} // namespace CodeGen } // namespace CodeGen
} // namespace Luau } // namespace Luau
View file
@ -50,6 +50,9 @@ public:
uint8_t* gateData = nullptr; uint8_t* gateData = nullptr;
size_t gateDataSize = 0; size_t gateDataSize = 0;
void* userdataRemappingContext = nullptr;
UserdataRemapperCallback* userdataRemapper = nullptr;
NativeContext context; NativeContext context;
}; };
@ -88,33 +91,5 @@ private:
SharedCodeAllocator sharedAllocator; SharedCodeAllocator sharedAllocator;
}; };
// The following will become the public interface, and can be moved into
// CodeGen.h after the shared allocator work is complete. When the old
// implementation is removed, the _NEW suffix can be dropped from these
// functions.
// Initializes native code-gen on the provided Luau VM, using a VM-specific
// code-gen context and either the default allocator parameters or custom
// allocator parameters.
void create_NEW(lua_State* L);
void create_NEW(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext);
void create_NEW(lua_State* L, size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext);
// Initializes native code-gen on the provided Luau VM, using the provided
// SharedCodeGenContext. Note that after this function is called, the
// SharedCodeGenContext must not be destroyed until after the Luau VM L is
// destroyed via lua_close.
void create_NEW(lua_State* L, SharedCodeGenContext* codeGenContext);
CompilationResult compile_NEW(lua_State* L, int idx, unsigned int flags, CompilationStats* stats);
CompilationResult compile_NEW(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats);
// Returns true if native execution is currently enabled for this VM
[[nodiscard]] bool isNativeExecutionEnabled_NEW(lua_State* L);
// Enables or disables native execution for this VM
void setNativeExecutionEnabled_NEW(lua_State* L, bool enabled);
} // namespace CodeGen } // namespace CodeGen
} // namespace Luau } // namespace Luau
View file
@ -27,14 +27,15 @@ LUAU_FASTFLAG(DebugCodegenSkipNumbering)
LUAU_FASTINT(CodegenHeuristicsInstructionLimit) LUAU_FASTINT(CodegenHeuristicsInstructionLimit)
LUAU_FASTINT(CodegenHeuristicsBlockLimit) LUAU_FASTINT(CodegenHeuristicsBlockLimit)
LUAU_FASTINT(CodegenHeuristicsBlockInstructionLimit) LUAU_FASTINT(CodegenHeuristicsBlockInstructionLimit)
LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) LUAU_FASTFLAG(LuauLoadUserdataInfo)
LUAU_FASTFLAG(LuauNativeAttribute)
namespace Luau namespace Luau
{ {
namespace CodeGen namespace CodeGen
{ {
inline void gatherFunctions(std::vector<Proto*>& results, Proto* proto, unsigned int flags) inline void gatherFunctions_DEPRECATED(std::vector<Proto*>& results, Proto* proto, unsigned int flags)
{ {
if (results.size() <= size_t(proto->bytecodeid)) if (results.size() <= size_t(proto->bytecodeid))
results.resize(proto->bytecodeid + 1); results.resize(proto->bytecodeid + 1);
@ -49,7 +50,36 @@ inline void gatherFunctions(std::vector<Proto*>& results, Proto* proto, unsigned
// Recursively traverse child protos even if we aren't compiling this one // Recursively traverse child protos even if we aren't compiling this one
for (int i = 0; i < proto->sizep; i++) for (int i = 0; i < proto->sizep; i++)
gatherFunctions(results, proto->p[i], flags); gatherFunctions_DEPRECATED(results, proto->p[i], flags);
}
inline void gatherFunctionsHelper(
std::vector<Proto*>& results, Proto* proto, const unsigned int flags, const bool hasNativeFunctions, const bool root)
{
if (results.size() <= size_t(proto->bytecodeid))
results.resize(proto->bytecodeid + 1);
// Skip protos that we've already compiled in this run: this happens because at -O2, inlined functions get their protos reused
if (results[proto->bytecodeid])
return;
// if the module has functions marked with the native attribute, gather only non-root functions carrying it;
// otherwise (whole-module compilation), gather everything except cold functions, unless cold compilation is requested
bool shouldGather = hasNativeFunctions ? (!root && (proto->flags & LPF_NATIVE_FUNCTION) != 0)
: ((proto->flags & LPF_NATIVE_COLD) == 0 || (flags & CodeGen_ColdFunctions) != 0);
if (shouldGather)
results[proto->bytecodeid] = proto;
// Recursively traverse child protos even if we aren't compiling this one
for (int i = 0; i < proto->sizep; i++)
gatherFunctionsHelper(results, proto->p[i], flags, hasNativeFunctions, false);
}
inline void gatherFunctions(std::vector<Proto*>& results, Proto* root, const unsigned int flags, const bool hasNativeFunctions = false)
{
LUAU_ASSERT(FFlag::LuauNativeAttribute);
gatherFunctionsHelper(results, root, flags, hasNativeFunctions, true);
} }
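Restating the gathering rule as a standalone predicate may make the two modes easier to compare; this is an illustration with assumed stand-in flag constants, not the shipped code:

#include <cstdio>

constexpr unsigned kNativeCold = 1 << 0;     // stands in for LPF_NATIVE_COLD
constexpr unsigned kNativeFunction = 1 << 1; // stands in for LPF_NATIVE_FUNCTION
constexpr unsigned kColdFunctions = 1 << 0;  // stands in for CodeGen_ColdFunctions

static bool shouldGather(unsigned protoFlags, unsigned compileFlags, bool hasNativeFunctions, bool root)
{
    // attribute-driven compilation: only non-root functions marked native
    if (hasNativeFunctions)
        return !root && (protoFlags & kNativeFunction) != 0;

    // whole-module compilation: everything except cold functions, unless requested
    return (protoFlags & kNativeCold) == 0 || (compileFlags & kColdFunctions) != 0;
}

int main()
{
    printf("%d\n", shouldGather(kNativeFunction, 0, /*hasNativeFunctions=*/true, /*root=*/false)); // prints "1"
}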
inline unsigned getInstructionCount(const std::vector<IrInst>& instructions, IrCmd cmd) inline unsigned getInstructionCount(const std::vector<IrInst>& instructions, IrCmd cmd)
@ -149,7 +179,11 @@ inline bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction&
if (bcTypes.result != LBC_TYPE_ANY || bcTypes.a != LBC_TYPE_ANY || bcTypes.b != LBC_TYPE_ANY || bcTypes.c != LBC_TYPE_ANY) if (bcTypes.result != LBC_TYPE_ANY || bcTypes.a != LBC_TYPE_ANY || bcTypes.b != LBC_TYPE_ANY || bcTypes.c != LBC_TYPE_ANY)
{ {
toString(ctx.result, bcTypes); if (FFlag::LuauLoadUserdataInfo)
toString(ctx.result, bcTypes, options.compilationOptions.userdataTypes);
else
toString_DEPRECATED(ctx.result, bcTypes);
build.logAppend("\n"); build.logAppend("\n");
} }
} }
@ -312,8 +346,7 @@ inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& build, ModuleHelpers&
} }
} }
if (FFlag::LuauCodegenRemoveDeadStores5) markDeadStoresInBlockChains(ir);
markDeadStoresInBlockChains(ir);
} }
std::vector<uint32_t> sortedBlocks = getSortedBlockOrder(ir.function); std::vector<uint32_t> sortedBlocks = getSortedBlockOrder(ir.function);
View file
@ -14,6 +14,7 @@
#include "lstate.h" #include "lstate.h"
#include "lstring.h" #include "lstring.h"
#include "ltable.h" #include "ltable.h"
#include "ludata.h"
#include <string.h> #include <string.h>
@ -219,6 +220,20 @@ void callEpilogC(lua_State* L, int nresults, int n)
L->top = (nresults == LUA_MULTRET) ? res : cip->top; L->top = (nresults == LUA_MULTRET) ? res : cip->top;
} }
Udata* newUserdata(lua_State* L, size_t s, int tag)
{
Udata* u = luaU_newudata(L, s, tag);
if (Table* h = L->global->udatamt[tag])
{
u->metatable = h;
luaC_objbarrier(L, u, h);
}
return u;
}
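The udatamt lookup gives natively-allocated userdata its metatable without a round trip through Lua; the table is populated from the embedder side. A sketch, under the assumption that this build exposes lua_setuserdatametatable(L, tag, idx) for tagged userdata and with the tag value chosen for illustration:

#include "lua.h"

constexpr int kMyTag = 1; // assumed tag

static void registerMyType(lua_State* L)
{
    lua_newtable(L); // shared metatable for all userdata carrying kMyTag
    lua_setuserdatametatable(L, kMyTag, -1);
}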
// Extracted as-is from lvmexecute.cpp with the exception of control flow (reentry) and removed interrupts/savedpc // Extracted as-is from lvmexecute.cpp with the exception of control flow (reentry) and removed interrupts/savedpc
Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults) Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults)
{ {
View file
@ -17,6 +17,8 @@ void forgPrepXnextFallback(lua_State* L, TValue* ra, int pc);
Closure* callProlog(lua_State* L, TValue* ra, StkId argtop, int nresults); Closure* callProlog(lua_State* L, TValue* ra, StkId argtop, int nresults);
void callEpilogC(lua_State* L, int nresults, int n); void callEpilogC(lua_State* L, int nresults, int n);
Udata* newUserdata(lua_State* L, size_t s, int tag);
#define CALL_FALLBACK_YIELD 1 #define CALL_FALLBACK_YIELD 1
Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults); Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults);
View file
@ -181,44 +181,11 @@ static EntryLocations buildEntryFunction(AssemblyBuilderX64& build, UnwindBuilde
build.ret(); build.ret();
// Our entry function is special, it spans the whole remaining code area // Our entry function is special, it spans the whole remaining code area
unwind.finishFunction(build.getLabelOffset(locations.start), kFullBlockFuncton); unwind.finishFunction(build.getLabelOffset(locations.start), kFullBlockFunction);
return locations; return locations;
} }
bool initHeaderFunctions(NativeState& data)
{
AssemblyBuilderX64 build(/* logText= */ false);
UnwindBuilder& unwind = *data.unwindBuilder.get();
unwind.startInfo(UnwindBuilder::X64);
EntryLocations entryLocations = buildEntryFunction(build, unwind);
build.finalize();
unwind.finishInfo();
CODEGEN_ASSERT(build.data.empty());
uint8_t* codeStart = nullptr;
if (!data.codeAllocator.allocate(
build.data.data(), int(build.data.size()), build.code.data(), int(build.code.size()), data.gateData, data.gateDataSize, codeStart))
{
CODEGEN_ASSERT(!"Failed to create entry function");
return false;
}
// Set the offset at the beginning so that functions in new blocks will not overlay the locations
// specified by the unwind information of the entry function
unwind.setBeginOffset(build.getLabelOffset(entryLocations.prologueEnd));
data.context.gateEntry = codeStart + build.getLabelOffset(entryLocations.start);
data.context.gateExit = codeStart + build.getLabelOffset(entryLocations.epilogueStart);
return true;
}
bool initHeaderFunctions(BaseCodeGenContext& codeGenContext) bool initHeaderFunctions(BaseCodeGenContext& codeGenContext)
{ {
AssemblyBuilderX64 build(/* logText= */ false); AssemblyBuilderX64 build(/* logText= */ false);
View file
@ -7,7 +7,6 @@ namespace CodeGen
{ {
class BaseCodeGenContext; class BaseCodeGenContext;
struct NativeState;
struct ModuleHelpers; struct ModuleHelpers;
namespace X64 namespace X64
@ -15,7 +14,6 @@ namespace X64
class AssemblyBuilderX64; class AssemblyBuilderX64;
bool initHeaderFunctions(NativeState& data);
bool initHeaderFunctions(BaseCodeGenContext& codeGenContext); bool initHeaderFunctions(BaseCodeGenContext& codeGenContext);
void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers); void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers);
View file
@ -12,7 +12,7 @@
#include "lstate.h" #include "lstate.h"
LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) LUAU_FASTFLAG(LuauCodegenMathSign)
namespace Luau namespace Luau
{ {
@ -29,17 +29,13 @@ static void emitBuiltinMathFrexp(IrRegAllocX64& regs, AssemblyBuilderX64& build,
callWrap.call(qword[rNativeContext + offsetof(NativeContext, libm_frexp)]); callWrap.call(qword[rNativeContext + offsetof(NativeContext, libm_frexp)]);
build.vmovsd(luauRegValue(ra), xmm0); build.vmovsd(luauRegValue(ra), xmm0);
build.mov(luauRegTag(ra), LUA_TNUMBER);
if (FFlag::LuauCodegenRemoveDeadStores5)
build.mov(luauRegTag(ra), LUA_TNUMBER);
if (nresults > 1) if (nresults > 1)
{ {
build.vcvtsi2sd(xmm0, xmm0, dword[sTemporarySlot + 0]); build.vcvtsi2sd(xmm0, xmm0, dword[sTemporarySlot + 0]);
build.vmovsd(luauRegValue(ra + 1), xmm0); build.vmovsd(luauRegValue(ra + 1), xmm0);
build.mov(luauRegTag(ra + 1), LUA_TNUMBER);
if (FFlag::LuauCodegenRemoveDeadStores5)
build.mov(luauRegTag(ra + 1), LUA_TNUMBER);
} }
} }
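For reference, the mantissa/exponent pair being stored into ra and ra+1 above follows the standard libm contract; a minimal check:

#include <cmath>
#include <cstdio>

int main()
{
    int e = 0;
    double m = frexp(12.0, &e); // 12.0 == 0.75 * 2^4
    printf("%g %d\n", m, e);    // prints "0.75 4"
}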
@ -52,21 +48,19 @@ static void emitBuiltinMathModf(IrRegAllocX64& regs, AssemblyBuilderX64& build,
build.vmovsd(xmm1, qword[sTemporarySlot + 0]); build.vmovsd(xmm1, qword[sTemporarySlot + 0]);
build.vmovsd(luauRegValue(ra), xmm1); build.vmovsd(luauRegValue(ra), xmm1);
build.mov(luauRegTag(ra), LUA_TNUMBER);
if (FFlag::LuauCodegenRemoveDeadStores5)
build.mov(luauRegTag(ra), LUA_TNUMBER);
if (nresults > 1) if (nresults > 1)
{ {
build.vmovsd(luauRegValue(ra + 1), xmm0); build.vmovsd(luauRegValue(ra + 1), xmm0);
build.mov(luauRegTag(ra + 1), LUA_TNUMBER);
if (FFlag::LuauCodegenRemoveDeadStores5)
build.mov(luauRegTag(ra + 1), LUA_TNUMBER);
} }
} }
static void emitBuiltinMathSign(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg) static void emitBuiltinMathSign(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg)
{ {
CODEGEN_ASSERT(!FFlag::LuauCodegenMathSign);
ScopedRegX64 tmp0{regs, SizeX64::xmmword}; ScopedRegX64 tmp0{regs, SizeX64::xmmword};
ScopedRegX64 tmp1{regs, SizeX64::xmmword}; ScopedRegX64 tmp1{regs, SizeX64::xmmword};
ScopedRegX64 tmp2{regs, SizeX64::xmmword}; ScopedRegX64 tmp2{regs, SizeX64::xmmword};
@ -90,23 +84,22 @@ static void emitBuiltinMathSign(IrRegAllocX64& regs, AssemblyBuilderX64& build,
build.vblendvpd(tmp0.reg, tmp2.reg, build.f64x2(1, 1), tmp0.reg); build.vblendvpd(tmp0.reg, tmp2.reg, build.f64x2(1, 1), tmp0.reg);
build.vmovsd(luauRegValue(ra), tmp0.reg); build.vmovsd(luauRegValue(ra), tmp0.reg);
build.mov(luauRegTag(ra), LUA_TNUMBER);
if (FFlag::LuauCodegenRemoveDeadStores5)
build.mov(luauRegTag(ra), LUA_TNUMBER);
} }
void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, OperandX64 arg2, int nparams, int nresults) void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, int nresults)
{ {
switch (bfid) switch (bfid)
{ {
case LBF_MATH_FREXP: case LBF_MATH_FREXP:
CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); CODEGEN_ASSERT(nresults == 1 || nresults == 2);
return emitBuiltinMathFrexp(regs, build, ra, arg, nresults); return emitBuiltinMathFrexp(regs, build, ra, arg, nresults);
case LBF_MATH_MODF: case LBF_MATH_MODF:
CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); CODEGEN_ASSERT(nresults == 1 || nresults == 2);
return emitBuiltinMathModf(regs, build, ra, arg, nresults); return emitBuiltinMathModf(regs, build, ra, arg, nresults);
case LBF_MATH_SIGN: case LBF_MATH_SIGN:
CODEGEN_ASSERT(nparams == 1 && nresults == 1); CODEGEN_ASSERT(!FFlag::LuauCodegenMathSign);
CODEGEN_ASSERT(nresults == 1);
return emitBuiltinMathSign(regs, build, ra, arg); return emitBuiltinMathSign(regs, build, ra, arg);
default: default:
CODEGEN_ASSERT(!"Missing x64 lowering"); CODEGEN_ASSERT(!"Missing x64 lowering");
View file
@ -16,7 +16,7 @@ class AssemblyBuilderX64;
struct OperandX64; struct OperandX64;
struct IrRegAllocX64; struct IrRegAllocX64;
void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, OperandX64 arg2, int nparams, int nresults); void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, int nresults);
} // namespace X64 } // namespace X64
} // namespace CodeGen } // namespace CodeGen
View file
@ -22,8 +22,6 @@ namespace Luau
namespace CodeGen namespace CodeGen
{ {
struct NativeState;
namespace A64 namespace A64
{ {
View file
@ -155,8 +155,37 @@ void callArithHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, Ope
callWrap.addArgument(SizeX64::qword, luauRegAddress(ra)); callWrap.addArgument(SizeX64::qword, luauRegAddress(ra));
callWrap.addArgument(SizeX64::qword, b); callWrap.addArgument(SizeX64::qword, b);
callWrap.addArgument(SizeX64::qword, c); callWrap.addArgument(SizeX64::qword, c);
callWrap.addArgument(SizeX64::dword, tm);
callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarith)]); switch (tm)
{
case TM_ADD:
callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithadd)]);
break;
case TM_SUB:
callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithsub)]);
break;
case TM_MUL:
callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithmul)]);
break;
case TM_DIV:
callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithdiv)]);
break;
case TM_IDIV:
callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithidiv)]);
break;
case TM_MOD:
callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithmod)]);
break;
case TM_POW:
callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithpow)]);
break;
case TM_UNM:
callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithunm)]);
break;
default:
CODEGEN_ASSERT(!"Invalid doarith helper operation tag");
break;
}
emitUpdateBase(build); emitUpdateBase(build);
} }
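The switch replaces the generic luaV_doarith entry, which took the TM operation as a runtime argument, with one specialized helper per operation, so each opcode lowers to a direct call with one fewer argument. A plain-C++ analogy of the design choice, with toy function names:

static double doArithAdd(double a, double b) { return a + b; }
static double doArithSub(double a, double b) { return a - b; }

static double dispatchArith(int tm, double a, double b)
{
    // one direct call per operation, mirroring the lowering above
    switch (tm)
    {
    case 0:
        return doArithAdd(a, b);
    case 1:
        return doArithSub(a, b);
    default:
        return 0.0; // the real helper asserts on an invalid tag instead
    }
}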
View file
@ -26,7 +26,6 @@ namespace CodeGen
{ {
enum class IrCondition : uint8_t; enum class IrCondition : uint8_t;
struct NativeState;
struct IrOp; struct IrOp;
namespace X64 namespace X64
View file
@ -13,6 +13,8 @@
#include <stddef.h> #include <stddef.h>
LUAU_FASTFLAGVARIABLE(LuauCodegenInstG, false)
namespace Luau namespace Luau
{ {
namespace CodeGen namespace CodeGen
@ -52,6 +54,9 @@ void updateUseCounts(IrFunction& function)
checkOp(inst.d); checkOp(inst.d);
checkOp(inst.e); checkOp(inst.e);
checkOp(inst.f); checkOp(inst.f);
if (FFlag::LuauCodegenInstG)
checkOp(inst.g);
} }
} }
@ -95,6 +100,9 @@ void updateLastUseLocations(IrFunction& function, const std::vector<uint32_t>& s
checkOp(inst.d); checkOp(inst.d);
checkOp(inst.e); checkOp(inst.e);
checkOp(inst.f); checkOp(inst.f);
if (FFlag::LuauCodegenInstG)
checkOp(inst.g);
} }
} }
} }
@ -128,6 +136,12 @@ uint32_t getNextInstUse(IrFunction& function, uint32_t targetInstIdx, uint32_t s
if (inst.f.kind == IrOpKind::Inst && inst.f.index == targetInstIdx) if (inst.f.kind == IrOpKind::Inst && inst.f.index == targetInstIdx)
return i; return i;
if (FFlag::LuauCodegenInstG)
{
if (inst.g.kind == IrOpKind::Inst && inst.g.index == targetInstIdx)
return i;
}
} }
// There must be a next use since there is the last use location // There must be a next use since there is the last use location
@ -165,6 +179,9 @@ std::pair<uint32_t, uint32_t> getLiveInOutValueCount(IrFunction& function, IrBlo
checkOp(inst.d); checkOp(inst.d);
checkOp(inst.e); checkOp(inst.e);
checkOp(inst.f); checkOp(inst.f);
if (FFlag::LuauCodegenInstG)
checkOp(inst.g);
} }
return std::make_pair(liveIns, liveOuts); return std::make_pair(liveIns, liveOuts);
@ -488,6 +505,9 @@ static void computeCfgBlockEdges(IrFunction& function)
checkOp(inst.d); checkOp(inst.d);
checkOp(inst.e); checkOp(inst.e);
checkOp(inst.f); checkOp(inst.f);
if (FFlag::LuauCodegenInstG)
checkOp(inst.g);
} }
} }
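A toy model of the flag-guarded operand walk introduced in this file, with assumed stand-in types; every analysis loop now visits a seventh operand when LuauCodegenInstG is set:

struct OpToy { int kind = 0; unsigned index = 0; };
struct InstToy { OpToy a, b, c, d, e, f, g; };

template<typename F>
static void forEachOp(const InstToy& inst, bool instGEnabled, F&& checkOp)
{
    checkOp(inst.a);
    checkOp(inst.b);
    checkOp(inst.c);
    checkOp(inst.d);
    checkOp(inst.e);
    checkOp(inst.f);

    if (instGEnabled) // mirrors the FFlag::LuauCodegenInstG guard
        checkOp(inst.g);
}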
View file
@ -13,8 +13,9 @@
#include <string.h> #include <string.h>
LUAU_FASTFLAG(LuauLoadTypeInfo) // Because new VM typeinfo load changes the format used by Codegen, same flag is used LUAU_FASTFLAG(LuauLoadUserdataInfo)
LUAU_FASTFLAG(LuauTypeInfoLookupImprovement) LUAU_FASTFLAG(LuauCodegenInstG)
LUAU_FASTFLAG(LuauCodegenFastcall3)
namespace Luau namespace Luau
{ {
@ -23,120 +24,25 @@ namespace CodeGen
constexpr unsigned kNoAssociatedBlockIndex = ~0u; constexpr unsigned kNoAssociatedBlockIndex = ~0u;
IrBuilder::IrBuilder() IrBuilder::IrBuilder(const HostIrHooks& hostHooks)
: constantMap({IrConstKind::Tag, ~0ull}) : hostHooks(hostHooks)
, constantMap({IrConstKind::Tag, ~0ull})
{ {
} }
static bool hasTypedParameters_DEPRECATED(Proto* proto)
{
CODEGEN_ASSERT(!FFlag::LuauLoadTypeInfo);
return proto->typeinfo && proto->numparams != 0;
}
static void buildArgumentTypeChecks_DEPRECATED(IrBuilder& build, Proto* proto)
{
CODEGEN_ASSERT(!FFlag::LuauLoadTypeInfo);
CODEGEN_ASSERT(hasTypedParameters_DEPRECATED(proto));
for (int i = 0; i < proto->numparams; ++i)
{
uint8_t et = proto->typeinfo[2 + i];
uint8_t tag = et & ~LBC_TYPE_OPTIONAL_BIT;
uint8_t optional = et & LBC_TYPE_OPTIONAL_BIT;
if (tag == LBC_TYPE_ANY)
continue;
IrOp load = build.inst(IrCmd::LOAD_TAG, build.vmReg(i));
IrOp nextCheck;
if (optional)
{
nextCheck = build.block(IrBlockKind::Internal);
IrOp fallbackCheck = build.block(IrBlockKind::Internal);
build.inst(IrCmd::JUMP_EQ_TAG, load, build.constTag(LUA_TNIL), nextCheck, fallbackCheck);
build.beginBlock(fallbackCheck);
}
switch (tag)
{
case LBC_TYPE_NIL:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TNIL), build.vmExit(kVmExitEntryGuardPc));
break;
case LBC_TYPE_BOOLEAN:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TBOOLEAN), build.vmExit(kVmExitEntryGuardPc));
break;
case LBC_TYPE_NUMBER:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TNUMBER), build.vmExit(kVmExitEntryGuardPc));
break;
case LBC_TYPE_STRING:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TSTRING), build.vmExit(kVmExitEntryGuardPc));
break;
case LBC_TYPE_TABLE:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TTABLE), build.vmExit(kVmExitEntryGuardPc));
break;
case LBC_TYPE_FUNCTION:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TFUNCTION), build.vmExit(kVmExitEntryGuardPc));
break;
case LBC_TYPE_THREAD:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TTHREAD), build.vmExit(kVmExitEntryGuardPc));
break;
case LBC_TYPE_USERDATA:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TUSERDATA), build.vmExit(kVmExitEntryGuardPc));
break;
case LBC_TYPE_VECTOR:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TVECTOR), build.vmExit(kVmExitEntryGuardPc));
break;
case LBC_TYPE_BUFFER:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TBUFFER), build.vmExit(kVmExitEntryGuardPc));
break;
}
if (optional)
{
build.inst(IrCmd::JUMP, nextCheck);
build.beginBlock(nextCheck);
}
}
// If the last argument is optional, we can skip creating a new internal block since one will already have been created.
if (!(proto->typeinfo[2 + proto->numparams - 1] & LBC_TYPE_OPTIONAL_BIT))
{
IrOp next = build.block(IrBlockKind::Internal);
build.inst(IrCmd::JUMP, next);
build.beginBlock(next);
}
}
static bool hasTypedParameters(const BytecodeTypeInfo& typeInfo) static bool hasTypedParameters(const BytecodeTypeInfo& typeInfo)
{ {
CODEGEN_ASSERT(FFlag::LuauLoadTypeInfo); for (auto el : typeInfo.argumentTypes)
if (FFlag::LuauTypeInfoLookupImprovement)
{ {
for (auto el : typeInfo.argumentTypes) if (el != LBC_TYPE_ANY)
{ return true;
if (el != LBC_TYPE_ANY) }
return true;
}
return false; return false;
}
else
{
return !typeInfo.argumentTypes.empty();
}
} }
static void buildArgumentTypeChecks(IrBuilder& build) static void buildArgumentTypeChecks(IrBuilder& build)
{ {
CODEGEN_ASSERT(FFlag::LuauLoadTypeInfo);
const BytecodeTypeInfo& typeInfo = build.function.bcTypeInfo; const BytecodeTypeInfo& typeInfo = build.function.bcTypeInfo;
CODEGEN_ASSERT(hasTypedParameters(typeInfo)); CODEGEN_ASSERT(hasTypedParameters(typeInfo));
@ -195,6 +101,19 @@ static void buildArgumentTypeChecks(IrBuilder& build)
case LBC_TYPE_BUFFER: case LBC_TYPE_BUFFER:
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TBUFFER), build.vmExit(kVmExitEntryGuardPc)); build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TBUFFER), build.vmExit(kVmExitEntryGuardPc));
break; break;
default:
if (FFlag::LuauLoadUserdataInfo)
{
if (tag >= LBC_TYPE_TAGGED_USERDATA_BASE && tag < LBC_TYPE_TAGGED_USERDATA_END)
{
build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TUSERDATA), build.vmExit(kVmExitEntryGuardPc));
}
else
{
CODEGEN_ASSERT(!"unknown argument type tag");
}
}
break;
} }
if (optional) if (optional)
@ -219,18 +138,17 @@ void IrBuilder::buildFunctionIr(Proto* proto)
function.proto = proto; function.proto = proto;
function.variadic = proto->is_vararg != 0; function.variadic = proto->is_vararg != 0;
if (FFlag::LuauLoadTypeInfo) loadBytecodeTypeInfo(function);
loadBytecodeTypeInfo(function);
// Reserve entry block // Reserve entry block
bool generateTypeChecks = FFlag::LuauLoadTypeInfo ? hasTypedParameters(function.bcTypeInfo) : hasTypedParameters_DEPRECATED(proto); bool generateTypeChecks = hasTypedParameters(function.bcTypeInfo);
IrOp entry = generateTypeChecks ? block(IrBlockKind::Internal) : IrOp{}; IrOp entry = generateTypeChecks ? block(IrBlockKind::Internal) : IrOp{};
// Rebuild original control flow blocks // Rebuild original control flow blocks
rebuildBytecodeBasicBlocks(proto); rebuildBytecodeBasicBlocks(proto);
// Infer register tags in bytecode // Infer register tags in bytecode
analyzeBytecodeTypes(function); analyzeBytecodeTypes(function, hostHooks);
function.bcMapping.resize(proto->sizecode, {~0u, ~0u}); function.bcMapping.resize(proto->sizecode, {~0u, ~0u});
@ -238,10 +156,7 @@ void IrBuilder::buildFunctionIr(Proto* proto)
{ {
beginBlock(entry); beginBlock(entry);
if (FFlag::LuauLoadTypeInfo) buildArgumentTypeChecks(*this);
buildArgumentTypeChecks(*this);
else
buildArgumentTypeChecks_DEPRECATED(*this, proto);
inst(IrCmd::JUMP, blockAtInst(0)); inst(IrCmd::JUMP, blockAtInst(0));
} }
@ -283,10 +198,10 @@ void IrBuilder::buildFunctionIr(Proto* proto)
translateInst(op, pc, i); translateInst(op, pc, i);
if (fastcallSkipTarget != -1) if (cmdSkipTarget != -1)
{ {
nexti = fastcallSkipTarget; nexti = cmdSkipTarget;
fastcallSkipTarget = -1; cmdSkipTarget = -1;
} }
} }
@ -535,16 +450,21 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i)
translateInstCloseUpvals(*this, pc); translateInstCloseUpvals(*this, pc);
break; break;
case LOP_FASTCALL: case LOP_FASTCALL:
handleFastcallFallback(translateFastCallN(*this, pc, i, false, 0, {}), pc, i); handleFastcallFallback(translateFastCallN(*this, pc, i, false, 0, {}, {}), pc, i);
break; break;
case LOP_FASTCALL1: case LOP_FASTCALL1:
handleFastcallFallback(translateFastCallN(*this, pc, i, true, 1, undef()), pc, i); handleFastcallFallback(translateFastCallN(*this, pc, i, true, 1, undef(), undef()), pc, i);
break; break;
case LOP_FASTCALL2: case LOP_FASTCALL2:
handleFastcallFallback(translateFastCallN(*this, pc, i, true, 2, vmReg(pc[1])), pc, i); handleFastcallFallback(translateFastCallN(*this, pc, i, true, 2, vmReg(pc[1]), undef()), pc, i);
break; break;
case LOP_FASTCALL2K: case LOP_FASTCALL2K:
handleFastcallFallback(translateFastCallN(*this, pc, i, true, 2, vmConst(pc[1])), pc, i); handleFastcallFallback(translateFastCallN(*this, pc, i, true, 2, vmConst(pc[1]), undef()), pc, i);
break;
case LOP_FASTCALL3:
CODEGEN_ASSERT(FFlag::LuauCodegenFastcall3);
handleFastcallFallback(translateFastCallN(*this, pc, i, true, 3, vmReg(pc[1] & 0xff), vmReg((pc[1] >> 8) & 0xff)), pc, i);
break; break;
case LOP_FORNPREP: case LOP_FORNPREP:
translateInstForNPrep(*this, pc, i); translateInstForNPrep(*this, pc, i);
@ -613,7 +533,8 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i)
translateInstCapture(*this, pc, i); translateInstCapture(*this, pc, i);
break; break;
case LOP_NAMECALL: case LOP_NAMECALL:
translateInstNamecall(*this, pc, i); if (translateInstNamecall(*this, pc, i))
cmdSkipTarget = i + 3;
break; break;
case LOP_PREPVARARGS: case LOP_PREPVARARGS:
inst(IrCmd::FALLBACK_PREPVARARGS, constUint(i), constInt(LUAU_INSN_A(*pc))); inst(IrCmd::FALLBACK_PREPVARARGS, constUint(i), constInt(LUAU_INSN_A(*pc)));
@ -654,7 +575,7 @@ void IrBuilder::handleFastcallFallback(IrOp fallbackOrUndef, const Instruction*
} }
else else
{ {
fastcallSkipTarget = i + skip + 2; cmdSkipTarget = i + skip + 2;
} }
} }
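LOP_FASTCALL3 carries its second and third argument registers packed into one auxiliary byte pair, which the vmReg(pc[1] & 0xff) and vmReg((pc[1] >> 8) & 0xff) operands above decode. The equivalent standalone decode:

#include <cstdint>
#include <cstdio>

static void decodeFastcall3Args(uint32_t aux, int& arg2, int& arg3)
{
    arg2 = int(aux & 0xff);        // low byte: second argument register
    arg3 = int((aux >> 8) & 0xff); // next byte: third argument register
}

int main()
{
    int a2 = 0, a3 = 0;
    decodeFastcall3Args(0x0302, a2, a3);
    printf("%d %d\n", a2, a3); // prints "2 3"
}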
@ -725,6 +646,9 @@ void IrBuilder::clone(const IrBlock& source, bool removeCurrentTerminator)
redirect(clone.e); redirect(clone.e);
redirect(clone.f); redirect(clone.f);
if (FFlag::LuauCodegenInstG)
redirect(clone.g);
addUse(function, clone.a); addUse(function, clone.a);
addUse(function, clone.b); addUse(function, clone.b);
addUse(function, clone.c); addUse(function, clone.c);
@ -732,11 +656,17 @@ void IrBuilder::clone(const IrBlock& source, bool removeCurrentTerminator)
addUse(function, clone.e); addUse(function, clone.e);
addUse(function, clone.f); addUse(function, clone.f);
if (FFlag::LuauCodegenInstG)
addUse(function, clone.g);
// Instructions that referenced the original will have to be adjusted to use the clone // Instructions that referenced the original will have to be adjusted to use the clone
instRedir[index] = uint32_t(function.instructions.size()); instRedir[index] = uint32_t(function.instructions.size());
// Reconstruct the fresh clone // Reconstruct the fresh clone
inst(clone.cmd, clone.a, clone.b, clone.c, clone.d, clone.e, clone.f); if (FFlag::LuauCodegenInstG)
inst(clone.cmd, clone.a, clone.b, clone.c, clone.d, clone.e, clone.f, clone.g);
else
inst(clone.cmd, clone.a, clone.b, clone.c, clone.d, clone.e, clone.f);
} }
} }
@ -834,8 +764,33 @@ IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e)
IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f) IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f)
{ {
if (FFlag::LuauCodegenInstG)
{
return inst(cmd, a, b, c, d, e, f, {});
}
else
{
uint32_t index = uint32_t(function.instructions.size());
function.instructions.push_back({cmd, a, b, c, d, e, f});
CODEGEN_ASSERT(!inTerminatedBlock);
if (isBlockTerminator(cmd))
{
function.blocks[activeBlockIdx].finish = index;
inTerminatedBlock = true;
}
return {IrOpKind::Inst, index};
}
}
IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f, IrOp g)
{
CODEGEN_ASSERT(FFlag::LuauCodegenInstG);
uint32_t index = uint32_t(function.instructions.size()); uint32_t index = uint32_t(function.instructions.size());
function.instructions.push_back({cmd, a, b, c, d, e, f}); function.instructions.push_back({cmd, a, b, c, d, e, f, g});
CODEGEN_ASSERT(!inTerminatedBlock); CODEGEN_ASSERT(!inTerminatedBlock);
View file
@ -7,6 +7,9 @@
#include <stdarg.h> #include <stdarg.h>
LUAU_FASTFLAG(LuauLoadUserdataInfo)
LUAU_FASTFLAG(LuauCodegenInstG)
namespace Luau namespace Luau
{ {
namespace CodeGen namespace CodeGen
@ -151,6 +154,8 @@ const char* getCmdName(IrCmd cmd)
return "SQRT_NUM"; return "SQRT_NUM";
case IrCmd::ABS_NUM: case IrCmd::ABS_NUM:
return "ABS_NUM"; return "ABS_NUM";
case IrCmd::SIGN_NUM:
return "SIGN_NUM";
case IrCmd::ADD_VEC: case IrCmd::ADD_VEC:
return "ADD_VEC"; return "ADD_VEC";
case IrCmd::SUB_VEC: case IrCmd::SUB_VEC:
@ -197,6 +202,8 @@ const char* getCmdName(IrCmd cmd)
return "TRY_NUM_TO_INDEX"; return "TRY_NUM_TO_INDEX";
case IrCmd::TRY_CALL_FASTGETTM: case IrCmd::TRY_CALL_FASTGETTM:
return "TRY_CALL_FASTGETTM"; return "TRY_CALL_FASTGETTM";
case IrCmd::NEW_USERDATA:
return "NEW_USERDATA";
case IrCmd::INT_TO_NUM: case IrCmd::INT_TO_NUM:
return "INT_TO_NUM"; return "INT_TO_NUM";
case IrCmd::UINT_TO_NUM: case IrCmd::UINT_TO_NUM:
@ -255,6 +262,8 @@ const char* getCmdName(IrCmd cmd)
return "CHECK_NODE_VALUE"; return "CHECK_NODE_VALUE";
case IrCmd::CHECK_BUFFER_LEN: case IrCmd::CHECK_BUFFER_LEN:
return "CHECK_BUFFER_LEN"; return "CHECK_BUFFER_LEN";
case IrCmd::CHECK_USERDATA_TAG:
return "CHECK_USERDATA_TAG";
case IrCmd::INTERRUPT: case IrCmd::INTERRUPT:
return "INTERRUPT"; return "INTERRUPT";
case IrCmd::CHECK_GC: case IrCmd::CHECK_GC:
@ -411,6 +420,9 @@ void toString(IrToStringContext& ctx, const IrInst& inst, uint32_t index)
checkOp(inst.d, ", "); checkOp(inst.d, ", ");
checkOp(inst.e, ", "); checkOp(inst.e, ", ");
checkOp(inst.f, ", "); checkOp(inst.f, ", ");
if (FFlag::LuauCodegenInstG)
checkOp(inst.g, ", ");
} }
void toString(IrToStringContext& ctx, const IrBlock& block, uint32_t index) void toString(IrToStringContext& ctx, const IrBlock& block, uint32_t index)
@ -480,8 +492,10 @@ void toString(std::string& result, IrConst constant)
} }
} }
const char* getBytecodeTypeName(uint8_t type) const char* getBytecodeTypeName_DEPRECATED(uint8_t type)
{ {
CODEGEN_ASSERT(!FFlag::LuauLoadUserdataInfo);
switch (type & ~LBC_TYPE_OPTIONAL_BIT) switch (type & ~LBC_TYPE_OPTIONAL_BIT)
{ {
case LBC_TYPE_NIL: case LBC_TYPE_NIL:
@ -512,13 +526,78 @@ const char* getBytecodeTypeName(uint8_t type)
return nullptr; return nullptr;
} }
void toString(std::string& result, const BytecodeTypes& bcTypes) const char* getBytecodeTypeName(uint8_t type, const char* const* userdataTypes)
{ {
CODEGEN_ASSERT(FFlag::LuauLoadUserdataInfo);
// Optional bit should be handled externally
type = type & ~LBC_TYPE_OPTIONAL_BIT;
if (type >= LBC_TYPE_TAGGED_USERDATA_BASE && type < LBC_TYPE_TAGGED_USERDATA_END)
{
if (userdataTypes)
return userdataTypes[type - LBC_TYPE_TAGGED_USERDATA_BASE];
return "userdata";
}
switch (type)
{
case LBC_TYPE_NIL:
return "nil";
case LBC_TYPE_BOOLEAN:
return "boolean";
case LBC_TYPE_NUMBER:
return "number";
case LBC_TYPE_STRING:
return "string";
case LBC_TYPE_TABLE:
return "table";
case LBC_TYPE_FUNCTION:
return "function";
case LBC_TYPE_THREAD:
return "thread";
case LBC_TYPE_USERDATA:
return "userdata";
case LBC_TYPE_VECTOR:
return "vector";
case LBC_TYPE_BUFFER:
return "buffer";
case LBC_TYPE_ANY:
return "any";
}
CODEGEN_ASSERT(!"Unhandled type in getBytecodeTypeName");
return nullptr;
}
void toString_DEPRECATED(std::string& result, const BytecodeTypes& bcTypes)
{
CODEGEN_ASSERT(!FFlag::LuauLoadUserdataInfo);
if (bcTypes.c != LBC_TYPE_ANY) if (bcTypes.c != LBC_TYPE_ANY)
append(result, "%s <- %s, %s, %s", getBytecodeTypeName(bcTypes.result), getBytecodeTypeName(bcTypes.a), getBytecodeTypeName(bcTypes.b), append(result, "%s <- %s, %s, %s", getBytecodeTypeName_DEPRECATED(bcTypes.result), getBytecodeTypeName_DEPRECATED(bcTypes.a),
getBytecodeTypeName(bcTypes.c)); getBytecodeTypeName_DEPRECATED(bcTypes.b), getBytecodeTypeName_DEPRECATED(bcTypes.c));
else else
append(result, "%s <- %s, %s", getBytecodeTypeName(bcTypes.result), getBytecodeTypeName(bcTypes.a), getBytecodeTypeName(bcTypes.b)); append(result, "%s <- %s, %s", getBytecodeTypeName_DEPRECATED(bcTypes.result), getBytecodeTypeName_DEPRECATED(bcTypes.a),
getBytecodeTypeName_DEPRECATED(bcTypes.b));
}
void toString(std::string& result, const BytecodeTypes& bcTypes, const char* const* userdataTypes)
{
CODEGEN_ASSERT(FFlag::LuauLoadUserdataInfo);
append(result, "%s%s", getBytecodeTypeName(bcTypes.result, userdataTypes), (bcTypes.result & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" : "");
append(result, " <- ");
append(result, "%s%s", getBytecodeTypeName(bcTypes.a, userdataTypes), (bcTypes.a & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" : "");
append(result, ", ");
append(result, "%s%s", getBytecodeTypeName(bcTypes.b, userdataTypes), (bcTypes.b & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" : "");
if (bcTypes.c != LBC_TYPE_ANY)
{
append(result, ", ");
append(result, "%s%s", getBytecodeTypeName(bcTypes.c, userdataTypes), (bcTypes.c & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" : "");
}
} }
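With the optional bit rendered as a trailing '?', the new formatter produces lines like "Vector2? <- Vector2, number" once the remapper has registered "Vector2" in userdataTypes, with "userdata" remaining the fallback when no name table is supplied. A sketch of the per-component formatting, type names assumed:

#include <cstdio>
#include <string>

static std::string formatTypeName(const char* name, bool optional)
{
    // mirrors the "%s%s" + LBC_TYPE_OPTIONAL_BIT pattern above
    return std::string(name) + (optional ? "?" : "");
}

int main()
{
    printf("%s <- %s, %s\n", formatTypeName("Vector2", true).c_str(), formatTypeName("Vector2", false).c_str(),
        formatTypeName("number", false).c_str()); // prints "Vector2? <- Vector2, number"
}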
static void appendBlockSet(IrToStringContext& ctx, BlockIteratorWrapper blocks) static void appendBlockSet(IrToStringContext& ctx, BlockIteratorWrapper blocks)
@ -583,6 +662,8 @@ static RegisterSet getJumpTargetExtraLiveIn(IrToStringContext& ctx, const IrBloc
op = inst.e; op = inst.e;
else if (inst.f.kind == IrOpKind::Block) else if (inst.f.kind == IrOpKind::Block)
op = inst.f; op = inst.f;
else if (FFlag::LuauCodegenInstG && inst.g.kind == IrOpKind::Block)
op = inst.g;
if (op.kind == IrOpKind::Block && op.index < ctx.cfg.in.size()) if (op.kind == IrOpKind::Block && op.index < ctx.cfg.in.size())
{ {
@ -867,6 +948,9 @@ std::string toDot(const IrFunction& function, bool includeInst)
checkOp(inst.d); checkOp(inst.d);
checkOp(inst.e); checkOp(inst.e);
checkOp(inst.f); checkOp(inst.f);
if (FFlag::LuauCodegenInstG)
checkOp(inst.g);
} }
} }
View file
@ -11,7 +11,11 @@
#include "lstate.h" #include "lstate.h"
#include "lgc.h" #include "lgc.h"
LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) LUAU_FASTFLAG(LuauCodegenUserdataOps)
LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataAlloc, false)
LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataOpsFixA64, false)
LUAU_FASTFLAG(LuauCodegenFastcall3)
LUAU_FASTFLAG(LuauCodegenMathSign)
namespace Luau namespace Luau
{ {
@ -193,78 +197,51 @@ static void emitInvokeLibm1P(AssemblyBuilderA64& build, size_t func, int arg)
build.blr(x1); build.blr(x1);
} }
static bool emitBuiltin( static bool emitBuiltin(AssemblyBuilderA64& build, IrFunction& function, IrRegAllocA64& regs, int bfid, int res, int arg, int nresults)
AssemblyBuilderA64& build, IrFunction& function, IrRegAllocA64& regs, int bfid, int res, int arg, IrOp args, int nparams, int nresults)
{ {
switch (bfid) switch (bfid)
{ {
case LBF_MATH_FREXP: case LBF_MATH_FREXP:
{ {
if (FFlag::LuauCodegenRemoveDeadStores5) CODEGEN_ASSERT(nresults == 1 || nresults == 2);
{ emitInvokeLibm1P(build, offsetof(NativeContext, libm_frexp), arg);
CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n)));
emitInvokeLibm1P(build, offsetof(NativeContext, libm_frexp), arg);
build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n)));
RegisterA64 temp = regs.allocTemp(KindA64::w); RegisterA64 temp = regs.allocTemp(KindA64::w);
build.mov(temp, LUA_TNUMBER); build.mov(temp, LUA_TNUMBER);
build.str(temp, mem(rBase, res * sizeof(TValue) + offsetof(TValue, tt))); build.str(temp, mem(rBase, res * sizeof(TValue) + offsetof(TValue, tt)));
if (nresults == 2) if (nresults == 2)
{
build.ldr(w0, sTemporary);
build.scvtf(d1, w0);
build.str(d1, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, value.n)));
build.str(temp, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, tt)));
}
}
else
{ {
CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); build.ldr(w0, sTemporary);
emitInvokeLibm1P(build, offsetof(NativeContext, libm_frexp), arg); build.scvtf(d1, w0);
build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); build.str(d1, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, value.n)));
if (nresults == 2) build.str(temp, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, tt)));
{
build.ldr(w0, sTemporary);
build.scvtf(d1, w0);
build.str(d1, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, value.n)));
}
} }
return true; return true;
} }
case LBF_MATH_MODF: case LBF_MATH_MODF:
{ {
if (FFlag::LuauCodegenRemoveDeadStores5) CODEGEN_ASSERT(nresults == 1 || nresults == 2);
{ emitInvokeLibm1P(build, offsetof(NativeContext, libm_modf), arg);
CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); build.ldr(d1, sTemporary);
emitInvokeLibm1P(build, offsetof(NativeContext, libm_modf), arg); build.str(d1, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n)));
build.ldr(d1, sTemporary);
build.str(d1, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n)));
RegisterA64 temp = regs.allocTemp(KindA64::w); RegisterA64 temp = regs.allocTemp(KindA64::w);
build.mov(temp, LUA_TNUMBER); build.mov(temp, LUA_TNUMBER);
build.str(temp, mem(rBase, res * sizeof(TValue) + offsetof(TValue, tt))); build.str(temp, mem(rBase, res * sizeof(TValue) + offsetof(TValue, tt)));
if (nresults == 2) if (nresults == 2)
{
build.str(d0, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, value.n)));
build.str(temp, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, tt)));
}
}
else
{ {
CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); build.str(d0, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, value.n)));
emitInvokeLibm1P(build, offsetof(NativeContext, libm_modf), arg); build.str(temp, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, tt)));
build.ldr(d1, sTemporary);
build.str(d1, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n)));
if (nresults == 2)
build.str(d0, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, value.n)));
} }
return true; return true;
} }
case LBF_MATH_SIGN: case LBF_MATH_SIGN:
{ {
CODEGEN_ASSERT(nparams == 1 && nresults == 1); CODEGEN_ASSERT(!FFlag::LuauCodegenMathSign);
CODEGEN_ASSERT(nresults == 1);
build.ldr(d0, mem(rBase, arg * sizeof(TValue) + offsetof(TValue, value.n))); build.ldr(d0, mem(rBase, arg * sizeof(TValue) + offsetof(TValue, value.n)));
build.fcmpz(d0); build.fcmpz(d0);
build.fmov(d0, 0.0); build.fmov(d0, 0.0);
@ -274,12 +251,10 @@ static bool emitBuiltin(
build.fcsel(d0, d1, d0, getConditionFP(IrCondition::Less)); build.fcsel(d0, d1, d0, getConditionFP(IrCondition::Less));
build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n)));
if (FFlag::LuauCodegenRemoveDeadStores5) RegisterA64 temp = regs.allocTemp(KindA64::w);
{ build.mov(temp, LUA_TNUMBER);
RegisterA64 temp = regs.allocTemp(KindA64::w); build.str(temp, mem(rBase, res * sizeof(TValue) + offsetof(TValue, tt)));
build.mov(temp, LUA_TNUMBER);
build.str(temp, mem(rBase, res * sizeof(TValue) + offsetof(TValue, tt)));
}
return true; return true;
} }
@@ -723,6 +698,24 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
         build.fabs(inst.regA64, temp);
         break;
     }
+    case IrCmd::SIGN_NUM:
+    {
+        CODEGEN_ASSERT(FFlag::LuauCodegenMathSign);
+        inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a});
+        RegisterA64 temp = tempDouble(inst.a);
+        RegisterA64 temp0 = regs.allocTemp(KindA64::d);
+        RegisterA64 temp1 = regs.allocTemp(KindA64::d);
+        build.fcmpz(temp);
+        build.fmov(temp0, 0.0);
+        build.fmov(temp1, 1.0);
+        build.fcsel(inst.regA64, temp1, temp0, getConditionFP(IrCondition::Greater));
+        build.fmov(temp1, -1.0);
+        build.fcsel(inst.regA64, temp1, inst.regA64, getConditionFP(IrCondition::Less));
+        break;
+    }
     case IrCmd::ADD_VEC:
     {
         inst.regA64 = regs.allocReuse(KindA64::q, index, {inst.a, inst.b});
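The new SIGN_NUM case computes math.sign without branches, using two conditional selects. A scalar sketch of the same decision tree (assuming IEEE compares, so a NaN input fails both tests and yields 0.0):

    double signSketch(double a)
    {
        double r = (a > 0.0) ? 1.0 : 0.0; // first fcsel: Greater picks 1.0
        return (a < 0.0) ? -1.0 : r;      // second fcsel: Less picks -1.0
    }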
@@ -1082,6 +1075,19 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
         inst.regA64 = regs.takeReg(x0, index);
         break;
     }
+    case IrCmd::NEW_USERDATA:
+    {
+        CODEGEN_ASSERT(FFlag::LuauCodegenUserdataAlloc);
+        regs.spill(build, index);
+        build.mov(x0, rState);
+        build.mov(x1, intOp(inst.a));
+        build.mov(x2, intOp(inst.b));
+        build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, newUserdata)));
+        build.blr(x3);
+        inst.regA64 = regs.takeReg(x0, index);
+        break;
+    }
     case IrCmd::INT_TO_NUM:
     {
         inst.regA64 = regs.allocReg(KindA64::d, index);
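NEW_USERDATA lowers to a plain AAPCS64 call: x0 carries the lua_State pointer, x1 the payload size, x2 the userdata tag, and the allocation is returned in x0. A hedged sketch of the helper shape this assumes (the actual declaration lives in NativeContext and is not shown in this diff):

    // assumed shape, for illustration only
    Udata* newUserdata(lua_State* L, size_t size, int tag);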
@@ -1188,34 +1194,88 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
     }
     case IrCmd::FASTCALL:
         regs.spill(build, index);
-        error |= !emitBuiltin(build, function, regs, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), inst.d, intOp(inst.e), intOp(inst.f));
+        if (FFlag::LuauCodegenFastcall3)
+            error |= !emitBuiltin(build, function, regs, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.d));
+        else
+            error |= !emitBuiltin(build, function, regs, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.f));
         break;
     case IrCmd::INVOKE_FASTCALL:
     {
-        regs.spill(build, index);
-        build.mov(x0, rState);
-        build.add(x1, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue)));
-        build.add(x2, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue)));
-        build.mov(w3, intOp(inst.f)); // nresults
-        if (inst.d.kind == IrOpKind::VmReg)
-            build.add(x4, rBase, uint16_t(vmRegOp(inst.d) * sizeof(TValue)));
-        else if (inst.d.kind == IrOpKind::VmConst)
-            emitAddOffset(build, x4, rConstants, vmConstOp(inst.d) * sizeof(TValue));
-        else
-            CODEGEN_ASSERT(inst.d.kind == IrOpKind::Undef);
-        // nparams
-        if (intOp(inst.e) == LUA_MULTRET)
+        if (FFlag::LuauCodegenFastcall3)
         {
-            // L->top - (ra + 1)
-            build.ldr(x5, mem(rState, offsetof(lua_State, top)));
-            build.sub(x5, x5, rBase);
-            build.sub(x5, x5, uint16_t((vmRegOp(inst.b) + 1) * sizeof(TValue)));
-            build.lsr(x5, x5, kTValueSizeLog2);
+            // We might need a temporary and we have to preserve it over the spill
+            RegisterA64 temp = regs.allocTemp(KindA64::q);
+            regs.spill(build, index, {temp});
+            build.mov(x0, rState);
+            build.add(x1, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue)));
+            build.add(x2, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue)));
+            build.mov(w3, intOp(inst.g)); // nresults
+            // 'E' argument can only be produced by LOP_FASTCALL3 lowering
+            if (inst.e.kind != IrOpKind::Undef)
+            {
+                CODEGEN_ASSERT(intOp(inst.f) == 3);
+                build.ldr(x4, mem(rState, offsetof(lua_State, top)));
+                build.ldr(temp, mem(rBase, vmRegOp(inst.d) * sizeof(TValue)));
+                build.str(temp, mem(x4, 0));
+                build.ldr(temp, mem(rBase, vmRegOp(inst.e) * sizeof(TValue)));
+                build.str(temp, mem(x4, sizeof(TValue)));
+            }
+            else
+            {
+                if (inst.d.kind == IrOpKind::VmReg)
+                    build.add(x4, rBase, uint16_t(vmRegOp(inst.d) * sizeof(TValue)));
+                else if (inst.d.kind == IrOpKind::VmConst)
+                    emitAddOffset(build, x4, rConstants, vmConstOp(inst.d) * sizeof(TValue));
+                else
+                    CODEGEN_ASSERT(inst.d.kind == IrOpKind::Undef);
+            }
+            // nparams
+            if (intOp(inst.f) == LUA_MULTRET)
+            {
+                // L->top - (ra + 1)
+                build.ldr(x5, mem(rState, offsetof(lua_State, top)));
+                build.sub(x5, x5, rBase);
+                build.sub(x5, x5, uint16_t((vmRegOp(inst.b) + 1) * sizeof(TValue)));
+                build.lsr(x5, x5, kTValueSizeLog2);
+            }
+            else
+                build.mov(w5, intOp(inst.f));
         }
         else
-            build.mov(w5, intOp(inst.e));
+        {
+            regs.spill(build, index);
+            build.mov(x0, rState);
+            build.add(x1, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue)));
+            build.add(x2, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue)));
+            build.mov(w3, intOp(inst.f)); // nresults
+            if (inst.d.kind == IrOpKind::VmReg)
+                build.add(x4, rBase, uint16_t(vmRegOp(inst.d) * sizeof(TValue)));
+            else if (inst.d.kind == IrOpKind::VmConst)
+                emitAddOffset(build, x4, rConstants, vmConstOp(inst.d) * sizeof(TValue));
+            else
+                CODEGEN_ASSERT(inst.d.kind == IrOpKind::Undef);
+            // nparams
+            if (intOp(inst.e) == LUA_MULTRET)
+            {
+                // L->top - (ra + 1)
+                build.ldr(x5, mem(rState, offsetof(lua_State, top)));
+                build.sub(x5, x5, rBase);
+                build.sub(x5, x5, uint16_t((vmRegOp(inst.b) + 1) * sizeof(TValue)));
+                build.lsr(x5, x5, kTValueSizeLog2);
+            }
+            else
+                build.mov(w5, intOp(inst.e));
+        }
         build.ldr(x6, mem(rNativeContext, offsetof(NativeContext, luauF_table) + uintOp(inst.a) * sizeof(luau_FastFunction)));
         build.blr(x6);
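Both branches finish by calling through the builtin's luauF_table entry. The fast-call entry shape being targeted is roughly the following (parameter names illustrative, not quoted from this diff):

    typedef int (*luau_FastFunction)(lua_State* L, StkId res, StkId arg0, int nresults, StkId args, int nparams);

Under LuauCodegenFastcall3 a third argument may come from a non-adjacent VM register, so the 'E'-operand path copies arguments 2 and 3 into the two slots just above L->top and passes that address in x4 as the args pointer, keeping the callee's view of the arguments contiguous.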
@@ -1242,9 +1302,38 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
         else
             build.add(x3, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue)));
-        build.mov(w4, TMS(intOp(inst.d)));
-        build.ldr(x5, mem(rNativeContext, offsetof(NativeContext, luaV_doarith)));
-        build.blr(x5);
+        switch (TMS(intOp(inst.d)))
+        {
+        case TM_ADD:
+            build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithadd)));
+            break;
+        case TM_SUB:
+            build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithsub)));
+            break;
+        case TM_MUL:
+            build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithmul)));
+            break;
+        case TM_DIV:
+            build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithdiv)));
+            break;
+        case TM_IDIV:
+            build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithidiv)));
+            break;
+        case TM_MOD:
+            build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithmod)));
+            break;
+        case TM_POW:
+            build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithpow)));
+            break;
+        case TM_UNM:
+            build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithunm)));
+            break;
+        default:
+            CODEGEN_ASSERT(!"Invalid doarith helper operation tag");
+            break;
+        }
+        build.blr(x4);
         emitUpdateBase(build);
         break;
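Rather than passing the TMS operation in w4 to a single generic luaV_doarith entry, the lowering now picks a per-operation helper at compile time, which removes one runtime argument and the operation dispatch inside the helper. The helpers are assumed to share the generic entry's shape minus the op parameter, e.g.:

    // assumed shape, for illustration only
    void luaV_doarithadd(lua_State* L, StkId ra, const TValue* rb, const TValue* rc);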
@@ -1388,35 +1477,14 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
         Label fresh; // used when guard aborts execution or jumps to a VM exit
         Label& fail = getTargetLabel(inst.c, fresh);
-        if (FFlag::LuauCodegenRemoveDeadStores5)
+        if (tagOp(inst.b) == 0)
         {
-            if (tagOp(inst.b) == 0)
-            {
-                build.cbnz(regOp(inst.a), fail);
-            }
-            else
-            {
-                build.cmp(regOp(inst.a), tagOp(inst.b));
-                build.b(ConditionA64::NotEqual, fail);
-            }
+            build.cbnz(regOp(inst.a), fail);
         }
         else
         {
-            // To support DebugLuauAbortingChecks, CHECK_TAG with VmReg has to be handled
-            RegisterA64 tag = inst.a.kind == IrOpKind::VmReg ? regs.allocTemp(KindA64::w) : regOp(inst.a);
-            if (inst.a.kind == IrOpKind::VmReg)
-                build.ldr(tag, mem(rBase, vmRegOp(inst.a) * sizeof(TValue) + offsetof(TValue, tt)));
-            if (tagOp(inst.b) == 0)
-            {
-                build.cbnz(tag, fail);
-            }
-            else
-            {
-                build.cmp(tag, tagOp(inst.b));
-                build.b(ConditionA64::NotEqual, fail);
-            }
+            build.cmp(regOp(inst.a), tagOp(inst.b));
+            build.b(ConditionA64::NotEqual, fail);
         }
         finalizeTargetLabel(inst.c, fresh);
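With the dead-store flag retired, CHECK_TAG always receives the tag in a register; the one remaining special case is an expected tag of 0 (LUA_TNIL), where the comparison folds into a single cbnz. In C terms:

    if (expected == 0 ? tag != 0 : tag != expected)
        goto fail; // one cbnz when checking for nil, cmp + b.ne otherwise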
@@ -1638,6 +1706,24 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
         finalizeTargetLabel(inst.d, fresh);
         break;
     }
+    case IrCmd::CHECK_USERDATA_TAG:
+    {
+        CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps);
+        Label fresh; // used when guard aborts execution or jumps to a VM exit
+        Label& fail = getTargetLabel(inst.c, fresh);
+        RegisterA64 temp = regs.allocTemp(KindA64::w);
+        build.ldrb(temp, mem(regOp(inst.a), offsetof(Udata, tag)));
+        if (FFlag::LuauCodegenUserdataOpsFixA64)
+            build.cmp(temp, intOp(inst.b));
+        else
+            build.cmp(temp, tagOp(inst.b));
+        build.b(ConditionA64::NotEqual, fail);
+        finalizeTargetLabel(inst.c, fresh);
+        break;
+    }
     case IrCmd::INTERRUPT:
     {
         regs.spill(build, index);
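The new guard loads the userdata's per-object tag byte and exits on mismatch; the FixA64 flag switches the operand accessor to intOp because the expected value is an integer constant rather than a type tag. The C-level meaning of the check, as a sketch:

    if (((Udata*)u)->tag != expectedTag)
        goto fail;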
@@ -2269,7 +2355,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
     case IrCmd::BUFFER_READI8:
     {
         inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b});
-        AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
+        AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
         build.ldrsb(inst.regA64, addr);
         break;
@@ -2278,7 +2364,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
     case IrCmd::BUFFER_READU8:
     {
         inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b});
-        AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
+        AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
         build.ldrb(inst.regA64, addr);
         break;
@@ -2287,7 +2373,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
     case IrCmd::BUFFER_WRITEI8:
     {
         RegisterA64 temp = tempInt(inst.c);
-        AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
+        AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d));
         build.strb(temp, addr);
         break;
@@ -2296,7 +2382,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
     case IrCmd::BUFFER_READI16:
     {
         inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b});
-        AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
+        AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
         build.ldrsh(inst.regA64, addr);
         break;
@@ -2305,7 +2391,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
     case IrCmd::BUFFER_READU16:
     {
         inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b});
-        AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
+        AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
         build.ldrh(inst.regA64, addr);
         break;
@@ -2314,7 +2400,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
     case IrCmd::BUFFER_WRITEI16:
     {
         RegisterA64 temp = tempInt(inst.c);
-        AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
+        AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d));
         build.strh(temp, addr);
         break;
@@ -2323,7 +2409,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
     case IrCmd::BUFFER_READI32:
     {
         inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b});
-        AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
+        AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
         build.ldr(inst.regA64, addr);
         break;
@@ -2332,7 +2418,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
     case IrCmd::BUFFER_WRITEI32:
     {
         RegisterA64 temp = tempInt(inst.c);
-        AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
+        AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d));
         build.str(temp, addr);
         break;
@@ -2342,7 +2428,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
     {
         inst.regA64 = regs.allocReg(KindA64::d, index);
         RegisterA64 temp = castReg(KindA64::s, inst.regA64); // safe to alias a fresh register
-        AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
+        AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
         build.ldr(temp, addr);
         build.fcvt(inst.regA64, temp);
@@ -2353,7 +2439,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
     {
         RegisterA64 temp1 = tempDouble(inst.c);
         RegisterA64 temp2 = regs.allocTemp(KindA64::s);
-        AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
+        AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d));
         build.fcvt(temp2, temp1);
         build.str(temp2, addr);
@@ -2363,7 +2449,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
     case IrCmd::BUFFER_READF64:
     {
         inst.regA64 = regs.allocReg(KindA64::d, index);
-        AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
+        AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c));
         build.ldr(inst.regA64, addr);
         break;
@@ -2372,7 +2458,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
     case IrCmd::BUFFER_WRITEF64:
     {
         RegisterA64 temp = tempDouble(inst.c);
-        AddressA64 addr = tempAddrBuffer(inst.a, inst.b);
+        AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d));
         build.str(temp, addr);
         break;
@@ -2600,32 +2686,68 @@ AddressA64 IrLoweringA64::tempAddr(IrOp op, int offset)
     }
 }
-AddressA64 IrLoweringA64::tempAddrBuffer(IrOp bufferOp, IrOp indexOp)
+AddressA64 IrLoweringA64::tempAddrBuffer(IrOp bufferOp, IrOp indexOp, uint8_t tag)
 {
-    if (indexOp.kind == IrOpKind::Inst)
+    if (FFlag::LuauCodegenUserdataOps)
     {
-        RegisterA64 temp = regs.allocTemp(KindA64::x);
-        build.add(temp, regOp(bufferOp), regOp(indexOp)); // implicit uxtw
-        return mem(temp, offsetof(Buffer, data));
-    }
-    else if (indexOp.kind == IrOpKind::Constant)
-    {
-        // Since the resulting address may be used to load any size, including 1 byte, from an unaligned offset, we are limited by unscaled encoding
-        if (unsigned(intOp(indexOp)) + offsetof(Buffer, data) <= 255)
-            return mem(regOp(bufferOp), int(intOp(indexOp) + offsetof(Buffer, data)));
-        // indexOp can only be negative in dead code (since offsets are checked); this avoids assertion in emitAddOffset
-        if (intOp(indexOp) < 0)
-            return mem(regOp(bufferOp), offsetof(Buffer, data));
-        RegisterA64 temp = regs.allocTemp(KindA64::x);
-        emitAddOffset(build, temp, regOp(bufferOp), size_t(intOp(indexOp)));
-        return mem(temp, offsetof(Buffer, data));
+        CODEGEN_ASSERT(tag == LUA_TUSERDATA || tag == LUA_TBUFFER);
+        int dataOffset = tag == LUA_TBUFFER ? offsetof(Buffer, data) : offsetof(Udata, data);
+        if (indexOp.kind == IrOpKind::Inst)
+        {
+            RegisterA64 temp = regs.allocTemp(KindA64::x);
+            build.add(temp, regOp(bufferOp), regOp(indexOp)); // implicit uxtw
+            return mem(temp, dataOffset);
+        }
+        else if (indexOp.kind == IrOpKind::Constant)
+        {
+            // Since the resulting address may be used to load any size, including 1 byte, from an unaligned offset, we are limited by unscaled
+            // encoding
+            if (unsigned(intOp(indexOp)) + dataOffset <= 255)
+                return mem(regOp(bufferOp), int(intOp(indexOp) + dataOffset));
+            // indexOp can only be negative in dead code (since offsets are checked); this avoids assertion in emitAddOffset
+            if (intOp(indexOp) < 0)
+                return mem(regOp(bufferOp), dataOffset);
+            RegisterA64 temp = regs.allocTemp(KindA64::x);
+            emitAddOffset(build, temp, regOp(bufferOp), size_t(intOp(indexOp)));
+            return mem(temp, dataOffset);
+        }
+        else
+        {
+            CODEGEN_ASSERT(!"Unsupported instruction form");
+            return noreg;
+        }
     }
     else
     {
-        CODEGEN_ASSERT(!"Unsupported instruction form");
-        return noreg;
+        if (indexOp.kind == IrOpKind::Inst)
+        {
+            RegisterA64 temp = regs.allocTemp(KindA64::x);
+            build.add(temp, regOp(bufferOp), regOp(indexOp)); // implicit uxtw
+            return mem(temp, offsetof(Buffer, data));
+        }
+        else if (indexOp.kind == IrOpKind::Constant)
+        {
+            // Since the resulting address may be used to load any size, including 1 byte, from an unaligned offset, we are limited by unscaled
+            // encoding
+            if (unsigned(intOp(indexOp)) + offsetof(Buffer, data) <= 255)
+                return mem(regOp(bufferOp), int(intOp(indexOp) + offsetof(Buffer, data)));
+            // indexOp can only be negative in dead code (since offsets are checked); this avoids assertion in emitAddOffset
+            if (intOp(indexOp) < 0)
+                return mem(regOp(bufferOp), offsetof(Buffer, data));
+            RegisterA64 temp = regs.allocTemp(KindA64::x);
+            emitAddOffset(build, temp, regOp(bufferOp), size_t(intOp(indexOp)));
+            return mem(temp, offsetof(Buffer, data));
+        }
+        else
+        {
+            CODEGEN_ASSERT(!"Unsupported instruction form");
+            return noreg;
+        }
     }
 }
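The added tag parameter only moves the payload base; index handling is identical on both paths. The <= 255 test keeps constant offsets within the signed 9-bit immediate range of A64 unscaled load/store encodings (ldur/stur), which matters because the resulting address may be used for an access of any size and alignment. Effective address, as a sketch:

    // addr = object + index + (tag == LUA_TBUFFER ? offsetof(Buffer, data) : offsetof(Udata, data))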
View file
@@ -44,7 +44,7 @@ struct IrLoweringA64
     RegisterA64 tempInt(IrOp op);
     RegisterA64 tempUint(IrOp op);
     AddressA64 tempAddr(IrOp op, int offset);
-    AddressA64 tempAddrBuffer(IrOp bufferOp, IrOp indexOp);
+    AddressA64 tempAddrBuffer(IrOp bufferOp, IrOp indexOp, uint8_t tag);
     // May emit restore instructions
     RegisterA64 regOp(IrOp op);
View file
@@ -15,6 +15,11 @@
 #include "lstate.h"
 #include "lgc.h"
+LUAU_FASTFLAG(LuauCodegenUserdataOps)
+LUAU_FASTFLAG(LuauCodegenUserdataAlloc)
+LUAU_FASTFLAG(LuauCodegenFastcall3)
+LUAU_FASTFLAG(LuauCodegenMathSign)
 namespace Luau
 {
 namespace CodeGen
@@ -586,6 +591,33 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
         build.vandpd(inst.regX64, inst.regX64, build.i64(~(1LL << 63)));
         break;
+    case IrCmd::SIGN_NUM:
+    {
+        CODEGEN_ASSERT(FFlag::LuauCodegenMathSign);
+        inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a});
+        ScopedRegX64 tmp0{regs, SizeX64::xmmword};
+        ScopedRegX64 tmp1{regs, SizeX64::xmmword};
+        ScopedRegX64 tmp2{regs, SizeX64::xmmword};
+        build.vxorpd(tmp0.reg, tmp0.reg, tmp0.reg);
+        // Set tmp1 to -1 if arg < 0, else 0
+        build.vcmpltsd(tmp1.reg, regOp(inst.a), tmp0.reg);
+        build.vmovsd(tmp2.reg, build.f64(-1));
+        build.vandpd(tmp1.reg, tmp1.reg, tmp2.reg);
+        // Set mask bit to 1 if 0 < arg, else 0
+        build.vcmpltsd(inst.regX64, tmp0.reg, regOp(inst.a));
+        // Result = (mask-bit == 1) ? 1.0 : tmp1
+        // If arg < 0 then tmp1 is -1 and mask-bit is 0, result is -1
+        // If arg == 0 then tmp1 is 0 and mask-bit is 0, result is 0
+        // If arg > 0 then tmp1 is 0 and mask-bit is 1, result is 1
+        build.vblendvpd(inst.regX64, tmp1.reg, build.f64x2(1, 1), inst.regX64);
+        break;
+    }
     case IrCmd::ADD_VEC:
     {
         inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b});
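On X64 the same truth table is built from compare masks: vcmpltsd produces an all-ones pattern on a true compare, the AND with -1.0 turns that mask directly into the value -1.0, and vblendvpd selects 1.0 wherever the sign bit of the second compare's mask is set. A bit-level sketch of the negative half (C++20 std::bit_cast, illustrative):

    #include <bit>
    #include <cstdint>

    double negPart(double a)
    {
        uint64_t mask = (a < 0.0) ? ~0ull : 0ull;                           // vcmpltsd
        return std::bit_cast<double>(mask & std::bit_cast<uint64_t>(-1.0)); // vandpd: -1.0 or +0.0
    }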
@@ -905,6 +937,18 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
         inst.regX64 = regs.takeReg(rax, index);
         break;
     }
+    case IrCmd::NEW_USERDATA:
+    {
+        CODEGEN_ASSERT(FFlag::LuauCodegenUserdataAlloc);
+        IrCallWrapperX64 callWrap(regs, build, index);
+        callWrap.addArgument(SizeX64::qword, rState);
+        callWrap.addArgument(SizeX64::qword, intOp(inst.a));
+        callWrap.addArgument(SizeX64::dword, intOp(inst.b));
+        callWrap.call(qword[rNativeContext + offsetof(NativeContext, newUserdata)]);
+        inst.regX64 = regs.takeReg(rax, index);
+        break;
+    }
     case IrCmd::INT_TO_NUM:
         inst.regX64 = regs.allocReg(SizeX64::xmmword, index);
@@ -993,9 +1037,10 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
     case IrCmd::FASTCALL:
     {
-        OperandX64 arg2 = inst.d.kind != IrOpKind::Undef ? memRegDoubleOp(inst.d) : OperandX64{0};
-        emitBuiltin(regs, build, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), arg2, intOp(inst.e), intOp(inst.f));
+        if (FFlag::LuauCodegenFastcall3)
+            emitBuiltin(regs, build, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.d));
+        else
+            emitBuiltin(regs, build, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.f));
         break;
     }
     case IrCmd::INVOKE_FASTCALL:
@@ -1003,25 +1048,49 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
         unsigned bfid = uintOp(inst.a);
         OperandX64 args = 0;
-        if (inst.d.kind == IrOpKind::VmReg)
-            args = luauRegAddress(vmRegOp(inst.d));
-        else if (inst.d.kind == IrOpKind::VmConst)
-            args = luauConstantAddress(vmConstOp(inst.d));
+        ScopedRegX64 argsAlt{regs};
+        // 'E' argument can only be produced by LOP_FASTCALL3
+        if (FFlag::LuauCodegenFastcall3 && inst.e.kind != IrOpKind::Undef)
+        {
+            CODEGEN_ASSERT(intOp(inst.f) == 3);
+            ScopedRegX64 tmp{regs, SizeX64::xmmword};
+            argsAlt.alloc(SizeX64::qword);
+            build.mov(argsAlt.reg, qword[rState + offsetof(lua_State, top)]);
+            build.vmovups(tmp.reg, luauReg(vmRegOp(inst.d)));
+            build.vmovups(xmmword[argsAlt.reg], tmp.reg);
+            build.vmovups(tmp.reg, luauReg(vmRegOp(inst.e)));
+            build.vmovups(xmmword[argsAlt.reg + sizeof(TValue)], tmp.reg);
+        }
         else
-            CODEGEN_ASSERT(inst.d.kind == IrOpKind::Undef);
+        {
+            if (inst.d.kind == IrOpKind::VmReg)
+                args = luauRegAddress(vmRegOp(inst.d));
+            else if (inst.d.kind == IrOpKind::VmConst)
+                args = luauConstantAddress(vmConstOp(inst.d));
+            else
+                CODEGEN_ASSERT(inst.d.kind == IrOpKind::Undef);
+        }
         int ra = vmRegOp(inst.b);
         int arg = vmRegOp(inst.c);
-        int nparams = intOp(inst.e);
-        int nresults = intOp(inst.f);
+        int nparams = intOp(FFlag::LuauCodegenFastcall3 ? inst.f : inst.e);
+        int nresults = intOp(FFlag::LuauCodegenFastcall3 ? inst.g : inst.f);
         IrCallWrapperX64 callWrap(regs, build, index);
         callWrap.addArgument(SizeX64::qword, rState);
         callWrap.addArgument(SizeX64::qword, luauRegAddress(ra));
         callWrap.addArgument(SizeX64::qword, luauRegAddress(arg));
         callWrap.addArgument(SizeX64::dword, nresults);
-        callWrap.addArgument(SizeX64::qword, args);
+        if (FFlag::LuauCodegenFastcall3 && inst.e.kind != IrOpKind::Undef)
+            callWrap.addArgument(SizeX64::qword, argsAlt);
+        else
+            callWrap.addArgument(SizeX64::qword, args);
         if (nparams == LUA_MULTRET)
         {
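The builtin expects its extra arguments as one contiguous TValue block, but LOP_FASTCALL3's second and third arguments can live in non-adjacent VM registers, so this path stages both values in the two slots just above L->top and passes that pointer instead of args. Layout sketch (illustrative):

    // argsAlt -> L->top[0] = R[inst.d] (argument 2)
    //            L->top[1] = R[inst.e] (argument 3)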
@@ -1350,6 +1419,14 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
         }
         break;
     }
+    case IrCmd::CHECK_USERDATA_TAG:
+    {
+        CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps);
+        build.cmp(byte[regOp(inst.a) + offsetof(Udata, tag)], intOp(inst.b));
+        jumpOrAbortOnUndef(ConditionX64::NotEqual, inst.c, next);
+        break;
+    }
     case IrCmd::INTERRUPT:
     {
         unsigned pcpos = uintOp(inst.a);
@@ -1895,71 +1972,71 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
     case IrCmd::BUFFER_READI8:
         inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
-        build.movsx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b)]);
+        build.movsx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
         break;
     case IrCmd::BUFFER_READU8:
         inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
-        build.movzx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b)]);
+        build.movzx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
         break;
     case IrCmd::BUFFER_WRITEI8:
     {
         OperandX64 value = inst.c.kind == IrOpKind::Inst ? byteReg(regOp(inst.c)) : OperandX64(int8_t(intOp(inst.c)));
-        build.mov(byte[bufferAddrOp(inst.a, inst.b)], value);
+        build.mov(byte[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], value);
         break;
     }
     case IrCmd::BUFFER_READI16:
         inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
-        build.movsx(inst.regX64, word[bufferAddrOp(inst.a, inst.b)]);
+        build.movsx(inst.regX64, word[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
         break;
     case IrCmd::BUFFER_READU16:
         inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
-        build.movzx(inst.regX64, word[bufferAddrOp(inst.a, inst.b)]);
+        build.movzx(inst.regX64, word[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
         break;
     case IrCmd::BUFFER_WRITEI16:
     {
         OperandX64 value = inst.c.kind == IrOpKind::Inst ? wordReg(regOp(inst.c)) : OperandX64(int16_t(intOp(inst.c)));
-        build.mov(word[bufferAddrOp(inst.a, inst.b)], value);
+        build.mov(word[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], value);
         break;
     }
     case IrCmd::BUFFER_READI32:
         inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
-        build.mov(inst.regX64, dword[bufferAddrOp(inst.a, inst.b)]);
+        build.mov(inst.regX64, dword[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
         break;
     case IrCmd::BUFFER_WRITEI32:
     {
         OperandX64 value = inst.c.kind == IrOpKind::Inst ? regOp(inst.c) : OperandX64(intOp(inst.c));
-        build.mov(dword[bufferAddrOp(inst.a, inst.b)], value);
+        build.mov(dword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], value);
         break;
     }
     case IrCmd::BUFFER_READF32:
         inst.regX64 = regs.allocReg(SizeX64::xmmword, index);
-        build.vcvtss2sd(inst.regX64, inst.regX64, dword[bufferAddrOp(inst.a, inst.b)]);
+        build.vcvtss2sd(inst.regX64, inst.regX64, dword[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
         break;
     case IrCmd::BUFFER_WRITEF32:
-        storeDoubleAsFloat(dword[bufferAddrOp(inst.a, inst.b)], inst.c);
+        storeDoubleAsFloat(dword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], inst.c);
         break;
     case IrCmd::BUFFER_READF64:
         inst.regX64 = regs.allocReg(SizeX64::xmmword, index);
-        build.vmovsd(inst.regX64, qword[bufferAddrOp(inst.a, inst.b)]);
+        build.vmovsd(inst.regX64, qword[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
        break;
     case IrCmd::BUFFER_WRITEF64:
@@ -1967,11 +2044,11 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
     {
         ScopedRegX64 tmp{regs, SizeX64::xmmword};
         build.vmovsd(tmp.reg, build.f64(doubleOp(inst.c)));
-        build.vmovsd(qword[bufferAddrOp(inst.a, inst.b)], tmp.reg);
+        build.vmovsd(qword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], tmp.reg);
     }
     else if (inst.c.kind == IrOpKind::Inst)
     {
-        build.vmovsd(qword[bufferAddrOp(inst.a, inst.b)], regOp(inst.c));
+        build.vmovsd(qword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], regOp(inst.c));
     }
     else
     {
@@ -2190,12 +2267,25 @@ RegisterX64 IrLoweringX64::regOp(IrOp op)
     return inst.regX64;
 }
-OperandX64 IrLoweringX64::bufferAddrOp(IrOp bufferOp, IrOp indexOp)
+OperandX64 IrLoweringX64::bufferAddrOp(IrOp bufferOp, IrOp indexOp, uint8_t tag)
 {
-    if (indexOp.kind == IrOpKind::Inst)
-        return regOp(bufferOp) + qwordReg(regOp(indexOp)) + offsetof(Buffer, data);
-    else if (indexOp.kind == IrOpKind::Constant)
-        return regOp(bufferOp) + intOp(indexOp) + offsetof(Buffer, data);
+    if (FFlag::LuauCodegenUserdataOps)
+    {
+        CODEGEN_ASSERT(tag == LUA_TUSERDATA || tag == LUA_TBUFFER);
+        int dataOffset = tag == LUA_TBUFFER ? offsetof(Buffer, data) : offsetof(Udata, data);
+        if (indexOp.kind == IrOpKind::Inst)
+            return regOp(bufferOp) + qwordReg(regOp(indexOp)) + dataOffset;
+        else if (indexOp.kind == IrOpKind::Constant)
+            return regOp(bufferOp) + intOp(indexOp) + dataOffset;
+    }
+    else
+    {
+        if (indexOp.kind == IrOpKind::Inst)
+            return regOp(bufferOp) + qwordReg(regOp(indexOp)) + offsetof(Buffer, data);
+        else if (indexOp.kind == IrOpKind::Constant)
+            return regOp(bufferOp) + intOp(indexOp) + offsetof(Buffer, data);
+    }
     CODEGEN_ASSERT(!"Unsupported instruction form");
     return noreg;
View file
@@ -50,7 +50,7 @@ struct IrLoweringX64
     OperandX64 memRegUintOp(IrOp op);
     OperandX64 memRegTagOp(IrOp op);
     RegisterX64 regOp(IrOp op);
-    OperandX64 bufferAddrOp(IrOp bufferOp, IrOp indexOp);
+    OperandX64 bufferAddrOp(IrOp bufferOp, IrOp indexOp, uint8_t tag);
     RegisterX64 vecOp(IrOp op, ScopedRegX64& tmp);
     IrConst constOp(IrOp op) const;
Some files were not shown because too many files have changed in this diff