diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index 7a11fbe1..8e1bf983 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -77,10 +77,12 @@ jobs:
           valgrind --tool=callgrind ./luau-compile --null -O1 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O1 | tee -a compile-output.txt
           valgrind --tool=callgrind ./luau-compile --null -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2 | tee -a compile-output.txt
           valgrind --tool=callgrind ./luau-compile --codegennull -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2-codegen | tee -a compile-output.txt
+          valgrind --tool=callgrind ./luau-compile --codegennull -O2 -t1 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2-t1-codegen | tee -a compile-output.txt
           valgrind --tool=callgrind ./luau-compile --null -O0 bench/other/regex.lua 2>&1 | filter regex-O0 | tee -a compile-output.txt
           valgrind --tool=callgrind ./luau-compile --null -O1 bench/other/regex.lua 2>&1 | filter regex-O1 | tee -a compile-output.txt
           valgrind --tool=callgrind ./luau-compile --null -O2 bench/other/regex.lua 2>&1 | filter regex-O2 | tee -a compile-output.txt
           valgrind --tool=callgrind ./luau-compile --codegennull -O2 bench/other/regex.lua 2>&1 | filter regex-O2-codegen | tee -a compile-output.txt
+          valgrind --tool=callgrind ./luau-compile --codegennull -O2 -t1 bench/other/regex.lua 2>&1 | filter regex-O2-t1-codegen | tee -a compile-output.txt

       - name: Checkout benchmark results
         uses: actions/checkout@v3
diff --git a/.gitignore b/.gitignore
index 528ab204..8de6d91d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,7 @@
 /build/
 /build[.-]*/
+/cmake/
+/cmake[.-]*/
 /coverage/
 /.vs/
 /.vscode/
diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h
index d52ae6e0..3f3ad641 100644
--- a/Analysis/include/Luau/Constraint.h
+++ b/Analysis/include/Luau/Constraint.h
@@ -57,7 +57,7 @@ struct GeneralizationConstraint
 struct IterableConstraint
 {
     TypePackId iterator;
-    TypePackId variables;
+    std::vector<TypeId> variables;

     const AstNode* nextAstFragment;
     DenseHashMap<const AstNode*, TypeId>* astForInNextTypes;
@@ -179,23 +179,6 @@ struct HasPropConstraint
     bool suppressSimplification = false;
 };

-// result ~ setProp subjectType ["prop", "prop2", ...] propType
-//
-// If the subject is a table or table-like thing that already has the named
-// property chain, we unify propType with that existing property type.
-//
-// If the subject is a free table, we augment it in place.
-//
-// If the subject is an unsealed table, result is an augmented table that
-// includes that new prop.
-struct SetPropConstraint
-{
-    TypeId resultType;
-    TypeId subjectType;
-    std::vector<std::string> path;
-    TypeId propType;
-};
-
 // resultType ~ hasIndexer subjectType indexType
 //
 // If the subject type is a table or table-like thing that supports indexing,
@@ -209,46 +192,48 @@ struct HasIndexerConstraint
     TypeId indexType;
 };

-// result ~ setIndexer subjectType indexType propType
+// assignProp lhsType propName rhsType
 //
-// If the subject is a table or table-like thing that already has an indexer,
-// unify its indexType and propType with those from this constraint.
-//
-// If the table is a free or unsealed table, we augment it with a new indexer.
-struct SetIndexerConstraint
+// Assign a value of type rhsType into the named property of lhsType.
+
+struct AssignPropConstraint
 {
-    TypeId subjectType;
+    TypeId lhsType;
+    std::string propName;
+    TypeId rhsType;
+
+    /// The canonical write type of the property. It is _solely_ used to
+    /// populate astTypes during constraint resolution. Nothing should ever
+    /// block on it.
+    TypeId propType;
+
+    // When we generate constraints, we increment the remaining prop count on
+    // the table if we are able. This flag informs the solver as to whether or
+    // not it should in turn decrement the prop count when this constraint is
+    // dispatched.
+    bool decrementPropCount = false;
+};
+
+struct AssignIndexConstraint
+{
+    TypeId lhsType;
     TypeId indexType;
+    TypeId rhsType;
+
+    /// The canonical write type of the property. It is _solely_ used to
+    /// populate astTypes during constraint resolution. Nothing should ever
+    /// block on it.
     TypeId propType;
 };

-// resultType ~ unpack sourceTypePack
+// resultTypes ~ unpack sourceTypePack
 //
 // Similar to PackSubtypeConstraint, but with one important difference: If the
 // sourcePack is blocked, this constraint blocks.
 struct UnpackConstraint
 {
-    TypePackId resultPack;
+    std::vector<TypeId> resultPack;
     TypePackId sourcePack;
-
-    // UnpackConstraint is sometimes used to resolve the types of assignments.
-    // When this is the case, any LocalTypes in resultPack can have their
-    // domains extended by the corresponding type from sourcePack.
-    bool resultIsLValue = false;
-};
-
-// resultType ~ unpack sourceType
-//
-// The same as UnpackConstraint, but specialized for a pair of types as opposed to packs.
-struct Unpack1Constraint
-{
-    TypeId resultType;
-    TypeId sourceType;
-
-    // UnpackConstraint is sometimes used to resolve the types of assignments.
-    // When this is the case, any LocalTypes in resultPack can have their
-    // domains extended by the corresponding type from sourcePack.
-    bool resultIsLValue = false;
 };

 // ty ~ reduce ty
@@ -268,8 +253,8 @@ struct ReducePackConstraint
 };

 using ConstraintV = Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, IterableConstraint, NameConstraint,
-    TypeAliasExpansionConstraint, FunctionCallConstraint, FunctionCheckConstraint, PrimitiveTypeConstraint, HasPropConstraint, SetPropConstraint,
-    HasIndexerConstraint, SetIndexerConstraint, UnpackConstraint, Unpack1Constraint, ReduceConstraint, ReducePackConstraint, EqualityConstraint>;
+    TypeAliasExpansionConstraint, FunctionCallConstraint, FunctionCheckConstraint, PrimitiveTypeConstraint, HasPropConstraint, HasIndexerConstraint,
+    AssignPropConstraint, AssignIndexConstraint, UnpackConstraint, ReduceConstraint, ReducePackConstraint, EqualityConstraint>;

 struct Constraint
 {
@@ -284,11 +269,13 @@ struct Constraint

     std::vector<NotNull<Constraint>> dependencies;

-    DenseHashSet<TypeId> getFreeTypes() const;
+    DenseHashSet<TypeId> getMaybeMutatedFreeTypes() const;
 };

 using ConstraintPtr = std::unique_ptr<Constraint>;

+bool isReferenceCountedType(const TypeId typ);
+
 inline Constraint& asMutable(const Constraint& c)
 {
     return const_cast<Constraint&>(c);
diff --git a/Analysis/include/Luau/ConstraintGenerator.h b/Analysis/include/Luau/ConstraintGenerator.h
index 6cb4b6d6..28cfb5aa 100644
--- a/Analysis/include/Luau/ConstraintGenerator.h
+++ b/Analysis/include/Luau/ConstraintGenerator.h
@@ -118,6 +118,8 @@ struct ConstraintGenerator
     std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope;
     std::vector<RequireCycle> requireCycles;

+    DenseHashMap<TypeId, TypeIds> localTypes{nullptr};
+
     DcrLogger* logger;

     ConstraintGenerator(ModulePtr module, NotNull<Normalizer> normalizer, NotNull<ModuleResolver> moduleResolver, NotNull<BuiltinTypes> builtinTypes,
@@ -254,18 +256,11 @@ private:
     Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType);
     std::tuple<TypeId, TypeId, RefinementId> checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType);

-    struct LValueBounds
-    {
-        std::optional<TypeId> annotationTy;
-        std::optional<TypeId> assignedTy;
-    };
-
-    LValueBounds checkLValue(const ScopePtr& scope, AstExpr* expr);
-    LValueBounds checkLValue(const ScopePtr& scope, AstExprLocal* local);
-    LValueBounds checkLValue(const ScopePtr& scope, AstExprGlobal* global);
-    LValueBounds checkLValue(const ScopePtr& scope, AstExprIndexName* indexName);
-    LValueBounds checkLValue(const ScopePtr& scope, AstExprIndexExpr* indexExpr);
-    LValueBounds updateProperty(const ScopePtr& scope, AstExpr* expr);
+    void visitLValue(const ScopePtr& scope, AstExpr* expr, TypeId rhsType);
+    void visitLValue(const ScopePtr& scope, AstExprLocal* local, TypeId rhsType);
+    void visitLValue(const ScopePtr& scope, AstExprGlobal* global, TypeId rhsType);
+    void visitLValue(const ScopePtr& scope, AstExprIndexName* indexName, TypeId rhsType);
+    void visitLValue(const ScopePtr& scope, AstExprIndexExpr* indexExpr, TypeId rhsType);

     struct FunctionSignature
     {
@@ -361,6 +356,8 @@ private:
      */
     void prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program);

+    bool recordPropertyAssignment(TypeId ty);
+
     // Record the fact that a particular local has a particular type in at least
     // one of its states.
     void recordInferredBinding(AstLocal* local, TypeId ty);
@@ -373,7 +370,8 @@ private:
      */
     std::vector<std::optional<TypeId>> getExpectedCallTypesForFunctionOverloads(const TypeId fnType);

-    TypeId createFamilyInstance(TypeFamilyInstanceType instance, const ScopePtr& scope, Location location);
+    TypeId createTypeFamilyInstance(
+        const TypeFamily& family, std::vector<TypeId> typeArguments, std::vector<TypePackId> packArguments, const ScopePtr& scope, Location location);
 };

 /** Borrow a vector of pointers from a vector of owning pointers to constraints.
diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h
index bb1fe2d8..925be04e 100644
--- a/Analysis/include/Luau/ConstraintSolver.h
+++ b/Analysis/include/Luau/ConstraintSolver.h
@@ -94,6 +94,10 @@ struct ConstraintSolver
     // Irreducible/uninhabited type families or type pack families.
     DenseHashSet<const void*> uninhabitedTypeFamilies{{}};

+    // The set of types that will definitely be unchanged by generalization.
+    DenseHashSet<TypeId> generalizedTypes_{nullptr};
+    const NotNull<DenseHashSet<TypeId>> generalizedTypes{&generalizedTypes_};
+
     // Recorded errors that take place within the solver.
     ErrorVec errors;

@@ -103,6 +107,8 @@ struct ConstraintSolver
     DcrLogger* logger;
     TypeCheckLimits limits;

+    DenseHashMap<TypeId, const Constraint*> typeFamiliesToFinalize{nullptr};
+
     explicit ConstraintSolver(NotNull<Normalizer> normalizer, NotNull<Scope> rootScope, std::vector<NotNull<Constraint>> constraints,
         ModuleName moduleName, NotNull<ModuleResolver> moduleResolver, std::vector<RequireCycle> requireCycles, DcrLogger* logger,
         TypeCheckLimits limits);

@@ -116,8 +122,35 @@ struct ConstraintSolver
      **/
     void run();

+
+    /**
+     * Attempts to perform one final reduction on type families after every constraint has been completed
+     *
+     **/
+    void finalizeTypeFamilies();
+
     bool isDone();

+private:
+    /**
+     * Bind a type variable to another type.
+     *
+     * A constraint is required and will validate that blockedTy is owned by this
+     * constraint. This prevents one constraint from interfering with another's
+     * blocked types.
+     *
+     * Bind will also unblock the type variable for you.
+     */
+    void bind(NotNull<const Constraint> constraint, TypeId ty, TypeId boundTo);
+    void bind(NotNull<const Constraint> constraint, TypePackId tp, TypePackId boundTo);
+
+    template<typename T, typename... Args>
+    void emplace(NotNull<const Constraint> constraint, TypeId ty, Args&&... args);
+
+    template<typename T, typename... Args>
+    void emplace(NotNull<const Constraint> constraint, TypePackId tp, Args&&... args);
+
+public:
     /** Attempt to dispatch a constraint. Returns true if it was successful. If
      * tryDispatch() returns false, the constraint remains in the unsolved set
     * and will be retried later.
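
For intuition, here is a minimal, self-contained sketch of the retry contract the comment above describes. Every name in it (the miniature Constraint, runToFixedPoint) is a hypothetical stand-in, not code from this diff; the real loop is ConstraintSolver::run, which additionally handles blocking, forcing, cycle breaking, and logging:

    #include <cstddef>
    #include <memory>
    #include <vector>

    // Hypothetical miniature of a constraint: dispatchable once nothing blocks it.
    struct Constraint
    {
        int blockers = 0;
    };

    static bool tryDispatch(Constraint* c)
    {
        if (c->blockers > 0)
        {
            --c->blockers; // pretend a later dispatch unblocked us
            return false;  // still blocked: stays in the unsolved set
        }
        return true; // dispatched successfully
    }

    static void runToFixedPoint(std::vector<std::unique_ptr<Constraint>>& unsolved)
    {
        bool progress = true;
        while (progress)
        {
            progress = false;
            for (std::size_t i = 0; i < unsolved.size();)
            {
                if (tryDispatch(unsolved[i].get()))
                {
                    // A successful dispatch may unblock other constraints,
                    // so sweep the whole set again.
                    unsolved.erase(unsolved.begin() + i);
                    progress = true;
                }
                else
                    ++i; // left in place; retried on the next sweep
            }
        }
    }

    int main()
    {
        std::vector<std::unique_ptr<Constraint>> unsolved;
        unsolved.push_back(std::make_unique<Constraint>(Constraint{1}));
        unsolved.push_back(std::make_unique<Constraint>(Constraint{0}));
        runToFixedPoint(unsolved);
        return unsolved.empty() ? 0 : 1;
    }
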
@@ -134,20 +167,15 @@ struct ConstraintSolver
     bool tryDispatch(const FunctionCheckConstraint& c, NotNull<const Constraint> constraint);
     bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull<const Constraint> constraint);
     bool tryDispatch(const HasPropConstraint& c, NotNull<const Constraint> constraint);
-    bool tryDispatch(const SetPropConstraint& c, NotNull<const Constraint> constraint);
+
     bool tryDispatchHasIndexer(
         int& recursionDepth, NotNull<const Constraint> constraint, TypeId subjectType, TypeId indexType, TypeId resultType, Set<TypeId>& seen);
     bool tryDispatch(const HasIndexerConstraint& c, NotNull<const Constraint> constraint);
-    std::pair<bool, std::optional<TypeId>> tryDispatchSetIndexer(
-        NotNull<const Constraint> constraint, TypeId subjectType, TypeId indexType, TypeId propType, bool expandFreeTypeBounds);
-    bool tryDispatch(const SetIndexerConstraint& c, NotNull<const Constraint> constraint, bool force);
-
-    bool tryDispatchUnpack1(NotNull<const Constraint> constraint, TypeId resultType, TypeId sourceType, bool resultIsLValue);
+    bool tryDispatch(const AssignPropConstraint& c, NotNull<const Constraint> constraint);
+    bool tryDispatch(const AssignIndexConstraint& c, NotNull<const Constraint> constraint);

     bool tryDispatch(const UnpackConstraint& c, NotNull<const Constraint> constraint);
-    bool tryDispatch(const Unpack1Constraint& c, NotNull<const Constraint> constraint);
-
     bool tryDispatch(const ReduceConstraint& c, NotNull<const Constraint> constraint, bool force);
     bool tryDispatch(const ReducePackConstraint& c, NotNull<const Constraint> constraint, bool force);
     bool tryDispatch(const EqualityConstraint& c, NotNull<const Constraint> constraint, bool force);
@@ -157,14 +185,28 @@ struct ConstraintSolver
     bool tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);

     // for a, ... in next_function, t, ... do
-    bool tryDispatchIterableFunction(
-        TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
+    bool tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);

     std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(NotNull<const Constraint> constraint, TypeId subjectType,
         const std::string& propName, ValueContext context, bool inConditional = false, bool suppressSimplification = false);
     std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(NotNull<const Constraint> constraint, TypeId subjectType,
         const std::string& propName, ValueContext context, bool inConditional, bool suppressSimplification, DenseHashSet<TypeId>& seen);

+    /**
+     * Generate constraints to unpack the types of srcTypes and assign each
+     * value to the corresponding BlockedType in destTypes.
+     *
+     * This function also overwrites the owners of each BlockedType. This is
+     * okay because this function is only used to decompose IterableConstraint
+     * into an UnpackConstraint.
+     *
+     * @param destTypes A vector of types comprised of BlockedTypes.
+     * @param srcTypes A TypePack that represents rvalues to be assigned.
+     * @returns The underlying UnpackConstraint. There's a bit of code in
+     * iteration that needs to pass blocks on to this constraint.
+     */
+    NotNull<const Constraint> unpackAndAssign(const std::vector<TypeId> destTypes, TypePackId srcTypes, NotNull<const Constraint> constraint);
+
     void block(NotNull<const Constraint> target, NotNull<const Constraint> constraint);

     /**
      * Block a constraint on the resolution of a Type.
@@ -242,6 +284,24 @@ struct ConstraintSolver
     void reportError(TypeErrorData&& data, const Location& location);
     void reportError(TypeError e);

+    /**
+     * Shifts the count of references from `source` to `target`. This should be paired
+     * with any instance of binding a free type in order to maintain accurate refcounts.
+     * If `target` is not a free type, this is a noop.
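+     *
+     * Illustrative example (not from this diff): when dispatching a constraint
+     * binds a free type 'a to `number`, the solver pairs that bind with
+     * `shiftReferences('a, number)`; since `number` is not a free type, the
+     * counts previously attributed to 'a are dropped rather than moved.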
+     * @param source the free type which is being bound
+     * @param target the type which the free type is being bound to
+     */
+    void shiftReferences(TypeId source, TypeId target);
+
+    /**
+     * Generalizes the given free type if the reference counting allows it.
+     * @param scope the scope to generalize in
+     * @param type the free type we want to generalize
+     * @returns a non-free type that generalizes the argument, or `std::nullopt` if one
+     * does not exist
+     */
+    std::optional<TypeId> generalizeFreeType(NotNull<Scope> scope, TypeId type, bool avoidSealingTables = false);
+
     /**
      * Checks the existing set of constraints to see if there exist any that contain
      * the provided free type, indicating that it is not yet ready to be replaced by
@@ -266,22 +326,6 @@ struct ConstraintSolver
     template<typename TID>
     bool unify(NotNull<const Constraint> constraint, TID subTy, TID superTy);

-private:
-    /**
-     * Bind a BlockedType to another type while taking care not to bind it to
-     * itself in the case that resultTy == blockedTy. This can happen if we
-     * have a tautological constraint. When it does, we must instead bind
-     * blockedTy to a fresh type belonging to an appropriate scope.
-     *
-     * To determine which scope is appropriate, we also accept rootTy, which is
-     * to be the type that contains blockedTy.
-     *
-     * A constraint is required and will validate that blockedTy is owned by this
-     * constraint. This prevents one constraint from interfering with another's
-     * blocked types.
-     */
-    void bindBlockedType(TypeId blockedTy, TypeId resultTy, TypeId rootTy, NotNull<const Constraint> constraint);
-
     /**
      * Marks a constraint as being blocked on a type or type pack. The constraint
      * solver will not attempt to dispatch blocked constraints until their
diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h
index fc9bc54f..27a67f40 100644
--- a/Analysis/include/Luau/Frontend.h
+++ b/Analysis/include/Luau/Frontend.h
@@ -191,7 +191,7 @@ struct Frontend
     void queueModuleCheck(const std::vector<ModuleName>& names);
     void queueModuleCheck(const ModuleName& name);
     std::vector<ModuleName> checkQueuedModules(std::optional<FrontendOptions> optionOverride = {},
-        std::function<void(std::function<void()> task)> executeTask = {}, std::function<void(size_t done, size_t total)> progress = {});
+        std::function<void(std::function<void()> task)> executeTask = {}, std::function<bool(size_t done, size_t total)> progress = {});

     std::optional<CheckResult> getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete = false);
diff --git a/Analysis/include/Luau/Generalization.h b/Analysis/include/Luau/Generalization.h
new file mode 100644
index 00000000..04ac2df1
--- /dev/null
+++ b/Analysis/include/Luau/Generalization.h
@@ -0,0 +1,13 @@
+// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
+#pragma once
+
+#include "Luau/Scope.h"
+#include "Luau/NotNull.h"
+#include "Luau/TypeFwd.h"
+
+namespace Luau
+{
+
+std::optional<TypeId> generalize(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<Scope> scope,
+    NotNull<DenseHashSet<TypeId>> bakedTypes, TypeId ty, /* avoid sealing tables*/ bool avoidSealingTables = false);
+}
diff --git a/Analysis/include/Luau/Instantiation.h b/Analysis/include/Luau/Instantiation.h
index 2122f0fa..58ba88ab 100644
--- a/Analysis/include/Luau/Instantiation.h
+++ b/Analysis/include/Luau/Instantiation.h
@@ -27,12 +27,16 @@ struct ReplaceGenerics : Substitution
     {
     }

+    void resetState(const TxnLog* log, TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, TypeLevel level, Scope* scope,
+        const std::vector<TypeId>& generics, const std::vector<TypePackId>& genericPacks);
+
     NotNull<BuiltinTypes> builtinTypes;

     TypeLevel level;
     Scope* scope;
     std::vector<TypeId> generics;
     std::vector<TypePackId> genericPacks;

+    bool ignoreChildren(TypeId ty) override;
     bool
isDirty(TypeId ty) override; bool isDirty(TypePackId tp) override; @@ -48,13 +52,19 @@ struct Instantiation : Substitution , builtinTypes(builtinTypes) , level(level) , scope(scope) + , reusableReplaceGenerics(log, arena, builtinTypes, level, scope, {}, {}) { } + void resetState(const TxnLog* log, TypeArena* arena, NotNull builtinTypes, TypeLevel level, Scope* scope); + NotNull builtinTypes; TypeLevel level; Scope* scope; + + ReplaceGenerics reusableReplaceGenerics; + bool ignoreChildren(TypeId ty) override; bool isDirty(TypeId ty) override; bool isDirty(TypePackId tp) override; diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 197c7f9c..152d8c65 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -102,6 +102,12 @@ struct Module DenseHashMap astResolvedTypes{nullptr}; DenseHashMap astResolvedTypePacks{nullptr}; + // The computed result type of a compound assignment. (eg foo += 1) + // + // Type checking uses this to check that the result of such an operation is + // actually compatible with the left-side operand. + DenseHashMap astCompoundAssignResultTypes{nullptr}; + DenseHashMap>> upperBoundContributors{nullptr}; // Map AST nodes to the scope they create. Cannot be NotNull because diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 35e0c7a1..b21e470c 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -307,6 +307,9 @@ struct NormalizedType /// Returns true if the type is a subtype of string(it could be a singleton). Behaves like Type::isString() bool isSubtypeOfString() const; + /// Returns true if the type is a subtype of boolean(it could be a singleton). Behaves like Type::isBoolean() + bool isSubtypeOfBooleans() const; + /// Returns true if this type should result in error suppressing behavior. 
bool shouldSuppressErrors() const; @@ -360,7 +363,6 @@ public: Normalizer& operator=(Normalizer&) = delete; // If this returns null, the typechecker should emit a "too complex" error - const NormalizedType* DEPRECATED_normalize(TypeId ty); std::shared_ptr normalize(TypeId ty); void clearNormal(NormalizedType& norm); @@ -395,6 +397,7 @@ public: TypeId negate(TypeId there); void subtractPrimitive(NormalizedType& here, TypeId ty); void subtractSingleton(NormalizedType& here, TypeId ty); + NormalizationResult intersectNormalWithNegationTy(TypeId toNegate, NormalizedType& intersect); // ------- Normalizing intersections TypeId intersectionOfTops(TypeId here, TypeId there); @@ -403,8 +406,8 @@ public: void intersectClassesWithClass(NormalizedClassType& heres, TypeId there); void intersectStrings(NormalizedStringType& here, const NormalizedStringType& there); std::optional intersectionOfTypePacks(TypePackId here, TypePackId there); - std::optional intersectionOfTables(TypeId here, TypeId there); - void intersectTablesWithTable(TypeIds& heres, TypeId there); + std::optional intersectionOfTables(TypeId here, TypeId there, Set& seenSet); + void intersectTablesWithTable(TypeIds& heres, TypeId there, Set& seenSetTypes); void intersectTables(TypeIds& heres, const TypeIds& theres); std::optional intersectionOfFunctions(TypeId here, TypeId there); void intersectFunctionsWithFunction(NormalizedFunctionType& heress, TypeId there); @@ -412,7 +415,7 @@ public: NormalizationResult intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there, Set& seenSetTypes); NormalizationResult intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); NormalizationResult intersectNormalWithTy(NormalizedType& here, TypeId there, Set& seenSetTypes); - NormalizationResult normalizeIntersections(const std::vector& intersections, NormalizedType& outType); + NormalizationResult normalizeIntersections(const std::vector& intersections, NormalizedType& outType, Set& seenSet); // Check for inhabitance NormalizationResult isInhabited(TypeId ty); @@ -422,6 +425,7 @@ public: // Check for intersections being inhabited NormalizationResult isIntersectionInhabited(TypeId left, TypeId right); + NormalizationResult isIntersectionInhabited(TypeId left, TypeId right, Set& seenSet); // -------- Convert back from a normalized type to a type TypeId typeFromNormal(const NormalizedType& norm); diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index 5f1630d5..0e6eff56 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -102,4 +102,12 @@ bool subsumesStrict(Scope* left, Scope* right); // outermost-possible scope. 
bool subsumes(Scope* left, Scope* right); +inline Scope* max(Scope* left, Scope* right) +{ + if (subsumes(left, right)) + return right; + else + return left; +} + } // namespace Luau diff --git a/Analysis/include/Luau/Set.h b/Analysis/include/Luau/Set.h index 2fea2e6a..274375cf 100644 --- a/Analysis/include/Luau/Set.h +++ b/Analysis/include/Luau/Set.h @@ -4,7 +4,6 @@ #include "Luau/Common.h" #include "Luau/DenseHash.h" -LUAU_FASTFLAG(LuauFixSetIter) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) namespace Luau @@ -143,11 +142,8 @@ public: : impl(impl_) , end(end_) { - if (FFlag::LuauFixSetIter || FFlag::DebugLuauDeferredConstraintResolution) - { - while (impl != end && impl->second == false) - ++impl; - } + while (impl != end && impl->second == false) + ++impl; } const T& operator*() const diff --git a/Analysis/include/Luau/Simplify.h b/Analysis/include/Luau/Simplify.h index 10f27d4e..5b363e96 100644 --- a/Analysis/include/Luau/Simplify.h +++ b/Analysis/include/Luau/Simplify.h @@ -5,6 +5,7 @@ #include "Luau/DenseHash.h" #include "Luau/NotNull.h" #include "Luau/TypeFwd.h" +#include namespace Luau { @@ -19,6 +20,8 @@ struct SimplifyResult }; SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, TypeId ty, TypeId discriminant); +SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, std::set parts); + SimplifyResult simplifyUnion(NotNull builtinTypes, NotNull arena, TypeId ty, TypeId discriminant); enum class Relation diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h index 16e36e09..28ebc93d 100644 --- a/Analysis/include/Luau/Substitution.h +++ b/Analysis/include/Luau/Substitution.h @@ -134,7 +134,8 @@ struct Tarjan TarjanResult visitRoot(TypeId ty); TarjanResult visitRoot(TypePackId ty); - void clearTarjan(); + // Used to reuse the object for a new operation + void clearTarjan(const TxnLog* log); // Get/set the dirty bit for an index (grows the vector if needed) bool getDirty(int index); @@ -212,6 +213,8 @@ public: std::optional substitute(TypeId ty); std::optional substitute(TypePackId tp); + void resetState(const TxnLog* log, TypeArena* arena); + TypeId replace(TypeId ty); TypePackId replace(TypePackId tp); diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 6e88aecb..b543e414 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -86,24 +86,6 @@ struct FreeType TypeId upperBound = nullptr; }; -/** A type that tracks the domain of a local variable. - * - * We consider each local's domain to be the union of all types assigned to it. - * We accomplish this with LocalType. Each time we dispatch an assignment to a - * local, we accumulate this union and decrement blockCount. - * - * When blockCount reaches 0, we can consider the LocalType to be "fully baked" - * and replace it with the union we've built. - */ -struct LocalType -{ - TypeId domain; - int blockCount = 0; - - // Used for debugging - std::string name; -}; - struct GenericType { // By default, generics are global, with a synthetic name @@ -148,6 +130,7 @@ struct BlockedType Constraint* getOwner() const; void setOwner(Constraint* newOwner); + void replaceOwner(Constraint* newOwner); private: // The constraint that is intended to unblock this type. Other constraints @@ -471,6 +454,11 @@ struct TableType // Methods of this table that have an untyped self will use the same shared self type. std::optional selfTy; + + // We track the number of as-yet-unadded properties to unsealed tables. 
+ // Some constraints will use this information to decide whether or not they + // are able to dispatch. + size_t remainingProps = 0; }; // Represents a metatable attached to a table type. Somewhat analogous to a bound type. @@ -672,9 +660,9 @@ struct NegationType using ErrorType = Unifiable::Error; -using TypeVariant = Unifiable::Variant; +using TypeVariant = + Unifiable::Variant; struct Type final { diff --git a/Analysis/include/Luau/TypeFamily.h b/Analysis/include/Luau/TypeFamily.h index fa418e17..9d2182df 100644 --- a/Analysis/include/Luau/TypeFamily.h +++ b/Analysis/include/Luau/TypeFamily.h @@ -6,7 +6,6 @@ #include "Luau/NotNull.h" #include "Luau/TypeCheckLimits.h" #include "Luau/TypeFwd.h" -#include "Luau/Variant.h" #include #include @@ -19,22 +18,6 @@ struct TypeArena; struct TxnLog; class Normalizer; -struct TypeFamilyQueue -{ - NotNull> queuedTys; - NotNull> queuedTps; - - void add(TypeId instanceTy); - void add(TypePackId instanceTp); - - template - void add(const std::vector& ts) - { - for (const T& t : ts) - enqueue(t); - } -}; - struct TypeFamilyContext { NotNull arena; @@ -99,8 +82,8 @@ struct TypeFamilyReductionResult }; template -using ReducerFunction = std::function( - T, NotNull, const std::vector&, const std::vector&, NotNull)>; +using ReducerFunction = + std::function(T, const std::vector&, const std::vector&, NotNull)>; /// Represents a type function that may be applied to map a series of types and /// type packs to a single output type. @@ -196,11 +179,12 @@ struct BuiltinTypeFamilies TypeFamily keyofFamily; TypeFamily rawkeyofFamily; + TypeFamily indexFamily; + TypeFamily rawgetFamily; + void addToScope(NotNull arena, NotNull scope) const; }; - - -const BuiltinTypeFamilies kBuiltinTypeFamilies{}; +const BuiltinTypeFamilies& builtinTypeFunctions(); } // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 26a67c7a..340c1e72 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -4,6 +4,7 @@ #include "Luau/Anyification.h" #include "Luau/ControlFlow.h" #include "Luau/Error.h" +#include "Luau/Instantiation.h" #include "Luau/Module.h" #include "Luau/Predicate.h" #include "Luau/Substitution.h" @@ -362,6 +363,8 @@ public: UnifierSharedState unifierState; Normalizer normalizer; + Instantiation reusableInstantiation; + std::vector requireCycles; // Type inference limits diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index 81c8a5ca..c8ee99e9 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -55,6 +55,9 @@ struct InConditionalContext using ScopePtr = std::shared_ptr; +std::optional findTableProperty( + NotNull builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location); + std::optional findMetatableEntry( NotNull builtinTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location); std::optional findTablePropertyRespectingMeta( diff --git a/Analysis/include/Luau/Unifier2.h b/Analysis/include/Luau/Unifier2.h index a7d64312..bbf3a63a 100644 --- a/Analysis/include/Luau/Unifier2.h +++ b/Analysis/include/Luau/Unifier2.h @@ -69,7 +69,6 @@ struct Unifier2 */ bool unify(TypeId subTy, TypeId superTy); bool unifyFreeWithType(TypeId subTy, TypeId superTy); - bool unify(const LocalType* subTy, TypeId superFn); bool unify(TypeId subTy, const FunctionType* superFn); bool unify(const UnionType* subUnion, TypeId superTy); bool unify(TypeId subTy, const UnionType* superUnion); @@ -78,6 
+77,11 @@ struct Unifier2 bool unify(TableType* subTable, const TableType* superTable); bool unify(const MetatableType* subMetatable, const MetatableType* superMetatable); + bool unify(const AnyType* subAny, const FunctionType* superFn); + bool unify(const FunctionType* subFn, const AnyType* superAny); + bool unify(const AnyType* subAny, const TableType* superTable); + bool unify(const TableType* subTable, const AnyType* superAny); + // TODO think about this one carefully. We don't do unions or intersections of type packs bool unify(TypePackId subTp, TypePackId superTp); diff --git a/Analysis/include/Luau/VisitType.h b/Analysis/include/Luau/VisitType.h index 40dccbd2..8c0f5ed9 100644 --- a/Analysis/include/Luau/VisitType.h +++ b/Analysis/include/Luau/VisitType.h @@ -100,10 +100,6 @@ struct GenericTypeVisitor { return visit(ty); } - virtual bool visit(TypeId ty, const LocalType& ftv) - { - return visit(ty); - } virtual bool visit(TypeId ty, const GenericType& gtv) { return visit(ty); @@ -248,11 +244,6 @@ struct GenericTypeVisitor else visit(ty, *ftv); } - else if (auto lt = get(ty)) - { - if (visit(ty, *lt)) - traverse(lt->domain); - } else if (auto gtv = get(ty)) visit(ty, *gtv); else if (auto etv = get(ty)) @@ -357,16 +348,38 @@ struct GenericTypeVisitor { if (visit(ty, *utv)) { + bool unionChanged = false; for (TypeId optTy : utv->options) + { traverse(optTy); + if (!get(follow(ty))) + { + unionChanged = true; + break; + } + } + + if (unionChanged) + traverse(ty); } } else if (auto itv = get(ty)) { if (visit(ty, *itv)) { + bool intersectionChanged = false; for (TypeId partTy : itv->parts) + { traverse(partTy); + if (!get(follow(ty))) + { + intersectionChanged = true; + break; + } + } + + if (intersectionChanged) + traverse(ty); } } else if (auto ltv = get(ty)) diff --git a/Analysis/src/AstJsonEncoder.cpp b/Analysis/src/AstJsonEncoder.cpp index 470d69b3..3507a68f 100644 --- a/Analysis/src/AstJsonEncoder.cpp +++ b/Analysis/src/AstJsonEncoder.cpp @@ -8,6 +8,8 @@ #include +LUAU_FASTFLAG(LuauDeclarationExtraPropData) + namespace Luau { @@ -735,8 +737,21 @@ struct AstJsonEncoder : public AstVisitor void write(class AstStatDeclareFunction* node) { writeNode(node, "AstStatDeclareFunction", [&]() { + // TODO: attributes PROP(name); + + if (FFlag::LuauDeclarationExtraPropData) + PROP(nameLocation); + PROP(params); + + if (FFlag::LuauDeclarationExtraPropData) + { + PROP(paramNames); + PROP(vararg); + PROP(varargLocation); + } + PROP(retTypes); PROP(generics); PROP(genericPacks); @@ -747,6 +762,10 @@ struct AstJsonEncoder : public AstVisitor { writeNode(node, "AstStatDeclareGlobal", [&]() { PROP(name); + + if (FFlag::LuauDeclarationExtraPropData) + PROP(nameLocation); + PROP(type); }); } @@ -756,8 +775,16 @@ struct AstJsonEncoder : public AstVisitor writeRaw("{"); bool c = pushComma(); write("name", prop.name); + + if (FFlag::LuauDeclarationExtraPropData) + write("nameLocation", prop.nameLocation); + writeType("AstDeclaredClassProp"); write("luauType", prop.ty); + + if (FFlag::LuauDeclarationExtraPropData) + write("location", prop.location); + popComma(c); writeRaw("}"); } diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index cebb226a..928e5dfb 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -12,6 +12,7 @@ #include LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAGVARIABLE(LuauFixBindingForGlobalPos, false); namespace Luau { @@ -332,6 +333,11 @@ std::optional findExpectedTypeAtPosition(const Module& module, const Sou static std::optional 
findBindingLocalStatement(const SourceModule& source, const Binding& binding) { + // Bindings coming from global sources (e.g., definition files) have a zero position. + // They cannot be defined from a local statement + if (FFlag::LuauFixBindingForGlobalPos && binding.location == Location{{0, 0}, {0, 0}}) + return std::nullopt; + std::vector nodes = findAstAncestryOfPosition(source, binding.location.begin); auto iter = std::find_if(nodes.rbegin(), nodes.rend(), [](AstNode* node) { return node->is(); diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index d6f0ab83..0dab640f 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -1830,12 +1830,21 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName if (!sourceModule) return {}; - ModulePtr module = frontend.moduleResolverForAutocomplete.getModule(moduleName); + ModulePtr module; + if (FFlag::DebugLuauDeferredConstraintResolution) + module = frontend.moduleResolver.getModule(moduleName); + else + module = frontend.moduleResolverForAutocomplete.getModule(moduleName); + if (!module) return {}; NotNull builtinTypes = frontend.builtinTypes; - Scope* globalScope = frontend.globalsForAutocomplete.globalScope.get(); + Scope* globalScope; + if (FFlag::DebugLuauDeferredConstraintResolution) + globalScope = frontend.globals.globalScope.get(); + else + globalScope = frontend.globalsForAutocomplete.globalScope.get(); TypeArena typeArena; return autocomplete(*sourceModule, module, builtinTypes, &typeArena, globalScope, position, callback); diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index f9ce87e0..582d5a7d 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -24,7 +24,6 @@ */ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -LUAU_FASTFLAGVARIABLE(LuauMakeStringMethodsChecked, false); namespace Luau { @@ -217,7 +216,7 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC NotNull builtinTypes = globals.builtinTypes; if (FFlag::DebugLuauDeferredConstraintResolution) - kBuiltinTypeFamilies.addToScope(NotNull{&arena}, NotNull{globals.globalScope.get()}); + builtinTypeFunctions().addToScope(NotNull{&arena}, NotNull{globals.globalScope.get()}); LoadDefinitionFileResult loadResult = frontend.loadDefinitionFile( globals, globals.globalScope, getBuiltinDefinitionSource(), "@luau", /* captureComments */ false, typeCheckForAutocomplete); @@ -257,21 +256,44 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC TypeId tableMetaMT = arena.addType(MetatableType{tabTy, genericMT}); + // getmetatable : ({ @metatable MT, {+ +} }) -> MT addGlobalBinding(globals, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau"); - // clang-format off - // setmetatable(T, MT) -> { @metatable MT, T } - addGlobalBinding(globals, "setmetatable", - arena.addType( - FunctionType{ - {genericMT}, - {}, - arena.addTypePack(TypePack{{tabTy, genericMT}}), - arena.addTypePack(TypePack{{tableMetaMT}}) - } - ), "@luau" - ); - // clang-format on + if (FFlag::DebugLuauDeferredConstraintResolution) + { + TypeId genericT = arena.addType(GenericType{"T"}); + TypeId tMetaMT = arena.addType(MetatableType{genericT, genericMT}); + + // clang-format off + // setmetatable(T, MT) -> { @metatable MT, T } + addGlobalBinding(globals, "setmetatable", + arena.addType( + FunctionType{ + {genericT, genericMT}, + {}, + 
arena.addTypePack(TypePack{{genericT, genericMT}}), + arena.addTypePack(TypePack{{tMetaMT}}) + } + ), "@luau" + ); + // clang-format on + } + else + { + // clang-format off + // setmetatable(T, MT) -> { @metatable MT, T } + addGlobalBinding(globals, "setmetatable", + arena.addType( + FunctionType{ + {genericMT}, + {}, + arena.addTypePack(TypePack{{tabTy, genericMT}}), + arena.addTypePack(TypePack{{tableMetaMT}}) + } + ), "@luau" + ); + // clang-format on + } for (const auto& pair : globals.globalScope->bindings) { @@ -291,7 +313,7 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC // declare function assert(value: T, errorMessage: string?): intersect TypeId genericT = arena.addType(GenericType{"T"}); TypeId refinedTy = arena.addType(TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.intersectFamily}, {genericT, arena.addType(NegationType{builtinTypes->falsyType})}, {}}); + NotNull{&builtinTypeFunctions().intersectFamily}, {genericT, arena.addType(NegationType{builtinTypes->falsyType})}, {}}); TypeId assertTy = arena.addType(FunctionType{ {genericT}, {}, arena.addTypePack(TypePack{{genericT, builtinTypes->optionalStringType}}), arena.addTypePack(TypePack{{refinedTy}})}); @@ -773,153 +795,87 @@ TypeId makeStringMetatable(NotNull builtinTypes) const TypePackId numberVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{numberType}}); - if (FFlag::LuauMakeStringMethodsChecked) - { - FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack}; - formatFTV.magicFunction = &magicFunctionFormat; - formatFTV.isCheckedFunction = true; - const TypeId formatFn = arena->addType(formatFTV); - attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); + FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack}; + formatFTV.magicFunction = &magicFunctionFormat; + formatFTV.isCheckedFunction = true; + const TypeId formatFn = arena->addType(formatFTV); + attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); - const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ true); + const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ true); - const TypeId replArgType = arena->addType( - UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)), - makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ false)}}); - const TypeId gsubFunc = - makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}, /* checked */ false); - const TypeId gmatchFunc = makeFunction( - *arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}, /* checked */ true); - attachMagicFunction(gmatchFunc, magicFunctionGmatch); - attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch); + const TypeId replArgType = + arena->addType(UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)), + makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ false)}}); + const TypeId gsubFunc = + makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}, /* checked */ false); + const TypeId gmatchFunc = + makeFunction(*arena, stringType, {}, 
{}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}, /* checked */ true); + attachMagicFunction(gmatchFunc, magicFunctionGmatch); + attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch); - FunctionType matchFuncTy{ - arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})}; - matchFuncTy.isCheckedFunction = true; - const TypeId matchFunc = arena->addType(matchFuncTy); - attachMagicFunction(matchFunc, magicFunctionMatch); - attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch); + FunctionType matchFuncTy{ + arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})}; + matchFuncTy.isCheckedFunction = true; + const TypeId matchFunc = arena->addType(matchFuncTy); + attachMagicFunction(matchFunc, magicFunctionMatch); + attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch); - FunctionType findFuncTy{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), - arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})}; - findFuncTy.isCheckedFunction = true; - const TypeId findFunc = arena->addType(findFuncTy); - attachMagicFunction(findFunc, magicFunctionFind); - attachDcrMagicFunction(findFunc, dcrMagicFunctionFind); + FunctionType findFuncTy{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), + arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})}; + findFuncTy.isCheckedFunction = true; + const TypeId findFunc = arena->addType(findFuncTy); + attachMagicFunction(findFunc, magicFunctionFind); + attachDcrMagicFunction(findFunc, dcrMagicFunctionFind); - // string.byte : string -> number? -> number? -> ...number - FunctionType stringDotByte{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList}; - stringDotByte.isCheckedFunction = true; + // string.byte : string -> number? -> number? -> ...number + FunctionType stringDotByte{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList}; + stringDotByte.isCheckedFunction = true; - // string.char : .... number -> string - FunctionType stringDotChar{numberVariadicList, arena->addTypePack({stringType})}; - stringDotChar.isCheckedFunction = true; + // string.char : .... number -> string + FunctionType stringDotChar{numberVariadicList, arena->addTypePack({stringType})}; + stringDotChar.isCheckedFunction = true; - // string.unpack : string -> string -> number? -> ...any - FunctionType stringDotUnpack{ - arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}), - variadicTailPack, - }; - stringDotUnpack.isCheckedFunction = true; + // string.unpack : string -> string -> number? 
-> ...any + FunctionType stringDotUnpack{ + arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}), + variadicTailPack, + }; + stringDotUnpack.isCheckedFunction = true; - TableType::Props stringLib = { - {"byte", {arena->addType(stringDotByte)}}, - {"char", {arena->addType(stringDotChar)}}, - {"find", {findFunc}}, - {"format", {formatFn}}, // FIXME - {"gmatch", {gmatchFunc}}, - {"gsub", {gsubFunc}}, - {"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType}, /* checked */ true)}}, - {"lower", {stringToStringType}}, - {"match", {matchFunc}}, - {"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType}, /* checked */ true)}}, - {"reverse", {stringToStringType}}, - {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType}, /* checked */ true)}}, - {"upper", {stringToStringType}}, - {"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {}, - {arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})}, - /* checked */ true)}}, - {"pack", {arena->addType(FunctionType{ - arena->addTypePack(TypePack{{stringType}, variadicTailPack}), - oneStringPack, - })}}, - {"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType}, /* checked */ true)}}, - {"unpack", {arena->addType(stringDotUnpack)}}, - }; - assignPropDocumentationSymbols(stringLib, "@luau/global/string"); + TableType::Props stringLib = { + {"byte", {arena->addType(stringDotByte)}}, + {"char", {arena->addType(stringDotChar)}}, + {"find", {findFunc}}, + {"format", {formatFn}}, // FIXME + {"gmatch", {gmatchFunc}}, + {"gsub", {gsubFunc}}, + {"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType}, /* checked */ true)}}, + {"lower", {stringToStringType}}, + {"match", {matchFunc}}, + {"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType}, /* checked */ true)}}, + {"reverse", {stringToStringType}}, + {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType}, /* checked */ true)}}, + {"upper", {stringToStringType}}, + {"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {}, + {arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})}, + /* checked */ true)}}, + {"pack", {arena->addType(FunctionType{ + arena->addTypePack(TypePack{{stringType}, variadicTailPack}), + oneStringPack, + })}}, + {"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType}, /* checked */ true)}}, + {"unpack", {arena->addType(stringDotUnpack)}}, + }; - TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); + assignPropDocumentationSymbols(stringLib, "@luau/global/string"); - if (TableType* ttv = getMutable(tableType)) - ttv->name = "typeof(string)"; + TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); - return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); - } - else - { - FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack}; - formatFTV.magicFunction = &magicFunctionFormat; - const TypeId formatFn = arena->addType(formatFTV); - attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); + if (TableType* ttv = getMutable(tableType)) + ttv->name = "typeof(string)"; - const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, 
{}, {stringType}, {}, {stringType}); - - const TypeId replArgType = arena->addType( - UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)), - makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType})}}); - const TypeId gsubFunc = makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}); - const TypeId gmatchFunc = - makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}); - attachMagicFunction(gmatchFunc, magicFunctionGmatch); - attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch); - - const TypeId matchFunc = arena->addType(FunctionType{ - arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})}); - attachMagicFunction(matchFunc, magicFunctionMatch); - attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch); - - const TypeId findFunc = arena->addType(FunctionType{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), - arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})}); - attachMagicFunction(findFunc, magicFunctionFind); - attachDcrMagicFunction(findFunc, dcrMagicFunctionFind); - - TableType::Props stringLib = { - {"byte", {arena->addType(FunctionType{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}}, - {"char", {arena->addType(FunctionType{numberVariadicList, arena->addTypePack({stringType})})}}, - {"find", {findFunc}}, - {"format", {formatFn}}, // FIXME - {"gmatch", {gmatchFunc}}, - {"gsub", {gsubFunc}}, - {"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, - {"lower", {stringToStringType}}, - {"match", {matchFunc}}, - {"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}}, - {"reverse", {stringToStringType}}, - {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}}, - {"upper", {stringToStringType}}, - {"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {}, - {arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})})}}, - {"pack", {arena->addType(FunctionType{ - arena->addTypePack(TypePack{{stringType}, variadicTailPack}), - oneStringPack, - })}}, - {"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, - {"unpack", {arena->addType(FunctionType{ - arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}), - variadicTailPack, - })}}, - }; - - assignPropDocumentationSymbols(stringLib, "@luau/global/string"); - - TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); - - if (TableType* ttv = getMutable(tableType)) - ttv->name = "typeof(string)"; - - return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); - } + return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); } static std::optional> magicFunctionSelect( diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index a96e5866..371ace2e 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -271,11 +271,6 @@ private: t->upperBound = shallowClone(t->upperBound); } - void cloneChildren(LocalType* t) - { - t->domain = shallowClone(t->domain); - } - void cloneChildren(GenericType* t) { 
 // TOOD: clone upper bounds.
diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp
index 4f35b58f..a62879fa 100644
--- a/Analysis/src/Constraint.cpp
+++ b/Analysis/src/Constraint.cpp
@@ -13,12 +13,12 @@ Constraint::Constraint(NotNull<Scope> scope, const Location& location, Constrain
 {
 }

-struct FreeTypeCollector : TypeOnceVisitor
+struct ReferenceCountInitializer : TypeOnceVisitor
 {
     DenseHashSet<TypeId>* result;

-    FreeTypeCollector(DenseHashSet<TypeId>* result)
+    ReferenceCountInitializer(DenseHashSet<TypeId>* result)
         : result(result)
     {
     }
@@ -29,6 +29,18 @@ struct ReferenceCountInitializer : TypeOnceVisitor
         return false;
     }

+    bool visit(TypeId ty, const BlockedType&) override
+    {
+        result->insert(ty);
+        return false;
+    }
+
+    bool visit(TypeId ty, const PendingExpansionType&) override
+    {
+        result->insert(ty);
+        return false;
+    }
+
     bool visit(TypeId ty, const ClassType&) override
     {
         // ClassTypes never contain free types.
@@ -36,26 +48,89 @@ struct ReferenceCountInitializer : TypeOnceVisitor
 };

-DenseHashSet<TypeId> Constraint::getFreeTypes() const
+bool isReferenceCountedType(const TypeId typ)
 {
-    DenseHashSet<TypeId> types{{}};
-    FreeTypeCollector ftc{&types};
+    // n.b. this should match whatever `ReferenceCountInitializer` includes.
+    return get<FreeType>(typ) || get<BlockedType>(typ) || get<PendingExpansionType>(typ);
+}

-    if (auto sc = get<SubtypeConstraint>(*this))
+DenseHashSet<TypeId> Constraint::getMaybeMutatedFreeTypes() const
+{
+    // For the purpose of this function and reference counting in general, we are only considering
+    // mutations that affect the _bounds_ of the free type, and not something that may bind the free
+    // type itself to a new type. As such, `ReduceConstraint` and `GeneralizationConstraint` have no
+    // contribution to the output set here.
+
+    DenseHashSet<TypeId> types{{}};
+    ReferenceCountInitializer rci{&types};
+
+    if (auto ec = get<EqualityConstraint>(*this))
     {
-        ftc.traverse(sc->subType);
-        ftc.traverse(sc->superType);
+        rci.traverse(ec->resultType);
+        // `EqualityConstraints` should not mutate `assignmentType`.
+    }
+    else if (auto sc = get<SubtypeConstraint>(*this))
+    {
+        rci.traverse(sc->subType);
+        rci.traverse(sc->superType);
     }
     else if (auto psc = get<PackSubtypeConstraint>(*this))
     {
-        ftc.traverse(psc->subPack);
-        ftc.traverse(psc->superPack);
+        rci.traverse(psc->subPack);
+        rci.traverse(psc->superPack);
+    }
+    else if (auto itc = get<IterableConstraint>(*this))
+    {
+        for (TypeId ty : itc->variables)
+            rci.traverse(ty);
+        // `IterableConstraints` should not mutate `iterator`.
+    }
+    else if (auto nc = get<NameConstraint>(*this))
+    {
+        rci.traverse(nc->namedType);
+    }
+    else if (auto taec = get<TypeAliasExpansionConstraint>(*this))
+    {
+        rci.traverse(taec->target);
+    }
+    else if (auto fchc = get<FunctionCheckConstraint>(*this))
+    {
+        rci.traverse(fchc->argsPack);
     }
     else if (auto ptc = get<PrimitiveTypeConstraint>(*this))
     {
-        // we need to take into account primitive type constraints to prevent type families from reducing on
-        // primitive whose types we have not yet selected to be singleton or not.
-        ftc.traverse(ptc->freeType);
+        rci.traverse(ptc->freeType);
+    }
+    else if (auto hpc = get<HasPropConstraint>(*this))
+    {
+        rci.traverse(hpc->resultType);
+        // `HasPropConstraints` should not mutate `subjectType`.
+    }
+    else if (auto hic = get<HasIndexerConstraint>(*this))
+    {
+        rci.traverse(hic->resultType);
+        // `HasIndexerConstraint` should not mutate `subjectType` or `indexType`.
+    }
+    else if (auto apc = get<AssignPropConstraint>(*this))
+    {
+        rci.traverse(apc->lhsType);
+        rci.traverse(apc->rhsType);
+    }
+    else if (auto aic = get<AssignIndexConstraint>(*this))
+    {
+        rci.traverse(aic->lhsType);
+        rci.traverse(aic->indexType);
+        rci.traverse(aic->rhsType);
+    }
+    else if (auto uc = get<UnpackConstraint>(*this))
+    {
+        for (TypeId ty : uc->resultPack)
+            rci.traverse(ty);
+        // `UnpackConstraint` should not mutate `sourcePack`.
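+        // Illustrative example (not from this diff): for `local x, y = f()`
+        // the generator emits UnpackConstraint{resultPack = {'x, 'y}, sourcePack = results of f},
+        // so 'x and 'y are counted as maybe-mutated here, while f's result
+        // pack is only ever read.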
+    }
+    else if (auto rpc = get<ReducePackConstraint>(*this))
+    {
+        rci.traverse(rpc->tp);
     }

     return types;
diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp
index 25bb98d3..7ec531d3 100644
--- a/Analysis/src/ConstraintGenerator.cpp
+++ b/Analysis/src/ConstraintGenerator.cpp
@@ -28,6 +28,8 @@ LUAU_FASTINT(LuauCheckRecursionLimit);
 LUAU_FASTFLAG(DebugLuauLogSolverToJson);
 LUAU_FASTFLAG(DebugLuauMagicTypes);
+LUAU_FASTFLAG(LuauAttributeSyntax);
+LUAU_FASTFLAG(LuauDeclarationExtraPropData);

 namespace Luau
 {
@@ -246,6 +248,21 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block)

     if (logger)
         logger->captureGenerationModule(module);
+
+    for (const auto& [ty, domain] : localTypes)
+    {
+        // FIXME: This isn't the most efficient thing.
+        TypeId domainTy = builtinTypes->neverType;
+        for (TypeId d : domain)
+        {
+            if (d == ty)
+                continue;
+            domainTy = simplifyUnion(builtinTypes, arena, domainTy, d).result;
+        }
+
+        LUAU_ASSERT(get<BlockedType>(ty));
+        asMutable(ty)->ty.emplace<BoundType>(domainTy);
+    }
 }

 TypeId ConstraintGenerator::freshType(const ScopePtr& scope)
@@ -311,6 +328,7 @@ std::optional<TypeId> ConstraintGenerator::lookup(const ScopePtr& scope, Locatio
         if (!ty)
         {
             ty = arena->addType(BlockedType{});
+            localTypes.try_insert(*ty, {});
             rootScope->lvalueTypes[operand] = *ty;
         }
@@ -414,7 +432,7 @@ void ConstraintGenerator::computeRefinement(const ScopePtr& scope, Location loca
         discriminantTy = arena->addType(NegationType{discriminantTy});

     if (eq)
-        discriminantTy = arena->addTypeFamily(kBuiltinTypeFamilies.singletonFamily, {discriminantTy});
+        discriminantTy = createTypeFamilyInstance(builtinTypeFunctions().singletonFamily, {discriminantTy}, {}, scope, location);

     for (const RefinementKey* key = proposition->key; key; key = key->parent)
     {
@@ -526,13 +544,7 @@ void ConstraintGenerator::applyRefinements(const ScopePtr& scope, Location locat
         {
             if (mustDeferIntersection(ty) || mustDeferIntersection(dt))
             {
-                TypeId resultType = createFamilyInstance(
-                    TypeFamilyInstanceType{
-                        NotNull{&kBuiltinTypeFamilies.refineFamily},
-                        {ty, dt},
-                        {},
-                    },
-                    scope, location);
+                TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().refineFamily, {ty, dt}, {}, scope, location);

                 ty = resultType;
             }
@@ -709,7 +721,8 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat
     {
         const Location location = local->location;

-        TypeId assignee = arena->addType(LocalType{builtinTypes->neverType, /* blockCount */ 1, local->name.value});
+        TypeId assignee = arena->addType(BlockedType{});
+        localTypes.try_insert(assignee, {});

         assignees.push_back(assignee);
@@ -745,12 +758,48 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat

     if (hasAnnotation)
     {
+        for (size_t i = 0; i < statLocal->vars.size; ++i)
+        {
+            LUAU_ASSERT(get<BlockedType>(assignees[i]));
+            TypeIds* localDomain = localTypes.find(assignees[i]);
+            LUAU_ASSERT(localDomain);
+            localDomain->insert(annotatedTypes[i]);
+        }
+
         TypePackId annotatedPack = arena->addTypePack(std::move(annotatedTypes));
-        addConstraint(scope, statLocal->location, UnpackConstraint{arena->addTypePack(std::move(assignees)), annotatedPack, /*resultIsLValue*/ true});
         addConstraint(scope, statLocal->location, PackSubtypeConstraint{rvaluePack, annotatedPack});
     }
     else
-        addConstraint(scope, statLocal->location, UnpackConstraint{arena->addTypePack(std::move(assignees)), rvaluePack, /*resultIsLValue*/ true});
+    {
+        std::vector<TypeId> valueTypes;
+        valueTypes.reserve(statLocal->vars.size);
+
+        auto [head, tail] = flatten(rvaluePack);
+
+        if (head.size() >=
statLocal->vars.size) + { + for (size_t i = 0; i < statLocal->vars.size; ++i) + valueTypes.push_back(head[i]); + } + else + { + for (size_t i = 0; i < statLocal->vars.size; ++i) + valueTypes.push_back(arena->addType(BlockedType{})); + + auto uc = addConstraint(scope, statLocal->location, UnpackConstraint{valueTypes, rvaluePack}); + + for (TypeId t : valueTypes) + getMutable(t)->setOwner(uc); + } + + for (size_t i = 0; i < statLocal->vars.size; ++i) + { + LUAU_ASSERT(get(assignees[i])); + TypeIds* localDomain = localTypes.find(assignees[i]); + LUAU_ASSERT(localDomain); + localDomain->insert(valueTypes[i]); + } + } if (statLocal->vars.size == 1 && statLocal->values.size == 1 && firstValueType && scope.get() == rootScope && !hasAnnotation) { @@ -843,7 +892,6 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFor* for_) ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatForIn* forIn) { ScopePtr loopScope = childScope(forIn, scope); - TypePackId iterator = checkPack(scope, forIn->values).tp; std::vector variableTypes; @@ -851,27 +899,43 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatForIn* forI for (AstLocal* var : forIn->vars) { - TypeId assignee = arena->addType(LocalType{builtinTypes->neverType, /* blockCount */ 1, var->name.value}); + TypeId assignee = arena->addType(BlockedType{}); variableTypes.push_back(assignee); + TypeId loopVar = arena->addType(BlockedType{}); + localTypes[loopVar].insert(assignee); + if (var->annotation) { TypeId annotationTy = resolveType(loopScope, var->annotation, /*inTypeArguments*/ false); loopScope->bindings[var] = Binding{annotationTy, var->location}; - addConstraint(scope, var->location, SubtypeConstraint{assignee, annotationTy}); + addConstraint(scope, var->location, SubtypeConstraint{loopVar, annotationTy}); } else - loopScope->bindings[var] = Binding{assignee, var->location}; + loopScope->bindings[var] = Binding{loopVar, var->location}; DefId def = dfg->getDef(var); - loopScope->lvalueTypes[def] = assignee; + loopScope->lvalueTypes[def] = loopVar; } - TypePackId variablePack = arena->addTypePack(std::move(variableTypes)); - addConstraint( - loopScope, getLocation(forIn->values), IterableConstraint{iterator, variablePack, forIn->values.data[0], &module->astForInNextTypes}); + auto iterable = addConstraint( + loopScope, getLocation(forIn->values), IterableConstraint{iterator, variableTypes, forIn->values.data[0], &module->astForInNextTypes}); + for (TypeId var : variableTypes) + { + auto bt = getMutable(var); + LUAU_ASSERT(bt); + bt->setOwner(iterable); + } + + Checkpoint start = checkpoint(this); visit(loopScope, forIn->body); + Checkpoint end = checkpoint(this); + + // This iter constraint must dispatch first. + forEachConstraint(start, end, this, [&iterable](const ConstraintPtr& runLater) { + runLater->dependencies.push_back(iterable); + }); return ControlFlow::None; } @@ -963,67 +1027,63 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFunction* f // Name could be AstStatLocal, AstStatGlobal, AstStatIndexName. 
// With or without self - TypeId generalizedType = arena->addType(BlockedType{}); Checkpoint start = checkpoint(this); FunctionSignature sig = checkFunctionSignature(scope, function->func, /* expectedType */ std::nullopt, function->name->location); bool sigFullyDefined = !hasFreeType(sig.signature); + checkFunctionBody(sig.bodyScope, function->func); + Checkpoint end = checkpoint(this); + + TypeId generalizedType = arena->addType(BlockedType{}); if (sigFullyDefined) emplaceType(asMutable(generalizedType), sig.signature); + else + { + const ScopePtr& constraintScope = sig.signatureScope ? sig.signatureScope : sig.bodyScope; - DenseHashSet excludeList{nullptr}; + NotNull c = addConstraint(constraintScope, function->name->location, GeneralizationConstraint{generalizedType, sig.signature}); + getMutable(generalizedType)->setOwner(c); + + Constraint* previous = nullptr; + forEachConstraint(start, end, this, [&c, &previous](const ConstraintPtr& constraint) { + c->dependencies.push_back(NotNull{constraint.get()}); + + if (auto psc = get(*constraint); psc && psc->returns) + { + if (previous) + constraint->dependencies.push_back(NotNull{previous}); + + previous = constraint.get(); + } + }); + } DefId def = dfg->getDef(function->name); std::optional existingFunctionTy = follow(lookup(scope, function->name->location, def)); - if (get(existingFunctionTy) && sigFullyDefined) - emplaceType(asMutable(*existingFunctionTy), sig.signature); - if (AstExprLocal* localName = function->name->as()) { - if (existingFunctionTy) - { - addConstraint(scope, function->name->location, SubtypeConstraint{generalizedType, *existingFunctionTy}); - - Symbol sym{localName->local}; - scope->bindings[sym].typeId = generalizedType; - } - else - scope->bindings[localName->local] = Binding{generalizedType, localName->location}; + visitLValue(scope, localName, generalizedType); scope->bindings[localName->local] = Binding{sig.signature, localName->location}; scope->lvalueTypes[def] = sig.signature; - scope->rvalueRefinements[def] = sig.signature; } else if (AstExprGlobal* globalName = function->name->as()) { if (!existingFunctionTy) ice->ice("prepopulateGlobalScope did not populate a global name", globalName->location); - if (!sigFullyDefined) - generalizedType = *existingFunctionTy; + // Sketchy: We're specifically looking for BlockedTypes that were + // initially created by ConstraintGenerator::prepopulateGlobalScope. + if (auto bt = get(*existingFunctionTy); bt && nullptr == bt->getOwner()) + emplaceType(asMutable(*existingFunctionTy), generalizedType); scope->bindings[globalName->name] = Binding{sig.signature, globalName->location}; scope->lvalueTypes[def] = sig.signature; - scope->rvalueRefinements[def] = sig.signature; } else if (AstExprIndexName* indexName = function->name->as()) { - Checkpoint check1 = checkpoint(this); - auto [_, lvalueType] = checkLValue(scope, indexName); - Checkpoint check2 = checkpoint(this); - - forEachConstraint(check1, check2, this, [&excludeList](const ConstraintPtr& c) { - excludeList.insert(c.get()); - }); - - // TODO figure out how to populate the location field of the table Property. 
- - if (lvalueType && *lvalueType != generalizedType) - { - LUAU_ASSERT(get(lvalueType)); - emplaceType(asMutable(*lvalueType), generalizedType); - } + visitLValue(scope, indexName, generalizedType); } else if (AstExprError* err = function->name->as()) { @@ -1035,48 +1095,6 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFunction* f scope->rvalueRefinements[def] = generalizedType; - checkFunctionBody(sig.bodyScope, function->func); - Checkpoint end = checkpoint(this); - - if (!sigFullyDefined) - { - NotNull constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}; - std::unique_ptr c = - std::make_unique(constraintScope, function->name->location, GeneralizationConstraint{generalizedType, sig.signature}); - - Constraint* previous = nullptr; - forEachConstraint(start, end, this, [&c, &excludeList, &previous](const ConstraintPtr& constraint) { - if (!excludeList.contains(constraint.get())) - c->dependencies.push_back(NotNull{constraint.get()}); - - if (auto psc = get(*constraint); psc && psc->returns) - { - if (previous) - constraint->dependencies.push_back(NotNull{previous}); - - previous = constraint.get(); - } - }); - - - // We need to check if the blocked type has no owner here because - // if a function is defined twice anywhere in the program like: - // `function f() end` and then later like `function f() end` - // Then there will be exactly one definition in the scope for it because it's a global - // (this is the same as writing f = function() end) - // Therefore, when we visit() the multiple different expression of this global variable - // They will all be aliased to the same blocked type, which means we can create multiple constraints - // for the same blocked type. - if (auto blocked = getMutable(generalizedType); blocked && !blocked->getOwner()) - blocked->setOwner(addConstraint(scope, std::move(c))); - } - - if (BlockedType* bt = getMutable(follow(existingFunctionTy)); bt && !bt->getOwner()) - { - auto uc = addConstraint(scope, function->name->location, Unpack1Constraint{*existingFunctionTy, generalizedType}); - bt->setOwner(uc); - } - return ControlFlow::None; } @@ -1130,38 +1148,37 @@ static void bindFreeType(TypeId a, TypeId b) ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatAssign* assign) { - std::vector upperBounds; - upperBounds.reserve(assign->vars.size); + TypePackId resultPack = checkPack(scope, assign->values).tp; - std::vector typeStates; - typeStates.reserve(assign->vars.size); + std::vector valueTypes; + valueTypes.reserve(assign->vars.size); - Checkpoint lvalueBeginCheckpoint = checkpoint(this); - - for (AstExpr* lvalue : assign->vars) + auto [head, tail] = flatten(resultPack); + if (head.size() >= assign->vars.size) { - auto [upperBound, typeState] = checkLValue(scope, lvalue); - upperBounds.push_back(upperBound.value_or(builtinTypes->unknownType)); - typeStates.push_back(typeState.value_or(builtinTypes->unknownType)); + // If the resultPack is definitely long enough for each variable, we can + // skip the UnpackConstraint and use the result types directly. + + for (size_t i = 0; i < assign->vars.size; ++i) + valueTypes.push_back(head[i]); + } + else + { + // We're not sure how many types are produced by the right-side + // expressions. We'll use an UnpackConstraint to defer this until + // later. 
+ for (size_t i = 0; i < assign->vars.size; ++i) + valueTypes.push_back(arena->addType(BlockedType{})); + + auto uc = addConstraint(scope, assign->location, UnpackConstraint{valueTypes, resultPack}); + + for (TypeId t : valueTypes) + getMutable(t)->setOwner(uc); } - Checkpoint lvalueEndCheckpoint = checkpoint(this); - - TypePackId resultPack = checkPack(scope, assign->values).tp; - auto uc = addConstraint(scope, assign->location, UnpackConstraint{arena->addTypePack(typeStates), resultPack, /*resultIsLValue*/ true}); - forEachConstraint(lvalueBeginCheckpoint, lvalueEndCheckpoint, this, [uc](const ConstraintPtr& constraint) { - uc->dependencies.push_back(NotNull{constraint.get()}); - }); - - auto psc = addConstraint(scope, assign->location, PackSubtypeConstraint{resultPack, arena->addTypePack(std::move(upperBounds))}); - psc->dependencies.push_back(uc); - - for (TypeId assignee : typeStates) + for (size_t i = 0; i < assign->vars.size; ++i) { - auto blocked = getMutable(assignee); - - if (blocked && !blocked->getOwner()) - blocked->setOwner(uc); + visitLValue(scope, assign->vars.data[i], valueTypes[i]); } return ControlFlow::None; @@ -1171,25 +1188,13 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatCompoundAss { AstExprBinary binop = AstExprBinary{assign->location, assign->op, assign->var, assign->value}; TypeId resultTy = check(scope, &binop).ty; + module->astCompoundAssignResultTypes[assign] = resultTy; - auto [upperBound, typeState] = checkLValue(scope, assign->var); + TypeId lhsType = check(scope, assign->var).ty; + visitLValue(scope, assign->var, lhsType); - Constraint* sc = nullptr; - if (upperBound) - sc = addConstraint(scope, assign->location, SubtypeConstraint{resultTy, *upperBound}); - - if (typeState) - { - NotNull uc = addConstraint(scope, assign->location, Unpack1Constraint{*typeState, resultTy, /*resultIsLValue=*/true}); - if (auto blocked = getMutable(*typeState); blocked && !blocked->getOwner()) - blocked->setOwner(uc); - - if (sc) - uc->dependencies.push_back(NotNull{sc}); - } - - DefId def = dfg->getDef(assign->var); - scope->lvalueTypes[def] = resultTy; + follow(lhsType); + follow(resultTy); return ControlFlow::None; } @@ -1385,19 +1390,34 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareClas ftv->argTypes = addTypePack({classTy}, ftv->argTypes); ftv->hasSelf = true; + + if (FFlag::LuauDeclarationExtraPropData) + { + FunctionDefinition defn; + + defn.definitionModuleName = module->name; + defn.definitionLocation = prop.location; + // No data is preserved for varargLocation + defn.originalNameLocation = prop.nameLocation; + + ftv->definition = defn; + } } } - if (ctv->props.count(propName) == 0) + TableType::Props& props = assignToMetatable ? metatable->props : ctv->props; + + if (props.count(propName) == 0) { - if (assignToMetatable) - metatable->props[propName] = {propTy}; + if (FFlag::LuauDeclarationExtraPropData) + props[propName] = {propTy, /*deprecated*/ false, /*deprecatedSuggestion*/ "", prop.location}; else - ctv->props[propName] = {propTy}; + props[propName] = {propTy}; } - else + else if (FFlag::LuauDeclarationExtraPropData) { - TypeId currentTy = assignToMetatable ? metatable->props[propName].type() : ctv->props[propName].type(); + Luau::Property& prop = props[propName]; + TypeId currentTy = prop.type(); // We special-case this logic to keep the intersection flat; otherwise we // would create a ton of nested intersection types. 
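The intersection-flattening rule described above is easy to state on its own. The following is a minimal C++ sketch of that rule, using stand-in types (an `int` TypeId and a bare IntersectionType struct) rather than Luau's actual arena and Property machinery:

#include <vector>

using TypeId = int;

struct IntersectionType
{
    std::vector<TypeId> parts;
};

// (A & B) plus C becomes A & B & C, never ((A & B) & C): if the current type
// is already an intersection, extend its flat part list; otherwise start a
// fresh two-part intersection from the current type and the new overload.
IntersectionType addOverload(const IntersectionType* existing, TypeId currentTy, TypeId newOverloadTy)
{
    if (existing)
    {
        IntersectionType flat{existing->parts};
        flat.parts.push_back(newOverloadTy);
        return flat;
    }

    return IntersectionType{{currentTy, newOverloadTy}};
}

The declared-class code above applies this rule in both of its branches: once when the current property type is already an IntersectionType, and once when a second overload first joins an existing function type.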
@@ -1407,19 +1427,40 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareClas options.push_back(propTy); TypeId newItv = arena->addType(IntersectionType{std::move(options)}); - if (assignToMetatable) - metatable->props[propName] = {newItv}; - else - ctv->props[propName] = {newItv}; + prop.readTy = newItv; + prop.writeTy = newItv; } else if (get(currentTy)) { TypeId intersection = arena->addType(IntersectionType{{currentTy, propTy}}); - if (assignToMetatable) - metatable->props[propName] = {intersection}; - else - ctv->props[propName] = {intersection}; + prop.readTy = intersection; + prop.writeTy = intersection; + } + else + { + reportError(declaredClass->location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())}); + } + } + else + { + TypeId currentTy = props[propName].type(); + + // We special-case this logic to keep the intersection flat; otherwise we + // would create a ton of nested intersection types. + if (const IntersectionType* itv = get(currentTy)) + { + std::vector options = itv->parts; + options.push_back(propTy); + TypeId newItv = arena->addType(IntersectionType{std::move(options)}); + + props[propName] = {newItv}; + } + else if (get(currentTy)) + { + TypeId intersection = arena->addType(IntersectionType{{currentTy, propTy}}); + + props[propName] = {intersection}; } else { @@ -1456,9 +1497,20 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareFunc TypePackId paramPack = resolveTypePack(funScope, global->params, /* inTypeArguments */ false); TypePackId retPack = resolveTypePack(funScope, global->retTypes, /* inTypeArguments */ false); - TypeId fnType = arena->addType(FunctionType{TypeLevel{}, funScope.get(), std::move(genericTys), std::move(genericTps), paramPack, retPack}); + + FunctionDefinition defn; + + if (FFlag::LuauDeclarationExtraPropData) + { + defn.definitionModuleName = module->name; + defn.definitionLocation = global->location; + defn.varargLocation = global->vararg ? std::make_optional(global->varargLocation) : std::nullopt; + defn.originalNameLocation = global->nameLocation; + } + + TypeId fnType = arena->addType(FunctionType{TypeLevel{}, funScope.get(), std::move(genericTys), std::move(genericTps), paramPack, retPack, defn}); FunctionType* ftv = getMutable(fnType); - ftv->isCheckedFunction = global->checkedFunction; + ftv->isCheckedFunction = FFlag::LuauAttributeSyntax ? 
global->isCheckedFunction() : false; ftv->argNames.reserve(global->paramNames.size); for (const auto& el : global->paramNames) @@ -1664,9 +1716,8 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall* mt = arena->addType(BlockedType{}); unpackedTypes.emplace_back(mt); - TypePackId mtPack = arena->addTypePack(std::move(unpackedTypes)); - auto c = addConstraint(scope, call->location, UnpackConstraint{mtPack, *argTail}); + auto c = addConstraint(scope, call->location, UnpackConstraint{unpackedTypes, *argTail}); getMutable(mt)->setOwner(c); if (auto b = getMutable(target); b && b->getOwner() == nullptr) b->setOwner(c); @@ -1900,16 +1951,44 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprGlobal* globa return Inference{*ty, refinementArena.proposition(key, builtinTypes->truthyType)}; } else - { - reportError(global->location, UnknownSymbol{global->name.value, UnknownSymbol::Binding}); return Inference{builtinTypes->errorRecoveryType()}; - } } -Inference ConstraintGenerator::checkIndexName(const ScopePtr& scope, const RefinementKey* key, AstExpr* indexee, const std::string& index, Location indexLocation) +Inference ConstraintGenerator::checkIndexName( + const ScopePtr& scope, const RefinementKey* key, AstExpr* indexee, const std::string& index, Location indexLocation) { TypeId obj = check(scope, indexee).ty; - TypeId result = arena->addType(BlockedType{}); + TypeId result = nullptr; + + // We optimize away the HasProp constraint in simple cases so that we can + // reason about updates to unsealed tables more accurately. + + const TableType* tt = getTableType(obj); + + // This is a little bit iffy but I *believe* it is okay because, if the + // local's domain is going to be extended at all, it will be someplace after + // the current lexical position within the script. 
+ if (!tt) + { + if (TypeIds* localDomain = localTypes.find(obj); localDomain && 1 == localDomain->size()) + tt = getTableType(*localDomain->begin()); + } + + if (tt) + { + auto it = tt->props.find(index); + if (it != tt->props.end() && it->second.readTy.has_value()) + result = *it->second.readTy; + } + + if (!result) + { + result = arena->addType(BlockedType{}); + + auto c = addConstraint( + scope, indexee->location, HasPropConstraint{result, obj, std::move(index), ValueContext::RValue, inConditional(typeContext)}); + getMutable(result)->setOwner(c); + } if (key) { @@ -1919,10 +1998,6 @@ Inference ConstraintGenerator::checkIndexName(const ScopePtr& scope, const Refin scope->rvalueRefinements[key->def] = result; } - auto c = - addConstraint(scope, indexee->location, HasPropConstraint{result, obj, std::move(index), ValueContext::RValue, inConditional(typeContext)}); - getMutable(result)->setOwner(c); - if (key) return Inference{result, refinementArena.proposition(key, builtinTypes->truthyType)}; else @@ -2012,35 +2087,17 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprUnary* unary) { case AstExprUnary::Op::Not: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.notFamily}, - {operandType}, - {}, - }, - scope, unary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().notFamily, {operandType}, {}, scope, unary->location); return Inference{resultType, refinementArena.negation(refinement)}; } case AstExprUnary::Op::Len: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.lenFamily}, - {operandType}, - {}, - }, - scope, unary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().lenFamily, {operandType}, {}, scope, unary->location); return Inference{resultType, refinementArena.negation(refinement)}; } case AstExprUnary::Op::Minus: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.unmFamily}, - {operandType}, - {}, - }, - scope, unary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().unmFamily, {operandType}, {}, scope, unary->location); return Inference{resultType, refinementArena.negation(refinement)}; } default: // msvc can't prove that this is exhaustive. 
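For reference, createTypeFamilyInstance (its definition appears later in this diff) does exactly two things: it allocates the family application via arena->addTypeFamily, then queues a ReduceConstraint on the result. The sketch below models that contract with stand-in types; Solver, FamilyInstance, and the integer TypeId are illustrative, not Luau's API:

#include <cstddef>
#include <string>
#include <vector>

using TypeId = std::size_t;

struct FamilyInstance
{
    std::string family;       // e.g. "lenFamily" or "unmFamily"
    std::vector<TypeId> args; // operand types the family is applied to
};

struct Solver
{
    std::vector<FamilyInstance> arena; // allocated family applications
    std::vector<TypeId> reduceQueue;   // ids with a pending reduction queued

    // Mirrors the helper's contract: allocate the application, then queue a
    // reduction so the solver later collapses it to a concrete type.
    TypeId createFamilyInstance(std::string family, std::vector<TypeId> args)
    {
        arena.push_back(FamilyInstance{std::move(family), std::move(args)});
        TypeId result = arena.size() - 1;
        reduceQueue.push_back(result);
        return result;
    }
};

Under this scheme, `not x`, `#x`, and `-x` each become a pending application of notFamily, lenFamily, or unmFamily that reduces to a concrete result type once the operand type is known.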
@@ -2056,168 +2113,97 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprBinary* binar { case AstExprBinary::Op::Add: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.addFamily}, - {leftType, rightType}, - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().addFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Sub: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.subFamily}, - {leftType, rightType}, - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().subFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Mul: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.mulFamily}, - {leftType, rightType}, - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().mulFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Div: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.divFamily}, - {leftType, rightType}, - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().divFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::FloorDiv: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.idivFamily}, - {leftType, rightType}, - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().idivFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Pow: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.powFamily}, - {leftType, rightType}, - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().powFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Mod: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.modFamily}, - {leftType, rightType}, - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().modFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Concat: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.concatFamily}, - {leftType, rightType}, - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().concatFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::And: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.andFamily}, - {leftType, rightType}, - {}, - }, - scope, binary->location); + TypeId resultType = 
createTypeFamilyInstance(builtinTypeFunctions().andFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Or: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.orFamily}, - {leftType, rightType}, - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().orFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::CompareLt: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.ltFamily}, - {leftType, rightType}, - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().ltFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::CompareGe: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.ltFamily}, - {rightType, leftType}, // lua decided that `__ge(a, b)` is instead just `__lt(b, a)` - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().ltFamily, + {rightType, leftType}, // lua decided that `__ge(a, b)` is instead just `__lt(b, a)` + {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::CompareLe: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.leFamily}, - {leftType, rightType}, - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().leFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::CompareGt: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.leFamily}, - {rightType, leftType}, // lua decided that `__gt(a, b)` is instead just `__le(b, a)` - {}, - }, - scope, binary->location); + TypeId resultType = createTypeFamilyInstance( +builtinTypeFunctions().leFamily, + {rightType, leftType}, // lua decided that `__gt(a, b)` is instead just `__le(b, a)` + {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::CompareEq: case AstExprBinary::Op::CompareNe: { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.eqFamily}, - {leftType, rightType}, - {}, - }, - scope, binary->location); + DefId leftDef = dfg->getDef(binary->left); + DefId rightDef = dfg->getDef(binary->right); + bool leftSubscripted = containsSubscriptedDefinition(leftDef); + bool rightSubscripted = containsSubscriptedDefinition(rightDef); + + if (leftSubscripted && rightSubscripted) + { + // we cannot add nil in this case because then we will blindly accept comparisons that we should not. 
+ } + else if (leftSubscripted) + leftType = makeUnion(scope, binary->location, leftType, builtinTypes->nilType); + else if (rightSubscripted) + rightType = makeUnion(scope, binary->location, rightType, builtinTypes->nilType); + + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().eqFamily, {leftType, rightType}, {}, scope, binary->location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Op__Count: @@ -2371,26 +2357,25 @@ std::tuple ConstraintGenerator::checkBinary( } } -ConstraintGenerator::LValueBounds ConstraintGenerator::checkLValue(const ScopePtr& scope, AstExpr* expr) +void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExpr* expr, TypeId rhsType) { - if (auto local = expr->as()) - return checkLValue(scope, local); - else if (auto global = expr->as()) - return checkLValue(scope, global); - else if (auto indexName = expr->as()) - return checkLValue(scope, indexName); - else if (auto indexExpr = expr->as()) - return checkLValue(scope, indexExpr); - else if (auto error = expr->as()) + if (auto e = expr->as()) + visitLValue(scope, e, rhsType); + else if (auto e = expr->as()) + visitLValue(scope, e, rhsType); + else if (auto e = expr->as()) + visitLValue(scope, e, rhsType); + else if (auto e = expr->as()) + visitLValue(scope, e, rhsType); + else if (auto e = expr->as()) { - check(scope, error); - return {builtinTypes->errorRecoveryType(), builtinTypes->errorRecoveryType()}; + // Nothing? } else - ice->ice("checkLValue is inexhaustive"); + ice->ice("Unexpected lvalue expression", expr->location); } -ConstraintGenerator::LValueBounds ConstraintGenerator::checkLValue(const ScopePtr& scope, AstExprLocal* local) +void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprLocal* local, TypeId rhsType) { std::optional annotatedTy = scope->lookup(local->local); LUAU_ASSERT(annotatedTy); @@ -2400,18 +2385,14 @@ ConstraintGenerator::LValueBounds ConstraintGenerator::checkLValue(const ScopePt if (ty) { - if (auto lt = getMutable(*ty)) - ++lt->blockCount; - else if (auto ut = getMutable(*ty)) - { - for (TypeId optTy : ut->options) - if (auto lt = getMutable(optTy)) - ++lt->blockCount; - } + TypeIds* localDomain = localTypes.find(*ty); + if (localDomain) + localDomain->insert(rhsType); } else { - ty = arena->addType(LocalType{builtinTypes->neverType, /* blockCount */ 1, local->local->name.value}); + ty = arena->addType(BlockedType{}); + localTypes[*ty].insert(rhsType); if (annotatedTy) { @@ -2431,181 +2412,63 @@ ConstraintGenerator::LValueBounds ConstraintGenerator::checkLValue(const ScopePt scope->lvalueTypes[defId] = *ty; } - // TODO: Need to clip this, but this requires more code to be reworked first before we can clip this. 
- std::optional assignedTy = arena->addType(BlockedType{}); - - auto unpackC = addConstraint(scope, local->location, Unpack1Constraint{*ty, *assignedTy, /*resultIsLValue*/ true}); - - if (auto blocked = get(*ty)) - { - if (blocked->getOwner()) - unpackC->dependencies.push_back(NotNull{blocked->getOwner()}); - else if (auto blocked = getMutable(*ty)) - blocked->setOwner(unpackC); - } - recordInferredBinding(local->local, *ty); - return {annotatedTy, assignedTy}; + if (annotatedTy) + addConstraint(scope, local->location, SubtypeConstraint{rhsType, *annotatedTy}); + + if (TypeIds* localDomain = localTypes.find(*ty)) + localDomain->insert(rhsType); } -ConstraintGenerator::LValueBounds ConstraintGenerator::checkLValue(const ScopePtr& scope, AstExprGlobal* global) +void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprGlobal* global, TypeId rhsType) { std::optional annotatedTy = scope->lookup(Symbol{global->name}); if (annotatedTy) - return {annotatedTy, arena->addType(BlockedType{})}; - else - return {annotatedTy, std::nullopt}; + { + DefId def = dfg->getDef(global); + rootScope->lvalueTypes[def] = rhsType; + + addConstraint(scope, global->location, SubtypeConstraint{rhsType, *annotatedTy}); + } } -ConstraintGenerator::LValueBounds ConstraintGenerator::checkLValue(const ScopePtr& scope, AstExprIndexName* indexName) +void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexName* expr, TypeId rhsType) { - return updateProperty(scope, indexName); + TypeId lhsTy = check(scope, expr->expr).ty; + TypeId propTy = arena->addType(BlockedType{}); + module->astTypes[expr] = propTy; + + bool incremented = recordPropertyAssignment(lhsTy); + + auto apc = addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, expr->index.value, rhsType, propTy, incremented}); + getMutable(propTy)->setOwner(apc); } -ConstraintGenerator::LValueBounds ConstraintGenerator::checkLValue(const ScopePtr& scope, AstExprIndexExpr* indexExpr) +void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexExpr* expr, TypeId rhsType) { - return updateProperty(scope, indexExpr); -} - -/** - * This function is mostly about identifying properties that are being inserted into unsealed tables. - * - * If expr has the form name.a.b.c - */ -ConstraintGenerator::LValueBounds ConstraintGenerator::updateProperty(const ScopePtr& scope, AstExpr* expr) -{ - // There are a bunch of cases where we realize that this is not the kind of - // assignment that potentially changes the shape of a table. When we - // encounter them, we call this to fall back and do the "usual thing." - auto fallback = [&]() -> LValueBounds { - TypeId resTy = check(scope, expr).ty; - return {resTy, std::nullopt}; - }; - - LUAU_ASSERT(expr->is() || expr->is()); - - if (auto indexExpr = expr->as(); indexExpr && !indexExpr->index->is()) + if (auto constantString = expr->index->as()) { - // An indexer is only interesting in an lvalue-ey way if it is at the - // tail of an expression. - // - // If the indexer is not at the tail, then we are not interested in - // augmenting the lhs data structure with a new indexer. Constraint - // generation can treat it as an ordinary lvalue. - // - // eg - // - // a.b.c[1] = 44 -- lvalue - // a.b[4].c = 2 -- rvalue + TypeId lhsTy = check(scope, expr->expr).ty; + TypeId propTy = arena->addType(BlockedType{}); + module->astTypes[expr] = propTy; + module->astTypes[expr->index] = builtinTypes->stringType; // FIXME? Singleton strings exist. 
+ std::string propName{constantString->value.data, constantString->value.size}; - TypeId subjectType = check(scope, indexExpr->expr).ty; - TypeId indexType = check(scope, indexExpr->index).ty; - TypeId assignedTy = arena->addType(BlockedType{}); - auto sic = addConstraint(scope, expr->location, SetIndexerConstraint{subjectType, indexType, assignedTy}); - getMutable(assignedTy)->setOwner(sic); + bool incremented = recordPropertyAssignment(lhsTy); - module->astTypes[expr] = assignedTy; + auto apc = addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, std::move(propName), rhsType, propTy, incremented}); + getMutable(propTy)->setOwner(apc); - return {assignedTy, assignedTy}; + return; } - Symbol sym; - const Def* def = nullptr; - std::vector segments; - std::vector exprs; - - AstExpr* e = expr; - while (e) - { - if (auto global = e->as()) - { - sym = global->name; - def = dfg->getDef(global); - break; - } - else if (auto local = e->as()) - { - sym = local->local; - def = dfg->getDef(local); - break; - } - else if (auto indexName = e->as()) - { - segments.push_back(indexName->index.value); - exprs.push_back(e); - e = indexName->expr; - } - else if (auto indexExpr = e->as()) - { - if (auto strIndex = indexExpr->index->as()) - { - // We need to populate astTypes for the index value. - check(scope, indexExpr->index); - - segments.push_back(std::string(strIndex->value.data, strIndex->value.size)); - exprs.push_back(e); - e = indexExpr->expr; - } - else - { - return fallback(); - } - } - else - { - return fallback(); - } - } - - LUAU_ASSERT(!segments.empty()); - - std::reverse(begin(segments), end(segments)); - std::reverse(begin(exprs), end(exprs)); - - LUAU_ASSERT(def); - std::optional> lookupResult = scope->lookupEx(NotNull{def}); - if (!lookupResult) - return fallback(); - - const auto [subjectType, subjectScope] = *lookupResult; - - std::vector segmentStrings(begin(segments), end(segments)); - - TypeId updatedType = arena->addType(BlockedType{}); - TypeId assignedTy = arena->addType(BlockedType{}); - auto setC = addConstraint(scope, expr->location, SetPropConstraint{updatedType, subjectType, std::move(segmentStrings), assignedTy}); - getMutable(updatedType)->setOwner(setC); - - TypeId prevSegmentTy = updatedType; - for (size_t i = 0; i < segments.size(); ++i) - { - TypeId segmentTy = arena->addType(BlockedType{}); - module->astTypes[exprs[i]] = segmentTy; - ValueContext ctx = i == segments.size() - 1 ? ValueContext::LValue : ValueContext::RValue; - auto hasC = addConstraint(scope, expr->location, HasPropConstraint{segmentTy, prevSegmentTy, segments[i], ctx, inConditional(typeContext)}); - getMutable(segmentTy)->setOwner(hasC); - setC->dependencies.push_back(hasC); - prevSegmentTy = segmentTy; - } - - module->astTypes[expr] = prevSegmentTy; - module->astTypes[e] = updatedType; - - if (!subjectType->persistent) - { - subjectScope->bindings[sym].typeId = updatedType; - - // This can fail if the user is erroneously trying to augment a builtin - // table like os or string. 
- if (auto key = dfg->getRefinementKey(e)) - { - subjectScope->lvalueTypes[key->def] = updatedType; - subjectScope->rvalueRefinements[key->def] = updatedType; - } - } - - return {assignedTy, assignedTy}; + TypeId lhsTy = check(scope, expr->expr).ty; + TypeId indexTy = check(scope, expr->index).ty; + TypeId propTy = arena->addType(BlockedType{}); + module->astTypes[expr] = propTy; + auto aic = addConstraint(scope, expr->location, AssignIndexConstraint{lhsTy, indexTy, rhsType, propTy}); + getMutable(propTy)->setOwner(aic); } Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType) @@ -2645,7 +2508,8 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr, if (AstExprConstantString* key = item.key->as()) { - ttv->props[key->value.begin()] = {itemTy}; + std::string propName{key->value.data, key->value.size}; + ttv->props[propName] = {itemTy}; } else { @@ -3061,7 +2925,7 @@ TypeId ConstraintGenerator::resolveType(const ScopePtr& scope, AstType* ty, bool // TODO: FunctionType needs a pointer to the scope so that we know // how to quantify/instantiate it. FunctionType ftv{TypeLevel{}, scope.get(), {}, {}, argTypes, returnTypes}; - ftv.isCheckedFunction = fn->checkedFunction; + ftv.isCheckedFunction = FFlag::LuauAttributeSyntax ? fn->isCheckedFunction() : false; // This replicates the behavior of the appropriate FunctionType // constructors. @@ -3263,8 +3127,7 @@ Inference ConstraintGenerator::flattenPack(const ScopePtr& scope, Location locat return Inference{*f, refinement}; TypeId typeResult = arena->addType(BlockedType{}); - TypePackId resultPack = arena->addTypePack({typeResult}, arena->freshTypePack(scope.get())); - auto c = addConstraint(scope, location, UnpackConstraint{resultPack, tp}); + auto c = addConstraint(scope, location, UnpackConstraint{{typeResult}, tp}); getMutable(typeResult)->setOwner(c); return Inference{typeResult, refinement}; @@ -3288,26 +3151,19 @@ void ConstraintGenerator::reportCodeTooComplex(Location location) TypeId ConstraintGenerator::makeUnion(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs) { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.unionFamily}, - {lhs, rhs}, - {}, - }, - scope, location); + if (get(follow(lhs))) + return rhs; + if (get(follow(rhs))) + return lhs; + + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().unionFamily, {lhs, rhs}, {}, scope, location); return resultType; } TypeId ConstraintGenerator::makeIntersect(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs) { - TypeId resultType = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.intersectFamily}, - {lhs, rhs}, - {}, - }, - scope, location); + TypeId resultType = createTypeFamilyInstance(builtinTypeFunctions().intersectFamily, {lhs, rhs}, {}, scope, location); return resultType; } @@ -3368,6 +3224,46 @@ void ConstraintGenerator::prepopulateGlobalScope(const ScopePtr& globalScope, As program->visit(&gp); } +bool ConstraintGenerator::recordPropertyAssignment(TypeId ty) +{ + DenseHashSet seen{nullptr}; + VecDeque queue; + + queue.push_back(ty); + + bool incremented = false; + + while (!queue.empty()) + { + const TypeId t = follow(queue.front()); + queue.pop_front(); + + if (seen.find(t)) + continue; + seen.insert(t); + + if (auto tt = getMutable(t); tt && tt->state == TableState::Unsealed) + { + tt->remainingProps += 1; + incremented = true; + } + else if (auto mt = get(t)) + 
queue.push_back(mt->table); + else if (TypeIds* localDomain = localTypes.find(t)) + { + for (TypeId domainTy : *localDomain) + queue.push_back(domainTy); + } + else if (auto ut = get(t)) + { + for (TypeId part : ut) + queue.push_back(part); + } + } + + return incremented; +} + void ConstraintGenerator::recordInferredBinding(AstLocal* local, TypeId ty) { if (InferredBinding* ib = inferredBindings.find(local)) @@ -3385,13 +3281,7 @@ void ConstraintGenerator::fillInInferredBindings(const ScopePtr& globalScope, As scope->bindings[symbol] = Binding{tys.front(), location}; else { - TypeId ty = createFamilyInstance( - TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.unionFamily}, - std::move(tys), - {}, - }, - globalScope, location); + TypeId ty = createTypeFamilyInstance(builtinTypeFunctions().unionFamily, std::move(tys), {}, globalScope, location); scope->bindings[symbol] = Binding{ty, location}; } @@ -3461,9 +3351,10 @@ std::vector> ConstraintGenerator::getExpectedCallTypesForF return expectedTypes; } -TypeId ConstraintGenerator::createFamilyInstance(TypeFamilyInstanceType instance, const ScopePtr& scope, Location location) +TypeId ConstraintGenerator::createTypeFamilyInstance( + const TypeFamily& family, std::vector typeArguments, std::vector packArguments, const ScopePtr& scope, Location location) { - TypeId result = arena->addType(std::move(instance)); + TypeId result = arena->addTypeFamily(family, typeArguments, packArguments); addConstraint(scope, location, ReduceConstraint{result}); return result; } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 6a9dd031..8756ec44 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -1,10 +1,11 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/ConstraintSolver.h" #include "Luau/Anyification.h" #include "Luau/ApplyTypeFunction.h" #include "Luau/Common.h" -#include "Luau/ConstraintSolver.h" #include "Luau/DcrLogger.h" +#include "Luau/Generalization.h" #include "Luau/Instantiation.h" #include "Luau/Instantiation2.h" #include "Luau/Location.h" @@ -21,12 +22,13 @@ #include "Luau/TypeFwd.h" #include "Luau/TypeUtils.h" #include "Luau/Unifier2.h" -#include "Luau/VecDeque.h" #include "Luau/VisitType.h" + #include #include LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); +LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverIncludeDependencies, false) LUAU_FASTFLAGVARIABLE(DebugLuauLogBindings, false); LUAU_FASTINTVARIABLE(LuauSolverRecursionLimit, 500); @@ -65,7 +67,11 @@ size_t HashBlockedConstraintId::operator()(const BlockedConstraintId& bci) const [[maybe_unused]] static bool canMutate(TypeId ty, NotNull constraint) { if (auto blocked = get(ty)) - return blocked->getOwner() == constraint; + { + Constraint* owner = blocked->getOwner(); + LUAU_ASSERT(owner); + return owner == constraint; + } return true; } @@ -74,7 +80,11 @@ size_t HashBlockedConstraintId::operator()(const BlockedConstraintId& bci) const [[maybe_unused]] static bool canMutate(TypePackId tp, NotNull constraint) { if (auto blocked = get(tp)) - return blocked->owner == nullptr || blocked->owner == constraint; + { + Constraint* owner = blocked->owner; + LUAU_ASSERT(owner); + return owner == constraint; + } return true; } @@ -204,6 +214,12 @@ static std::pair, std::vector> saturateArguments saturatedPackArguments.push_back(builtinTypes->errorRecoveryTypePack()); } + for (TypeId& arg : saturatedTypeArguments) + arg = follow(arg); + + for 
(TypePackId& pack : saturatedPackArguments) + pack = follow(pack); + // At this point, these two conditions should be true. If they aren't we // will run into access violations. LUAU_ASSERT(saturatedTypeArguments.size() == fn.typeParams.size()); @@ -251,6 +267,15 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts) auto it = cs->blockedConstraints.find(c); int blockCount = it == cs->blockedConstraints.end() ? 0 : int(it->second); printf("\t%d\t%s\n", blockCount, toString(*c, opts).c_str()); + + if (FFlag::DebugLuauLogSolverIncludeDependencies) + { + for (NotNull dep : c->dependencies) + { + if (std::find(cs->unsolvedConstraints.begin(), cs->unsolvedConstraints.end(), dep) != cs->unsolvedConstraints.end()) + printf("\t\t|\t%s\n", toString(*dep, opts).c_str()); + } + } } } @@ -305,7 +330,7 @@ ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNullgetFreeTypes()) + for (auto ty : c->getMaybeMutatedFreeTypes()) { // increment the reference count for `ty` auto [refCount, _] = unresolvedConstraints.try_insert(ty, 0); @@ -394,13 +419,19 @@ void ConstraintSolver::run() unsolvedConstraints.erase(unsolvedConstraints.begin() + i); // decrement the referenced free types for this constraint if we dispatched successfully! - for (auto ty : c->getFreeTypes()) + for (auto ty : c->getMaybeMutatedFreeTypes()) { - // this is a little weird, but because we're only counting free types in subtyping constraints, - // some constraints (like unpack) might actually produce _more_ references to a free type. size_t& refCount = unresolvedConstraints[ty]; if (refCount > 0) refCount -= 1; + + // We have two constraints that are designed to wait for the + // refCount on a free type to be equal to 1: the + // PrimitiveTypeConstraint and ReduceConstraint. We + // therefore wake any constraint waiting for a free type's + // refcount to be 1 or 0. + if (refCount <= 1) + unblock(ty, Location{}); } if (logger) @@ -455,6 +486,12 @@ void ConstraintSolver::run() progress |= runSolverPass(true); } while (progress); + // After we have run all the constraints, type families should be generalized + // At this point, we can try to perform one final simplification to suss out + // whether type families are truly uninhabited or if they can reduce + + finalizeTypeFamilies(); + if (FFlag::DebugLuauLogSolver || FFlag::DebugLuauLogBindings) dumpBindings(rootScope, opts); @@ -464,6 +501,25 @@ void ConstraintSolver::run() } } +void ConstraintSolver::finalizeTypeFamilies() +{ + // At this point, we've generalized. 
Let's try to finish reducing as much as we can; we'll leave warnings to the typechecker
+    for (auto [t, constraint] : typeFamiliesToFinalize)
+    {
+        TypeId ty = follow(t);
+        if (get(ty))
+        {
+            FamilyGraphReductionResult result =
+                reduceFamilies(t, constraint->location, TypeFamilyContext{NotNull{this}, constraint->scope, NotNull{constraint}}, true);
+
+            for (TypeId r : result.reducedTypes)
+                unblock(r, constraint->location);
+            for (TypePackId r : result.reducedPacks)
+                unblock(r, constraint->location);
+        }
+    }
+}
+
 bool ConstraintSolver::isDone()
 {
     return unsolvedConstraints.empty();
@@ -480,6 +536,56 @@ struct TypeAndLocation

 } // namespace

+void ConstraintSolver::bind(NotNull constraint, TypeId ty, TypeId boundTo)
+{
+    LUAU_ASSERT(get(ty) || get(ty) || get(ty));
+    LUAU_ASSERT(canMutate(ty, constraint));
+
+    boundTo = follow(boundTo);
+    if (get(ty) && ty == boundTo)
+        return emplace(constraint, ty, constraint->scope, builtinTypes->neverType, builtinTypes->unknownType);
+
+    shiftReferences(ty, boundTo);
+    emplaceType(asMutable(ty), boundTo);
+    unblock(ty, constraint->location);
+}
+
+void ConstraintSolver::bind(NotNull constraint, TypePackId tp, TypePackId boundTo)
+{
+    LUAU_ASSERT(get(tp) || get(tp));
+    LUAU_ASSERT(canMutate(tp, constraint));
+
+    boundTo = follow(boundTo);
+    LUAU_ASSERT(tp != boundTo);
+
+    emplaceTypePack(asMutable(tp), boundTo);
+    unblock(tp, constraint->location);
+}
+
+template
+void ConstraintSolver::emplace(NotNull constraint, TypeId ty, Args&&... args)
+{
+    static_assert(!std::is_same_v, "cannot use `emplace`! use `bind`");
+
+    LUAU_ASSERT(get(ty) || get(ty) || get(ty));
+    LUAU_ASSERT(canMutate(ty, constraint));
+
+    emplaceType(asMutable(ty), std::forward(args)...);
+    unblock(ty, constraint->location);
+}
+
+template
+void ConstraintSolver::emplace(NotNull constraint, TypePackId tp, Args&&... args)
+{
+    static_assert(!std::is_same_v, "cannot use `emplace`! 
use `bind`"); + + LUAU_ASSERT(get(tp) || get(tp)); + LUAU_ASSERT(canMutate(tp, constraint)); + + emplaceTypePack(asMutable(tp), std::forward(args)...); + unblock(tp, constraint->location); +} + bool ConstraintSolver::tryDispatch(NotNull constraint, bool force) { if (!force && isBlocked(constraint)) @@ -507,15 +613,13 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*fcc, constraint); else if (auto hpc = get(*constraint)) success = tryDispatch(*hpc, constraint); - else if (auto spc = get(*constraint)) - success = tryDispatch(*spc, constraint); else if (auto spc = get(*constraint)) success = tryDispatch(*spc, constraint); - else if (auto spc = get(*constraint)) - success = tryDispatch(*spc, constraint, force); - else if (auto uc = get(*constraint)) + else if (auto uc = get(*constraint)) success = tryDispatch(*uc, constraint); - else if (auto uc = get(*constraint)) + else if (auto uc = get(*constraint)) + success = tryDispatch(*uc, constraint); + else if (auto uc = get(*constraint)) success = tryDispatch(*uc, constraint); else if (auto rc = get(*constraint)) success = tryDispatch(*rc, constraint, force); @@ -526,9 +630,6 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo else LUAU_ASSERT(false); - if (success) - unblock(constraint); - return success; } @@ -567,9 +668,7 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull generalized; - Unifier2 u2{NotNull{arena}, builtinTypes, constraint->scope, NotNull{&iceReporter}}; - - std::optional generalizedTy = u2.generalize(c.sourceType); + std::optional generalizedTy = generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, c.sourceType); if (generalizedTy) generalized = QuantifierResult{*generalizedTy}; // FIXME insertedGenerics and insertedGenericPacks else @@ -578,7 +677,7 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull(generalizedType)) - bindBlockedType(generalizedType, generalized->result, c.sourceType, constraint); + bind(constraint, generalizedType, generalized->result); else unify(constraint, generalizedType, generalized->result); @@ -591,17 +690,11 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNulllocation); - emplaceType(asMutable(c.generalizedType), builtinTypes->errorType); + bind(constraint, c.generalizedType, builtinTypes->errorRecoveryType()); } - unblock(c.generalizedType, constraint->location); - unblock(c.sourceType, constraint->location); - for (TypeId ty : c.interiorTypes) - { - u2.generalize(ty); - unblock(ty, constraint->location); - } + generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, ty, /* avoidSealingTables */ false); return true; } @@ -665,14 +758,44 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNullanyTypePack, c.variables); + for (TypeId ty : c.variables) + unify(constraint, builtinTypes->errorRecoveryType(), ty); return true; } TypeId nextTy = follow(iterator.head[0]); if (get(nextTy)) - return block_(nextTy); + { + TypeId keyTy = freshType(arena, builtinTypes, constraint->scope); + TypeId valueTy = freshType(arena, builtinTypes, constraint->scope); + TypeId tableTy = + arena->addType(TableType{TableType::Props{}, TableIndexer{keyTy, valueTy}, TypeLevel{}, constraint->scope, TableState::Free}); + + unify(constraint, nextTy, tableTy); + + auto it = begin(c.variables); + auto endIt = end(c.variables); + + if (it != endIt) + { + bind(constraint, *it, keyTy); + ++it; + } + if (it != endIt) + { + 
bind(constraint, *it, valueTy); + ++it; + } + + while (it != endIt) + { + bind(constraint, *it, builtinTypes->nilType); + ++it; + } + + return true; + } if (get(nextTy)) { @@ -680,11 +803,7 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNull= 2) tableTy = iterator.head[1]; - TypeId firstIndexTy = builtinTypes->nilType; - if (iterator.head.size() >= 3) - firstIndexTy = iterator.head[2]; - - return tryDispatchIterableFunction(nextTy, tableTy, firstIndexTy, c, constraint, force); + return tryDispatchIterableFunction(nextTy, tableTy, c, constraint, force); } else @@ -720,8 +839,6 @@ bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNull(follow(c.target)); if (!petv) { - unblock(c.target, constraint->location); + unblock(c.target, constraint->location); // TODO: do we need this? any re-entrancy? return true; } auto bindResult = [this, &c, constraint](TypeId result) { LUAU_ASSERT(get(c.target)); - emplaceType(asMutable(c.target), result); - unblock(c.target, constraint->location); + shiftReferences(c.target, result); + bind(constraint, c.target, result); }; std::optional tf = (petv->prefix) ? constraint->scope->lookupImportedType(petv->prefix->value, petv->name.value) @@ -898,7 +1015,14 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul // Type function application will happily give us the exact same type if // there are e.g. generic saturatedTypeArguments that go unused. const TableType* tfTable = getTableType(tf->type); - bool needsClone = follow(tf->type) == target || (tfTable != nullptr && tfTable == getTableType(target)); + + //clang-format off + bool needsClone = follow(tf->type) == target || (tfTable != nullptr && tfTable == getTableType(target)) || + std::any_of(typeArguments.begin(), typeArguments.end(), [&](const auto& other) { + return other == target; + }); + //clang-format on + // Only tables have the properties we're trying to set. TableType* ttv = getMutableTableType(target); @@ -950,19 +1074,23 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull(fn)) + { + emplaceTypePack(asMutable(c.result), builtinTypes->anyTypePack); + unblock(c.result, constraint->location); + return true; + } + // if we're calling an error type, the result is an error type, and that's that. if (get(fn)) { - emplaceTypePack(asMutable(c.result), builtinTypes->errorTypePack); - unblock(c.result, constraint->location); - + bind(constraint, c.result, builtinTypes->errorRecoveryTypePack()); return true; } if (get(fn)) { - emplaceTypePack(asMutable(c.result), builtinTypes->neverTypePack); - unblock(c.result, constraint->location); + bind(constraint, c.result, builtinTypes->neverTypePack); return true; } @@ -1019,7 +1147,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulladdTypePack(TypePack{std::move(argsHead), argsTail}); fn = follow(*callMm); - emplaceTypePack(asMutable(c.result), constraint->scope); + emplace(constraint, c.result, constraint->scope); } else { @@ -1036,14 +1164,21 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull(asMutable(c.result), constraint->scope); + emplace(constraint, c.result, constraint->scope); } for (std::optional ty : c.discriminantTypes) { - if (!ty || !isBlocked(*ty)) + if (!ty) continue; + // If the discriminant type has been transmuted, we need to unblock them. 
+ if (!isBlocked(*ty)) + { + unblock(*ty, constraint->location); + continue; + } + // We use `any` here because the discriminant type may be pointed at by both branches, // where the discriminant type is not negated, and the other where it is negated, i.e. // `unknown ~ unknown` and `~unknown ~ never`, so `T & unknown ~ T` and `T & ~unknown ~ never` @@ -1051,7 +1186,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullanyType}; + emplaceType(asMutable(follow(*ty)), builtinTypes->anyType); } OverloadResolver resolver{ @@ -1061,7 +1196,6 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulladdType(FunctionType{TypeLevel{}, constraint->scope.get(), argsPack, c.result}); Unifier2 u2{NotNull{arena}, builtinTypes, constraint->scope, NotNull{&iceReporter}}; @@ -1091,12 +1225,12 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulllocation); - InstantiationQueuer queuer{constraint->scope, constraint->location, this}; queuer.traverse(overloadToUse); queuer.traverse(inferredTy); + unblock(c.result, constraint->location); + return true; } @@ -1190,7 +1324,8 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNullargs.data[j]->annotation && get(follow(lambdaArgTys[j]))) { - emplaceType(asMutable(lambdaArgTys[j]), expectedLambdaArgTys[j]); + shiftReferences(lambdaArgTys[j], expectedLambdaArgTys[j]); + bind(constraint, lambdaArgTys[j], expectedLambdaArgTys[j]); } } } @@ -1242,7 +1377,8 @@ bool ConstraintSolver::tryDispatch(const PrimitiveTypeConstraint& c, NotNulllowerBound; - emplaceType(asMutable(c.freeType), bindTo); + shiftReferences(c.freeType, bindTo); + bind(constraint, c.freeType, bindTo); return true; } @@ -1258,6 +1394,14 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull(subjectType) || get(subjectType)) return block(subjectType, constraint); + if (const TableType* subjectTable = getTableType(subjectType)) + { + if (subjectTable->state == TableState::Unsealed && subjectTable->remainingProps > 0 && subjectTable->props.count(c.prop) == 0) + { + return block(subjectType, constraint); + } + } + auto [blocked, result] = lookupTableProp(constraint, subjectType, c.prop, c.context, c.inConditional, c.suppressSimplification); if (!blocked.empty()) { @@ -1267,160 +1411,7 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNullanyType), c.subjectType, constraint); - unblock(resultType, constraint->location); - return true; -} - -static bool isUnsealedTable(TypeId ty) -{ - ty = follow(ty); - const TableType* ttv = get(ty); - return ttv && ttv->state == TableState::Unsealed; -} - -/** - * Given a path into a set of nested unsealed tables `ty`, insert a new property `replaceTy` as the leaf-most property. - * - * Fails and does nothing if every table along the way is not unsealed. - * - * Mutates the innermost table type in-place. - */ -static void updateTheTableType( - NotNull builtinTypes, NotNull arena, TypeId ty, const std::vector& path, TypeId replaceTy) -{ - if (path.empty()) - return; - - // First walk the path and ensure that it's unsealed tables all the way - // to the end. - { - TypeId t = ty; - for (size_t i = 0; i < path.size() - 1; ++i) - { - if (!isUnsealedTable(t)) - return; - - const TableType* tbl = get(t); - auto it = tbl->props.find(path[i]); - if (it == tbl->props.end()) - return; - - t = follow(it->second.type()); - } - - // The last path segment should not be a property of the table at all. - // We are not changing property types. 
We are only admitting this one - // new property to be appended. - if (!isUnsealedTable(t)) - return; - const TableType* tbl = get(t); - if (0 != tbl->props.count(path.back())) - return; - } - - TypeId t = ty; - ErrorVec dummy; - - for (size_t i = 0; i < path.size() - 1; ++i) - { - t = follow(t); - auto propTy = findTablePropertyRespectingMeta(builtinTypes, dummy, t, path[i], ValueContext::LValue, Location{}); - dummy.clear(); - - if (!propTy) - return; - - t = *propTy; - } - - const std::string& lastSegment = path.back(); - - t = follow(t); - TableType* tt = getMutable(t); - if (auto mt = get(t)) - tt = getMutable(mt->table); - - if (!tt) - return; - - tt->props[lastSegment].setType(replaceTy); -} - -bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull constraint) -{ - TypeId subjectType = follow(c.subjectType); - const TypeId propType = follow(c.propType); - - if (isBlocked(subjectType)) - return block(subjectType, constraint); - - std::optional existingPropType = subjectType; - - LUAU_ASSERT(!c.path.empty()); - if (c.path.empty()) - return false; - - for (size_t i = 0; i < c.path.size(); ++i) - { - const std::string& segment = c.path[i]; - if (!existingPropType) - break; - - ValueContext ctx = i == c.path.size() - 1 ? ValueContext::LValue : ValueContext::RValue; - - auto [blocked, result] = lookupTableProp(constraint, *existingPropType, segment, ctx); - if (!blocked.empty()) - { - for (TypeId blocked : blocked) - block(blocked, constraint); - return false; - } - - existingPropType = result; - } - - auto bind = [&](TypeId a, TypeId b) { - bindBlockedType(a, b, subjectType, constraint); - }; - - if (existingPropType) - { - unify(constraint, propType, *existingPropType); - unify(constraint, *existingPropType, propType); - bind(c.resultType, c.subjectType); - unblock(c.resultType, constraint->location); - return true; - } - - const TypeId originalSubjectType = subjectType; - - if (auto mt = get(subjectType)) - subjectType = follow(mt->table); - - if (get(subjectType)) - return false; - else if (auto ttv = getMutable(subjectType)) - { - if (ttv->state == TableState::Free) - { - LUAU_ASSERT(!subjectType->persistent); - - ttv->props[c.path[0]] = Property{propType}; - bind(c.resultType, subjectType); - unblock(c.resultType, constraint->location); - return true; - } - else if (ttv->state == TableState::Unsealed) - { - LUAU_ASSERT(!subjectType->persistent); - - updateTheTableType(builtinTypes, NotNull{arena}, subjectType, c.path, propType); - } - } - - bind(c.resultType, originalSubjectType); - unblock(c.resultType, constraint->location); + bind(constraint, resultType, result.value_or(builtinTypes->anyType)); return true; } @@ -1441,8 +1432,17 @@ bool ConstraintSolver::tryDispatchHasIndexer( if (auto ft = get(subjectType)) { + if (auto tbl = get(follow(ft->upperBound)); tbl && tbl->indexer) + { + unify(constraint, indexType, tbl->indexer->indexType); + bind(constraint, resultType, tbl->indexer->indexResultType); + return true; + } + else if (auto mt = get(follow(ft->upperBound))) + return tryDispatchHasIndexer(recursionDepth, constraint, mt->table, indexType, resultType, seen); + FreeType freeResult{ft->scope, builtinTypes->neverType, builtinTypes->unknownType}; - emplaceType(asMutable(resultType), freeResult); + emplace(constraint, resultType, freeResult); TypeId upperBound = arena->addType(TableType{/* props */ {}, TableIndexer{indexType, resultType}, TypeLevel{}, TableState::Unsealed}); @@ -1455,16 +1455,16 @@ bool ConstraintSolver::tryDispatchHasIndexer( if (auto indexer = 
tt->indexer) { unify(constraint, indexType, indexer->indexType); - - bindBlockedType(resultType, indexer->indexResultType, subjectType, constraint); + bind(constraint, resultType, indexer->indexResultType); return true; } - else if (tt->state == TableState::Unsealed) + + if (tt->state == TableState::Unsealed) { // FIXME this is greedy. FreeType freeResult{tt->scope, builtinTypes->neverType, builtinTypes->unknownType}; - emplaceType(asMutable(resultType), freeResult); + emplace(constraint, resultType, freeResult); tt->indexer = TableIndexer{indexType, resultType}; return true; @@ -1477,12 +1477,12 @@ bool ConstraintSolver::tryDispatchHasIndexer( if (auto indexer = ct->indexer) { unify(constraint, indexType, indexer->indexType); - bindBlockedType(resultType, indexer->indexResultType, subjectType, constraint); + bind(constraint, resultType, indexer->indexResultType); return true; } else if (isString(indexType)) { - bindBlockedType(resultType, builtinTypes->unknownType, subjectType, constraint); + bind(constraint, resultType, builtinTypes->unknownType); return true; } } @@ -1517,11 +1517,11 @@ bool ConstraintSolver::tryDispatchHasIndexer( } if (0 == results.size()) - bindBlockedType(resultType, builtinTypes->errorType, subjectType, constraint); + bind(constraint, resultType, builtinTypes->errorType); else if (1 == results.size()) - bindBlockedType(resultType, *results.begin(), subjectType, constraint); + bind(constraint, resultType, *results.begin()); else - emplaceType(asMutable(resultType), std::vector(results.begin(), results.end())); + emplace(constraint, resultType, std::vector(results.begin(), results.end())); return true; } @@ -1549,16 +1549,20 @@ bool ConstraintSolver::tryDispatchHasIndexer( } if (0 == results.size()) - emplaceType(asMutable(resultType), builtinTypes->errorType); + bind(constraint, resultType, builtinTypes->errorType); else if (1 == results.size()) - emplaceType(asMutable(resultType), *results.begin()); + { + TypeId firstResult = *results.begin(); + shiftReferences(resultType, firstResult); + bind(constraint, resultType, firstResult); + } else - emplaceType(asMutable(resultType), std::vector(results.begin(), results.end())); + emplace(constraint, resultType, std::vector(results.begin(), results.end())); return true; } - bindBlockedType(resultType, builtinTypes->errorType, subjectType, constraint); + bind(constraint, resultType, builtinTypes->errorType); return true; } @@ -1609,167 +1613,283 @@ bool ConstraintSolver::tryDispatch(const HasIndexerConstraint& c, NotNull> ConstraintSolver::tryDispatchSetIndexer( - NotNull constraint, TypeId subjectType, TypeId indexType, TypeId propType, bool expandFreeTypeBounds) +bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull constraint) { - if (isBlocked(subjectType)) - return {block(subjectType, constraint), std::nullopt}; + TypeId lhsType = follow(c.lhsType); + const std::string& propName = c.propName; + const TypeId rhsType = follow(c.rhsType); - if (auto tt = getMutable(subjectType)) + if (isBlocked(lhsType)) + return block(lhsType, constraint); + + // 1. lhsType is a class that already has the prop + // 2. lhsType is a table that already has the prop (or a union or + // intersection that has the prop in aggregate) + // 3. lhsType has a metatable that already has the prop + // 4. lhsType is an unsealed table that does not have the prop, but has a + // string indexer + // 5. 
lhsType is an unsealed table that does not have the prop or a string + // indexer + + // Important: In every codepath through this function, the type `c.propType` + // must be bound to something, even if it's just the errorType. + + if (auto lhsClass = get(lhsType)) { - if (tt->indexer) - { - if (isBlocked(tt->indexer->indexType)) - return {block(tt->indexer->indexType, constraint), std::nullopt}; - else if (isBlocked(tt->indexer->indexResultType)) - return {block(tt->indexer->indexResultType, constraint), std::nullopt}; + const Property* prop = lookupClassProp(lhsClass, propName); + if (!prop || !prop->writeTy.has_value()) + return true; - unify(constraint, indexType, tt->indexer->indexType); - return {true, tt->indexer->indexResultType}; - } - else if (tt->state == TableState::Free || tt->state == TableState::Unsealed) + bind(constraint, c.propType, *prop->writeTy); + unify(constraint, rhsType, *prop->writeTy); + return true; + } + + if (auto lhsFree = getMutable(lhsType)) + { + if (get(lhsFree->upperBound) || get(lhsFree->upperBound)) + lhsType = lhsFree->upperBound; + else { - TypeId resultTy = freshType(arena, builtinTypes, constraint->scope.get()); - tt->indexer = TableIndexer{indexType, resultTy}; - return {true, resultTy}; + TypeId newUpperBound = arena->addType(TableType{TableState::Free, TypeLevel{}, constraint->scope}); + TableType* upperTable = getMutable(newUpperBound); + LUAU_ASSERT(upperTable); + + upperTable->props[c.propName] = rhsType; + + // Food for thought: Could we block if simplification encounters a blocked type? + lhsFree->upperBound = simplifyIntersection(builtinTypes, arena, lhsFree->upperBound, newUpperBound).result; + + bind(constraint, c.propType, rhsType); + return true; } } - else if (auto ft = getMutable(subjectType); ft && expandFreeTypeBounds) + + // Handle the case that lhsType is a table that already has the property or + // a matching indexer. This also handles unions and intersections. + const auto [blocked, maybeTy] = lookupTableProp(constraint, lhsType, propName, ValueContext::LValue); + if (!blocked.empty()) { - // Setting an indexer on some fresh type means we use that fresh type in a negative position. - // Therefore, we only care about the upper bound. - // - // We'll extend the upper bound if we could dispatch, but could not find a table type to update the indexer. - auto [dispatched, resultTy] = tryDispatchSetIndexer(constraint, ft->upperBound, indexType, propType, /*expandFreeTypeBounds=*/false); - if (dispatched && !resultTy) + for (TypeId t : blocked) + block(t, constraint); + return false; + } + + if (maybeTy) + { + const TypeId propTy = *maybeTy; + bind(constraint, c.propType, propTy); + unify(constraint, rhsType, propTy); + return true; + } + + if (auto lhsMeta = get(lhsType)) + lhsType = follow(lhsMeta->table); + + // Handle the case where the lhs type is a table that does not have the + // named property. It could be a table with a string indexer, or an unsealed + // or free table that can grow. + if (auto lhsTable = getMutable(lhsType)) + { + if (auto it = lhsTable->props.find(propName); it != lhsTable->props.end()) { - // Despite that we haven't found a table type, adding a table type causes us to have one that we can /now/ find. 
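// A minimal standalone sketch of the upper-bound growth used just above for
// free lhs types: assigning `x.prop = v` when `x` is free intersects a
// one-property table into x's upper bound. Here std::map stands in for
// TableType and std::string for TypeId (assumptions for illustration only);
// the real solver builds a TableType and goes through simplifyIntersection.

#include <map>
#include <string>
#include <iostream>

// Stand-in for a free type's table upper bound: prop name -> type name.
using TableBound = std::map<std::string, std::string>;

// Build a one-property table and "intersect" it into the current bound.
// In this toy model, intersecting two table bounds just merges their props.
TableBound intersectBound(TableBound current, const std::string& prop, const std::string& rhsType)
{
    TableBound oneProp{{prop, rhsType}};
    current.insert(oneProp.begin(), oneProp.end());
    return current;
}

int main()
{
    TableBound upperBound; // 'a <: {} initially
    upperBound = intersectBound(upperBound, "x", "number");
    upperBound = intersectBound(upperBound, "y", "string");
    for (const auto& [k, v] : upperBound)
        std::cout << k << ": " << v << "\n"; // 'a <: {x: number, y: string}
}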
- resultTy = freshType(arena, builtinTypes, constraint->scope.get()); + Property& prop = it->second; - TypeId tableTy = arena->addType(TableType{TableState::Sealed, TypeLevel{}, constraint->scope.get()}); - TableType* tt2 = getMutable(tableTy); - tt2->indexer = TableIndexer{indexType, *resultTy}; - - ft->upperBound = - simplifyIntersection(builtinTypes, arena, ft->upperBound, tableTy).result; // TODO: intersect type family or a constraint. + if (prop.writeTy.has_value()) + { + bind(constraint, c.propType, *prop.writeTy); + unify(constraint, rhsType, *prop.writeTy); + return true; + } + else + { + LUAU_ASSERT(prop.isReadOnly()); + if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free) + { + prop.writeTy = prop.readTy; + bind(constraint, c.propType, *prop.writeTy); + unify(constraint, rhsType, *prop.writeTy); + return true; + } + else + { + bind(constraint, c.propType, builtinTypes->errorType); + return true; + } + } } - return {dispatched, resultTy}; - } - else if (auto it = get(subjectType)) - { - bool dispatched = true; - std::vector results; - - for (TypeId part : it) + if (lhsTable->indexer && maybeString(lhsTable->indexer->indexType)) { - auto [dispatched2, found] = tryDispatchSetIndexer(constraint, part, indexType, propType, expandFreeTypeBounds); - dispatched &= dispatched2; - results.push_back(found.value_or(builtinTypes->errorRecoveryType())); - - if (!dispatched) - return {dispatched, std::nullopt}; + bind(constraint, c.propType, rhsType); + unify(constraint, rhsType, lhsTable->indexer->indexResultType); + return true; } - TypeId resultTy = arena->addType(TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.unionFamily}, - std::move(results), - {}, - }); + if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free) + { + bind(constraint, c.propType, rhsType); + lhsTable->props[propName] = Property::rw(rhsType); - pushConstraint(constraint->scope, constraint->location, ReduceConstraint{resultTy}); + if (lhsTable->state == TableState::Unsealed && c.decrementPropCount) + { + LUAU_ASSERT(lhsTable->remainingProps > 0); + lhsTable->remainingProps -= 1; + } - return {dispatched, resultTy}; + return true; + } } - else if (is(subjectType)) - return {true, subjectType}; - return {true, std::nullopt}; + bind(constraint, c.propType, builtinTypes->errorType); + + return true; } -bool ConstraintSolver::tryDispatch(const SetIndexerConstraint& c, NotNull constraint, bool force) +bool ConstraintSolver::tryDispatch(const AssignIndexConstraint& c, NotNull constraint) { - TypeId subjectType = follow(c.subjectType); - if (isBlocked(subjectType)) - return block(subjectType, constraint); + const TypeId lhsType = follow(c.lhsType); + const TypeId indexType = follow(c.indexType); + const TypeId rhsType = follow(c.rhsType); - auto [dispatched, resultTy] = tryDispatchSetIndexer(constraint, subjectType, c.indexType, c.propType, /*expandFreeTypeBounds=*/true); - if (dispatched) - { - bindBlockedType(c.propType, resultTy.value_or(builtinTypes->errorRecoveryType()), subjectType, constraint); - unblock(c.propType, constraint->location); - } + if (isBlocked(lhsType)) + return block(lhsType, constraint); - return dispatched; -} + // 0. lhsType could be an intersection or union. + // 1. lhsType is a class with an indexer + // 2. lhsType is a table with an indexer, or it has a metatable that has an indexer + // 3. 
lhsType is a free or unsealed table and can grow an indexer -bool ConstraintSolver::tryDispatchUnpack1(NotNull constraint, TypeId resultTy, TypeId srcTy, bool resultIsLValue) -{ - resultTy = follow(resultTy); - LUAU_ASSERT(canMutate(resultTy, constraint)); + // Important: In every codepath through this function, the type `c.propType` + // must be bound to something, even if it's just the errorType. - auto tryExpand = [&](TypeId ty) { - LocalType* lt = getMutable(ty); - if (!lt || !resultIsLValue) - return; + auto tableStuff = [&](TableType* lhsTable) -> std::optional { + if (lhsTable->indexer) + { + unify(constraint, indexType, lhsTable->indexer->indexType); + unify(constraint, rhsType, lhsTable->indexer->indexResultType); + bind(constraint, c.propType, lhsTable->indexer->indexResultType); + return true; + } - lt->domain = simplifyUnion(builtinTypes, arena, lt->domain, srcTy).result; - LUAU_ASSERT(lt->blockCount > 0); - --lt->blockCount; + if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free) + { + lhsTable->indexer = TableIndexer{indexType, rhsType}; + bind(constraint, c.propType, rhsType); + return true; + } - if (0 == lt->blockCount) - emplaceType(asMutable(ty), lt->domain); + return {}; }; - if (auto ut = get(resultTy)) - std::for_each(begin(ut), end(ut), tryExpand); - else if (get(resultTy)) - tryExpand(resultTy); - else if (get(resultTy)) + if (auto lhsFree = getMutable(lhsType)) { - if (follow(srcTy) == resultTy) + if (auto lhsTable = getMutable(lhsFree->upperBound)) { - // It is sometimes the case that we find that a blocked type - // is only blocked on itself. This doesn't actually - // constitute any meaningful constraint, so we replace it - // with a free type. - TypeId f = freshType(arena, builtinTypes, constraint->scope); - emplaceType(asMutable(resultTy), f); + if (auto res = tableStuff(lhsTable)) + return *res; } - else - bindBlockedType(resultTy, srcTy, srcTy, constraint); - } - else - { - LUAU_ASSERT(resultIsLValue); - unify(constraint, srcTy, resultTy); + + TypeId newUpperBound = + arena->addType(TableType{/*props*/ {}, TableIndexer{indexType, rhsType}, TypeLevel{}, constraint->scope, TableState::Free}); + const TableType* newTable = get(newUpperBound); + LUAU_ASSERT(newTable); + + unify(constraint, lhsType, newUpperBound); + + LUAU_ASSERT(newTable->indexer); + bind(constraint, c.propType, newTable->indexer->indexResultType); + return true; } - unblock(resultTy, constraint->location); + if (auto lhsTable = getMutable(lhsType)) + { + std::optional res = tableStuff(lhsTable); + if (res.has_value()) + return *res; + } + + if (auto lhsClass = get(lhsType)) + { + while (true) + { + if (lhsClass->indexer) + { + unify(constraint, indexType, lhsClass->indexer->indexType); + unify(constraint, rhsType, lhsClass->indexer->indexResultType); + bind(constraint, c.propType, lhsClass->indexer->indexResultType); + return true; + } + + if (lhsClass->parent) + lhsClass = get(lhsClass->parent); + else + break; + } + return true; + } + + if (auto lhsIntersection = getMutable(lhsType)) + { + std::set parts; + + for (TypeId t : lhsIntersection) + { + if (auto tbl = getMutable(follow(t))) + { + if (tbl->indexer) + { + unify(constraint, indexType, tbl->indexer->indexType); + parts.insert(tbl->indexer->indexResultType); + } + + if (tbl->state == TableState::Unsealed || tbl->state == TableState::Free) + { + tbl->indexer = TableIndexer{indexType, rhsType}; + parts.insert(rhsType); + } + } + else if (auto cls = get(follow(t))) + { + while (true) + { + if (cls->indexer) + 
{ + unify(constraint, indexType, cls->indexer->indexType); + parts.insert(cls->indexer->indexResultType); + break; + } + + if (cls->parent) + cls = get(cls->parent); + else + break; + } + } + } + + TypeId res = simplifyIntersection(builtinTypes, arena, std::move(parts)).result; + + unify(constraint, rhsType, res); + } + + // Other types do not support index assignment. + bind(constraint, c.propType, builtinTypes->errorType); + return true; } bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull constraint) { TypePackId sourcePack = follow(c.sourcePack); - TypePackId resultPack = follow(c.resultPack); if (isBlocked(sourcePack)) return block(sourcePack, constraint); - if (isBlocked(resultPack)) - { - LUAU_ASSERT(canMutate(resultPack, constraint)); - LUAU_ASSERT(resultPack != sourcePack); - emplaceTypePack(asMutable(resultPack), sourcePack); - unblock(resultPack, constraint->location); - return true; - } + TypePack srcPack = extendTypePack(*arena, builtinTypes, sourcePack, c.resultPack.size()); - TypePack srcPack = extendTypePack(*arena, builtinTypes, sourcePack, size(resultPack)); - - auto resultIter = begin(resultPack); - auto resultEnd = end(resultPack); + auto resultIter = begin(c.resultPack); + auto resultEnd = end(c.resultPack); size_t i = 0; while (resultIter != resultEnd) @@ -1779,7 +1899,29 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull(resultTy)); + LUAU_ASSERT(canMutate(resultTy, constraint)); + + if (get(resultTy)) + { + if (follow(srcTy) == resultTy) + { + // It is sometimes the case that we find that a blocked type + // is only blocked on itself. This doesn't actually + // constitute any meaningful constraint, so we replace it + // with a free type. + TypeId f = freshType(arena, builtinTypes, constraint->scope); + shiftReferences(resultTy, f); + emplaceType(asMutable(resultTy), f); + } + else + bind(constraint, resultTy, srcTy); + } + else + unify(constraint, srcTy, resultTy); + + unblock(resultTy, constraint->location); ++resultIter; ++i; @@ -1793,19 +1935,9 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull(resultTy); c.resultIsLValue && lt) + if (get(resultTy) || get(resultTy)) { - lt->domain = simplifyUnion(builtinTypes, arena, lt->domain, builtinTypes->nilType).result; - LUAU_ASSERT(0 <= lt->blockCount); - --lt->blockCount; - - if (0 == lt->blockCount) - emplaceType(asMutable(resultTy), lt->domain); - } - else if (get(resultTy) || get(resultTy)) - { - emplaceType(asMutable(resultTy), builtinTypes->nilType); - unblock(resultTy, constraint->location); + bind(constraint, resultTy, builtinTypes->nilType); } ++resultIter; @@ -1814,11 +1946,6 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull constraint) -{ - return tryDispatchUnpack1(constraint, c.resultType, c.sourceType, c.resultIsLValue); -} - bool ConstraintSolver::tryDispatch(const ReduceConstraint& c, NotNull constraint, bool force) { TypeId ty = follow(c.ty); @@ -1833,6 +1960,11 @@ bool ConstraintSolver::tryDispatch(const ReduceConstraint& c, NotNull(ty)) + typeFamiliesToFinalize[ty] = constraint; + if (force || reductionFinished) { // if we're completely dispatching this constraint, we want to record any uninhabited type families to unblock. @@ -1904,28 +2036,37 @@ bool ConstraintSolver::tryDispatch(const EqualityConstraint& c, NotNull constraint, bool force) { - auto block_ = [&](auto&& t) { - if (force) - { - // TODO: I believe it is the case that, if we are asked to force - // this constraint, then we can do nothing but fail. 
I'd like to - // find a code sample that gets here. - LUAU_ASSERT(false); - } - else - block(t, constraint); - return false; - }; - - // We may have to block here if we don't know what the iteratee type is, - // if it's a free table, if we don't know it has a metatable, and so on. iteratorTy = follow(iteratorTy); + if (get(iteratorTy)) - return block_(iteratorTy); + { + TypeId keyTy = freshType(arena, builtinTypes, constraint->scope); + TypeId valueTy = freshType(arena, builtinTypes, constraint->scope); + TypeId tableTy = arena->addType(TableType{TableState::Sealed, {}, constraint->scope}); + getMutable(tableTy)->indexer = TableIndexer{keyTy, valueTy}; + + pushConstraint(constraint->scope, constraint->location, SubtypeConstraint{iteratorTy, tableTy}); + + auto it = begin(c.variables); + auto endIt = end(c.variables); + if (it != endIt) + { + bind(constraint, *it, keyTy); + ++it; + } + if (it != endIt) + bind(constraint, *it, valueTy); + + return true; + } auto unpack = [&](TypeId ty) { - TypePackId variadic = arena->addTypePack(VariadicTypePack{ty}); - pushConstraint(constraint->scope, constraint->location, UnpackConstraint{c.variables, variadic, /* resultIsLValue */ true}); + for (TypeId varTy : c.variables) + { + LUAU_ASSERT(get(varTy)); + LUAU_ASSERT(varTy != ty); + bind(constraint, varTy, ty); + } }; if (get(iteratorTy)) @@ -1963,24 +2104,17 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl if (iteratorTable->indexer) { - TypePackId expectedVariablePack = arena->addTypePack({iteratorTable->indexer->indexType, iteratorTable->indexer->indexResultType}); - unify(constraint, c.variables, expectedVariablePack); + std::vector expectedVariables{iteratorTable->indexer->indexType, iteratorTable->indexer->indexResultType}; + while (c.variables.size() >= expectedVariables.size()) + expectedVariables.push_back(builtinTypes->errorRecoveryType()); - auto [variableTys, variablesTail] = flatten(c.variables); - - // the local types for the indexer _should_ be all set after unification - for (TypeId ty : variableTys) + for (size_t i = 0; i < c.variables.size(); ++i) { - if (auto lt = getMutable(ty)) - { - LUAU_ASSERT(lt->blockCount > 0); - --lt->blockCount; + LUAU_ASSERT(c.variables[i] != expectedVariables[i]); - LUAU_ASSERT(0 <= lt->blockCount); + unify(constraint, c.variables[i], expectedVariables[i]); - if (0 == lt->blockCount) - emplaceType(asMutable(ty), lt->domain); - } + bind(constraint, c.variables[i], expectedVariables[i]); } } else @@ -2014,10 +2148,11 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl if (std::optional instantiatedNextFn = instantiate(builtinTypes, arena, NotNull{&limits}, constraint->scope, nextFn)) { const FunctionType* nextFn = get(*instantiatedNextFn); - LUAU_ASSERT(nextFn); - const TypePackId nextRetPack = nextFn->retTypes; - pushConstraint(constraint->scope, constraint->location, UnpackConstraint{c.variables, nextRetPack, /* resultIsLValue=*/true}); + // If nextFn is nullptr, then the iterator function has an improper signature. 
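// The free-iterator case above guesses "iterate like a table": it invents
// fresh key/value types, requires the iteratee to be a subtype of {[K]: V},
// and binds the first loop variables to K and V. A standalone sketch of that
// binding step, with std::string standing in for TypeId (a simplified model,
// not the solver's actual types).

#include <string>
#include <vector>
#include <iostream>

void bindIterationVariables(std::vector<std::string>& variables, const std::string& keyTy, const std::string& valueTy)
{
    // Mirrors the two guarded binds in the dispatch: bind only as many
    // variables as actually exist.
    auto it = variables.begin();
    if (it != variables.end())
        *it++ = keyTy;
    if (it != variables.end())
        *it = valueTy;
}

int main()
{
    std::vector<std::string> vars{"<blocked>", "<blocked>", "<blocked>"};
    bindIterationVariables(vars, "K", "V");
    for (const auto& v : vars)
        std::cout << v << "\n"; // K, V, <blocked> (extra variables handled elsewhere)
}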
+ if (nextFn) + unpackAndAssign(c.variables, nextFn->retTypes, constraint); + return true; } else @@ -2043,26 +2178,16 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl else if (auto primitiveTy = get(iteratorTy); primitiveTy && primitiveTy->type == PrimitiveType::Type::Table) unpack(builtinTypes->unknownType); else + { unpack(builtinTypes->errorType); + } return true; } bool ConstraintSolver::tryDispatchIterableFunction( - TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull constraint, bool force) + TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull constraint, bool force) { - // We need to know whether or not this type is nil or not. - // If we don't know, block and reschedule ourselves. - firstIndexTy = follow(firstIndexTy); - if (get(firstIndexTy)) - { - if (force) - LUAU_ASSERT(false); - else - block(firstIndexTy, constraint); - return false; - } - const FunctionType* nextFn = get(nextTy); // If this does not hold, we should've never called `tryDispatchIterableFunction` in the first place. LUAU_ASSERT(nextFn); @@ -2089,12 +2214,29 @@ bool ConstraintSolver::tryDispatchIterableFunction( modifiedNextRetHead.push_back(*it); TypePackId modifiedNextRetPack = arena->addTypePack(std::move(modifiedNextRetHead), it.tail()); - auto psc = pushConstraint(constraint->scope, constraint->location, UnpackConstraint{c.variables, modifiedNextRetPack, /* resultIsLValue */ true}); - inheritBlocks(constraint, psc); + + auto unpackConstraint = unpackAndAssign(c.variables, modifiedNextRetPack, constraint); + + inheritBlocks(constraint, unpackConstraint); return true; } +NotNull ConstraintSolver::unpackAndAssign( + const std::vector destTypes, TypePackId srcTypes, NotNull constraint) +{ + auto c = pushConstraint(constraint->scope, constraint->location, UnpackConstraint{destTypes, srcTypes}); + + for (TypeId t : destTypes) + { + BlockedType* bt = getMutable(t); + LUAU_ASSERT(bt); + bt->replaceOwner(c); + } + + return c; +} + std::pair, std::optional> ConstraintSolver::lookupTableProp(NotNull constraint, TypeId subjectType, const std::string& propName, ValueContext context, bool inConditional, bool suppressSimplification) { @@ -2368,37 +2510,9 @@ bool ConstraintSolver::unify(NotNull constraint, TID subTy, TI return false; } - unblock(subTy, constraint->location); - unblock(superTy, constraint->location); - return true; } -void ConstraintSolver::bindBlockedType(TypeId blockedTy, TypeId resultTy, TypeId rootTy, NotNull constraint) -{ - resultTy = follow(resultTy); - - LUAU_ASSERT(get(blockedTy) && canMutate(blockedTy, constraint)); - - if (blockedTy == resultTy) - { - rootTy = follow(rootTy); - Scope* freeScope = nullptr; - if (auto ft = get(rootTy)) - freeScope = ft->scope; - else if (auto tt = get(rootTy); tt && tt->state == TableState::Free) - freeScope = tt->scope; - else - iceReporter.ice("bindBlockedType couldn't find an appropriate scope for a fresh type!", constraint->location); - - LUAU_ASSERT(freeScope); - - emplaceType(asMutable(blockedTy), arena->freshType(freeScope)); - } - else - emplaceType(asMutable(blockedTy), resultTy); -} - bool ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) { // If a set is not present for the target, construct a new DenseHashSet for it, @@ -2608,9 +2722,6 @@ bool ConstraintSolver::isBlocked(TypeId ty) { ty = follow(ty); - if (auto lt = get(ty)) - return lt->blockCount > 0; - if (auto tfit = get(ty)) return uninhabitedTypeFamilies.contains(ty) == false; @@ -2698,6 
+2809,44 @@ void ConstraintSolver::reportError(TypeError e) errors.back().moduleName = currentModuleName; } +void ConstraintSolver::shiftReferences(TypeId source, TypeId target) +{ + target = follow(target); + + // if the target isn't a reference counted type, there's nothing to do. + // this stops us from keeping unnecessary counts for e.g. primitive types. + if (!isReferenceCountedType(target)) + return; + + auto sourceRefs = unresolvedConstraints.find(source); + if (!sourceRefs) + return; + + // we read out the count before proceeding to avoid hash invalidation issues. + size_t count = *sourceRefs; + + auto [targetRefs, _] = unresolvedConstraints.try_insert(target, 0); + targetRefs += count; +} + +std::optional ConstraintSolver::generalizeFreeType(NotNull scope, TypeId type, bool avoidSealingTables) +{ + TypeId t = follow(type); + if (get(t)) + { + auto refCount = unresolvedConstraints.find(t); + if (refCount && *refCount > 0) + return {}; + + // if no reference count is present, then that means the only constraints referring to + // this free type need only for it to be generalized. in principle, this means we could + // have actually never generated the free type in the first place, but we couldn't know + // that until all constraint generation is complete. + } + + return generalize(NotNull{arena}, builtinTypes, scope, generalizedTypes, type, avoidSealingTables); +} + bool ConstraintSolver::hasUnresolvedConstraints(TypeId ty) { if (auto refCount = unresolvedConstraints.find(ty)) diff --git a/Analysis/src/DataFlowGraph.cpp b/Analysis/src/DataFlowGraph.cpp index 33b41698..0a0a64d3 100644 --- a/Analysis/src/DataFlowGraph.cpp +++ b/Analysis/src/DataFlowGraph.cpp @@ -763,7 +763,8 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprCall* c) for (AstExpr* arg : c->args) visitExpr(scope, arg); - return {defArena->freshCell(), nullptr}; + // calls should be treated as subscripted. + return {defArena->freshCell(/* subscripted */ true), nullptr}; } DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexName* i) diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 4fe7c4b7..91d8006a 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -2,7 +2,7 @@ #include "Luau/BuiltinDefinitions.h" LUAU_FASTFLAGVARIABLE(LuauCheckedEmbeddedDefinitions2, false); -LUAU_FASTFLAG(LuauCheckedFunctionSyntax); +LUAU_FASTFLAG(LuauAttributeSyntax); namespace Luau { @@ -320,9 +320,9 @@ declare os: { clock: () -> number, } -declare function @checked require(target: any): any +@checked declare function require(target: any): any -declare function @checked getfenv(target: any): { [string]: any } +@checked declare function getfenv(target: any): { [string]: any } declare _G: any declare _VERSION: string @@ -364,7 +364,7 @@ declare function select(i: string | number, ...: A...): ...any -- (nil, string). declare function loadstring(src: string, chunkname: string?): (((A...) -> any)?, string?) -declare function @checked newproxy(mt: boolean?): any +@checked declare function newproxy(mt: boolean?): any declare coroutine: { create: (f: (A...) -> R...) 
-> thread, @@ -452,7 +452,7 @@ std::string getBuiltinDefinitionSource() std::string result = kBuiltinDefinitionLuaSrc; // Annotates each non generic function as checked - if (FFlag::LuauCheckedEmbeddedDefinitions2 && FFlag::LuauCheckedFunctionSyntax) + if (FFlag::LuauCheckedEmbeddedDefinitions2 && FFlag::LuauAttributeSyntax) result = kBuiltinDefinitionLuaSrcChecked; return result; diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 98b15b77..5a9e42a7 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -7,11 +7,14 @@ #include "Luau/NotNull.h" #include "Luau/StringUtils.h" #include "Luau/ToString.h" +#include "Luau/Type.h" +#include "Luau/TypeFamily.h" #include #include #include #include +#include LUAU_FASTINTVARIABLE(LuauIndentTypeMismatchMaxTypeLength, 10) @@ -61,6 +64,17 @@ static std::string wrongNumberOfArgsString( namespace Luau { +// this list of binary operator type families is used for better stringification of type families errors +static const std::unordered_map kBinaryOps{{"add", "+"}, {"sub", "-"}, {"mul", "*"}, {"div", "/"}, {"idiv", "//"}, + {"pow", "^"}, {"mod", "%"}, {"concat", ".."}, {"and", "and"}, {"or", "or"}, {"lt", "< or >="}, {"le", "<= or >"}, {"eq", "== or ~="}}; + +// this list of unary operator type families is used for better stringification of type families errors +static const std::unordered_map kUnaryOps{{"unm", "-"}, {"len", "#"}, {"not", "not"}}; + +// this list of type families will receive a special error indicating that the user should file a bug on the GitHub repository +// putting a type family in this list indicates that it is expected to _always_ reduce +static const std::unordered_set kUnreachableTypeFamilies{"refine", "singleton", "union", "intersect"}; + struct ErrorConverter { FileResolver* fileResolver = nullptr; @@ -565,6 +579,108 @@ struct ErrorConverter std::string operator()(const UninhabitedTypeFamily& e) const { + auto tfit = get(e.ty); + LUAU_ASSERT(tfit); // Luau analysis has actually done something wrong if this type is not a type family. + if (!tfit) + return "Unexpected type " + Luau::toString(e.ty) + " flagged as an uninhabited type family."; + + // unary operators + if (auto unaryString = kUnaryOps.find(tfit->family->name); unaryString != kUnaryOps.end()) + { + std::string result = "Operator '" + std::string(unaryString->second) + "' could not be applied to "; + + if (tfit->typeArguments.size() == 1 && tfit->packArguments.empty()) + { + result += "operand of type " + Luau::toString(tfit->typeArguments[0]); + + if (tfit->family->name != "not") + result += "; there is no corresponding overload for __" + tfit->family->name; + } + else + { + // if it's not the expected case, we ought to add a specialization later, but this is a sane default. + result += "operands of types "; + + bool isFirst = true; + for (auto arg : tfit->typeArguments) + { + if (!isFirst) + result += ", "; + + result += Luau::toString(arg); + isFirst = false; + } + + for (auto packArg : tfit->packArguments) + result += ", " + Luau::toString(packArg); + } + + return result; + } + + // binary operators + if (auto binaryString = kBinaryOps.find(tfit->family->name); binaryString != kBinaryOps.end()) + { + std::string result = "Operator '" + std::string(binaryString->second) + "' could not be applied to operands of types "; + + if (tfit->typeArguments.size() == 2 && tfit->packArguments.empty()) + { + // this is the expected case. 
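// A standalone sketch of the operator-message lookup these tables enable:
// map a type-family name to its surface operator and format the expected
// two-argument case. Toy strings stand in for Luau::toString over TypeIds;
// the map below is a subset of the kBinaryOps table defined above.

#include <string>
#include <unordered_map>
#include <iostream>

static const std::unordered_map<std::string, std::string> kBinaryOpsSketch{
    {"add", "+"}, {"sub", "-"}, {"mul", "*"}, {"div", "/"}, {"mod", "%"}, {"concat", ".."}};

std::string formatBinaryOpError(const std::string& familyName, const std::string& lhs, const std::string& rhs)
{
    auto it = kBinaryOpsSketch.find(familyName);
    if (it == kBinaryOpsSketch.end())
        return "Type family instance " + familyName + "<" + lhs + ", " + rhs + "> is uninhabited";

    return "Operator '" + it->second + "' could not be applied to operands of types " + lhs + " and " + rhs +
           "; there is no corresponding overload for __" + familyName;
}

int main()
{
    // Prints the user-facing message instead of leaking "type family" jargon.
    std::cout << formatBinaryOpError("add", "string", "number") << "\n";
}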
+ result += Luau::toString(tfit->typeArguments[0]) + " and " + Luau::toString(tfit->typeArguments[1]); + } + else + { + // if it's not the expected case, we ought to add a specialization later, but this is a sane default. + + bool isFirst = true; + for (auto arg : tfit->typeArguments) + { + if (!isFirst) + result += ", "; + + result += Luau::toString(arg); + isFirst = false; + } + + for (auto packArg : tfit->packArguments) + result += ", " + Luau::toString(packArg); + } + + result += "; there is no corresponding overload for __" + tfit->family->name; + + return result; + } + + // miscellaneous + + if ("keyof" == tfit->family->name || "rawkeyof" == tfit->family->name) + { + if (tfit->typeArguments.size() == 1 && tfit->packArguments.empty()) + return "Type '" + toString(tfit->typeArguments[0]) + "' does not have keys, so '" + Luau::toString(e.ty) + "' is invalid"; + else + return "Type family instance " + Luau::toString(e.ty) + " is ill-formed, and thus invalid"; + } + + if ("index" == tfit->family->name || "rawget" == tfit->family->name) + { + if (tfit->typeArguments.size() != 2) + return "Type family instance " + Luau::toString(e.ty) + " is ill-formed, and thus invalid"; + + if (auto errType = get(tfit->typeArguments[1])) // Second argument to (index | rawget)<_,_> is not a type + return "Second argument to " + tfit->family->name + "<" + Luau::toString(tfit->typeArguments[0]) + ", _> is not a valid index type"; + else // Property `indexer` does not exist on type `indexee` + return "Property '" + Luau::toString(tfit->typeArguments[1]) + "' does not exist on type '" + Luau::toString(tfit->typeArguments[0]) + + "'"; + } + + if (kUnreachableTypeFamilies.count(tfit->family->name)) + { + return "Type family instance " + Luau::toString(e.ty) + " is uninhabited\n" + + "This is likely to be a bug, please report it at https://github.com/luau-lang/luau/issues"; + } + + // Everything should be specialized above to report a more descriptive error that hopefully does not mention "type families" explicitly. + // If we produce this message, it's an indication that we've missed a specialization and it should be fixed! 
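// The branch order above is a specialization cascade: try the most specific
// message first and only fall through to the generic "is uninhabited" text.
// A compact standalone sketch of that shape; tryUnaryMessage/tryBinaryMessage
// are hypothetical helpers standing in for the kUnaryOps/kBinaryOps lookups.

#include <optional>
#include <string>
#include <iostream>

std::optional<std::string> tryUnaryMessage(const std::string& family)
{
    if (family == "unm")
        return std::string("Operator '-' could not be applied to its operand");
    return std::nullopt;
}

std::optional<std::string> tryBinaryMessage(const std::string& family)
{
    if (family == "add")
        return std::string("Operator '+' could not be applied to its operands");
    return std::nullopt;
}

std::string describeUninhabitedFamily(const std::string& family)
{
    if (auto msg = tryUnaryMessage(family))
        return *msg;
    if (auto msg = tryBinaryMessage(family))
        return *msg;
    // Reaching this line means a specialization is missing above.
    return "Type family instance " + family + " is uninhabited";
}

int main()
{
    std::cout << describeUninhabitedFamily("add") << "\n";
    std::cout << describeUninhabitedFamily("refine") << "\n";
}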
return "Type family instance " + Luau::toString(e.ty) + " is uninhabited"; } @@ -1205,7 +1321,7 @@ void copyError(T& e, TypeArena& destArena, CloneState& cloneState) else if constexpr (std::is_same_v) { e.recommendedReturn = clone(e.recommendedReturn); - for (auto [_, t] : e.recommendedArgs) + for (auto& [_, t] : e.recommendedArgs) t = clone(t); } else if constexpr (std::is_same_v) diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 55cff7f6..4339960d 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -34,6 +34,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) +LUAU_FASTFLAGVARIABLE(LuauCancelFromProgress, false) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJsonFile, false) @@ -440,6 +441,8 @@ CheckResult Frontend::check(const ModuleName& name, std::optional result = getCheckResult(name, true, frontendOptions.forAutocomplete)) return std::move(*result); @@ -492,9 +495,11 @@ void Frontend::queueModuleCheck(const ModuleName& name) } std::vector Frontend::checkQueuedModules(std::optional optionOverride, - std::function task)> executeTask, std::function progress) + std::function task)> executeTask, std::function progress) { FrontendOptions frontendOptions = optionOverride.value_or(options); + if (FFlag::DebugLuauDeferredConstraintResolution) + frontendOptions.forAutocomplete = false; // By taking data into locals, we make sure queue is cleared at the end, even if an ICE or a different exception is thrown std::vector currModuleQueue; @@ -673,7 +678,17 @@ std::vector Frontend::checkQueuedModules(std::optional Frontend::checkQueuedModules(std::optional Frontend::getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete) { + if (FFlag::DebugLuauDeferredConstraintResolution) + forAutocomplete = false; + auto it = sourceNodes.find(name); if (it == sourceNodes.end() || it->second->hasDirtyModule(forAutocomplete)) @@ -1003,11 +1021,10 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item) module->astForInNextTypes.clear(); module->astResolvedTypes.clear(); module->astResolvedTypePacks.clear(); + module->astCompoundAssignResultTypes.clear(); module->astScopes.clear(); module->upperBoundContributors.clear(); - - if (!FFlag::DebugLuauDeferredConstraintResolution) - module->scopes.clear(); + module->scopes.clear(); } if (mode != Mode::NoCheck) @@ -1196,12 +1213,6 @@ struct InternalTypeFinder : TypeOnceVisitor return false; } - bool visit(TypeId, const LocalType&) override - { - LUAU_ASSERT(false); - return false; - } - bool visit(TypePackId, const BlockedTypePack&) override { LUAU_ASSERT(false); @@ -1297,6 +1308,30 @@ ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vectortype = sourceModule.type; result->upperBoundContributors = std::move(cs.upperBoundContributors); + if (result->timeout || result->cancelled) + { + // If solver was interrupted, skip typechecking and replace all module results with error-supressing types to avoid leaking blocked/pending + // types + ScopePtr moduleScope = result->getModuleScope(); + moduleScope->returnType = builtinTypes->errorRecoveryTypePack(); + + for (auto& [name, ty] : result->declaredGlobals) + ty = builtinTypes->errorRecoveryType(); + + for (auto& [name, tf] : result->exportedTypeBindings) + tf.type = builtinTypes->errorRecoveryType(); + } + else 
+ { + if (mode == Mode::Nonstrict) + Luau::checkNonStrict(builtinTypes, iceHandler, NotNull{&unifierState}, NotNull{&dfg}, NotNull{&limits}, sourceModule, result.get()); + else + Luau::check(builtinTypes, NotNull{&unifierState}, NotNull{&limits}, logger.get(), sourceModule, result.get()); + } + + unfreeze(result->interfaceTypes); + result->clonePublicInterface(builtinTypes, *iceHandler); + if (FFlag::DebugLuauForbidInternalTypes) { InternalTypeFinder finder; @@ -1325,30 +1360,6 @@ ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vectortimeout || result->cancelled) - { - // If solver was interrupted, skip typechecking and replace all module results with error-supressing types to avoid leaking blocked/pending - // types - ScopePtr moduleScope = result->getModuleScope(); - moduleScope->returnType = builtinTypes->errorRecoveryTypePack(); - - for (auto& [name, ty] : result->declaredGlobals) - ty = builtinTypes->errorRecoveryType(); - - for (auto& [name, tf] : result->exportedTypeBindings) - tf.type = builtinTypes->errorRecoveryType(); - } - else - { - if (mode == Mode::Nonstrict) - Luau::checkNonStrict(builtinTypes, iceHandler, NotNull{&unifierState}, NotNull{&dfg}, NotNull{&limits}, sourceModule, result.get()); - else - Luau::check(builtinTypes, NotNull{&unifierState}, NotNull{&limits}, logger.get(), sourceModule, result.get()); - } - - unfreeze(result->interfaceTypes); - result->clonePublicInterface(builtinTypes, *iceHandler); - // It would be nice if we could freeze the arenas before doing type // checking, but we'll have to do some work to get there. // diff --git a/Analysis/src/Generalization.cpp b/Analysis/src/Generalization.cpp new file mode 100644 index 00000000..5020ea58 --- /dev/null +++ b/Analysis/src/Generalization.cpp @@ -0,0 +1,910 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Generalization.h" + +#include "Luau/Scope.h" +#include "Luau/Type.h" +#include "Luau/ToString.h" +#include "Luau/TypeArena.h" +#include "Luau/TypePack.h" +#include "Luau/VisitType.h" + +namespace Luau +{ + +struct MutatingGeneralizer : TypeOnceVisitor +{ + NotNull builtinTypes; + + NotNull scope; + NotNull> cachedTypes; + DenseHashMap positiveTypes; + DenseHashMap negativeTypes; + std::vector generics; + std::vector genericPacks; + + bool isWithinFunction = false; + bool avoidSealingTables = false; + + MutatingGeneralizer(NotNull builtinTypes, NotNull scope, NotNull> cachedTypes, + DenseHashMap positiveTypes, DenseHashMap negativeTypes, bool avoidSealingTables) + : TypeOnceVisitor(/* skipBoundTypes */ true) + , builtinTypes(builtinTypes) + , scope(scope) + , cachedTypes(cachedTypes) + , positiveTypes(std::move(positiveTypes)) + , negativeTypes(std::move(negativeTypes)) + , avoidSealingTables(avoidSealingTables) + { + } + + static void replace(DenseHashSet& seen, TypeId haystack, TypeId needle, TypeId replacement) + { + haystack = follow(haystack); + + if (seen.find(haystack)) + return; + seen.insert(haystack); + + if (UnionType* ut = getMutable(haystack)) + { + for (auto iter = ut->options.begin(); iter != ut->options.end();) + { + // FIXME: I bet this function has reentrancy problems + TypeId option = follow(*iter); + + if (option == needle && get(replacement)) + { + iter = ut->options.erase(iter); + continue; + } + + if (option == needle) + { + *iter = replacement; + iter++; + continue; + } + + // advance the iterator, nothing after this can use it. 
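// A standalone model of the MutatingGeneralizer::replace loop over union
// options above: occurrences of `needle` become `replacement`, members that
// would become `never` are dropped, and a one-element union collapses to its
// only member. std::string stands in for TypeId (a simplified, non-recursive
// model of the in-place mutation).

#include <string>
#include <vector>
#include <iostream>

std::string replaceInUnion(std::vector<std::string> options, const std::string& needle, const std::string& replacement)
{
    for (auto it = options.begin(); it != options.end();)
    {
        if (*it == needle && replacement == "never")
        {
            it = options.erase(it); // dropping a never member shrinks the union
            continue;
        }
        if (*it == needle)
            *it = replacement;
        ++it;
    }

    if (options.size() == 1)
        return options[0]; // a singleton union collapses to its member

    std::string result;
    for (size_t i = 0; i < options.size(); ++i)
        result += (i ? " | " : "") + options[i];
    return result;
}

int main()
{
    std::cout << replaceInUnion({"'a", "number"}, "'a", "never") << "\n";   // number
    std::cout << replaceInUnion({"'a", "number"}, "'a", "unknown") << "\n"; // unknown | number
}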
+ iter++; + + if (seen.find(option)) + continue; + seen.insert(option); + + if (get(option)) + replace(seen, option, needle, haystack); + else if (get(option)) + replace(seen, option, needle, haystack); + } + + if (ut->options.size() == 1) + { + TypeId onlyType = ut->options[0]; + LUAU_ASSERT(onlyType != haystack); + emplaceType(asMutable(haystack), onlyType); + } + + return; + } + + if (IntersectionType* it = getMutable(needle)) + { + for (auto iter = it->parts.begin(); iter != it->parts.end();) + { + // FIXME: I bet this function has reentrancy problems + TypeId part = follow(*iter); + + if (part == needle && get(replacement)) + { + iter = it->parts.erase(iter); + continue; + } + + if (part == needle) + { + *iter = replacement; + iter++; + continue; + } + + // advance the iterator, nothing after this can use it. + iter++; + + if (seen.find(part)) + continue; + seen.insert(part); + + if (get(part)) + replace(seen, part, needle, haystack); + else if (get(part)) + replace(seen, part, needle, haystack); + } + + if (it->parts.size() == 1) + { + TypeId onlyType = it->parts[0]; + LUAU_ASSERT(onlyType != needle); + emplaceType(asMutable(needle), onlyType); + } + + return; + } + } + + bool visit(TypeId ty, const FunctionType& ft) override + { + if (cachedTypes->contains(ty)) + return false; + + const bool oldValue = isWithinFunction; + + isWithinFunction = true; + + traverse(ft.argTypes); + traverse(ft.retTypes); + + isWithinFunction = oldValue; + + return false; + } + + bool visit(TypeId ty, const FreeType&) override + { + LUAU_ASSERT(!cachedTypes->contains(ty)); + + const FreeType* ft = get(ty); + LUAU_ASSERT(ft); + + traverse(ft->lowerBound); + traverse(ft->upperBound); + + // It is possible for the above traverse() calls to cause ty to be + // transmuted. We must reacquire ft if this happens. + ty = follow(ty); + ft = get(ty); + if (!ft) + return false; + + const size_t positiveCount = getCount(positiveTypes, ty); + const size_t negativeCount = getCount(negativeTypes, ty); + + if (!positiveCount && !negativeCount) + return false; + + const bool hasLowerBound = !get(follow(ft->lowerBound)); + const bool hasUpperBound = !get(follow(ft->upperBound)); + + DenseHashSet seen{nullptr}; + seen.insert(ty); + + if (!hasLowerBound && !hasUpperBound) + { + if (!isWithinFunction || (positiveCount + negativeCount == 1)) + emplaceType(asMutable(ty), builtinTypes->unknownType); + else + { + emplaceType(asMutable(ty), scope); + generics.push_back(ty); + } + } + + // It is possible that this free type has other free types in its upper + // or lower bounds. If this is the case, we must replace those + // references with never (for the lower bound) or unknown (for the upper + // bound). + // + // If we do not do this, we get tautological bounds like a <: a <: unknown. + else if (positiveCount && !hasUpperBound) + { + TypeId lb = follow(ft->lowerBound); + if (FreeType* lowerFree = getMutable(lb); lowerFree && lowerFree->upperBound == ty) + lowerFree->upperBound = builtinTypes->unknownType; + else + { + DenseHashSet replaceSeen{nullptr}; + replace(replaceSeen, lb, ty, builtinTypes->unknownType); + } + + if (lb != ty) + emplaceType(asMutable(ty), lb); + else if (!isWithinFunction || (positiveCount + negativeCount == 1)) + emplaceType(asMutable(ty), builtinTypes->unknownType); + else + { + // if the lower bound is the type in question, we don't actually have a lower bound. 
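// The unconstrained-free-type decision above reduces to a small table: the
// type becomes `unknown` when it occurs only once or sits outside any
// function, and a fresh generic otherwise. A standalone sketch of just that
// decision, assuming the counts and flag are precomputed as FreeTypeSearcher
// does.

#include <cstddef>
#include <iostream>
#include <string>

std::string generalizeUnconstrainedFree(bool isWithinFunction, std::size_t positiveCount, std::size_t negativeCount)
{
    if (!isWithinFunction || (positiveCount + negativeCount == 1))
        return "unknown";
    return "generic"; // becomes a new generic attached to the enclosing function
}

int main()
{
    std::cout << generalizeUnconstrainedFree(false, 3, 2) << "\n"; // unknown
    std::cout << generalizeUnconstrainedFree(true, 1, 0) << "\n";  // unknown
    std::cout << generalizeUnconstrainedFree(true, 2, 1) << "\n";  // generic
}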
+ emplaceType(asMutable(ty), scope); + generics.push_back(ty); + } + } + else + { + TypeId ub = follow(ft->upperBound); + if (FreeType* upperFree = getMutable(ub); upperFree && upperFree->lowerBound == ty) + upperFree->lowerBound = builtinTypes->neverType; + else + { + DenseHashSet replaceSeen{nullptr}; + replace(replaceSeen, ub, ty, builtinTypes->neverType); + } + + if (ub != ty) + emplaceType(asMutable(ty), ub); + else if (!isWithinFunction || (positiveCount + negativeCount == 1)) + emplaceType(asMutable(ty), builtinTypes->unknownType); + else + { + // if the upper bound is the type in question, we don't actually have an upper bound. + emplaceType(asMutable(ty), scope); + generics.push_back(ty); + } + } + + return false; + } + + size_t getCount(const DenseHashMap& map, const void* ty) + { + if (const size_t* count = map.find(ty)) + return *count; + else + return 0; + } + + bool visit(TypeId ty, const TableType&) override + { + if (cachedTypes->contains(ty)) + return false; + + const size_t positiveCount = getCount(positiveTypes, ty); + const size_t negativeCount = getCount(negativeTypes, ty); + + // FIXME: Free tables should probably just be replaced by upper bounds on free types. + // + // eg never <: 'a <: {x: number} & {z: boolean} + + if (!positiveCount && !negativeCount) + return true; + + TableType* tt = getMutable(ty); + LUAU_ASSERT(tt); + + if (!avoidSealingTables) + tt->state = TableState::Sealed; + + return true; + } + + bool visit(TypePackId tp, const FreeTypePack& ftp) override + { + if (!subsumes(scope, ftp.scope)) + return true; + + tp = follow(tp); + + const size_t positiveCount = getCount(positiveTypes, tp); + const size_t negativeCount = getCount(negativeTypes, tp); + + if (1 == positiveCount + negativeCount) + emplaceTypePack(asMutable(tp), builtinTypes->unknownTypePack); + else + { + emplaceTypePack(asMutable(tp), scope); + genericPacks.push_back(tp); + } + + return true; + } +}; + +struct FreeTypeSearcher : TypeVisitor +{ + NotNull scope; + NotNull> cachedTypes; + + explicit FreeTypeSearcher(NotNull scope, NotNull> cachedTypes) + : TypeVisitor(/*skipBoundTypes*/ true) + , scope(scope) + , cachedTypes(cachedTypes) + { + } + + enum Polarity + { + Positive, + Negative, + Both, + }; + + Polarity polarity = Positive; + + void flip() + { + switch (polarity) + { + case Positive: + polarity = Negative; + break; + case Negative: + polarity = Positive; + break; + case Both: + break; + } + } + + DenseHashSet seenPositive{nullptr}; + DenseHashSet seenNegative{nullptr}; + + bool seenWithPolarity(const void* ty) + { + switch (polarity) + { + case Positive: + { + if (seenPositive.contains(ty)) + return true; + + seenPositive.insert(ty); + return false; + } + case Negative: + { + if (seenNegative.contains(ty)) + return true; + + seenNegative.insert(ty); + return false; + } + case Both: + { + if (seenPositive.contains(ty) && seenNegative.contains(ty)) + return true; + + seenPositive.insert(ty); + seenNegative.insert(ty); + return false; + } + } + + return false; + } + + // The keys in these maps are either TypeIds or TypePackIds. It's safe to + // mix them because we only use these pointers as unique keys. We never + // indirect them. 
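// FreeTypeSearcher's polarity bookkeeping above follows the usual variance
// rule: argument positions flip polarity, return positions keep it. A
// standalone sketch of flip() plus a nested-argument walk (simplified; the
// real visitor threads this state through TypeVisitor overrides).

#include <iostream>

enum class Polarity
{
    Positive,
    Negative,
    Both,
};

Polarity flip(Polarity p)
{
    switch (p)
    {
    case Polarity::Positive:
        return Polarity::Negative;
    case Polarity::Negative:
        return Polarity::Positive;
    case Polarity::Both:
        return Polarity::Both; // invariant positions stay Both
    }
    return p;
}

int main()
{
    // A type used as (t) -> (t) -> X: two nested argument positions flip twice.
    Polarity p = Polarity::Positive;
    p = flip(p); // entering the outer argument list: Negative
    p = flip(p); // entering the inner argument list: Positive again
    std::cout << (p == Polarity::Positive) << "\n"; // 1
}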
+ DenseHashMap negativeTypes{0}; + DenseHashMap positiveTypes{0}; + + bool visit(TypeId ty) override + { + if (cachedTypes->contains(ty) || seenWithPolarity(ty)) + return false; + + LUAU_ASSERT(ty); + return true; + } + + bool visit(TypeId ty, const FreeType& ft) override + { + if (cachedTypes->contains(ty) || seenWithPolarity(ty)) + return false; + + if (!subsumes(scope, ft.scope)) + return true; + + switch (polarity) + { + case Positive: + positiveTypes[ty]++; + break; + case Negative: + negativeTypes[ty]++; + break; + case Both: + positiveTypes[ty]++; + negativeTypes[ty]++; + break; + } + + return true; + } + + bool visit(TypeId ty, const TableType& tt) override + { + if (cachedTypes->contains(ty) || seenWithPolarity(ty)) + return false; + + if ((tt.state == TableState::Free || tt.state == TableState::Unsealed) && subsumes(scope, tt.scope)) + { + switch (polarity) + { + case Positive: + positiveTypes[ty]++; + break; + case Negative: + negativeTypes[ty]++; + break; + case Both: + positiveTypes[ty]++; + negativeTypes[ty]++; + break; + } + } + + for (const auto& [_name, prop] : tt.props) + { + if (prop.isReadOnly()) + traverse(*prop.readTy); + else + { + LUAU_ASSERT(prop.isShared()); + + Polarity p = polarity; + polarity = Both; + traverse(prop.type()); + polarity = p; + } + } + + if (tt.indexer) + { + traverse(tt.indexer->indexType); + traverse(tt.indexer->indexResultType); + } + + return false; + } + + bool visit(TypeId ty, const FunctionType& ft) override + { + if (cachedTypes->contains(ty) || seenWithPolarity(ty)) + return false; + + flip(); + traverse(ft.argTypes); + flip(); + + traverse(ft.retTypes); + + return false; + } + + bool visit(TypeId, const ClassType&) override + { + return false; + } + + bool visit(TypePackId tp, const FreeTypePack& ftp) override + { + if (seenWithPolarity(tp)) + return false; + + if (!subsumes(scope, ftp.scope)) + return true; + + switch (polarity) + { + case Positive: + positiveTypes[tp]++; + break; + case Negative: + negativeTypes[tp]++; + break; + case Both: + positiveTypes[tp]++; + negativeTypes[tp]++; + break; + } + + return true; + } +}; + +// We keep a running set of types that will not change under generalization and +// only have outgoing references to types that are the same. We use this to +// short circuit generalization. It improves performance quite a lot. +// +// We do this by tracing through the type and searching for types that are +// uncacheable. If a type has a reference to an uncacheable type, it is itself +// uncacheable. +// +// If a type has no outbound references to uncacheable types, we add it to the +// cache. 
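// The caching rule described above is a simple closure property over the type
// graph: a node is cacheable exactly when none of its successors is
// uncacheable. A standalone sketch over a toy adjacency list (assuming
// acyclic input here; the real pass is a TypeOnceVisitor over TypeIds).

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <iostream>

using Graph = std::unordered_map<std::string, std::vector<std::string>>;

bool isUncacheable(const Graph& g, const std::string& node,
    const std::unordered_set<std::string>& uncacheableLeaves,
    std::unordered_set<std::string>& memoUncacheable,
    std::unordered_set<std::string>& memoCacheable)
{
    if (uncacheableLeaves.count(node) || memoUncacheable.count(node))
        return true;
    if (memoCacheable.count(node))
        return false;

    // Uncacheability propagates upward from any outgoing reference.
    bool result = false;
    if (auto it = g.find(node); it != g.end())
        for (const std::string& child : it->second)
            result = isUncacheable(g, child, uncacheableLeaves, memoUncacheable, memoCacheable) || result;

    (result ? memoUncacheable : memoCacheable).insert(node);
    return result;
}

int main()
{
    // table -> prop -> 'a; 'a is free, hence uncacheable; the table inherits that.
    Graph g{{"table", {"prop"}}, {"prop", {"'a"}}};
    std::unordered_set<std::string> leaves{"'a"}, bad, good;
    std::cout << isUncacheable(g, "table", leaves, bad, good) << "\n"; // 1
}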
+struct TypeCacher : TypeOnceVisitor +{ + NotNull> cachedTypes; + + DenseHashSet uncacheable{nullptr}; + DenseHashSet uncacheablePacks{nullptr}; + + explicit TypeCacher(NotNull> cachedTypes) + : TypeOnceVisitor(/* skipBoundTypes */ true) + , cachedTypes(cachedTypes) + { + } + + void cache(TypeId ty) + { + cachedTypes->insert(ty); + } + + bool isCached(TypeId ty) const + { + return cachedTypes->contains(ty); + } + + void markUncacheable(TypeId ty) + { + uncacheable.insert(ty); + } + + void markUncacheable(TypePackId tp) + { + uncacheablePacks.insert(tp); + } + + bool isUncacheable(TypeId ty) const + { + return uncacheable.contains(ty); + } + + bool isUncacheable(TypePackId tp) const + { + return uncacheablePacks.contains(tp); + } + + bool visit(TypeId ty) override + { + if (isUncacheable(ty) || isCached(ty)) + return false; + return true; + } + + bool visit(TypeId ty, const FreeType& ft) override + { + // Free types are never cacheable. + LUAU_ASSERT(!isCached(ty)); + + if (!isUncacheable(ty)) + { + traverse(ft.lowerBound); + traverse(ft.upperBound); + + markUncacheable(ty); + } + + return false; + } + + bool visit(TypeId ty, const GenericType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const PrimitiveType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const SingletonType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const BlockedType&) override + { + markUncacheable(ty); + return false; + } + + bool visit(TypeId ty, const PendingExpansionType&) override + { + markUncacheable(ty); + return false; + } + + bool visit(TypeId ty, const FunctionType& ft) override + { + if (isCached(ty) || isUncacheable(ty)) + return false; + + traverse(ft.argTypes); + traverse(ft.retTypes); + for (TypeId gen : ft.generics) + traverse(gen); + + bool uncacheable = false; + + if (isUncacheable(ft.argTypes)) + uncacheable = true; + + else if (isUncacheable(ft.retTypes)) + uncacheable = true; + + for (TypeId argTy : ft.argTypes) + { + if (isUncacheable(argTy)) + { + uncacheable = true; + break; + } + } + + for (TypeId retTy : ft.retTypes) + { + if (isUncacheable(retTy)) + { + uncacheable = true; + break; + } + } + + for (TypeId g : ft.generics) + { + if (isUncacheable(g)) + { + uncacheable = true; + break; + } + } + + if (uncacheable) + markUncacheable(ty); + else + cache(ty); + + return false; + } + + bool visit(TypeId ty, const TableType& tt) override + { + if (isCached(ty) || isUncacheable(ty)) + return false; + + if (tt.boundTo) + { + traverse(*tt.boundTo); + if (isUncacheable(*tt.boundTo)) + { + markUncacheable(ty); + return false; + } + } + + bool uncacheable = false; + + // This logic runs immediately after generalization, so any remaining + // unsealed tables are assuredly not cacheable. They may yet have + // properties added to them. 
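// TypeCacher's leaf rules classify in one direction only: leaves that can
// still mutate (free, blocked, pending expansion) poison the cache, while
// settled leaves (generics, primitives, singletons, classes, any, unknown,
// never) seed it. A standalone sketch of that leaf classification, using
// string tags as simplified stand-ins for the visited type variants.

#include <iostream>
#include <string>
#include <unordered_set>

bool isSettledLeaf(const std::string& tag)
{
    static const std::unordered_set<std::string> settled{
        "generic", "primitive", "singleton", "class", "any", "unknown", "never"};
    return settled.count(tag) != 0;
}

int main()
{
    for (const std::string tag : {"primitive", "blocked", "free", "class"})
        std::cout << tag << (isSettledLeaf(tag) ? " -> cache\n" : " -> uncacheable\n");
}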
+ if (tt.state == TableState::Free || tt.state == TableState::Unsealed) + uncacheable = true; + + for (const auto& [_name, prop] : tt.props) + { + if (prop.readTy) + { + traverse(*prop.readTy); + + if (isUncacheable(*prop.readTy)) + uncacheable = true; + } + if (prop.writeTy && prop.writeTy != prop.readTy) + { + traverse(*prop.writeTy); + + if (isUncacheable(*prop.writeTy)) + uncacheable = true; + } + } + + if (tt.indexer) + { + traverse(tt.indexer->indexType); + if (isUncacheable(tt.indexer->indexType)) + uncacheable = true; + + traverse(tt.indexer->indexResultType); + if (isUncacheable(tt.indexer->indexResultType)) + uncacheable = true; + } + + if (uncacheable) + markUncacheable(ty); + else + cache(ty); + + return false; + } + + bool visit(TypeId ty, const ClassType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const AnyType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const UnionType& ut) override + { + if (isUncacheable(ty) || isCached(ty)) + return false; + + bool uncacheable = false; + + for (TypeId partTy : ut.options) + { + traverse(partTy); + + uncacheable |= isUncacheable(partTy); + } + + if (uncacheable) + markUncacheable(ty); + else + cache(ty); + + return false; + } + + bool visit(TypeId ty, const IntersectionType& it) override + { + if (isUncacheable(ty) || isCached(ty)) + return false; + + bool uncacheable = false; + + for (TypeId partTy : it.parts) + { + traverse(partTy); + + uncacheable |= isUncacheable(partTy); + } + + if (uncacheable) + markUncacheable(ty); + else + cache(ty); + + return false; + } + + bool visit(TypeId ty, const UnknownType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const NeverType&) override + { + cache(ty); + return false; + } + + bool visit(TypeId ty, const NegationType& nt) override + { + if (!isCached(ty) && !isUncacheable(ty)) + { + traverse(nt.ty); + + if (isUncacheable(nt.ty)) + markUncacheable(ty); + else + cache(ty); + } + + return false; + } + + bool visit(TypeId ty, const TypeFamilyInstanceType& tfit) override + { + if (isCached(ty) || isUncacheable(ty)) + return false; + + bool uncacheable = false; + + for (TypeId argTy : tfit.typeArguments) + { + traverse(argTy); + + if (isUncacheable(argTy)) + uncacheable = true; + } + + for (TypePackId argPack : tfit.packArguments) + { + traverse(argPack); + + if (isUncacheable(argPack)) + uncacheable = true; + } + + if (uncacheable) + markUncacheable(ty); + else + cache(ty); + + return false; + } + + bool visit(TypePackId tp, const FreeTypePack&) override + { + markUncacheable(tp); + return false; + } + + bool visit(TypePackId tp, const VariadicTypePack& vtp) override + { + if (isUncacheable(tp)) + return false; + + traverse(vtp.ty); + + if (isUncacheable(vtp.ty)) + markUncacheable(tp); + + return false; + } + + bool visit(TypePackId tp, const BlockedTypePack&) override + { + markUncacheable(tp); + return false; + } + + bool visit(TypePackId tp, const TypeFamilyInstanceTypePack&) override + { + markUncacheable(tp); + return false; + } +}; + +std::optional generalize(NotNull arena, NotNull builtinTypes, NotNull scope, + NotNull> cachedTypes, TypeId ty, bool avoidSealingTables) +{ + ty = follow(ty); + + if (ty->owningArena != arena || ty->persistent) + return ty; + + if (const FunctionType* ft = get(ty); ft && (!ft->generics.empty() || !ft->genericPacks.empty())) + return ty; + + FreeTypeSearcher fts{scope, cachedTypes}; + fts.traverse(ty); + + MutatingGeneralizer gen{builtinTypes, scope, cachedTypes, 
std::move(fts.positiveTypes), std::move(fts.negativeTypes), avoidSealingTables}; + + gen.traverse(ty); + + /* MutatingGeneralizer mutates types in place, so it is possible that ty has + * been transmuted to a BoundType. We must follow it again and verify that + * we are allowed to mutate it before we attach generics to it. + */ + ty = follow(ty); + + if (ty->owningArena != arena || ty->persistent) + return ty; + + TypeCacher cacher{cachedTypes}; + cacher.traverse(ty); + + FunctionType* ftv = getMutable(ty); + if (ftv) + { + ftv->generics = std::move(gen.generics); + ftv->genericPacks = std::move(gen.genericPacks); + } + + return ty; +} + +} // namespace Luau diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index 525319c6..811aa048 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -11,10 +11,23 @@ #include LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauReusableSubstitutions) namespace Luau { +void Instantiation::resetState(const TxnLog* log, TypeArena* arena, NotNull builtinTypes, TypeLevel level, Scope* scope) +{ + LUAU_ASSERT(FFlag::LuauReusableSubstitutions); + + Substitution::resetState(log, arena); + + this->builtinTypes = builtinTypes; + + this->level = level; + this->scope = scope; +} + bool Instantiation::isDirty(TypeId ty) { if (const FunctionType* ftv = log->getMutable(ty)) @@ -58,13 +71,26 @@ TypeId Instantiation::clean(TypeId ty) clone.argNames = ftv->argNames; TypeId result = addType(std::move(clone)); - // Annoyingly, we have to do this even if there are no generics, - // to replace any generic tables. - ReplaceGenerics replaceGenerics{log, arena, builtinTypes, level, scope, ftv->generics, ftv->genericPacks}; + if (FFlag::LuauReusableSubstitutions) + { + // Annoyingly, we have to do this even if there are no generics, + // to replace any generic tables. + reusableReplaceGenerics.resetState(log, arena, builtinTypes, level, scope, ftv->generics, ftv->genericPacks); - // TODO: What to do if this returns nullopt? - // We don't have access to the error-reporting machinery - result = replaceGenerics.substitute(result).value_or(result); + // TODO: What to do if this returns nullopt? + // We don't have access to the error-reporting machinery + result = reusableReplaceGenerics.substitute(result).value_or(result); + } + else + { + // Annoyingly, we have to do this even if there are no generics, + // to replace any generic tables. + ReplaceGenerics replaceGenerics{log, arena, builtinTypes, level, scope, ftv->generics, ftv->genericPacks}; + + // TODO: What to do if this returns nullopt? 
+ // We don't have access to the error-reporting machinery + result = replaceGenerics.substitute(result).value_or(result); + } asMutable(result)->documentationSymbol = ty->documentationSymbol; return result; @@ -76,6 +102,22 @@ TypePackId Instantiation::clean(TypePackId tp) return tp; } +void ReplaceGenerics::resetState(const TxnLog* log, TypeArena* arena, NotNull builtinTypes, TypeLevel level, Scope* scope, + const std::vector& generics, const std::vector& genericPacks) +{ + LUAU_ASSERT(FFlag::LuauReusableSubstitutions); + + Substitution::resetState(log, arena); + + this->builtinTypes = builtinTypes; + + this->level = level; + this->scope = scope; + + this->generics = generics; + this->genericPacks = genericPacks; +} + bool ReplaceGenerics::ignoreChildren(TypeId ty) { if (const FunctionType* ftv = log->getMutable(ty)) diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index d79361c0..e9d4ca53 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -16,6 +16,11 @@ LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauAttributeSyntax) +LUAU_FASTFLAG(LuauAttribute) +LUAU_FASTFLAG(LuauNativeAttribute) +LUAU_FASTFLAGVARIABLE(LintRedundantNativeAttribute, false) + namespace Luau { @@ -2922,6 +2927,64 @@ static void lintComments(LintContext& context, const std::vector& ho } } +static bool hasNativeCommentDirective(const std::vector& hotcomments) +{ + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + LUAU_ASSERT(FFlag::LuauNativeAttribute); + LUAU_ASSERT(FFlag::LintRedundantNativeAttribute); + + for (const HotComment& hc : hotcomments) + { + if (hc.content.empty() || hc.content[0] == ' ' || hc.content[0] == '\t') + continue; + + if (hc.header) + { + size_t space = hc.content.find_first_of(" \t"); + std::string_view first = std::string_view(hc.content).substr(0, space); + + if (first == "native") + return true; + } + } + + return false; +} + +struct LintRedundantNativeAttribute : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + LUAU_ASSERT(FFlag::LuauNativeAttribute); + LUAU_ASSERT(FFlag::LintRedundantNativeAttribute); + + LintRedundantNativeAttribute pass; + pass.context = &context; + context.root->visit(&pass); + } + +private: + LintContext* context; + + bool visit(AstExprFunction* node) override + { + node->body->visit(this); + + for (const auto attribute : node->attributes) + { + if (attribute->type == AstAttr::Type::Native) + { + emitWarning(*context, LintWarning::Code_RedundantNativeAttribute, attribute->location, + "native attribute on a function is redundant in a native module; consider removing it"); + } + } + + return false; + } +}; + std::vector lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, const std::vector& hotcomments, const LintOptions& options) { @@ -3008,6 +3071,13 @@ std::vector lint(AstStat* root, const AstNameTable& names, const Sc if (context.warningEnabled(LintWarning::Code_ComparisonPrecedence)) LintComparisonPrecedence::process(context); + if (FFlag::LuauAttributeSyntax && FFlag::LuauNativeAttribute && FFlag::LintRedundantNativeAttribute && + context.warningEnabled(LintWarning::Code_RedundantNativeAttribute)) + { + if (hasNativeCommentDirective(hotcomments)) + LintRedundantNativeAttribute::process(context); + } + std::sort(context.result.begin(), context.result.end(), WarningComparator()); return context.result; diff --git a/Analysis/src/Normalize.cpp 
b/Analysis/src/Normalize.cpp index 848c8684..16fe9546 100644
--- a/Analysis/src/Normalize.cpp
+++ b/Analysis/src/Normalize.cpp
@@ -17,23 +17,23 @@
 LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false)
 LUAU_FASTFLAGVARIABLE(LuauNormalizeAwayUninhabitableTables, false)
-LUAU_FASTFLAGVARIABLE(LuauFixNormalizeCaching, false);
 LUAU_FASTFLAGVARIABLE(LuauNormalizeNotUnknownIntersection, false);
-LUAU_FASTFLAGVARIABLE(LuauFixCyclicUnionsOfIntersections, false);
+LUAU_FASTFLAGVARIABLE(LuauFixReduceStackPressure, false);
+LUAU_FASTFLAGVARIABLE(LuauFixCyclicTablesBlowingStack, false);

 // This could theoretically be 2000 on amd64, but x86 requires this.
 LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200);
 LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000);

 LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);

-static bool fixNormalizeCaching()
+static bool fixReduceStackPressure()
 {
-    return FFlag::LuauFixNormalizeCaching || FFlag::DebugLuauDeferredConstraintResolution;
+    return FFlag::LuauFixReduceStackPressure || FFlag::DebugLuauDeferredConstraintResolution;
 }

-static bool fixCyclicUnionsOfIntersections()
+static bool fixCyclicTablesBlowingStack()
 {
-    return FFlag::LuauFixCyclicUnionsOfIntersections || FFlag::DebugLuauDeferredConstraintResolution;
+    return FFlag::LuauFixCyclicTablesBlowingStack || FFlag::DebugLuauDeferredConstraintResolution;
 }

 namespace Luau
@@ -45,6 +45,14 @@ static bool normalizeAwayUninhabitableTables()
     return FFlag::LuauNormalizeAwayUninhabitableTables || FFlag::DebugLuauDeferredConstraintResolution;
 }

+static bool shouldEarlyExit(NormalizationResult res)
+{
+    // If the result hit resource limits or is already false, the caller should stop and propagate it.
+    if (res == NormalizationResult::HitLimits || res == NormalizationResult::False)
+        return true;
+    return false;
+}
+
 TypeIds::TypeIds(std::initializer_list tys)
 {
     for (TypeId ty : tys)
@@ -339,6 +347,12 @@ bool NormalizedType::isSubtypeOfString() const
            !hasBuffers() && !hasTables() && !hasFunctions() && !hasTyvars();
 }

+bool NormalizedType::isSubtypeOfBooleans() const
+{
+    return hasBooleans() && !hasTops() && !hasClasses() && !hasErrors() && !hasNils() && !hasNumbers() && !hasStrings() && !hasThreads() &&
+           !hasBuffers() && !hasTables() && !hasFunctions() && !hasTyvars();
+}
+
 bool NormalizedType::shouldSuppressErrors() const
 {
     return hasErrors() || get(tops);
@@ -547,22 +561,21 @@ NormalizationResult Normalizer::isInhabited(TypeId ty, Set& seen)
         return isInhabited(mtv->metatable, seen);
     }

-    if (fixNormalizeCaching())
-    {
-        std::shared_ptr norm = normalize(ty);
-        return isInhabited(norm.get(), seen);
-    }
-    else
-    {
-        const NormalizedType* norm = DEPRECATED_normalize(ty);
-        return isInhabited(norm, seen);
-    }
+    std::shared_ptr norm = normalize(ty);
+    return isInhabited(norm.get(), seen);
 }

 NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right)
+{
+    Set seen{nullptr};
+    return isIntersectionInhabited(left, right, seen);
+}
+
+NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right, Set& seenSet)
 {
     left = follow(left);
     right = follow(right);

+    // We're asking whether the intersection of left and right is inhabited; the caller's seenSet
+    // guards against revisiting types we have already seen on the current path.
+
     if (cacheInhabitance)
     {
@@ -570,12 +583,8 @@ NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId righ
         return *result ?
NormalizationResult::True : NormalizationResult::False;
     }

-    Set seen{nullptr};
-    seen.insert(left);
-    seen.insert(right);
-
     NormalizedType norm{builtinTypes};
-    NormalizationResult res = normalizeIntersections({left, right}, norm);
+    NormalizationResult res = normalizeIntersections({left, right}, norm, seenSet);
     if (res != NormalizationResult::True)
     {
         if (cacheInhabitance && res == NormalizationResult::False)
@@ -584,7 +593,7 @@ NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId righ
         return res;
     }

-    NormalizationResult result = isInhabited(&norm, seen);
+    NormalizationResult result = isInhabited(&norm, seenSet);

     if (cacheInhabitance && result == NormalizationResult::True)
         cachedIsInhabitedIntersection[{left, right}] = true;
@@ -856,31 +865,6 @@ Normalizer::Normalizer(TypeArena* arena, NotNull builtinTypes, Not
 {
 }

-const NormalizedType* Normalizer::DEPRECATED_normalize(TypeId ty)
-{
-    if (!arena)
-        sharedState->iceHandler->ice("Normalizing types outside a module");
-
-    auto found = cachedNormals.find(ty);
-    if (found != cachedNormals.end())
-        return found->second.get();
-
-    NormalizedType norm{builtinTypes};
-    Set seenSetTypes{nullptr};
-    NormalizationResult res = unionNormalWithTy(norm, ty, seenSetTypes);
-    if (res != NormalizationResult::True)
-        return nullptr;
-    if (norm.isUnknown())
-    {
-        clearNormal(norm);
-        norm.tops = builtinTypes->unknownType;
-    }
-    std::shared_ptr shared = std::make_shared(std::move(norm));
-    const NormalizedType* result = shared.get();
-    cachedNormals[ty] = std::move(shared);
-    return result;
-}
-
 static bool isCacheable(TypeId ty, Set& seen);

 static bool isCacheable(TypePackId tp, Set& seen)
@@ -935,9 +919,6 @@ static bool isCacheable(TypeId ty, Set& seen)

 static bool isCacheable(TypeId ty)
 {
-    if (!fixNormalizeCaching())
-        return true;
-
     Set seen{nullptr};
     return isCacheable(ty, seen);
 }
@@ -971,7 +952,7 @@ std::shared_ptr Normalizer::normalize(TypeId ty)
     return shared;
 }

-NormalizationResult Normalizer::normalizeIntersections(const std::vector& intersections, NormalizedType& outType)
+NormalizationResult Normalizer::normalizeIntersections(const std::vector& intersections, NormalizedType& outType, Set& seenSet)
 {
     if (!arena)
         sharedState->iceHandler->ice("Normalizing types outside a module");
@@ -981,7 +962,7 @@ NormalizationResult Normalizer::normalizeIntersections(const std::vector
     Set seenSetTypes{nullptr};
     for (auto ty : intersections)
     {
-        NormalizationResult res = intersectNormalWithTy(norm, ty, seenSetTypes);
+        NormalizationResult res = intersectNormalWithTy(norm, ty, seenSet);
         if (res != NormalizationResult::True)
             return res;
     }
@@ -1729,6 +1710,20 @@ bool Normalizer::withinResourceLimits()
     return true;
 }

+NormalizationResult Normalizer::intersectNormalWithNegationTy(TypeId toNegate, NormalizedType& intersect)
+{
+    std::shared_ptr normal = normalize(toNegate);
+    std::optional negated = negateNormal(*normal);
+
+    if (!negated)
+        return NormalizationResult::False;
+
+    intersectNormals(intersect, *negated);
+    return NormalizationResult::True;
+}
+
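The table-intersection changes above exist to stop cyclic table types from blowing the native stack: inhabitance queries now thread one caller-owned seen set through normalizeIntersections, intersectNormalWithTy, and intersectionOfTables instead of building a fresh set per query. The sketch below is illustrative only (Node and isInhabitedRec are invented names, not Luau API); it shows the insert-on-entry, erase-on-exit discipline the patch applies to property read types when asking whether their intersection is inhabited.

#include <unordered_set>
#include <vector>

struct Node
{
    bool leafInhabited = true;   // payload for leaves
    std::vector<Node*> children; // cyclic edges are allowed
};

static bool isInhabitedRec(Node* n, std::unordered_set<Node*>& seen)
{
    // A node we are already in the middle of checking: break the cycle
    // optimistically instead of recursing until the stack overflows.
    if (!seen.insert(n).second)
        return true;

    bool result = n->children.empty() ? n->leafInhabited : true;
    for (Node* child : n->children)
        result = result && isInhabitedRec(child, seen);

    // Erase on the way out so sibling subtrees get a fresh look; the patch
    // likewise erases hprop/tprop read types from seenSet after each query.
    seen.erase(n);
    return result;
}

int main()
{
    Node a, b;
    a.children.push_back(&b);
    b.children.push_back(&a); // a cycle, as with mutually recursive table types

    std::unordered_set<Node*> seen;
    return isInhabitedRec(&a, seen) ? 0 : 1;
}

Keeping the set caller-owned is what lets the guard span the whole recursion rather than a single call.

 // See above for an explanation of `ignoreSmallerTyvars`.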
NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, Set& seenSetTypes, int ignoreSmallerTyvars) { @@ -1775,12 +1770,9 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t } else if (const IntersectionType* itv = get(there)) { - if (fixCyclicUnionsOfIntersections()) - { - if (seenSetTypes.count(there)) - return NormalizationResult::True; - seenSetTypes.insert(there); - } + if (seenSetTypes.count(there)) + return NormalizationResult::True; + seenSetTypes.insert(there); NormalizedType norm{builtinTypes}; norm.tops = builtinTypes->anyType; @@ -1789,14 +1781,12 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t NormalizationResult res = intersectNormalWithTy(norm, *it, seenSetTypes); if (res != NormalizationResult::True) { - if (fixCyclicUnionsOfIntersections()) - seenSetTypes.erase(there); + seenSetTypes.erase(there); return res; } } - if (fixCyclicUnionsOfIntersections()) - seenSetTypes.erase(there); + seenSetTypes.erase(there); return unionNormals(here, norm); } @@ -1814,12 +1804,6 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t if (!isCacheable(there)) here.isCacheable = false; } - else if (auto lt = get(there)) - { - // FIXME? This is somewhat questionable. - // Maybe we should assert because this should never happen? - unionNormalWithTy(here, lt->domain, seenSetTypes, ignoreSmallerTyvars); - } else if (get(there)) unionFunctionsWithFunction(here.functions, there); else if (get(there) || get(there)) @@ -1876,16 +1860,8 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t { std::optional tn; - if (fixNormalizeCaching()) - { - std::shared_ptr thereNormal = normalize(ntv->ty); - tn = negateNormal(*thereNormal); - } - else - { - const NormalizedType* thereNormal = DEPRECATED_normalize(ntv->ty); - tn = negateNormal(*thereNormal); - } + std::shared_ptr thereNormal = normalize(ntv->ty); + tn = negateNormal(*thereNormal); if (!tn) return NormalizationResult::False; @@ -2484,7 +2460,7 @@ std::optional Normalizer::intersectionOfTypePacks(TypePackId here, T return arena->addTypePack({}); } -std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there) +std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there, Set& seenSet) { if (here == there) return here; @@ -2541,8 +2517,9 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there state = tttv->state; TypeLevel level = max(httv->level, tttv->level); - TableType result{state, level}; + Scope* scope = max(httv->scope, tttv->scope); + std::unique_ptr result = nullptr; bool hereSubThere = true; bool thereSubHere = true; @@ -2563,8 +2540,43 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there if (tprop.readTy.has_value()) { // if the intersection of the read types of a property is uninhabited, the whole table is `never`. 
- if (normalizeAwayUninhabitableTables() && NormalizationResult::False == isIntersectionInhabited(*hprop.readTy, *tprop.readTy)) - return {builtinTypes->neverType}; + if (fixReduceStackPressure()) + { + // We've seen these table prop elements before and we're about to ask if their intersection + // is inhabited + if (fixCyclicTablesBlowingStack()) + { + if (seenSet.contains(*hprop.readTy) && seenSet.contains(*tprop.readTy)) + { + seenSet.erase(*hprop.readTy); + seenSet.erase(*tprop.readTy); + return {builtinTypes->neverType}; + } + else + { + seenSet.insert(*hprop.readTy); + seenSet.insert(*tprop.readTy); + } + } + + NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy, seenSet); + + // Cleanup + if (fixCyclicTablesBlowingStack()) + { + seenSet.erase(*hprop.readTy); + seenSet.erase(*tprop.readTy); + } + + if (normalizeAwayUninhabitableTables() && NormalizationResult::True != res) + return {builtinTypes->neverType}; + } + else + { + if (normalizeAwayUninhabitableTables() && + NormalizationResult::False == isIntersectionInhabited(*hprop.readTy, *tprop.readTy)) + return {builtinTypes->neverType}; + } TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result; prop.readTy = ty; @@ -2614,14 +2626,21 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there // TODO: string indexers if (prop.readTy || prop.writeTy) - result.props[name] = prop; + { + if (!result.get()) + result = std::make_unique(TableType{state, level, scope}); + result->props[name] = prop; + } } for (const auto& [name, tprop] : tttv->props) { if (httv->props.count(name) == 0) { - result.props[name] = tprop; + if (!result.get()) + result = std::make_unique(TableType{state, level, scope}); + + result->props[name] = tprop; hereSubThere = false; } } @@ -2631,18 +2650,24 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there // TODO: What should intersection of indexes be? 
TypeId index = unionType(httv->indexer->indexType, tttv->indexer->indexType);
             TypeId indexResult = intersectionType(httv->indexer->indexResultType, tttv->indexer->indexResultType);
-            result.indexer = {index, indexResult};
+            if (!result.get())
+                result = std::make_unique(TableType{state, level, scope});
+            result->indexer = {index, indexResult};
             hereSubThere &= (httv->indexer->indexType == index) && (httv->indexer->indexResultType == indexResult);
             thereSubHere &= (tttv->indexer->indexType == index) && (tttv->indexer->indexResultType == indexResult);
         }
         else if (httv->indexer)
         {
-            result.indexer = httv->indexer;
+            if (!result.get())
+                result = std::make_unique(TableType{state, level, scope});
+            result->indexer = httv->indexer;
             thereSubHere = false;
         }
         else if (tttv->indexer)
         {
-            result.indexer = tttv->indexer;
+            if (!result.get())
+                result = std::make_unique(TableType{state, level, scope});
+            result->indexer = tttv->indexer;
             hereSubThere = false;
         }
@@ -2652,12 +2677,17 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there
     else if (thereSubHere)
         table = ttable;
     else
-        table = arena->addType(std::move(result));
+    {
+        if (result.get())
+            table = arena->addType(std::move(*result));
+        else
+            table = arena->addType(TableType{state, level, scope});
+    }

     if (tmtable && hmtable)
     {
         // NOTE: this assumes metatables are invariant
-        if (std::optional mtable = intersectionOfTables(hmtable, tmtable))
+        if (std::optional mtable = intersectionOfTables(hmtable, tmtable, seenSet))
         {
             if (table == htable && *mtable == hmtable)
                 return here;
@@ -2687,12 +2717,12 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there
     return table;
 }

-void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there)
+void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there, Set& seenSetTypes)
 {
     TypeIds tmp;
     for (TypeId here : heres)
     {
-        if (std::optional inter = intersectionOfTables(here, there))
+        if (std::optional inter = intersectionOfTables(here, there, seenSetTypes))
             tmp.insert(*inter);
     }
     heres.retain(tmp);
@@ -2706,7 +2736,8 @@ void Normalizer::intersectTables(TypeIds& heres, const TypeIds& theres)
     {
         for (TypeId there : theres)
         {
-            if (std::optional inter = intersectionOfTables(here, there))
+            Set seenSetTypes{nullptr};
+            if (std::optional inter = intersectionOfTables(here, there, seenSetTypes))
                 tmp.insert(*inter);
         }
     }
@@ -3047,7 +3078,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
         return NormalizationResult::True;
     }
     else if (get(there) || get(there) || get(there) || get(there) ||
-             get(there) || get(there))
+             get(there))
     {
         NormalizedType thereNorm{builtinTypes};
         NormalizedType topNorm{builtinTypes};
@@ -3056,10 +3087,6 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
         here.isCacheable = false;
         return intersectNormals(here, thereNorm);
     }
-    else if (auto lt = get(there))
-    {
-        return intersectNormalWithTy(here, lt->domain, seenSetTypes);
-    }

     NormalizedTyvars tyvars = std::move(here.tyvars);
@@ -3074,7 +3101,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
     {
         TypeIds tables = std::move(here.tables);
         clearNormal(here);
-        intersectTablesWithTable(tables, there);
+        intersectTablesWithTable(tables, there, seenSetTypes);
         here.tables = std::move(tables);
     }
     else if (get(there))
@@ -3148,60 +3175,17 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
         subtractSingleton(here, follow(ntv->ty));
     else if (get(t))
     {
-        if (fixNormalizeCaching())
-
{ - std::shared_ptr normal = normalize(t); - std::optional negated = negateNormal(*normal); - if (!negated) - return NormalizationResult::False; - intersectNormals(here, *negated); - } - else - { - const NormalizedType* normal = DEPRECATED_normalize(t); - std::optional negated = negateNormal(*normal); - if (!negated) - return NormalizationResult::False; - intersectNormals(here, *negated); - } + NormalizationResult res = intersectNormalWithNegationTy(t, here); + if (shouldEarlyExit(res)) + return res; } else if (const UnionType* itv = get(t)) { - if (fixNormalizeCaching()) + for (TypeId part : itv->options) { - for (TypeId part : itv->options) - { - std::shared_ptr normalPart = normalize(part); - std::optional negated = negateNormal(*normalPart); - if (!negated) - return NormalizationResult::False; - intersectNormals(here, *negated); - } - } - else - { - if (fixNormalizeCaching()) - { - for (TypeId part : itv->options) - { - std::shared_ptr normalPart = normalize(part); - std::optional negated = negateNormal(*normalPart); - if (!negated) - return NormalizationResult::False; - intersectNormals(here, *negated); - } - } - else - { - for (TypeId part : itv->options) - { - const NormalizedType* normalPart = DEPRECATED_normalize(part); - std::optional negated = negateNormal(*normalPart); - if (!negated) - return NormalizationResult::False; - intersectNormals(here, *negated); - } - } + NormalizationResult res = intersectNormalWithNegationTy(part, here); + if (shouldEarlyExit(res)) + return res; } } else if (get(t)) diff --git a/Analysis/src/Set.cpp b/Analysis/src/Set.cpp deleted file mode 100644 index 1819e28a..00000000 --- a/Analysis/src/Set.cpp +++ /dev/null @@ -1,5 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details - -#include "Luau/Common.h" - -LUAU_FASTFLAGVARIABLE(LuauFixSetIter, false) diff --git a/Analysis/src/Simplify.cpp b/Analysis/src/Simplify.cpp index d29546a2..dae7b2d2 100644 --- a/Analysis/src/Simplify.cpp +++ b/Analysis/src/Simplify.cpp @@ -1255,6 +1255,10 @@ TypeId TypeSimplifier::union_(TypeId left, TypeId right) case Relation::Coincident: case Relation::Superset: return left; + case Relation::Subset: + newParts.insert(right); + changed = true; + break; default: newParts.insert(part); newParts.insert(right); @@ -1364,6 +1368,17 @@ SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull< return SimplifyResult{res, std::move(s.blockedTypes)}; } +SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, std::set parts) +{ + LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); + + TypeSimplifier s{builtinTypes, arena}; + + TypeId res = s.intersectFromParts(std::move(parts)); + + return SimplifyResult{res, std::move(s.blockedTypes)}; +} + SimplifyResult simplifyUnion(NotNull builtinTypes, NotNull arena, TypeId left, TypeId right) { LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index a79c75bf..5d8ed045 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -11,6 +11,7 @@ LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTINTVARIABLE(LuauTarjanPreallocationSize, 256); +LUAU_FASTFLAG(LuauReusableSubstitutions) namespace Luau { @@ -24,8 +25,6 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a // We decline to copy them. 
if constexpr (std::is_same_v) return ty; - else if constexpr (std::is_same_v) - return ty; else if constexpr (std::is_same_v) { // This should never happen, but visit() cannot see it. @@ -148,6 +147,8 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a } Tarjan::Tarjan() + : typeToIndex(nullptr, FFlag::LuauReusableSubstitutions ? FInt::LuauTarjanPreallocationSize : 0) + , packToIndex(nullptr, FFlag::LuauReusableSubstitutions ? FInt::LuauTarjanPreallocationSize : 0) { nodes.reserve(FInt::LuauTarjanPreallocationSize); stack.reserve(FInt::LuauTarjanPreallocationSize); @@ -448,14 +449,31 @@ TarjanResult Tarjan::visitRoot(TypePackId tp) return loop(); } -void Tarjan::clearTarjan() +void Tarjan::clearTarjan(const TxnLog* log) { - typeToIndex.clear(); - packToIndex.clear(); + if (FFlag::LuauReusableSubstitutions) + { + typeToIndex.clear(~0u); + packToIndex.clear(~0u); + } + else + { + typeToIndex.clear(); + packToIndex.clear(); + } + nodes.clear(); stack.clear(); + if (FFlag::LuauReusableSubstitutions) + { + childCount = 0; + // childLimit setting stays the same + + this->log = log; + } + edgesTy.clear(); edgesTp.clear(); worklist.clear(); @@ -530,7 +548,6 @@ Substitution::Substitution(const TxnLog* log_, TypeArena* arena) { log = log_; LUAU_ASSERT(log); - LUAU_ASSERT(arena); } void Substitution::dontTraverseInto(TypeId ty) @@ -548,7 +565,7 @@ std::optional Substitution::substitute(TypeId ty) ty = log->follow(ty); // clear algorithm state for reentrancy - clearTarjan(); + clearTarjan(log); auto result = findDirty(ty); if (result != TarjanResult::Ok) @@ -581,7 +598,7 @@ std::optional Substitution::substitute(TypePackId tp) tp = log->follow(tp); // clear algorithm state for reentrancy - clearTarjan(); + clearTarjan(log); auto result = findDirty(tp); if (result != TarjanResult::Ok) @@ -609,6 +626,23 @@ std::optional Substitution::substitute(TypePackId tp) return newTp; } +void Substitution::resetState(const TxnLog* log, TypeArena* arena) +{ + LUAU_ASSERT(FFlag::LuauReusableSubstitutions); + + clearTarjan(log); + + this->arena = arena; + + newTypes.clear(); + newPacks.clear(); + replacedTypes.clear(); + replacedTypePacks.clear(); + + noTraverseTypes.clear(); + noTraverseTypePacks.clear(); +} + TypeId Substitution::clone(TypeId ty) { return shallowClone(ty, *arena, log, /* alwaysClone */ true); diff --git a/Analysis/src/Subtyping.cpp b/Analysis/src/Subtyping.cpp index f2d51b31..040c3fc6 100644 --- a/Analysis/src/Subtyping.cpp +++ b/Analysis/src/Subtyping.cpp @@ -1438,6 +1438,7 @@ SubtypingResult Subtyping::isCovariantWith( result.andAlso(isCovariantWith(env, subNorm->strings, superNorm->strings)); result.andAlso(isCovariantWith(env, subNorm->strings, superNorm->tables)); result.andAlso(isCovariantWith(env, subNorm->threads, superNorm->threads)); + result.andAlso(isCovariantWith(env, subNorm->buffers, superNorm->buffers)); result.andAlso(isCovariantWith(env, subNorm->tables, superNorm->tables)); result.andAlso(isCovariantWith(env, subNorm->functions, superNorm->functions)); // isCovariantWith(subNorm->tyvars, superNorm->tyvars); diff --git a/Analysis/src/TableLiteralInference.cpp b/Analysis/src/TableLiteralInference.cpp index 414544b6..3514ff65 100644 --- a/Analysis/src/TableLiteralInference.cpp +++ b/Analysis/src/TableLiteralInference.cpp @@ -337,7 +337,9 @@ TypeId matchLiteralType(NotNull> astTypes, TypeId matchedType = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, expectedTableTy->indexer->indexResultType, *propTy, item.value, toBlock); - 
tableTy->indexer->indexResultType = matchedType; + // if the index result type is the prop type, we can replace it with the matched type here. + if (tableTy->indexer->indexResultType == *propTy) + tableTy->indexer->indexResultType = matchedType; } } else if (item.kind == AstExprTable::Item::General) diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index 9093b38a..17b595b1 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -262,14 +262,6 @@ void StateDot::visitChildren(TypeId ty, int index) visitChild(t.upperBound, index, "[upperBound]"); } } - else if constexpr (std::is_same_v) - { - formatAppend(result, "LocalType"); - finishNodeLabel(ty); - finishNode(); - - visitChild(t.domain, 1, "[domain]"); - } else if constexpr (std::is_same_v) { formatAppend(result, "AnyType %d", index); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index cb6b2f4a..dca041a2 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -20,7 +20,6 @@ #include LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAGVARIABLE(LuauToStringiteTypesSingleLine, false) /* * Enables increasing levels of verbosity for Luau type names when stringifying. @@ -101,16 +100,6 @@ struct FindCyclicTypes final : TypeVisitor return false; } - bool visit(TypeId ty, const LocalType& lt) override - { - if (!visited.insert(ty)) - return false; - - traverse(lt.domain); - - return false; - } - bool visit(TypeId ty, const TableType& ttv) override { if (!visited.insert(ty)) @@ -526,21 +515,6 @@ struct TypeStringifier } } - void operator()(TypeId ty, const LocalType& lt) - { - state.emit("l-"); - state.emit(lt.name); - if (FInt::DebugLuauVerboseTypeNames >= 1) - { - state.emit("["); - state.emit(lt.blockCount); - state.emit("]"); - } - state.emit("=["); - stringify(lt.domain); - state.emit("]"); - } - void operator()(TypeId, const BoundType& btv) { stringify(btv.boundTo); @@ -1725,6 +1699,18 @@ std::string generateName(size_t i) return n; } +std::string toStringVector(const std::vector& types, ToStringOptions& opts) +{ + std::string s; + for (TypeId ty : types) + { + if (!s.empty()) + s += ", "; + s += toString(ty, opts); + } + return s; +} + std::string toString(const Constraint& constraint, ToStringOptions& opts) { auto go = [&opts](auto&& c) -> std::string { @@ -1755,7 +1741,7 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) else if constexpr (std::is_same_v) { std::string iteratorStr = tos(c.iterator); - std::string variableStr = tos(c.variables); + std::string variableStr = toStringVector(c.variables, opts); return variableStr + " ~ iterate " + iteratorStr; } @@ -1788,23 +1774,16 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) { return tos(c.resultType) + " ~ hasProp " + tos(c.subjectType) + ", \"" + c.prop + "\" ctx=" + std::to_string(int(c.context)); } - else if constexpr (std::is_same_v) - { - const std::string pathStr = c.path.size() == 1 ? 
"\"" + c.path[0] + "\"" : "[\"" + join(c.path, "\", \"") + "\"]"; - return tos(c.resultType) + " ~ setProp " + tos(c.subjectType) + ", " + pathStr + " " + tos(c.propType); - } else if constexpr (std::is_same_v) { return tos(c.resultType) + " ~ hasIndexer " + tos(c.subjectType) + " " + tos(c.indexType); } - else if constexpr (std::is_same_v) - { - return "setIndexer " + tos(c.subjectType) + " [ " + tos(c.indexType) + " ] " + tos(c.propType); - } + else if constexpr (std::is_same_v) + return "assignProp " + tos(c.lhsType) + " " + c.propName + " " + tos(c.rhsType); + else if constexpr (std::is_same_v) + return "assignIndex " + tos(c.lhsType) + " " + tos(c.indexType) + " " + tos(c.rhsType); else if constexpr (std::is_same_v) - return tos(c.resultPack) + " ~ ...unpack " + tos(c.sourcePack); - else if constexpr (std::is_same_v) - return tos(c.resultType) + " ~ unpack " + tos(c.sourceType); + return toStringVector(c.resultPack, opts) + " ~ ...unpack " + tos(c.sourcePack); else if constexpr (std::is_same_v) return "reduce " + tos(c.ty); else if constexpr (std::is_same_v) diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 85b8849f..d78bf157 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -1182,11 +1182,11 @@ std::string toString(AstNode* node) Printer printer(writer); printer.writeTypes = true; - if (auto statNode = dynamic_cast(node)) + if (auto statNode = node->asStat()) printer.visualize(*statNode); - else if (auto exprNode = dynamic_cast(node)) + else if (auto exprNode = node->asExpr()) printer.visualize(*exprNode); - else if (auto typeNode = dynamic_cast(node)) + else if (auto typeNode = node->asType()) printer.visualizeTypeAnnotation(*typeNode); return writer.str(); diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index 2bd858a8..0d65b787 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -561,6 +561,11 @@ void BlockedType::setOwner(Constraint* newOwner) owner = newOwner; } +void BlockedType::replaceOwner(Constraint* newOwner) +{ + owner = newOwner; +} + PendingExpansionType::PendingExpansionType( std::optional prefix, AstName name, std::vector typeArguments, std::vector packArguments) : prefix(prefix) diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index f1fe83ee..c0294fc9 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -338,10 +338,6 @@ public: { return allocator->alloc(Location(), std::nullopt, AstName("free"), std::nullopt, Location()); } - AstType* operator()(const LocalType& lt) - { - return Luau::visit(*this, lt.domain->ty); - } AstType* operator()(const UnionType& uv) { AstArray unionTypes; diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index a888564e..c53a5d30 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -446,7 +446,6 @@ struct TypeChecker2 .errors; if (!isErrorSuppressing(location, instance)) reportErrors(std::move(errors)); - return instance; } @@ -1108,10 +1107,13 @@ struct TypeChecker2 void visit(AstStatCompoundAssign* stat) { AstExprBinary fake{stat->location, stat->op, stat->var, stat->value}; - TypeId resultTy = visit(&fake, stat); + visit(&fake, stat); + + TypeId* resultTy = module->astCompoundAssignResultTypes.find(stat); + LUAU_ASSERT(resultTy); TypeId varTy = lookupType(stat->var); - testIsSubtype(resultTy, varTy, stat->location); + testIsSubtype(*resultTy, varTy, stat->location); } void visit(AstStatFunction* stat) @@ -1242,13 +1244,14 @@ struct TypeChecker2 void 
visit(AstExprConstantBool* expr) { -#if defined(LUAU_ENABLE_ASSERT) + // booleans use specialized inference logic for singleton types, which can lead to real type errors here. + const TypeId bestType = expr->value ? builtinTypes->trueType : builtinTypes->falseType; const TypeId inferredType = lookupType(expr); const SubtypingResult r = subtyping->isSubtype(bestType, inferredType); - LUAU_ASSERT(r.isSubtype || isErrorSuppressing(expr->location, inferredType)); -#endif + if (!r.isSubtype && !isErrorSuppressing(expr->location, inferredType)) + reportError(TypeMismatch{inferredType, bestType}, expr->location); } void visit(AstExprConstantNumber* expr) @@ -1264,13 +1267,14 @@ struct TypeChecker2 void visit(AstExprConstantString* expr) { -#if defined(LUAU_ENABLE_ASSERT) + // strings use specialized inference logic for singleton types, which can lead to real type errors here. + const TypeId bestType = module->internalTypes.addType(SingletonType{StringSingleton{std::string{expr->value.data, expr->value.size}}}); const TypeId inferredType = lookupType(expr); const SubtypingResult r = subtyping->isSubtype(bestType, inferredType); - LUAU_ASSERT(r.isSubtype || isErrorSuppressing(expr->location, inferredType)); -#endif + if (!r.isSubtype && !isErrorSuppressing(expr->location, inferredType)) + reportError(TypeMismatch{inferredType, bestType}, expr->location); } void visit(AstExprLocal* expr) @@ -1280,7 +1284,9 @@ struct TypeChecker2 void visit(AstExprGlobal* expr) { - // TODO! + NotNull scope = stack.back(); + if (!scope->lookup(expr->name)) + reportError(UnknownSymbol{expr->name.value, UnknownSymbol::Binding}, expr->location); } void visit(AstExprVarargs* expr) @@ -1534,6 +1540,24 @@ struct TypeChecker2 visitExprName(indexName->expr, indexName->location, indexName->index.value, context, builtinTypes->stringType); } + void indexExprMetatableHelper(AstExprIndexExpr* indexExpr, const MetatableType* metaTable, TypeId exprType, TypeId indexType) + { + if (auto tt = get(follow(metaTable->table)); tt && tt->indexer) + testIsSubtype(indexType, tt->indexer->indexType, indexExpr->index->location); + else if (auto mt = get(follow(metaTable->table))) + indexExprMetatableHelper(indexExpr, mt, exprType, indexType); + else if (auto tmt = get(follow(metaTable->metatable)); tmt && tmt->indexer) + testIsSubtype(indexType, tmt->indexer->indexType, indexExpr->index->location); + else if (auto mtmt = get(follow(metaTable->metatable))) + indexExprMetatableHelper(indexExpr, mtmt, exprType, indexType); + else + { + LUAU_ASSERT(tt || get(follow(metaTable->table))); + + reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location); + } + } + void visit(AstExprIndexExpr* indexExpr, ValueContext context) { if (auto str = indexExpr->index->as()) @@ -1557,6 +1581,10 @@ struct TypeChecker2 else reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location); } + else if (auto mt = get(exprType)) + { + return indexExprMetatableHelper(indexExpr, mt, exprType, indexType); + } else if (auto cls = get(exprType)) { if (cls->indexer) @@ -1577,6 +1605,19 @@ struct TypeChecker2 reportError(OptionalValueAccess{exprType}, indexExpr->location); } } + else if (auto exprIntersection = get(exprType)) + { + for (TypeId part : exprIntersection) + { + (void)part; + } + } + else if (get(exprType) || isErrorSuppressing(indexExpr->location, exprType)) + { + // Nothing + } + else + reportError(NotATable{exprType}, indexExpr->location); } void visit(AstExprFunction* fn) @@ 
-1589,7 +1630,6 @@ struct TypeChecker2 functionDeclStack.push_back(inferredFnTy); std::shared_ptr normalizedFnTy = normalizer.normalize(inferredFnTy); - const FunctionType* inferredFtv = get(normalizedFnTy->functions.parts.front()); if (!normalizedFnTy) { reportError(CodeTooComplex{}, fn->location); @@ -1684,16 +1724,23 @@ struct TypeChecker2 if (fn->returnAnnotation) visit(*fn->returnAnnotation); + // If the function type has a family annotation, we need to see if we can suggest an annotation - TypeFamilyReductionGuesser guesser{NotNull{&module->internalTypes}, builtinTypes, NotNull{&normalizer}}; - for (TypeId retTy : inferredFtv->retTypes) + if (normalizedFnTy) { - if (get(follow(retTy))) + const FunctionType* inferredFtv = get(normalizedFnTy->functions.parts.front()); + LUAU_ASSERT(inferredFtv); + + TypeFamilyReductionGuesser guesser{NotNull{&module->internalTypes}, builtinTypes, NotNull{&normalizer}}; + for (TypeId retTy : inferredFtv->retTypes) { - TypeFamilyReductionGuessResult result = guesser.guessTypeFamilyReductionForFunction(*fn, inferredFtv, retTy); - if (result.shouldRecommendAnnotation) - reportError( - ExplicitFunctionAnnotationRecommended{std::move(result.guessedFunctionAnnotations), result.guessedReturnType}, fn->location); + if (get(follow(retTy))) + { + TypeFamilyReductionGuessResult result = guesser.guessTypeFamilyReductionForFunction(*fn, inferredFtv, retTy); + if (result.shouldRecommendAnnotation) + reportError(ExplicitFunctionAnnotationRecommended{std::move(result.guessedFunctionAnnotations), result.guessedReturnType}, + fn->location); + } } } @@ -1822,7 +1869,7 @@ struct TypeChecker2 bool isStringOperation = (normLeft ? normLeft->isSubtypeOfString() : isString(leftType)) && (normRight ? normRight->isSubtypeOfString() : isString(rightType)); - + leftType = follow(leftType); if (get(leftType) || get(leftType) || get(leftType)) return leftType; else if (get(rightType) || get(rightType) || get(rightType)) @@ -2091,24 +2138,39 @@ struct TypeChecker2 TypeId annotationType = lookupAnnotation(expr->annotation); TypeId computedType = lookupType(expr->expr); - // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. 
- if (subtyping->isSubtype(annotationType, computedType).isSubtype) - return; - - if (subtyping->isSubtype(computedType, annotationType).isSubtype) - return; - switch (shouldSuppressErrors(NotNull{&normalizer}, computedType).orElse(shouldSuppressErrors(NotNull{&normalizer}, annotationType))) { case ErrorSuppression::Suppress: return; case ErrorSuppression::NormalizationFailed: reportError(NormalizationTooComplex{}, expr->location); + return; case ErrorSuppression::DoNotSuppress: break; } - reportError(TypesAreUnrelated{computedType, annotationType}, expr->location); + switch (normalizer.isInhabited(computedType)) + { + case NormalizationResult::True: + break; + case NormalizationResult::False: + return; + case NormalizationResult::HitLimits: + reportError(NormalizationTooComplex{}, expr->location); + return; + } + + switch (normalizer.isIntersectionInhabited(computedType, annotationType)) + { + case NormalizationResult::True: + return; + case NormalizationResult::False: + reportError(TypesAreUnrelated{computedType, annotationType}, expr->location); + break; + case NormalizationResult::HitLimits: + reportError(NormalizationTooComplex{}, expr->location); + break; + } } void visit(AstExprIfElse* expr) @@ -2710,6 +2772,8 @@ struct TypeChecker2 fetch(builtinTypes->stringType); if (normValid) fetch(norm->threads); + if (normValid) + fetch(norm->buffers); if (normValid) { diff --git a/Analysis/src/TypeFamily.cpp b/Analysis/src/TypeFamily.cpp index a685c216..816cf005 100644 --- a/Analysis/src/TypeFamily.cpp +++ b/Analysis/src/TypeFamily.cpp @@ -11,12 +11,10 @@ #include "Luau/OverloadResolution.h" #include "Luau/Set.h" #include "Luau/Simplify.h" -#include "Luau/Substitution.h" #include "Luau/Subtyping.h" #include "Luau/ToString.h" #include "Luau/TxnLog.h" #include "Luau/Type.h" -#include "Luau/TypeCheckLimits.h" #include "Luau/TypeFamilyReductionGuesser.h" #include "Luau/TypeFwd.h" #include "Luau/TypeUtils.h" @@ -37,7 +35,7 @@ LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyApplicationCartesianProductLimit, 5'0 // when this value is set to a negative value, guessing will be totally disabled. 
LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyUseGuesserDepth, -1); -LUAU_FASTFLAG(DebugLuauLogSolver); +LUAU_FASTFLAGVARIABLE(DebugLuauLogTypeFamilies, false); namespace Luau { @@ -182,9 +180,12 @@ struct FamilyReducer void replace(T subject, T replacement) { if (subject->owningArena != ctx.arena.get()) - ctx.ice->ice("Attempting to modify a type family instance from another arena", location); + { + result.errors.emplace_back(location, InternalError{"Attempting to modify a type family instance from another arena"}); + return; + } - if (FFlag::DebugLuauLogSolver) + if (FFlag::DebugLuauLogTypeFamilies) printf("%s -> %s\n", toString(subject, {true}).c_str(), toString(replacement, {true}).c_str()); asMutable(subject)->ty.template emplace>(replacement); @@ -206,7 +207,7 @@ struct FamilyReducer if (reduction.uninhabited || force) { - if (FFlag::DebugLuauLogSolver) + if (FFlag::DebugLuauLogTypeFamilies) printf("%s is uninhabited\n", toString(subject, {true}).c_str()); if constexpr (std::is_same_v) @@ -216,7 +217,7 @@ struct FamilyReducer } else if (!reduction.uninhabited && !force) { - if (FFlag::DebugLuauLogSolver) + if (FFlag::DebugLuauLogTypeFamilies) printf("%s is irreducible; blocked on %zu types, %zu packs\n", toString(subject, {true}).c_str(), reduction.blockedTypes.size(), reduction.blockedPacks.size()); @@ -243,7 +244,7 @@ struct FamilyReducer if (skip == SkipTestResult::Irreducible) { - if (FFlag::DebugLuauLogSolver) + if (FFlag::DebugLuauLogTypeFamilies) printf("%s is irreducible due to a dependency on %s\n", toString(subject, {true}).c_str(), toString(p, {true}).c_str()); irreducible.insert(subject); @@ -251,7 +252,7 @@ struct FamilyReducer } else if (skip == SkipTestResult::Defer) { - if (FFlag::DebugLuauLogSolver) + if (FFlag::DebugLuauLogTypeFamilies) printf("Deferring %s until %s is solved\n", toString(subject, {true}).c_str(), toString(p, {true}).c_str()); if constexpr (std::is_same_v) @@ -269,7 +270,7 @@ struct FamilyReducer if (skip == SkipTestResult::Irreducible) { - if (FFlag::DebugLuauLogSolver) + if (FFlag::DebugLuauLogTypeFamilies) printf("%s is irreducible due to a dependency on %s\n", toString(subject, {true}).c_str(), toString(p, {true}).c_str()); irreducible.insert(subject); @@ -277,7 +278,7 @@ struct FamilyReducer } else if (skip == SkipTestResult::Defer) { - if (FFlag::DebugLuauLogSolver) + if (FFlag::DebugLuauLogTypeFamilies) printf("Deferring %s until %s is solved\n", toString(subject, {true}).c_str(), toString(p, {true}).c_str()); if constexpr (std::is_same_v) @@ -297,7 +298,7 @@ struct FamilyReducer { if (shouldGuess.contains(subject)) { - if (FFlag::DebugLuauLogSolver) + if (FFlag::DebugLuauLogTypeFamilies) printf("Flagged %s for reduction with guesser.\n", toString(subject, {true}).c_str()); TypeFamilyReductionGuesser guesser{ctx.arena, ctx.builtins, ctx.normalizer}; @@ -305,14 +306,14 @@ struct FamilyReducer if (guessed) { - if (FFlag::DebugLuauLogSolver) + if (FFlag::DebugLuauLogTypeFamilies) printf("Selected %s as the guessed result type.\n", toString(*guessed, {true}).c_str()); replace(subject, *guessed); return true; } - if (FFlag::DebugLuauLogSolver) + if (FFlag::DebugLuauLogTypeFamilies) printf("Failed to produce a guess for the result of %s.\n", toString(subject, {true}).c_str()); } @@ -328,7 +329,7 @@ struct FamilyReducer if (irreducible.contains(subject)) return; - if (FFlag::DebugLuauLogSolver) + if (FFlag::DebugLuauLogTypeFamilies) printf("Trying to reduce %s\n", toString(subject, {true}).c_str()); if (const TypeFamilyInstanceType* tfit = 
get(subject)) @@ -337,7 +338,7 @@ struct FamilyReducer if (!testParameters(subject, tfit) && testCyclic != SkipTestResult::CyclicTypeFamily) { - if (FFlag::DebugLuauLogSolver) + if (FFlag::DebugLuauLogTypeFamilies) printf("Irreducible due to irreducible/pending and a non-cyclic family\n"); return; @@ -346,9 +347,7 @@ struct FamilyReducer if (tryGuessing(subject)) return; - TypeFamilyQueue queue{NotNull{&queuedTys}, NotNull{&queuedTps}}; - TypeFamilyReductionResult result = - tfit->family->reducer(subject, NotNull{&queue}, tfit->typeArguments, tfit->packArguments, NotNull{&ctx}); + TypeFamilyReductionResult result = tfit->family->reducer(subject, tfit->typeArguments, tfit->packArguments, NotNull{&ctx}); handleFamilyReduction(subject, result); } } @@ -361,7 +360,7 @@ struct FamilyReducer if (irreducible.contains(subject)) return; - if (FFlag::DebugLuauLogSolver) + if (FFlag::DebugLuauLogTypeFamilies) printf("Trying to reduce %s\n", toString(subject, {true}).c_str()); if (const TypeFamilyInstanceTypePack* tfit = get(subject)) @@ -372,9 +371,7 @@ struct FamilyReducer if (tryGuessing(subject)) return; - TypeFamilyQueue queue{NotNull{&queuedTys}, NotNull{&queuedTps}}; - TypeFamilyReductionResult result = - tfit->family->reducer(subject, NotNull{&queue}, tfit->typeArguments, tfit->packArguments, NotNull{&ctx}); + TypeFamilyReductionResult result = tfit->family->reducer(subject, tfit->typeArguments, tfit->packArguments, NotNull{&ctx}); handleFamilyReduction(subject, result); } } @@ -449,25 +446,90 @@ FamilyGraphReductionResult reduceFamilies(TypePackId entrypoint, Location locati std::move(collector.cyclicInstance), location, ctx, force); } -void TypeFamilyQueue::add(TypeId instanceTy) -{ - LUAU_ASSERT(get(instanceTy)); - queuedTys->push_back(instanceTy); -} - -void TypeFamilyQueue::add(TypePackId instanceTp) -{ - LUAU_ASSERT(get(instanceTp)); - queuedTps->push_back(instanceTp); -} - bool isPending(TypeId ty, ConstraintSolver* solver) { - return is(ty) || (solver && solver->hasUnresolvedConstraints(ty)); + return is(ty) || (solver && solver->hasUnresolvedConstraints(ty)); } -TypeFamilyReductionResult notFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +template +static std::optional> tryDistributeTypeFamilyApp(F f, TypeId instance, const std::vector& typeParams, + const std::vector& packParams, NotNull ctx, Args&&... args) +{ + // op (a | b) (c | d) ~ (op a (c | d)) | (op b (c | d)) ~ (op a c) | (op a d) | (op b c) | (op b d) + bool uninhabited = false; + std::vector blockedTypes; + std::vector results; + size_t cartesianProductSize = 1; + + const UnionType* firstUnion = nullptr; + size_t unionIndex; + + std::vector arguments = typeParams; + for (size_t i = 0; i < arguments.size(); ++i) + { + const UnionType* ut = get(follow(arguments[i])); + if (!ut) + continue; + + // We want to find the first union type in the set of arguments to distribute that one and only that one union. + // The function `f` we have is recursive, so `arguments[unionIndex]` will be updated in-place for each option in + // the union we've found in this context, so that index will no longer be a union type. Any other arguments at + // index + 1 or after will instead be distributed, if those are a union, which will be subjected to the same rules. 
+ if (!firstUnion && ut) + { + firstUnion = ut; + unionIndex = i; + } + + cartesianProductSize *= std::distance(begin(ut), end(ut)); + + // TODO: We'd like to report that the type family application is too complex here. + if (size_t(DFInt::LuauTypeFamilyApplicationCartesianProductLimit) <= cartesianProductSize) + return {{std::nullopt, true, {}, {}}}; + } + + if (!firstUnion) + { + // If we couldn't find any union type argument, we're not distributing. + return std::nullopt; + } + + for (TypeId option : firstUnion) + { + arguments[unionIndex] = option; + + TypeFamilyReductionResult result = f(instance, arguments, packParams, ctx, args...); + blockedTypes.insert(blockedTypes.end(), result.blockedTypes.begin(), result.blockedTypes.end()); + uninhabited |= result.uninhabited; + + if (result.uninhabited || !result.result) + break; + else + results.push_back(*result.result); + } + + if (uninhabited || !blockedTypes.empty()) + return {{std::nullopt, uninhabited, blockedTypes, {}}}; + + if (!results.empty()) + { + if (results.size() == 1) + return {{results[0], false, {}, {}}}; + + TypeId resultTy = ctx->arena->addType(TypeFamilyInstanceType{ + NotNull{&builtinTypeFunctions().unionFamily}, + std::move(results), + {}, + }); + + return {{resultTy, false, {}, {}}}; + } + + return std::nullopt; +} + +TypeFamilyReductionResult notFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 1 || !packParams.empty()) { @@ -477,15 +539,21 @@ TypeFamilyReductionResult notFamilyFn(TypeId instance, NotNullbuiltins->neverType, false, {}, {}}; + if (isPending(ty, ctx->solver)) return {std::nullopt, false, {ty}, {}}; + if (auto result = tryDistributeTypeFamilyApp(notFamilyFn, instance, typeParams, packParams, ctx)) + return *result; + // `not` operates on anything and returns a `boolean` always. return {ctx->builtins->booleanType, false, {}, {}}; } -TypeFamilyReductionResult lenFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult lenFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 1 || !packParams.empty()) { @@ -495,11 +563,23 @@ TypeFamilyReductionResult lenFamilyFn(TypeId instance, NotNullbuiltins->neverType, false, {}, {}}; + // check to see if the operand type is resolved enough, and wait to reduce if not // the use of `typeFromNormal` later necessitates blocking on local types. - if (isPending(operandTy, ctx->solver) || get(operandTy)) + if (isPending(operandTy, ctx->solver)) return {std::nullopt, false, {operandTy}, {}}; + // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. 
+    if (ctx->solver)
+    {
+        std::optional maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, operandTy, /* avoidSealingTables */ true);
+        if (!maybeGeneralized)
+            return {std::nullopt, false, {operandTy}, {}};
+        operandTy = *maybeGeneralized;
+    }
+
     std::shared_ptr normTy = ctx->normalizer->normalize(operandTy);
     NormalizationResult inhabited = ctx->normalizer->isInhabited(normTy.get());
@@ -524,6 +604,9 @@ TypeFamilyReductionResult lenFamilyFn(TypeId instance,
     if (normTy->hasTopTable() || get(normalizedOperand))
         return {ctx->builtins->numberType, false, {}, {}};

+    if (auto result = tryDistributeTypeFamilyApp(lenFamilyFn, instance, typeParams, packParams, ctx))
+        return *result;
+
     // findMetatableEntry demands the ability to emit errors, so we must give it
     // the necessary state to do that, even if we intend to just eat the errors.
     ErrorVec dummy;
@@ -561,8 +644,8 @@ TypeFamilyReductionResult lenFamilyFn(TypeId instance,
     return {ctx->builtins->numberType, false, {}, {}};
 }

-TypeFamilyReductionResult unmFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams,
-    const std::vector& packParams, NotNull ctx)
+TypeFamilyReductionResult unmFamilyFn(
+    TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx)
 {
     if (typeParams.size() != 1 || !packParams.empty())
     {
@@ -572,10 +655,22 @@ TypeFamilyReductionResult unmFamilyFn(TypeId instance,
     TypeId operandTy = follow(typeParams.at(0));

+    if (operandTy == instance)
+        return {ctx->builtins->neverType, false, {}, {}};
+
     // check to see if the operand type is resolved enough, and wait to reduce if not
     if (isPending(operandTy, ctx->solver))
         return {std::nullopt, false, {operandTy}, {}};

+    // if the type is free but has only one remaining reference, we can generalize it to its upper bound here.
+    if (ctx->solver)
+    {
+        std::optional maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, operandTy);
+        if (!maybeGeneralized)
+            return {std::nullopt, false, {operandTy}, {}};
+        operandTy = *maybeGeneralized;
+    }
+
     std::shared_ptr normTy = ctx->normalizer->normalize(operandTy);

     // if the operand failed to normalize, we can't reduce, but know nothing about inhabitance.
@@ -594,6 +689,9 @@ TypeFamilyReductionResult unmFamilyFn(TypeId instance,
     if (normTy->isExactlyNumber())
         return {ctx->builtins->numberType, false, {}, {}};

+    if (auto result = tryDistributeTypeFamilyApp(unmFamilyFn, instance, typeParams, packParams, ctx))
+        return *result;
+
     // findMetatableEntry demands the ability to emit errors, so we must give it
     // the necessary state to do that, even if we intend to just eat the errors.
     ErrorVec dummy;
@@ -646,7 +744,7 @@ NotNull TypeFamilyContext::pushConstraint(ConstraintV&& c)
     return newConstraint;
 }

-TypeFamilyReductionResult numericBinopFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams,
+TypeFamilyReductionResult numericBinopFamilyFn(TypeId instance, const std::vector& typeParams,
     const std::vector& packParams, NotNull ctx, const std::string metamethod)
 {
     if (typeParams.size() != 2 || !packParams.empty())
     {
@@ -674,6 +772,21 @@ TypeFamilyReductionResult numericBinopFamilyFn(TypeId instance,
     else if (isPending(rhsTy, ctx->solver))
         return {std::nullopt, false, {rhsTy}, {}};

+    // if either type is free but has only one remaining reference, we can generalize it to its upper bound here.
+ if (ctx->solver) + { + std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); + std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); + + if (!lhsMaybeGeneralized) + return {std::nullopt, false, {lhsTy}, {}}; + else if (!rhsMaybeGeneralized) + return {std::nullopt, false, {rhsTy}, {}}; + + lhsTy = *lhsMaybeGeneralized; + rhsTy = *rhsMaybeGeneralized; + } + // TODO: Normalization needs to remove cyclic type families from a `NormalizedType`. std::shared_ptr normLhsTy = ctx->normalizer->normalize(lhsTy); std::shared_ptr normRhsTy = ctx->normalizer->normalize(rhsTy); @@ -690,67 +803,8 @@ TypeFamilyReductionResult numericBinopFamilyFn(TypeId instance, NotNull< if (normLhsTy->isExactlyNumber() && normRhsTy->isExactlyNumber()) return {ctx->builtins->numberType, false, {}, {}}; - // op (a | b) (c | d) ~ (op a (c | d)) | (op b (c | d)) ~ (op a c) | (op a d) | (op b c) | (op b d) - std::vector results; - bool uninhabited = false; - std::vector blockedTypes; - std::vector arguments = typeParams; - auto distributeFamilyApp = [&](const UnionType* ut, size_t argumentIndex) { - // Returning true here means we completed the loop without any problems. - for (TypeId option : ut) - { - arguments[argumentIndex] = option; - - TypeFamilyReductionResult result = numericBinopFamilyFn(instance, queue, arguments, packParams, ctx, metamethod); - blockedTypes.insert(blockedTypes.end(), result.blockedTypes.begin(), result.blockedTypes.end()); - uninhabited |= result.uninhabited; - - if (result.uninhabited) - return false; - else if (!result.result) - return false; - else - results.push_back(*result.result); - } - - return true; - }; - - const UnionType* lhsUnion = get(lhsTy); - const UnionType* rhsUnion = get(rhsTy); - if (lhsUnion || rhsUnion) - { - // TODO: We'd like to report that the type family application is too complex here. - size_t lhsUnionSize = lhsUnion ? std::distance(begin(lhsUnion), end(lhsUnion)) : 1; - size_t rhsUnionSize = rhsUnion ? std::distance(begin(rhsUnion), end(rhsUnion)) : 1; - if (size_t(DFInt::LuauTypeFamilyApplicationCartesianProductLimit) <= lhsUnionSize * rhsUnionSize) - return {std::nullopt, true, {}, {}}; - - if (lhsUnion && !distributeFamilyApp(lhsUnion, 0)) - return {std::nullopt, uninhabited, std::move(blockedTypes), {}}; - - if (rhsUnion && !distributeFamilyApp(rhsUnion, 1)) - return {std::nullopt, uninhabited, std::move(blockedTypes), {}}; - - if (results.empty()) - { - // If this happens, it means `distributeFamilyApp` has improperly returned `true` even - // though there exists no arm of the union that is inhabited or have a reduced type. - ctx->ice->ice("`distributeFamilyApp` failed to add any types to the results vector?"); - } - - if (results.size() == 1) - return {results[0], false, {}, {}}; - - TypeId resultTy = ctx->arena->addType(TypeFamilyInstanceType{ - NotNull{&kBuiltinTypeFamilies.unionFamily}, - std::move(results), - {}, - }); - - queue->add(resultTy); - return {resultTy, false, {}, {}}; - } + if (auto result = tryDistributeTypeFamilyApp(numericBinopFamilyFn, instance, typeParams, packParams, ctx, metamethod)) + return *result; // findMetatableEntry demands the ability to emit errors, so we must give it // the necessary state to do that, even if we intend to just eat the errors. 
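For reference, the distribution rule that tryDistributeTypeFamilyApp now centralizes (and that the deleted per-family distributeFamilyApp lambda above used to implement) can be sketched in isolation. Everything below is a toy model rather than Luau's real types: Ty stands in for TypeId, distribute for tryDistributeTypeFamilyApp, and the recursion mirrors how only the first union argument is expanded while the recursive calls take care of any unions in later positions.

#include <memory>
#include <string>
#include <vector>

struct Ty;
using TyPtr = std::shared_ptr<Ty>;

struct Ty
{
    std::string name;           // non-empty for a leaf like "number"
    std::vector<TyPtr> options; // non-empty for a union
};

static TyPtr leaf(std::string n)
{
    return std::make_shared<Ty>(Ty{std::move(n), {}});
}

static TyPtr unionOf(std::vector<TyPtr> opts)
{
    return std::make_shared<Ty>(Ty{"", std::move(opts)});
}

// op is any pointwise reduction over non-union arguments.
template<typename F>
static TyPtr distribute(F op, std::vector<TyPtr> args)
{
    for (size_t i = 0; i < args.size(); ++i)
    {
        if (args[i]->options.empty())
            continue; // not a union

        // Distribute the first union and only that one; the recursive calls
        // revisit the argument list and handle any unions that remain.
        std::vector<TyPtr> results;
        for (const TyPtr& option : args[i]->options)
        {
            args[i] = option;
            results.push_back(distribute(op, args));
        }
        return results.size() == 1 ? results[0] : unionOf(std::move(results));
    }
    return op(args); // no unions left: apply the underlying operation
}

int main()
{
    auto add = [](std::vector<TyPtr> args) { return leaf(args[0]->name + " + " + args[1]->name); };

    // add<number | string, number> expands to add<number, number> | add<string, number>.
    TyPtr result = distribute(add, {unionOf({leaf("number"), leaf("string")}), leaf("number")});
    return result->options.size() == 2 ? 0 : 1;
}

The same product-size concern applies as in the patch: the real function caps the cartesian product with LuauTypeFamilyApplicationCartesianProductLimit before distributing at all.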
@@ -793,8 +847,8 @@ TypeFamilyReductionResult numericBinopFamilyFn(TypeId instance, NotNull< return {extracted.head.front(), false, {}, {}}; } -TypeFamilyReductionResult addFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult addFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -802,11 +856,11 @@ TypeFamilyReductionResult addFamilyFn(TypeId instance, NotNull subFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult subFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -814,11 +868,11 @@ TypeFamilyReductionResult subFamilyFn(TypeId instance, NotNull mulFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult mulFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -826,11 +880,11 @@ TypeFamilyReductionResult mulFamilyFn(TypeId instance, NotNull divFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult divFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -838,11 +892,11 @@ TypeFamilyReductionResult divFamilyFn(TypeId instance, NotNull idivFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult idivFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -850,11 +904,11 @@ TypeFamilyReductionResult idivFamilyFn(TypeId instance, NotNull powFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult powFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -862,11 +916,11 @@ TypeFamilyReductionResult powFamilyFn(TypeId instance, NotNull modFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult modFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -874,11 +928,11 @@ TypeFamilyReductionResult modFamilyFn(TypeId instance, NotNull concatFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult concatFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -889,12 +943,31 @@ TypeFamilyReductionResult concatFamilyFn(TypeId instance, NotNullbuiltins->neverType, false, {}, {}}; + // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(lhsTy, ctx->solver)) return {std::nullopt, false, {lhsTy}, {}}; else if 
(isPending(rhsTy, ctx->solver))
        return {std::nullopt, false, {rhsTy}, {}};

+    // if either type is free but has only one remaining reference, we can generalize it to its upper bound here.
+    if (ctx->solver)
+    {
+        std::optional<TypeId> lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy);
+        std::optional<TypeId> rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy);
+
+        if (!lhsMaybeGeneralized)
+            return {std::nullopt, false, {lhsTy}, {}};
+        else if (!rhsMaybeGeneralized)
+            return {std::nullopt, false, {rhsTy}, {}};
+
+        lhsTy = *lhsMaybeGeneralized;
+        rhsTy = *rhsMaybeGeneralized;
+    }
+
     std::shared_ptr<const NormalizedType> normLhsTy = ctx->normalizer->normalize(lhsTy);
     std::shared_ptr<const NormalizedType> normRhsTy = ctx->normalizer->normalize(rhsTy);
@@ -914,6 +987,9 @@ TypeFamilyReductionResult<TypeId> concatFamilyFn(TypeId instance, NotNull<TypeFamilyQueue> queue, const std::vector<TypeId>&
     if ((normLhsTy->isSubtypeOfString() || normLhsTy->isExactlyNumber()) && (normRhsTy->isSubtypeOfString() || normRhsTy->isExactlyNumber()))
         return {ctx->builtins->stringType, false, {}, {}};

+    if (auto result = tryDistributeTypeFamilyApp(concatFamilyFn, instance, typeParams, packParams, ctx))
+        return *result;
+
     // findMetatableEntry demands the ability to emit errors, so we must give it
     // the necessary state to do that, even if we intend to just eat the errors.
     ErrorVec dummy;
@@ -963,8 +1039,8 @@ TypeFamilyReductionResult<TypeId> concatFamilyFn(TypeId instance, NotNull<TypeFamilyQueue> queue, const std::vector<TypeId>&
     return {ctx->builtins->stringType, false, {}, {}};
 }

-TypeFamilyReductionResult<TypeId> andFamilyFn(TypeId instance, NotNull<TypeFamilyQueue> queue, const std::vector<TypeId>& typeParams,
-    const std::vector<TypePackId>& packParams, NotNull<TypeFamilyContext> ctx)
+TypeFamilyReductionResult<TypeId> andFamilyFn(
+    TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFamilyContext> ctx)
 {
     if (typeParams.size() != 2 || !packParams.empty())
     {
@@ -982,13 +1058,27 @@ TypeFamilyReductionResult<TypeId> andFamilyFn(TypeId instance, NotNull<TypeFamilyQueue> queue, const std::vector<TypeId>&
     if (isPending(lhsTy, ctx->solver))
         return {std::nullopt, false, {lhsTy}, {}};
     else if (isPending(rhsTy, ctx->solver))
         return {std::nullopt, false, {rhsTy}, {}};

+    // if either type is free but has only one remaining reference, we can generalize it to its upper bound here.
+    if (ctx->solver)
+    {
+        std::optional<TypeId> lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy);
+        std::optional<TypeId> rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy);
+
+        if (!lhsMaybeGeneralized)
+            return {std::nullopt, false, {lhsTy}, {}};
+        else if (!rhsMaybeGeneralized)
+            return {std::nullopt, false, {rhsTy}, {}};
+
+        lhsTy = *lhsMaybeGeneralized;
+        rhsTy = *rhsMaybeGeneralized;
+    }
+
     // And evaluates to a boolean if the LHS is falsy, and the RHS type if LHS is truthy.
     SimplifyResult filteredLhs = simplifyIntersection(ctx->builtins, ctx->arena, lhsTy, ctx->builtins->falsyType);
     SimplifyResult overallResult = simplifyUnion(ctx->builtins, ctx->arena, rhsTy, filteredLhs.result);
@@ -1000,8 +1090,8 @@ TypeFamilyReductionResult<TypeId> andFamilyFn(TypeId instance, NotNull<TypeFamilyQueue> queue, const std::vector<TypeId>&
-TypeFamilyReductionResult<TypeId> orFamilyFn(TypeId instance, NotNull<TypeFamilyQueue> queue, const std::vector<TypeId>& typeParams,
-    const std::vector<TypePackId>& packParams, NotNull<TypeFamilyContext> ctx)
+TypeFamilyReductionResult<TypeId> orFamilyFn(
+    TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFamilyContext> ctx)
 {
     if (typeParams.size() != 2 || !packParams.empty())
     {
@@ -1025,6 +1115,21 @@ TypeFamilyReductionResult<TypeId> orFamilyFn(TypeId instance, NotNull<TypeFamilyQueue> queue, const std::vector<TypeId>&
     if (isPending(rhsTy, ctx->solver))
         return {std::nullopt, false, {rhsTy}, {}};

+    // if either type is free but has only one remaining reference, we can generalize it to its upper bound here.
+    if (ctx->solver)
+    {
+        std::optional<TypeId> lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy);
+        std::optional<TypeId> rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy);
+
+        if (!lhsMaybeGeneralized)
+            return {std::nullopt, false, {lhsTy}, {}};
+        else if (!rhsMaybeGeneralized)
+            return {std::nullopt, false, {rhsTy}, {}};
+
+        lhsTy = *lhsMaybeGeneralized;
+        rhsTy = *rhsMaybeGeneralized;
+    }
+
     // Or evaluates to the LHS type if the LHS is truthy, and the RHS type if LHS is falsy.
     SimplifyResult filteredLhs = simplifyIntersection(ctx->builtins, ctx->arena, lhsTy, ctx->builtins->truthyType);
     SimplifyResult overallResult = simplifyUnion(ctx->builtins, ctx->arena, rhsTy, filteredLhs.result);
@@ -1036,7 +1141,7 @@ TypeFamilyReductionResult<TypeId> orFamilyFn(TypeId instance, NotNull<TypeFamilyQueue> queue, const std::vector<TypeId>&
-static TypeFamilyReductionResult<TypeId> comparisonFamilyFn(TypeId instance, NotNull<TypeFamilyQueue> queue, const std::vector<TypeId>& typeParams,
+static TypeFamilyReductionResult<TypeId> comparisonFamilyFn(TypeId instance, const std::vector<TypeId>& typeParams,
     const std::vector<TypePackId>& packParams, NotNull<TypeFamilyContext> ctx, const std::string metamethod)
 {
@@ -1049,6 +1154,9 @@ static TypeFamilyReductionResult<TypeId> comparisonFamilyFn(TypeId instance, NotNull<TypeFamilyQueue> queue, const std::vector<TypeId
     TypeId lhsTy = follow(typeParams.at(0));
     TypeId rhsTy = follow(typeParams.at(1));

+    if (lhsTy == instance || rhsTy == instance)
+        return {ctx->builtins->neverType, false, {}, {}};
+
     if (isPending(lhsTy, ctx->solver))
         return {std::nullopt, false, {lhsTy}, {}};
     else if (isPending(rhsTy, ctx->solver))
@@ -1088,6 +1196,21 @@ static TypeFamilyReductionResult<TypeId> comparisonFamilyFn(TypeId instance, NotNull<TypeFamilyQueue> queue, const std::vector<TypeId
     lhsTy = follow(lhsTy);
     rhsTy = follow(rhsTy);

+    // if either type is free but has only one remaining reference, we can generalize it to its upper bound here.
+    if (ctx->solver)
+    {
+        std::optional<TypeId> lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy);
+        std::optional<TypeId> rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy);
+
+        if (!lhsMaybeGeneralized)
+            return {std::nullopt, false, {lhsTy}, {}};
+        else if (!rhsMaybeGeneralized)
+            return {std::nullopt, false, {rhsTy}, {}};
+
+        lhsTy = *lhsMaybeGeneralized;
+        rhsTy = *rhsMaybeGeneralized;
+    }
+
     // check to see if both operand types are resolved enough, and wait to reduce if not

     std::shared_ptr<const NormalizedType> normLhsTy = ctx->normalizer->normalize(lhsTy);
@@ -1115,6 +1238,9 @@ static TypeFamilyReductionResult<TypeId> comparisonFamilyFn(TypeId instance, NotNull<TypeFamilyQueue> queue, const std::vector<TypeId
     if (normLhsTy->isExactlyNumber() && normRhsTy->isExactlyNumber())
         return {ctx->builtins->booleanType, false, {}, {}};

+    if (auto result = tryDistributeTypeFamilyApp(comparisonFamilyFn, instance, typeParams, packParams, ctx, metamethod))
+        return *result;
+
     // findMetatableEntry demands the ability to emit errors, so we must give it
     // the necessary state to do that, even if we intend to just eat the errors.
ErrorVec dummy; @@ -1154,8 +1280,8 @@ static TypeFamilyReductionResult comparisonFamilyFn(TypeId instance, Not return {ctx->builtins->booleanType, false, {}, {}}; } -TypeFamilyReductionResult ltFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult ltFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -1163,11 +1289,11 @@ TypeFamilyReductionResult ltFamilyFn(TypeId instance, NotNull leFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult leFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -1175,11 +1301,11 @@ TypeFamilyReductionResult leFamilyFn(TypeId instance, NotNull eqFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult eqFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -1196,6 +1322,21 @@ TypeFamilyReductionResult eqFamilyFn(TypeId instance, NotNullsolver)) return {std::nullopt, false, {rhsTy}, {}}; + // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); + std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); + + if (!lhsMaybeGeneralized) + return {std::nullopt, false, {lhsTy}, {}}; + else if (!rhsMaybeGeneralized) + return {std::nullopt, false, {rhsTy}, {}}; + + lhsTy = *lhsMaybeGeneralized; + rhsTy = *rhsMaybeGeneralized; + } + std::shared_ptr normLhsTy = ctx->normalizer->normalize(lhsTy); std::shared_ptr normRhsTy = ctx->normalizer->normalize(rhsTy); NormalizationResult lhsInhabited = ctx->normalizer->isInhabited(normLhsTy.get()); @@ -1223,10 +1364,25 @@ TypeFamilyReductionResult eqFamilyFn(TypeId instance, NotNullnormalizer->isIntersectionInhabited(lhsTy, rhsTy); - if (!mmType && intersectInhabited == NormalizationResult::True) - return {ctx->builtins->booleanType, false, {}, {}}; // if it's inhabited, everything is okay! - else if (!mmType) + if (!mmType) + { + if (intersectInhabited == NormalizationResult::True) + return {ctx->builtins->booleanType, false, {}, {}}; // if it's inhabited, everything is okay! + + // we might be in a case where we still want to accept the comparison... + if (intersectInhabited == NormalizationResult::False) + { + // if they're both subtypes of `string` but have no common intersection, the comparison is allowed but always `false`. + if (normLhsTy->isSubtypeOfString() && normRhsTy->isSubtypeOfString()) + return {ctx->builtins->falseType, false, {}, {}}; + + // if they're both subtypes of `boolean` but have no common intersection, the comparison is allowed but always `false`. + if (normLhsTy->isSubtypeOfBooleans() && normRhsTy->isSubtypeOfBooleans()) + return {ctx->builtins->falseType, false, {}, {}}; + } + return {std::nullopt, true, {}, {}}; // if it's not, then this family is irreducible! 
+ } mmType = follow(*mmType); if (isPending(*mmType, ctx->solver)) @@ -1272,12 +1428,6 @@ struct FindRefinementBlockers : TypeOnceVisitor return false; } - bool visit(TypeId ty, const LocalType&) override - { - found.insert(ty); - return false; - } - bool visit(TypeId ty, const ClassType&) override { return false; @@ -1285,8 +1435,8 @@ struct FindRefinementBlockers : TypeOnceVisitor }; -TypeFamilyReductionResult refineFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult refineFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -1303,6 +1453,21 @@ TypeFamilyReductionResult refineFamilyFn(TypeId instance, NotNullsolver)) return {std::nullopt, false, {discriminantTy}, {}}; + // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. + if (ctx->solver) + { + std::optional targetMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, targetTy); + std::optional discriminantMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, discriminantTy); + + if (!targetMaybeGeneralized) + return {std::nullopt, false, {targetTy}, {}}; + else if (!discriminantMaybeGeneralized) + return {std::nullopt, false, {discriminantTy}, {}}; + + targetTy = *targetMaybeGeneralized; + discriminantTy = *discriminantMaybeGeneralized; + } + // we need a more complex check for blocking on the discriminant in particular FindRefinementBlockers frb; frb.traverse(discriminantTy); @@ -1326,6 +1491,18 @@ TypeFamilyReductionResult refineFamilyFn(TypeId instance, NotNull(follow(nt->ty))) return {targetTy, false, {}, {}}; + // If the target type is a table, then simplification already implements the logic to deal with refinements properly since the + // type of the discriminant is guaranteed to only ever be an (arbitrarily-nested) table of a single property type. + if (get(targetTy)) + { + SimplifyResult result = simplifyIntersection(ctx->builtins, ctx->arena, targetTy, discriminantTy); + if (!result.blockedTypes.empty()) + return {std::nullopt, false, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; + + return {result.result, false, {}, {}}; + } + + // In the general case, we'll still use normalization though. TypeId intersection = ctx->arena->addType(IntersectionType{{targetTy, discriminantTy}}); std::shared_ptr normIntersection = ctx->normalizer->normalize(intersection); std::shared_ptr normType = ctx->normalizer->normalize(targetTy); @@ -1343,8 +1520,8 @@ TypeFamilyReductionResult refineFamilyFn(TypeId instance, NotNull singletonFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult singletonFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 1 || !packParams.empty()) { @@ -1358,10 +1535,19 @@ TypeFamilyReductionResult singletonFamilyFn(TypeId instance, NotNullsolver)) return {std::nullopt, false, {type}, {}}; + // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. 
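+    // (assumed from the calling pattern rather than stated in this patch:
+    // generalizeFreeType returns nullopt when the free type cannot be
+    // generalized yet, e.g. while other constraints still reference it,
+    // in which case we block on the original type below.)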
+ if (ctx->solver) + { + std::optional maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, type); + if (!maybeGeneralized) + return {std::nullopt, false, {type}, {}}; + type = *maybeGeneralized; + } + TypeId followed = type; // we want to follow through a negation here as well. if (auto negation = get(followed)) - followed = follow(negation->ty); + followed = follow(negation->ty); // if we have a singleton type or `nil`, which is its own singleton type... if (get(followed) || isNil(followed)) @@ -1371,8 +1557,8 @@ TypeFamilyReductionResult singletonFamilyFn(TypeId instance, NotNullbuiltins->unknownType, false, {}, {}}; } -TypeFamilyReductionResult unionFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult unionFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (!packParams.empty()) { @@ -1432,8 +1618,8 @@ TypeFamilyReductionResult unionFamilyFn(TypeId instance, NotNull intersectFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult intersectFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (!packParams.empty()) { @@ -1656,8 +1842,8 @@ TypeFamilyReductionResult keyofFamilyImpl( return {ctx->arena->addType(UnionType{singletons}), false, {}, {}}; } -TypeFamilyReductionResult keyofFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult keyofFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 1 || !packParams.empty()) { @@ -1668,8 +1854,8 @@ TypeFamilyReductionResult keyofFamilyFn(TypeId instance, NotNull rawkeyofFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, - const std::vector& packParams, NotNull ctx) +TypeFamilyReductionResult rawkeyofFamilyFn( + TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 1 || !packParams.empty()) { @@ -1680,6 +1866,228 @@ TypeFamilyReductionResult rawkeyofFamilyFn(TypeId instance, NotNull tblIndexer, DenseHashSet& result, NotNull ctx) +{ + ty = follow(ty); + + // index into tbl's properties + if (auto stringSingleton = get(get(ty))) + { + if (tblProps.find(stringSingleton->value) != tblProps.end()) + { + TypeId propTy = follow(tblProps.at(stringSingleton->value).type()); + + // property is a union type -> we need to extend our reduction type + if (auto propUnionTy = get(propTy)) + { + for (TypeId option : propUnionTy->options) + result.insert(option); + } + else // property is a singular type or intersection type -> we can simply append + result.insert(propTy); + + return true; + } + } + + // index into tbl's indexer + if (tblIndexer) + { + if (isSubtype(ty, tblIndexer->indexType, ctx->scope, ctx->builtins, *ctx->ice)) + { + TypeId idxResultTy = follow(tblIndexer->indexResultType); + + // indexResultType is a union type -> we need to extend our reduction type + if (auto idxResUnionTy = get(idxResultTy)) + { + for (TypeId option : idxResUnionTy->options) + result.insert(option); + } + else // indexResultType is a singular type or intersection type -> we can simply append + result.insert(idxResultTy); + + return true; + } + } + + return false; +} + +/* Handles recursion / metamethods of tables/classes + 
`isRaw` parameter indicates whether or not we should follow __index metamethods
+   returns false if the property of `ty` could not be found */
+bool tblIndexInto(TypeId indexer, TypeId indexee, DenseHashSet<TypeId>& result, NotNull<TypeFamilyContext> ctx, bool isRaw)
+{
+    indexer = follow(indexer);
+    indexee = follow(indexee);
+
+    // we have a table type to try indexing
+    if (auto tableTy = get<TableType>(indexee))
+    {
+        return searchPropsAndIndexer(indexer, tableTy->props, tableTy->indexer, result, ctx);
+    }
+
+    // we have a metatable type to try indexing
+    if (auto metatableTy = get<MetatableType>(indexee))
+    {
+        if (auto tableTy = get<TableType>(metatableTy->table))
+        {
+
+            // try finding all properties within the current scope of the table
+            if (searchPropsAndIndexer(indexer, tableTy->props, tableTy->indexer, result, ctx))
+                return true;
+        }
+
+        // if the code reached here, it means we weren't able to find all properties -> look into __index metamethod
+        if (!isRaw)
+        {
+            // findMetatableEntry demands the ability to emit errors, so we must give it
+            // the necessary state to do that, even if we intend to just eat the errors.
+            ErrorVec dummy;
+            std::optional<TypeId> mmType = findMetatableEntry(ctx->builtins, dummy, indexee, "__index", Location{});
+            if (mmType)
+                return tblIndexInto(indexer, *mmType, result, ctx, isRaw);
+        }
+    }
+
+    return false;
+}
+
+/* Vocabulary note: indexee refers to the type that contains the properties,
+   indexer refers to the type that is used to access indexee
+   Example: index<Person, "name"> => `Person` is the indexee and `"name"` is the indexer */
+TypeFamilyReductionResult<TypeId> indexFamilyImpl(
+    const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFamilyContext> ctx, bool isRaw)
+{
+    TypeId indexeeTy = follow(typeParams.at(0));
+    std::shared_ptr<const NormalizedType> indexeeNormTy = ctx->normalizer->normalize(indexeeTy);
+
+    // if the indexee failed to normalize, we can't reduce, but know nothing about inhabitance.
+    if (!indexeeNormTy)
+        return {std::nullopt, false, {}, {}};
+
+    // if we don't have either just tables or just classes, we've got nothing to index into
+    if (indexeeNormTy->hasTables() == indexeeNormTy->hasClasses())
+        return {std::nullopt, true, {}, {}};
+
+    // we're trying to reject any type that has not normalized to a table/class or a union of tables/classes.
+    if (indexeeNormTy->hasTops() || indexeeNormTy->hasBooleans() || indexeeNormTy->hasErrors() || indexeeNormTy->hasNils() ||
+        indexeeNormTy->hasNumbers() || indexeeNormTy->hasStrings() || indexeeNormTy->hasThreads() || indexeeNormTy->hasBuffers() ||
+        indexeeNormTy->hasFunctions() || indexeeNormTy->hasTyvars())
+        return {std::nullopt, true, {}, {}};
+
+    TypeId indexerTy = follow(typeParams.at(1));
+    std::shared_ptr<const NormalizedType> indexerNormTy = ctx->normalizer->normalize(indexerTy);
+
+    // if the indexer failed to normalize, we can't reduce, but know nothing about inhabitance.
+    if (!indexerNormTy)
+        return {std::nullopt, false, {}, {}};
+
+    // we're trying to reject any type that is not a string singleton or primitive (string, number, boolean, thread, nil, function, table, or buffer)
+    if (indexerNormTy->hasTops() || indexerNormTy->hasErrors())
+        return {std::nullopt, true, {}, {}};
+
+    // indexer can be a union -> break it down into a vector
+    const std::vector<TypeId>* typesToFind;
+    const std::vector<TypeId> singleType{indexerTy};
+    if (auto unionTy = get<UnionType>(indexerTy))
+        typesToFind = &unionTy->options;
+    else
+        typesToFind = &singleType;
+
+    DenseHashSet<TypeId> properties{{}}; // set of types that will be returned
+
+    if (indexeeNormTy->hasClasses())
+    {
+        LUAU_ASSERT(!indexeeNormTy->hasTables());
+
+        if (isRaw) // rawget should never reduce for classes (to match the behavior of the rawget global function)
+            return {std::nullopt, true, {}, {}};
+
+        // at least one class is guaranteed to be in the iterator by .hasClasses()
+        for (auto classesIter = indexeeNormTy->classes.ordering.begin(); classesIter != indexeeNormTy->classes.ordering.end(); ++classesIter)
+        {
+            auto classTy = get<ClassType>(*classesIter);
+            if (!classTy)
+            {
+                LUAU_ASSERT(false); // this should not be possible according to normalization's spec
+                return {std::nullopt, true, {}, {}};
+            }
+
+            for (TypeId ty : *typesToFind)
+            {
+                // Search for all instances of indexer in class->props and class->indexer
+                if (searchPropsAndIndexer(ty, classTy->props, classTy->indexer, properties, ctx))
+                    continue; // Indexer was found in this class, so we can move on to the next
+
+                // If code reaches here, that means the property was not found -> check in the metatable's __index
+
+                // findMetatableEntry demands the ability to emit errors, so we must give it
+                // the necessary state to do that, even if we intend to just eat the errors.
+                ErrorVec dummy;
+                std::optional<TypeId> mmType = findMetatableEntry(ctx->builtins, dummy, *classesIter, "__index", Location{});
+                if (!mmType) // if a metatable does not exist, there is nowhere else to look
+                    return {std::nullopt, true, {}, {}};
+
+                if (!tblIndexInto(ty, *mmType, properties, ctx, isRaw)) // if indexer is not in the metatable, we fail to reduce
+                    return {std::nullopt, true, {}, {}};
+            }
+        }
+    }
+
+    if (indexeeNormTy->hasTables())
+    {
+        LUAU_ASSERT(!indexeeNormTy->hasClasses());
+
+        // at least one table is guaranteed to be in the iterator by .hasTables()
+        for (auto tablesIter = indexeeNormTy->tables.begin(); tablesIter != indexeeNormTy->tables.end(); ++tablesIter)
+        {
+            for (TypeId ty : *typesToFind)
+                if (!tblIndexInto(ty, *tablesIter, properties, ctx, isRaw))
+                    return {std::nullopt, true, {}, {}};
+        }
+    }
+
+    // Call `follow()` on each element to resolve all Bound types before returning
+    std::transform(properties.begin(), properties.end(), properties.begin(), [](TypeId ty) {
+        return follow(ty);
+    });
+
+    // If the type being reduced to is a single type, no need to union
+    if (properties.size() == 1)
+        return {*properties.begin(), false, {}, {}};
+
+    return {ctx->arena->addType(UnionType{std::vector<TypeId>(properties.begin(), properties.end())}), false, {}, {}};
+}
+
+TypeFamilyReductionResult<TypeId> indexFamilyFn(
+    TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFamilyContext> ctx)
+{
+    if (typeParams.size() != 2 || !packParams.empty())
+    {
+        ctx->ice->ice("index type family: encountered a type family instance without the required argument structure");
+        LUAU_ASSERT(false);
+    }
+
+    return indexFamilyImpl(typeParams, packParams, ctx, /* isRaw */ false);
+}
+
+TypeFamilyReductionResult<TypeId> rawgetFamilyFn(
+    TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFamilyContext> ctx)
+{
+    if (typeParams.size() != 2 || !packParams.empty())
+    {
+        ctx->ice->ice("rawget type family: encountered a type family instance without the required argument structure");
+        LUAU_ASSERT(false);
+    }
+
+    return indexFamilyImpl(typeParams, packParams, ctx, /* isRaw */ true);
+}
+
 BuiltinTypeFamilies::BuiltinTypeFamilies()
     : notFamily{"not", notFamilyFn}
     , lenFamily{"len", lenFamilyFn}
@@ -1703,6 +2111,8 @@ BuiltinTypeFamilies::BuiltinTypeFamilies()
     , intersectFamily{"intersect", intersectFamilyFn}
     , keyofFamily{"keyof", keyofFamilyFn}
     , rawkeyofFamily{"rawkeyof", rawkeyofFamilyFn}
+    , indexFamily{"index", indexFamilyFn}
+    , rawgetFamily{"rawget", rawgetFamilyFn}
 {
 }

@@ -1744,6 +2154,16 @@ void BuiltinTypeFamilies::addToScope(NotNull<TypeArena> arena, NotNull<Scope> scope)
     scope->exportedTypeBindings[keyofFamily.name] = mkUnaryTypeFamily(&keyofFamily);
     scope->exportedTypeBindings[rawkeyofFamily.name] = mkUnaryTypeFamily(&rawkeyofFamily);
+
+    scope->exportedTypeBindings[indexFamily.name] = mkBinaryTypeFamily(&indexFamily);
+    scope->exportedTypeBindings[rawgetFamily.name] = mkBinaryTypeFamily(&rawgetFamily);
+}
+
+const BuiltinTypeFamilies& builtinTypeFunctions()
+{
+    static std::unique_ptr<BuiltinTypeFamilies> result = std::make_unique<BuiltinTypeFamilies>();
+
+    return *result;
 }

 } // namespace Luau
diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp
index 52cc927c..00d683dd 100644
--- a/Analysis/src/TypeInfer.cpp
+++ b/Analysis/src/TypeInfer.cpp
@@ -33,13 +33,12 @@ LUAU_FASTFLAG(LuauKnowsTheDataModel3)
 LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false)
 LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false)
 LUAU_FASTFLAG(LuauInstantiateInSubtyping)
-LUAU_FASTFLAGVARIABLE(LuauMetatableInstantiationCloneCheck, false)
LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false) LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false) LUAU_FASTFLAGVARIABLE(LuauRemoveBadRelationalOperatorWarning, false) -LUAU_FASTFLAGVARIABLE(LuauForbidAliasNamedTypeof, false) LUAU_FASTFLAGVARIABLE(LuauOkWithIteratingOverTableProperties, false) -LUAU_FASTFLAG(LuauFixNormalizeCaching) +LUAU_FASTFLAGVARIABLE(LuauReusableSubstitutions, false) +LUAU_FASTFLAG(LuauDeclarationExtraPropData) namespace Luau { @@ -216,6 +215,7 @@ TypeChecker::TypeChecker(const ScopePtr& globalScope, ModuleResolver* resolver, , iceHandler(iceHandler) , unifierState(iceHandler) , normalizer(nullptr, builtinTypes, NotNull{&unifierState}) + , reusableInstantiation(TxnLog::empty(), nullptr, builtinTypes, {}, nullptr) , nilType(builtinTypes->nilType) , numberType(builtinTypes->numberType) , stringType(builtinTypes->stringType) @@ -668,7 +668,7 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std { if (const auto& typealias = stat->as()) { - if (typealias->name == kParseNameError || (FFlag::LuauForbidAliasNamedTypeof && typealias->name == "typeof")) + if (typealias->name == kParseNameError || typealias->name == "typeof") continue; auto& bindings = typealias->exported ? scope->exportedTypeBindings : scope->privateTypeBindings; @@ -1536,7 +1536,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& ty if (name == kParseNameError) return ControlFlow::None; - if (FFlag::LuauForbidAliasNamedTypeof && name == "typeof") + if (name == "typeof") { reportError(typealias.location, GenericError{"Type aliases cannot be named typeof"}); return ControlFlow::None; @@ -1657,7 +1657,7 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea // If the alias is missing a name, we can't do anything with it. Ignore it. // Also, typeof is not a valid type alias name. We will report an error for // this in check() - if (name == kParseNameError || (FFlag::LuauForbidAliasNamedTypeof && name == "typeof")) + if (name == kParseNameError || name == "typeof") return; std::optional binding; @@ -1784,12 +1784,55 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); ftv->hasSelf = true; + + if (FFlag::LuauDeclarationExtraPropData) + { + FunctionDefinition defn; + + defn.definitionModuleName = currentModule->name; + defn.definitionLocation = prop.location; + // No data is preserved for varargLocation + defn.originalNameLocation = prop.nameLocation; + + ftv->definition = defn; + } } } if (assignTo.count(propName) == 0) { - assignTo[propName] = {propTy}; + if (FFlag::LuauDeclarationExtraPropData) + assignTo[propName] = {propTy, /*deprecated*/ false, /*deprecatedSuggestion*/ "", prop.location}; + else + assignTo[propName] = {propTy}; + } + else if (FFlag::LuauDeclarationExtraPropData) + { + Luau::Property& prop = assignTo[propName]; + TypeId currentTy = prop.type(); + + // We special-case this logic to keep the intersection flat; otherwise we + // would create a ton of nested intersection types. 
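+        // For example, three declared overloads of one method accumulate as the
+        // flat intersection A & B & C rather than the nested ((A & B) & C).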
+ if (const IntersectionType* itv = get(currentTy)) + { + std::vector options = itv->parts; + options.push_back(propTy); + TypeId newItv = addType(IntersectionType{std::move(options)}); + + prop.readTy = newItv; + prop.writeTy = newItv; + } + else if (get(currentTy)) + { + TypeId intersection = addType(IntersectionType{{currentTy, propTy}}); + + prop.readTy = intersection; + prop.writeTy = intersection; + } + else + { + reportError(declaredClass.location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())}); + } } else { @@ -1841,7 +1884,18 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFuncti TypePackId argPack = resolveTypePack(funScope, global.params); TypePackId retPack = resolveTypePack(funScope, global.retTypes); - TypeId fnType = addType(FunctionType{funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack}); + + FunctionDefinition defn; + + if (FFlag::LuauDeclarationExtraPropData) + { + defn.definitionModuleName = currentModule->name; + defn.definitionLocation = global.location; + defn.varargLocation = global.vararg ? std::make_optional(global.varargLocation) : std::nullopt; + defn.originalNameLocation = global.nameLocation; + } + + TypeId fnType = addType(FunctionType{funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack, defn}); FunctionType* ftv = getMutable(fnType); ftv->argNames.reserve(global.paramNames.size); @@ -2649,24 +2703,12 @@ static std::optional areEqComparable(NotNull arena, NotNulladdType(IntersectionType{{a, b}}); - std::shared_ptr n = normalizer->normalize(c); - if (!n) - return std::nullopt; + TypeId c = arena->addType(IntersectionType{{a, b}}); + std::shared_ptr n = normalizer->normalize(c); + if (!n) + return std::nullopt; - nr = normalizer->isInhabited(n.get()); - } - else - { - TypeId c = arena->addType(IntersectionType{{a, b}}); - const NormalizedType* n = normalizer->DEPRECATED_normalize(c); - if (!n) - return std::nullopt; - - nr = normalizer->isInhabited(n); - } + nr = normalizer->isInhabited(n.get()); switch (nr) { @@ -4879,12 +4921,27 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat if (ftv && ftv->hasNoFreeOrGenericTypes) return ty; - Instantiation instantiation{log, ¤tModule->internalTypes, builtinTypes, scope->level, /*scope*/ nullptr}; + std::optional instantiated; - if (instantiationChildLimit) - instantiation.childLimit = *instantiationChildLimit; + if (FFlag::LuauReusableSubstitutions) + { + reusableInstantiation.resetState(log, ¤tModule->internalTypes, builtinTypes, scope->level, /*scope*/ nullptr); + + if (instantiationChildLimit) + reusableInstantiation.childLimit = *instantiationChildLimit; + + instantiated = reusableInstantiation.substitute(ty); + } + else + { + Instantiation instantiation{log, ¤tModule->internalTypes, builtinTypes, scope->level, /*scope*/ nullptr}; + + if (instantiationChildLimit) + instantiation.childLimit = *instantiationChildLimit; + + instantiated = instantiation.substitute(ty); + } - std::optional instantiated = instantiation.substitute(ty); if (instantiated.has_value()) return *instantiated; else @@ -5633,8 +5690,8 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, TypeId instantiated = *maybeInstantiated; TypeId target = follow(instantiated); - const TableType* tfTable = FFlag::LuauMetatableInstantiationCloneCheck ? 
getTableType(tf.type) : nullptr; - bool needsClone = follow(tf.type) == target || (FFlag::LuauMetatableInstantiationCloneCheck && tfTable != nullptr && tfTable == getTableType(target)); + const TableType* tfTable = getTableType(tf.type); + bool needsClone = follow(tf.type) == target || (tfTable != nullptr && tfTable == getTableType(target)); bool shouldMutate = getTableType(tf.type); TableType* ttv = getMutableTableType(target); diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 588b1da1..c2512ddc 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -38,6 +38,59 @@ bool occursCheck(TypeId needle, TypeId haystack) return false; } +// FIXME: Property is quite large. +// +// Returning it on the stack like this isn't great. We'd like to just return a +// const Property*, but we mint a property of type any if the subject type is +// any. +std::optional findTableProperty(NotNull builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location) +{ + if (get(ty)) + return Property::rw(ty); + + if (const TableType* tableType = getTableType(ty)) + { + const auto& it = tableType->props.find(name); + if (it != tableType->props.end()) + return it->second; + } + + std::optional mtIndex = findMetatableEntry(builtinTypes, errors, ty, "__index", location); + int count = 0; + while (mtIndex) + { + TypeId index = follow(*mtIndex); + + if (count >= 100) + return std::nullopt; + + ++count; + + if (const auto& itt = getTableType(index)) + { + const auto& fit = itt->props.find(name); + if (fit != itt->props.end()) + return fit->second.type(); + } + else if (const auto& itf = get(index)) + { + std::optional r = first(follow(itf->retTypes)); + if (!r) + return builtinTypes->nilType; + else + return *r; + } + else if (get(index)) + return builtinTypes->anyType; + else + errors.push_back(TypeError{location, GenericError{"__index should either be a function or table. Got " + toString(index)}}); + + mtIndex = findMetatableEntry(builtinTypes, errors, *mtIndex, "__index", location); + } + + return std::nullopt; +} + std::optional findMetatableEntry( NotNull builtinTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location) { diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 3dc274a9..1802345d 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -23,7 +23,7 @@ LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAGVARIABLE(LuauFixIndexerSubtypingOrdering, false) LUAU_FASTFLAGVARIABLE(LuauUnifierShouldNotCopyError, false) -LUAU_FASTFLAG(LuauFixNormalizeCaching) +LUAU_FASTFLAGVARIABLE(LuauUnifierRecursionOnRestart, false) namespace Luau { @@ -580,28 +580,14 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { if (normalize) { - if (FFlag::LuauFixNormalizeCaching) - { - // TODO: there are probably cheaper ways to check if any <: T. - std::shared_ptr superNorm = normalizer->normalize(superTy); + // TODO: there are probably cheaper ways to check if any <: T. + std::shared_ptr superNorm = normalizer->normalize(superTy); - if (!superNorm) - return reportError(location, NormalizationTooComplex{}); + if (!superNorm) + return reportError(location, NormalizationTooComplex{}); - if (!log.get(superNorm->tops)) - failure = true; - } - else - { - // TODO: there are probably cheaper ways to check if any <: T. 
- const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy); - - if (!superNorm) - return reportError(location, NormalizationTooComplex{}); - - if (!log.get(superNorm->tops)) - failure = true; - } + if (!log.get(superNorm->tops)) + failure = true; } else failure = true; @@ -962,30 +948,15 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp // We deal with this by type normalization. Unifier innerState = makeChildUnifier(); - if (FFlag::LuauFixNormalizeCaching) - { - std::shared_ptr subNorm = normalizer->normalize(subTy); - std::shared_ptr superNorm = normalizer->normalize(superTy); - if (!subNorm || !superNorm) - return reportError(location, NormalizationTooComplex{}); - else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) - innerState.tryUnifyNormalizedTypes( - subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); - else - innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); - } + std::shared_ptr subNorm = normalizer->normalize(subTy); + std::shared_ptr superNorm = normalizer->normalize(superTy); + if (!subNorm || !superNorm) + return reportError(location, NormalizationTooComplex{}); + else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) + innerState.tryUnifyNormalizedTypes( + subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); else - { - const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy); - const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy); - if (!subNorm || !superNorm) - return reportError(location, NormalizationTooComplex{}); - else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) - innerState.tryUnifyNormalizedTypes( - subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); - else - innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); - } + innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); if (!innerState.failure) log.concat(std::move(innerState.log)); @@ -999,30 +970,14 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp // It is possible that T <: A | B even though T subNorm = normalizer->normalize(subTy); - std::shared_ptr superNorm = normalizer->normalize(superTy); - if (!subNorm || !superNorm) - reportError(location, NormalizationTooComplex{}); - else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) - tryUnifyNormalizedTypes( - subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); - else - tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); - } + std::shared_ptr subNorm = normalizer->normalize(subTy); + std::shared_ptr superNorm = normalizer->normalize(superTy); + if (!subNorm || !superNorm) + reportError(location, NormalizationTooComplex{}); + else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) + tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. 
For example:", *failedOption); else - { - const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy); - const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy); - if (!subNorm || !superNorm) - reportError(location, NormalizationTooComplex{}); - else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) - tryUnifyNormalizedTypes( - subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); - else - tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); - } + tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); } else if (!found) { @@ -1125,24 +1080,12 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* // It is possible that A & B <: T even though A subNorm = normalizer->normalize(subTy); - std::shared_ptr superNorm = normalizer->normalize(superTy); - if (subNorm && superNorm) - tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); - else - reportError(location, NormalizationTooComplex{}); - } + std::shared_ptr subNorm = normalizer->normalize(subTy); + std::shared_ptr superNorm = normalizer->normalize(superTy); + if (subNorm && superNorm) + tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); else - { - const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy); - const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy); - if (subNorm && superNorm) - tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); - else - reportError(location, NormalizationTooComplex{}); - } + reportError(location, NormalizationTooComplex{}); return; } @@ -1192,24 +1135,12 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* // for example string? & number? <: nil. // We deal with this by type normalization. 
- if (FFlag::LuauFixNormalizeCaching) - { - std::shared_ptr subNorm = normalizer->normalize(subTy); - std::shared_ptr superNorm = normalizer->normalize(superTy); - if (subNorm && superNorm) - tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); - else - reportError(location, NormalizationTooComplex{}); - } + std::shared_ptr subNorm = normalizer->normalize(subTy); + std::shared_ptr superNorm = normalizer->normalize(superTy); + if (subNorm && superNorm) + tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); else - { - const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy); - const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy); - if (subNorm && superNorm) - tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); - else - reportError(location, NormalizationTooComplex{}); - } + reportError(location, NormalizationTooComplex{}); } else if (!found) { @@ -2249,7 +2180,18 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection, // If one of the types stopped being a table altogether, we need to restart from the top if ((superTy != superTyNew || activeSubTy != subTyNew) && errors.empty()) - return tryUnify(subTy, superTy, false, isIntersection); + { + if (FFlag::LuauUnifierRecursionOnRestart) + { + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); + tryUnify(subTy, superTy, false, isIntersection); + return; + } + else + { + return tryUnify(subTy, superTy, false, isIntersection); + } + } // Otherwise, restart only the table unification TableType* newSuperTable = log.getMutable(superTyNew); @@ -2328,7 +2270,18 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection, // If one of the types stopped being a table altogether, we need to restart from the top if ((superTy != superTyNew || activeSubTy != subTyNew) && errors.empty()) - return tryUnify(subTy, superTy, false, isIntersection); + { + if (FFlag::LuauUnifierRecursionOnRestart) + { + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); + tryUnify(subTy, superTy, false, isIntersection); + return; + } + else + { + return tryUnify(subTy, superTy, false, isIntersection); + } + } // Recursive unification can change the txn log, and invalidate the old // table. 
If we detect that this has happened, we start over, with the updated @@ -2712,32 +2665,16 @@ void Unifier::tryUnifyNegations(TypeId subTy, TypeId superTy) if (!log.get(subTy) && !log.get(superTy)) ice("tryUnifyNegations superTy or subTy must be a negation type"); - if (FFlag::LuauFixNormalizeCaching) - { - std::shared_ptr subNorm = normalizer->normalize(subTy); - std::shared_ptr superNorm = normalizer->normalize(superTy); - if (!subNorm || !superNorm) - return reportError(location, NormalizationTooComplex{}); + std::shared_ptr subNorm = normalizer->normalize(subTy); + std::shared_ptr superNorm = normalizer->normalize(superTy); + if (!subNorm || !superNorm) + return reportError(location, NormalizationTooComplex{}); - // T DEPRECATED_normalize(subTy); - const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy); - if (!subNorm || !superNorm) - return reportError(location, NormalizationTooComplex{}); - - // T & queue, DenseHashSet& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) diff --git a/Analysis/src/Unifier2.cpp b/Analysis/src/Unifier2.cpp index 34fc6ee9..6dcd7197 100644 --- a/Analysis/src/Unifier2.cpp +++ b/Analysis/src/Unifier2.cpp @@ -158,12 +158,6 @@ bool Unifier2::unify(TypeId subTy, TypeId superTy) if (subFree || superFree) return true; - if (auto subLocal = getMutable(subTy)) - { - subLocal->domain = mkUnion(subLocal->domain, superTy); - expandedFreeTypes[subTy].push_back(superTy); - } - auto subFn = get(subTy); auto superFn = get(superTy); if (subFn && superFn) @@ -204,25 +198,21 @@ bool Unifier2::unify(TypeId subTy, TypeId superTy) auto subAny = get(subTy); auto superAny = get(superTy); - if (subAny && superAny) - return true; - else if (subAny && superFn) - { - // If `any` is the subtype, then we can propagate that inward. - bool argResult = unify(superFn->argTypes, builtinTypes->anyTypePack); - bool retResult = unify(builtinTypes->anyTypePack, superFn->retTypes); - return argResult && retResult; - } - else if (subFn && superAny) - { - // If `any` is the supertype, then we can propagate that inward. - bool argResult = unify(builtinTypes->anyTypePack, subFn->argTypes); - bool retResult = unify(subFn->retTypes, builtinTypes->anyTypePack); - return argResult && retResult; - } auto subTable = getMutable(subTy); auto superTable = get(superTy); + + if (subAny && superAny) + return true; + else if (subAny && superFn) + return unify(subAny, superFn); + else if (subFn && superAny) + return unify(subFn, superAny); + else if (subAny && superTable) + return unify(subAny, superTable); + else if (subTable && superAny) + return unify(subTable, superAny); + if (subTable && superTable) { // `boundTo` works like a bound type, and therefore we'd replace it @@ -451,7 +441,16 @@ bool Unifier2::unify(TableType* subTable, const TableType* superTable) * an indexer, we therefore conclude that the unsealed table has the * same indexer. 
*/ - subTable->indexer = *superTable->indexer; + + TypeId indexType = superTable->indexer->indexType; + if (TypeId* subst = genericSubstitutions.find(indexType)) + indexType = *subst; + + TypeId indexResultType = superTable->indexer->indexResultType; + if (TypeId* subst = genericSubstitutions.find(indexResultType)) + indexResultType = *subst; + + subTable->indexer = TableIndexer{indexType, indexResultType}; } return result; @@ -462,6 +461,62 @@ bool Unifier2::unify(const MetatableType* subMetatable, const MetatableType* sup return unify(subMetatable->metatable, superMetatable->metatable) && unify(subMetatable->table, superMetatable->table); } +bool Unifier2::unify(const AnyType* subAny, const FunctionType* superFn) +{ + // If `any` is the subtype, then we can propagate that inward. + bool argResult = unify(superFn->argTypes, builtinTypes->anyTypePack); + bool retResult = unify(builtinTypes->anyTypePack, superFn->retTypes); + return argResult && retResult; +} + +bool Unifier2::unify(const FunctionType* subFn, const AnyType* superAny) +{ + // If `any` is the supertype, then we can propagate that inward. + bool argResult = unify(builtinTypes->anyTypePack, subFn->argTypes); + bool retResult = unify(subFn->retTypes, builtinTypes->anyTypePack); + return argResult && retResult; +} + +bool Unifier2::unify(const AnyType* subAny, const TableType* superTable) +{ + for (const auto& [propName, prop] : superTable->props) + { + if (prop.readTy) + unify(builtinTypes->anyType, *prop.readTy); + + if (prop.writeTy) + unify(*prop.writeTy, builtinTypes->anyType); + } + + if (superTable->indexer) + { + unify(builtinTypes->anyType, superTable->indexer->indexType); + unify(builtinTypes->anyType, superTable->indexer->indexResultType); + } + + return true; +} + +bool Unifier2::unify(const TableType* subTable, const AnyType* superAny) +{ + for (const auto& [propName, prop] : subTable->props) + { + if (prop.readTy) + unify(*prop.readTy, builtinTypes->anyType); + + if (prop.writeTy) + unify(builtinTypes->anyType, *prop.writeTy); + } + + if (subTable->indexer) + { + unify(subTable->indexer->indexType, builtinTypes->anyType); + unify(subTable->indexer->indexResultType, builtinTypes->anyType); + } + + return true; +} + // FIXME? This should probably return an ErrorVec or an optional // rather than a boolean to signal an occurs check failure. bool Unifier2::unify(TypePackId subTp, TypePackId superTp) @@ -596,6 +651,43 @@ struct FreeTypeSearcher : TypeVisitor } } + DenseHashSet seenPositive{nullptr}; + DenseHashSet seenNegative{nullptr}; + + bool seenWithPolarity(const void* ty) + { + switch (polarity) + { + case Positive: + { + if (seenPositive.contains(ty)) + return true; + + seenPositive.insert(ty); + return false; + } + case Negative: + { + if (seenNegative.contains(ty)) + return true; + + seenNegative.insert(ty); + return false; + } + case Both: + { + if (seenPositive.contains(ty) && seenNegative.contains(ty)) + return true; + + seenPositive.insert(ty); + seenNegative.insert(ty); + return false; + } + } + + return false; + } + // The keys in these maps are either TypeIds or TypePackIds. It's safe to // mix them because we only use these pointers as unique keys. We never // indirect them. 
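For intuition about the positive/negative polarity these visitors track, a small Luau sketch (illustrative; the inferred generic below reflects standard generalization, not new behavior from this patch):

    --!strict
    -- While inferring `id`, the type of `x` is a free type variable: it occurs
    -- in negative position (as a parameter) and positive position (as the return
    -- value), so generalization produces the generic type <a>(a) -> a.
    local function id(x)
        return x
    end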
@@ -604,12 +696,18 @@ struct FreeTypeSearcher : TypeVisitor bool visit(TypeId ty) override { + if (seenWithPolarity(ty)) + return false; + LUAU_ASSERT(ty); return true; } bool visit(TypeId ty, const FreeType& ft) override { + if (seenWithPolarity(ty)) + return false; + if (!subsumes(scope, ft.scope)) return true; @@ -632,6 +730,9 @@ struct FreeTypeSearcher : TypeVisitor bool visit(TypeId ty, const TableType& tt) override { + if (seenWithPolarity(ty)) + return false; + if ((tt.state == TableState::Free || tt.state == TableState::Unsealed) && subsumes(scope, tt.scope)) { switch (polarity) @@ -675,6 +776,9 @@ struct FreeTypeSearcher : TypeVisitor bool visit(TypeId ty, const FunctionType& ft) override { + if (seenWithPolarity(ty)) + return false; + flip(); traverse(ft.argTypes); flip(); @@ -691,6 +795,9 @@ struct FreeTypeSearcher : TypeVisitor bool visit(TypePackId tp, const FreeTypePack& ftp) override { + if (seenWithPolarity(tp)) + return false; + if (!subsumes(scope, ftp.scope)) return true; @@ -712,315 +819,6 @@ struct FreeTypeSearcher : TypeVisitor } }; -struct MutatingGeneralizer : TypeOnceVisitor -{ - NotNull builtinTypes; - - NotNull scope; - DenseHashMap positiveTypes; - DenseHashMap negativeTypes; - std::vector generics; - std::vector genericPacks; - - bool isWithinFunction = false; - - MutatingGeneralizer(NotNull builtinTypes, NotNull scope, DenseHashMap positiveTypes, - DenseHashMap negativeTypes) - : TypeOnceVisitor(/* skipBoundTypes */ true) - , builtinTypes(builtinTypes) - , scope(scope) - , positiveTypes(std::move(positiveTypes)) - , negativeTypes(std::move(negativeTypes)) - { - } - - static void replace(DenseHashSet& seen, TypeId haystack, TypeId needle, TypeId replacement) - { - haystack = follow(haystack); - - if (seen.find(haystack)) - return; - seen.insert(haystack); - - if (UnionType* ut = getMutable(haystack)) - { - for (auto iter = ut->options.begin(); iter != ut->options.end();) - { - // FIXME: I bet this function has reentrancy problems - TypeId option = follow(*iter); - - if (option == needle && get(replacement)) - { - iter = ut->options.erase(iter); - continue; - } - - if (option == needle) - { - *iter = replacement; - iter++; - continue; - } - - // advance the iterator, nothing after this can use it. - iter++; - - if (seen.find(option)) - continue; - seen.insert(option); - - if (get(option)) - replace(seen, option, needle, haystack); - else if (get(option)) - replace(seen, option, needle, haystack); - } - - if (ut->options.size() == 1) - { - TypeId onlyType = ut->options[0]; - LUAU_ASSERT(onlyType != haystack); - emplaceType(asMutable(haystack), onlyType); - } - - return; - } - - if (IntersectionType* it = getMutable(needle)) - { - for (auto iter = it->parts.begin(); iter != it->parts.end();) - { - // FIXME: I bet this function has reentrancy problems - TypeId part = follow(*iter); - - if (part == needle && get(replacement)) - { - iter = it->parts.erase(iter); - continue; - } - - if (part == needle) - { - *iter = replacement; - iter++; - continue; - } - - // advance the iterator, nothing after this can use it. 
- iter++; - - if (seen.find(part)) - continue; - seen.insert(part); - - if (get(part)) - replace(seen, part, needle, haystack); - else if (get(part)) - replace(seen, part, needle, haystack); - } - - if (it->parts.size() == 1) - { - TypeId onlyType = it->parts[0]; - LUAU_ASSERT(onlyType != needle); - emplaceType(asMutable(needle), onlyType); - } - - return; - } - } - - bool visit(TypeId ty, const FunctionType& ft) override - { - const bool oldValue = isWithinFunction; - - isWithinFunction = true; - - traverse(ft.argTypes); - traverse(ft.retTypes); - - isWithinFunction = oldValue; - - return false; - } - - bool visit(TypeId ty, const FreeType&) override - { - const FreeType* ft = get(ty); - LUAU_ASSERT(ft); - - traverse(ft->lowerBound); - traverse(ft->upperBound); - - // It is possible for the above traverse() calls to cause ty to be - // transmuted. We must reacquire ft if this happens. - ty = follow(ty); - ft = get(ty); - if (!ft) - return false; - - const size_t positiveCount = getCount(positiveTypes, ty); - const size_t negativeCount = getCount(negativeTypes, ty); - - if (!positiveCount && !negativeCount) - return false; - - const bool hasLowerBound = !get(follow(ft->lowerBound)); - const bool hasUpperBound = !get(follow(ft->upperBound)); - - DenseHashSet seen{nullptr}; - seen.insert(ty); - - if (!hasLowerBound && !hasUpperBound) - { - if (!isWithinFunction || (positiveCount + negativeCount == 1)) - emplaceType(asMutable(ty), builtinTypes->unknownType); - else - { - emplaceType(asMutable(ty), scope); - generics.push_back(ty); - } - } - - // It is possible that this free type has other free types in its upper - // or lower bounds. If this is the case, we must replace those - // references with never (for the lower bound) or unknown (for the upper - // bound). - // - // If we do not do this, we get tautological bounds like a <: a <: unknown. - else if (positiveCount && !hasUpperBound) - { - TypeId lb = follow(ft->lowerBound); - if (FreeType* lowerFree = getMutable(lb); lowerFree && lowerFree->upperBound == ty) - lowerFree->upperBound = builtinTypes->unknownType; - else - { - DenseHashSet replaceSeen{nullptr}; - replace(replaceSeen, lb, ty, builtinTypes->unknownType); - } - - if (lb != ty) - emplaceType(asMutable(ty), lb); - else if (!isWithinFunction || (positiveCount + negativeCount == 1)) - emplaceType(asMutable(ty), builtinTypes->unknownType); - else - { - // if the lower bound is the type in question, we don't actually have a lower bound. - emplaceType(asMutable(ty), scope); - generics.push_back(ty); - } - } - else - { - TypeId ub = follow(ft->upperBound); - if (FreeType* upperFree = getMutable(ub); upperFree && upperFree->lowerBound == ty) - upperFree->lowerBound = builtinTypes->neverType; - else - { - DenseHashSet replaceSeen{nullptr}; - replace(replaceSeen, ub, ty, builtinTypes->neverType); - } - - if (ub != ty) - emplaceType(asMutable(ty), ub); - else if (!isWithinFunction || (positiveCount + negativeCount == 1)) - emplaceType(asMutable(ty), builtinTypes->unknownType); - else - { - // if the upper bound is the type in question, we don't actually have an upper bound. 
- emplaceType(asMutable(ty), scope); - generics.push_back(ty); - } - } - - return false; - } - - size_t getCount(const DenseHashMap& map, const void* ty) - { - if (const size_t* count = map.find(ty)) - return *count; - else - return 0; - } - - bool visit(TypeId ty, const TableType&) override - { - const size_t positiveCount = getCount(positiveTypes, ty); - const size_t negativeCount = getCount(negativeTypes, ty); - - // FIXME: Free tables should probably just be replaced by upper bounds on free types. - // - // eg never <: 'a <: {x: number} & {z: boolean} - - if (!positiveCount && !negativeCount) - return true; - - TableType* tt = getMutable(ty); - LUAU_ASSERT(tt); - - tt->state = TableState::Sealed; - - return true; - } - - bool visit(TypePackId tp, const FreeTypePack& ftp) override - { - if (!subsumes(scope, ftp.scope)) - return true; - - tp = follow(tp); - - const size_t positiveCount = getCount(positiveTypes, tp); - const size_t negativeCount = getCount(negativeTypes, tp); - - if (1 == positiveCount + negativeCount) - emplaceTypePack(asMutable(tp), builtinTypes->unknownTypePack); - else - { - emplaceTypePack(asMutable(tp), scope); - genericPacks.push_back(tp); - } - - return true; - } -}; - -std::optional Unifier2::generalize(TypeId ty) -{ - ty = follow(ty); - - if (ty->owningArena != arena || ty->persistent) - return ty; - - if (const FunctionType* ft = get(ty); ft && (!ft->generics.empty() || !ft->genericPacks.empty())) - return ty; - - FreeTypeSearcher fts{scope}; - fts.traverse(ty); - - MutatingGeneralizer gen{builtinTypes, scope, std::move(fts.positiveTypes), std::move(fts.negativeTypes)}; - - gen.traverse(ty); - - /* MutatingGeneralizer mutates types in place, so it is possible that ty has - * been transmuted to a BoundType. We must follow it again and verify that - * we are allowed to mutate it before we attach generics to it. 
- */ - ty = follow(ty); - - if (ty->owningArena != arena || ty->persistent) - return ty; - - FunctionType* ftv = getMutable(ty); - if (ftv) - { - ftv->generics = std::move(gen.generics); - ftv->genericPacks = std::move(gen.genericPacks); - } - - return ty; -} - TypeId Unifier2::mkUnion(TypeId left, TypeId right) { left = follow(left); diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 993116d6..e2ac8b7d 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -60,6 +60,8 @@ class AstStat; class AstStatBlock; class AstExpr; class AstTypePack; +class AstAttr; +class AstExprTable; struct AstLocal { @@ -172,6 +174,10 @@ public: { return nullptr; } + virtual AstAttr* asAttr() + { + return nullptr; + } template bool is() const @@ -193,6 +199,29 @@ public: Location location; }; +class AstAttr : public AstNode +{ +public: + LUAU_RTTI(AstAttr) + + enum Type + { + Checked, + Native, + }; + + AstAttr(const Location& location, Type type); + + AstAttr* asAttr() override + { + return this; + } + + void visit(AstVisitor* visitor) override; + + Type type; +}; + class AstExpr : public AstNode { public: @@ -384,13 +413,17 @@ class AstExprFunction : public AstExpr public: LUAU_RTTI(AstExprFunction) - AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, - AstLocal* self, const AstArray& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth, - const AstName& debugname, const std::optional& returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, + AstExprFunction(const Location& location, const AstArray& attributes, const AstArray& generics, + const AstArray& genericPacks, AstLocal* self, const AstArray& args, bool vararg, + const Location& varargLocation, AstStatBlock* body, size_t functionDepth, const AstName& debugname, + const std::optional& returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, const std::optional& argLocation = std::nullopt); void visit(AstVisitor* visitor) override; + bool hasNativeAttribute() const; + + AstArray attributes; AstArray generics; AstArray genericPacks; AstLocal* self; @@ -793,11 +826,12 @@ class AstStatDeclareGlobal : public AstStat public: LUAU_RTTI(AstStatDeclareGlobal) - AstStatDeclareGlobal(const Location& location, const AstName& name, AstType* type); + AstStatDeclareGlobal(const Location& location, const AstName& name, const Location& nameLocation, AstType* type); void visit(AstVisitor* visitor) override; AstName name; + Location nameLocation; AstType* type; }; @@ -806,31 +840,38 @@ class AstStatDeclareFunction : public AstStat public: LUAU_RTTI(AstStatDeclareFunction) - AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, - const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, - const AstTypeList& retTypes); + AstStatDeclareFunction(const Location& location, const AstName& name, const Location& nameLocation, const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, bool vararg, + const Location& varargLocation, const AstTypeList& retTypes); - AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, - const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, - const AstTypeList& retTypes, bool checkedFunction); + AstStatDeclareFunction(const Location& location, const AstArray& attributes, const AstName& name, const Location& nameLocation, + const 
AstArray& generics, const AstArray& genericPacks, const AstTypeList& params, + const AstArray& paramNames, bool vararg, const Location& varargLocation, const AstTypeList& retTypes); void visit(AstVisitor* visitor) override; + bool isCheckedFunction() const; + + AstArray attributes; AstName name; + Location nameLocation; AstArray generics; AstArray genericPacks; AstTypeList params; AstArray paramNames; + bool vararg = false; + Location varargLocation; AstTypeList retTypes; - bool checkedFunction; }; struct AstDeclaredClassProp { AstName name; + Location nameLocation; AstType* ty = nullptr; bool isMethod = false; + Location location; }; enum class AstTableAccess @@ -936,17 +977,20 @@ public: AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes); - AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, - const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes, bool checkedFunction); + AstTypeFunction(const Location& location, const AstArray& attributes, const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, + const AstTypeList& returnTypes); void visit(AstVisitor* visitor) override; + bool isCheckedFunction() const; + + AstArray attributes; AstArray generics; AstArray genericPacks; AstTypeList argTypes; AstArray> argNames; AstTypeList returnTypes; - bool checkedFunction; }; class AstTypeTypeof : public AstType @@ -1105,6 +1149,11 @@ public: return true; } + virtual bool visit(class AstAttr* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstExpr* node) { return visit(static_cast(node)); diff --git a/Ast/include/Luau/Lexer.h b/Ast/include/Luau/Lexer.h index e111030d..f6ac28ad 100644 --- a/Ast/include/Luau/Lexer.h +++ b/Ast/include/Luau/Lexer.h @@ -87,6 +87,8 @@ struct Lexeme Comment, BlockComment, + Attribute, + BrokenString, BrokenComment, BrokenUnicode, @@ -115,14 +117,20 @@ struct Lexeme ReservedTrue, ReservedUntil, ReservedWhile, - ReservedChecked, Reserved_END }; Type type; Location location; + + // Field declared here, before the union, to ensure that Lexeme size is 32 bytes. +private: + // length is used to extract a slice from the input buffer. + // This field is only valid for certain lexeme types which don't duplicate portions of input + // but instead store a pointer to a location in the input buffer and the length of lexeme. 
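+    // For example (a sketch, not part of this diff): callers read the payload through the
+    // new accessor, e.g. std::string text(lexeme.data, lexeme.getLength());
+    // rather than touching 'length' directly; see the Parser.cpp changes below.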
unsigned int length; +public: union { const char* data; // String, Number, Comment @@ -135,9 +143,13 @@ struct Lexeme Lexeme(const Location& location, Type type, const char* data, size_t size); Lexeme(const Location& location, Type type, const char* name); + unsigned int getLength() const; + std::string toString() const; }; +static_assert(sizeof(Lexeme) <= 32, "Size of `Lexeme` struct should be up to 32 bytes."); + class AstNameTable { public: diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index e97df66b..5a945e26 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -82,8 +82,8 @@ private: // if exp then block {elseif exp then block} [else block] end | // for Name `=' exp `,' exp [`,' exp] do block end | // for namelist in explist do block end | - // function funcname funcbody | - // local function Name funcbody | + // [attributes] function funcname funcbody | + // [attributes] local function Name funcbody | // local namelist [`=' explist] // laststat ::= return [explist] | break AstStat* parseStat(); @@ -114,11 +114,25 @@ private: AstExpr* parseFunctionName(Location start, bool& hasself, AstName& debugname); // function funcname funcbody - AstStat* parseFunctionStat(); + LUAU_FORCEINLINE AstStat* parseFunctionStat(const AstArray& attributes = {nullptr, 0}); + + std::pair validateAttribute(const char* attributeName, const TempVector& attributes); + + // attribute ::= '@' NAME + void parseAttribute(TempVector& attribute); + + // attributes ::= {attribute} + AstArray parseAttributes(); + + // attributes local function Name funcbody + // attributes function funcname funcbody + // attributes `declare function' Name`(' [parlist] `)' [`:` Type] + // declare Name '{' Name ':' attributes `(' [parlist] `)' [`:` Type] '}' + AstStat* parseAttributeStat(); // local function Name funcbody | // local namelist [`=' explist] - AstStat* parseLocal(); + AstStat* parseLocal(const AstArray& attributes); // return [explist] AstStat* parseReturn(); @@ -130,7 +144,7 @@ private: // `declare global' Name: Type | // `declare function' Name`(' [parlist] `)' [`:` Type] - AstStat* parseDeclaration(const Location& start); + AstStat* parseDeclaration(const Location& start, const AstArray& attributes); // varlist `=' explist AstStat* parseAssignment(AstExpr* initial); @@ -143,7 +157,7 @@ private: // funcbodyhead ::= `(' [namelist [`,' `...'] | `...'] `)' [`:` Type] // funcbody ::= funcbodyhead block end std::pair parseFunctionBody( - bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName); + bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName, const AstArray& attributes); // explist ::= {exp `,'} exp void parseExprList(TempVector& result); @@ -176,10 +190,10 @@ private: AstTableIndexer* parseTableIndexer(AstTableAccess access, std::optional accessLocation); - AstTypeOrPack parseFunctionType(bool allowPack, bool isCheckedFunction = false); - AstType* parseFunctionTypeTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, - AstArray params, AstArray> paramNames, AstTypePack* varargAnnotation, - bool isCheckedFunction = false); + AstTypeOrPack parseFunctionType(bool allowPack, const AstArray& attributes); + AstType* parseFunctionTypeTail(const Lexeme& begin, const AstArray& attributes, AstArray generics, + AstArray genericPacks, AstArray params, AstArray> paramNames, + AstTypePack* varargAnnotation); AstType* parseTableType(bool inDeclarationContext = false); AstTypeOrPack parseSimpleType(bool 
allowPack, bool inDeclarationContext = false); @@ -220,7 +234,7 @@ private: // asexp -> simpleexp [`::' Type] AstExpr* parseAssertionExpr(); - // simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | FUNCTION body | primaryexp + // simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | [attributes] FUNCTION body | primaryexp AstExpr* parseSimpleExpr(); // args ::= `(' [explist] `)' | tableconstructor | String @@ -393,6 +407,7 @@ private: std::vector matchRecoveryStopOnToken; + std::vector scratchAttr; std::vector scratchStat; std::vector> scratchString; std::vector scratchExpr; diff --git a/Ast/include/Luau/TimeTrace.h b/Ast/include/Luau/TimeTrace.h index 2f7daf2c..bd2ca86b 100644 --- a/Ast/include/Luau/TimeTrace.h +++ b/Ast/include/Luau/TimeTrace.h @@ -134,6 +134,14 @@ struct ThreadContext static constexpr size_t kEventFlushLimit = 8192; }; +using ThreadContextProvider = ThreadContext& (*)(); + +inline ThreadContextProvider& threadContextProvider() +{ + static ThreadContextProvider handler = nullptr; + return handler; +} + ThreadContext& getThreadContext(); struct Scope diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index bb82e0be..a3e53af5 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -3,6 +3,8 @@ #include "Luau/Common.h" +LUAU_FASTFLAG(LuauAttributeSyntax); +LUAU_FASTFLAG(LuauNativeAttribute); namespace Luau { @@ -16,6 +18,17 @@ static void visitTypeList(AstVisitor* visitor, const AstTypeList& list) list.tailType->visit(visitor); } +AstAttr::AstAttr(const Location& location, Type type) + : AstNode(ClassIndex(), location) + , type(type) +{ +} + +void AstAttr::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + int gAstRttiIndex = 0; AstExprGroup::AstExprGroup(const Location& location, AstExpr* expr) @@ -161,11 +174,12 @@ void AstExprIndexExpr::visit(AstVisitor* visitor) } } -AstExprFunction::AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, - AstLocal* self, const AstArray& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth, - const AstName& debugname, const std::optional& returnAnnotation, AstTypePack* varargAnnotation, - const std::optional& argLocation) +AstExprFunction::AstExprFunction(const Location& location, const AstArray& attributes, const AstArray& generics, + const AstArray& genericPacks, AstLocal* self, const AstArray& args, bool vararg, const Location& varargLocation, + AstStatBlock* body, size_t functionDepth, const AstName& debugname, const std::optional& returnAnnotation, + AstTypePack* varargAnnotation, const std::optional& argLocation) : AstExpr(ClassIndex(), location) + , attributes(attributes) , generics(generics) , genericPacks(genericPacks) , self(self) @@ -201,6 +215,18 @@ void AstExprFunction::visit(AstVisitor* visitor) } } +bool AstExprFunction::hasNativeAttribute() const +{ + LUAU_ASSERT(FFlag::LuauNativeAttribute); + + for (const auto attribute : attributes) + { + if (attribute->type == AstAttr::Type::Native) + return true; + } + return false; +} + AstExprTable::AstExprTable(const Location& location, const AstArray& items) : AstExpr(ClassIndex(), location) , items(items) @@ -679,9 +705,10 @@ void AstStatTypeAlias::visit(AstVisitor* visitor) } } -AstStatDeclareGlobal::AstStatDeclareGlobal(const Location& location, const AstName& name, AstType* type) +AstStatDeclareGlobal::AstStatDeclareGlobal(const Location& location, const AstName& name, const Location& nameLocation, AstType* type) : AstStat(ClassIndex(), location) , 
name(name) + , nameLocation(nameLocation) , type(type) { } @@ -692,31 +719,37 @@ void AstStatDeclareGlobal::visit(AstVisitor* visitor) type->visit(visitor); } -AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, - const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, - const AstTypeList& retTypes) +AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const Location& nameLocation, + const AstArray& generics, const AstArray& genericPacks, const AstTypeList& params, + const AstArray& paramNames, bool vararg, const Location& varargLocation, const AstTypeList& retTypes) : AstStat(ClassIndex(), location) + , attributes() , name(name) + , nameLocation(nameLocation) , generics(generics) , genericPacks(genericPacks) , params(params) , paramNames(paramNames) + , vararg(vararg) + , varargLocation(varargLocation) , retTypes(retTypes) - , checkedFunction(false) { } -AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, - const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, - const AstTypeList& retTypes, bool checkedFunction) +AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstArray& attributes, const AstName& name, + const Location& nameLocation, const AstArray& generics, const AstArray& genericPacks, + const AstTypeList& params, const AstArray& paramNames, bool vararg, const Location& varargLocation, const AstTypeList& retTypes) : AstStat(ClassIndex(), location) + , attributes(attributes) , name(name) + , nameLocation(nameLocation) , generics(generics) , genericPacks(genericPacks) , params(params) , paramNames(paramNames) + , vararg(vararg) + , varargLocation(varargLocation) , retTypes(retTypes) - , checkedFunction(checkedFunction) { } @@ -729,6 +762,19 @@ void AstStatDeclareFunction::visit(AstVisitor* visitor) } } +bool AstStatDeclareFunction::isCheckedFunction() const +{ + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + + for (const AstAttr* attr : attributes) + { + if (attr->type == AstAttr::Type::Checked) + return true; + } + + return false; +} + AstStatDeclareClass::AstStatDeclareClass(const Location& location, const AstName& name, std::optional superName, const AstArray& props, AstTableIndexer* indexer) : AstStat(ClassIndex(), location) @@ -820,25 +866,26 @@ void AstTypeTable::visit(AstVisitor* visitor) AstTypeFunction::AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes) : AstType(ClassIndex(), location) + , attributes() , generics(generics) , genericPacks(genericPacks) , argTypes(argTypes) , argNames(argNames) , returnTypes(returnTypes) - , checkedFunction(false) { LUAU_ASSERT(argNames.size == 0 || argNames.size == argTypes.types.size); } -AstTypeFunction::AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, - const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes, bool checkedFunction) +AstTypeFunction::AstTypeFunction(const Location& location, const AstArray& attributes, const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, + const AstTypeList& returnTypes) : AstType(ClassIndex(), location) + , attributes(attributes) , generics(generics) , genericPacks(genericPacks) , 
argTypes(argTypes) , argNames(argNames) , returnTypes(returnTypes) - , checkedFunction(checkedFunction) { LUAU_ASSERT(argNames.size == 0 || argNames.size == argTypes.types.size); } @@ -852,6 +899,19 @@ void AstTypeFunction::visit(AstVisitor* visitor) } } +bool AstTypeFunction::isCheckedFunction() const +{ + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + + for (const AstAttr* attr : attributes) + { + if (attr->type == AstAttr::Type::Checked) + return true; + } + + return false; +} + AstTypeTypeof::AstTypeTypeof(const Location& location, AstExpr* expr) : AstType(ClassIndex(), location) , expr(expr) diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index 96653a56..8e9b3be9 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -8,7 +8,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauLexerLookaheadRemembersBraceType, false) -LUAU_FASTFLAGVARIABLE(LuauCheckedFunctionSyntax, false) +LUAU_FASTFLAGVARIABLE(LuauAttributeSyntax, false) namespace Luau { @@ -103,11 +103,19 @@ Lexeme::Lexeme(const Location& location, Type type, const char* name) , length(0) , name(name) { - LUAU_ASSERT(type == Name || (type >= Reserved_BEGIN && type < Lexeme::Reserved_END)); + LUAU_ASSERT(type == Name || type == Attribute || (type >= Reserved_BEGIN && type < Lexeme::Reserved_END)); +} + +unsigned int Lexeme::getLength() const +{ + LUAU_ASSERT(type == RawString || type == QuotedString || type == InterpStringBegin || type == InterpStringMid || type == InterpStringEnd || + type == InterpStringSimple || type == BrokenInterpDoubleBrace || type == Number || type == Comment || type == BlockComment); + + return length; } static const char* kReserved[] = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in", "local", "nil", "not", "or", - "repeat", "return", "then", "true", "until", "while", "@checked"}; + "repeat", "return", "then", "true", "until", "while"}; std::string Lexeme::toString() const { @@ -192,6 +200,10 @@ std::string Lexeme::toString() const case Comment: return "comment"; + case Attribute: + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + return name ? format("'%s'", name) : "attribute"; + case BrokenString: return "malformed string"; @@ -279,7 +291,7 @@ std::pair AstNameTable::getOrAddWithType(const char* name nameData[length] = 0; const_cast(entry).value = AstName(nameData); - const_cast(entry).type = Lexeme::Name; + const_cast(entry).type = (name[0] == '@' ? Lexeme::Attribute : Lexeme::Name); return std::make_pair(entry.value, entry.type); } @@ -995,16 +1007,10 @@ Lexeme Lexer::readNext() } case '@': { - if (FFlag::LuauCheckedFunctionSyntax) + if (FFlag::LuauAttributeSyntax) { - // We're trying to lex the token @checked - LUAU_ASSERT(peekch() == '@'); - - std::pair maybeChecked = readName(); - if (maybeChecked.second != Lexeme::ReservedChecked) - return Lexeme(Location(start, position()), Lexeme::Error); - - return Lexeme(Location(start, position()), maybeChecked.second, maybeChecked.first.value); + std::pair attribute = readName(); + return Lexeme(Location(start, position()), Lexeme::Attribute, attribute.first.value); } } default: diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index a7363552..87af53cb 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -16,13 +16,24 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) // Warning: If you are introducing new syntax, ensure that it is behind a separate // flag so that we don't break production games by reverting syntax changes. // See docs/SyntaxChanges.md for an explanation. 
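// For orientation, a sketch of the surface syntax the attribute changes below accept
// (assumes the LuauAttributeSyntax and LuauNativeAttribute flags are enabled; the
// snippet is illustrative and not part of this diff):
//
//   @native
//   local function double(x: number): number
//       return x * 2
//   end
//
// and, when declaration syntax is enabled:
//
//   @checked declare function abs(n: number): number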
-LUAU_FASTFLAG(LuauCheckedFunctionSyntax) -LUAU_FASTFLAGVARIABLE(LuauReadWritePropertySyntax, false) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) +LUAU_FASTFLAG(LuauAttributeSyntax) +LUAU_FASTFLAGVARIABLE(LuauLeadingBarAndAmpersand2, false) +LUAU_FASTFLAGVARIABLE(LuauNativeAttribute, false) +LUAU_FASTFLAGVARIABLE(LuauAttributeSyntaxFunExpr, false) +LUAU_FASTFLAGVARIABLE(LuauDeclarationExtraPropData, false) namespace Luau { +struct AttributeEntry +{ + const char* name; + AstAttr::Type type; +}; + +AttributeEntry kAttributeEntries[] = {{"@checked", AstAttr::Type::Checked}, {"@native", AstAttr::Type::Native}, {nullptr, AstAttr::Type::Checked}}; + ParseError::ParseError(const Location& location, const std::string& message) : location(location) , message(message) @@ -281,7 +292,9 @@ AstStatBlock* Parser::parseBlockNoScope() // for binding `=' exp `,' exp [`,' exp] do block end | // for namelist in explist do block end | // function funcname funcbody | +// attributes function funcname funcbody | // local function Name funcbody | +// local attributes function Name funcbody | // local namelist [`=' explist] // laststat ::= return [explist] | break AstStat* Parser::parseStat() @@ -300,13 +313,16 @@ AstStat* Parser::parseStat() case Lexeme::ReservedRepeat: return parseRepeat(); case Lexeme::ReservedFunction: - return parseFunctionStat(); + return parseFunctionStat(AstArray({nullptr, 0})); case Lexeme::ReservedLocal: - return parseLocal(); + return parseLocal(AstArray({nullptr, 0})); case Lexeme::ReservedReturn: return parseReturn(); case Lexeme::ReservedBreak: return parseBreak(); + case Lexeme::Attribute: + if (FFlag::LuauAttributeSyntax) + return parseAttributeStat(); default:; } @@ -344,7 +360,7 @@ AstStat* Parser::parseStat() if (options.allowDeclarationSyntax) { if (ident == "declare") - return parseDeclaration(expr->location); + return parseDeclaration(expr->location, AstArray({nullptr, 0})); } // skip unexpected symbol if lexer couldn't advance at all (statements are parsed in a loop) @@ -653,7 +669,7 @@ AstExpr* Parser::parseFunctionName(Location start, bool& hasself, AstName& debug } // function funcname funcbody -AstStat* Parser::parseFunctionStat() +AstStat* Parser::parseFunctionStat(const AstArray& attributes) { Location start = lexer.current().location; @@ -666,16 +682,129 @@ AstStat* Parser::parseFunctionStat() matchRecoveryStopOnToken[Lexeme::ReservedEnd]++; - AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, nullptr).first; + AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, nullptr, attributes).first; matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; return allocator.alloc(Location(start, body->location), expr, body); } + +std::pair Parser::validateAttribute(const char* attributeName, const TempVector& attributes) +{ + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + + AstAttr::Type type; + + // check if the attribute name is valid + + bool found = false; + + for (int i = 0; kAttributeEntries[i].name; ++i) + { + found = !strcmp(attributeName, kAttributeEntries[i].name); + if (found) + { + type = kAttributeEntries[i].type; + + if (!FFlag::LuauNativeAttribute && type == AstAttr::Type::Native) + found = false; + + break; + } + } + + if (!found) + { + if (strlen(attributeName) == 1) + report(lexer.current().location, "Attribute name is missing"); + else + report(lexer.current().location, "Invalid attribute '%s'", attributeName); + } + else + { + // check that attribute is not duplicated + for (const AstAttr* attr : 
attributes) + { + if (attr->type == type) + { + report(lexer.current().location, "Cannot duplicate attribute '%s'", attributeName); + } + } + } + + return {found, type}; +} + +// attribute ::= '@' NAME +void Parser::parseAttribute(TempVector& attributes) +{ + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + + LUAU_ASSERT(lexer.current().type == Lexeme::Type::Attribute); + + Location loc = lexer.current().location; + + const char* name = lexer.current().name; + const auto [found, type] = validateAttribute(name, attributes); + + nextLexeme(); + + if (found) + attributes.push_back(allocator.alloc(loc, type)); +} + +// attributes ::= {attribute} +AstArray Parser::parseAttributes() +{ + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + + Lexeme::Type type = lexer.current().type; + + LUAU_ASSERT(type == Lexeme::Attribute); + + TempVector attributes(scratchAttr); + + while (lexer.current().type == Lexeme::Attribute) + parseAttribute(attributes); + + return copy(attributes); +} + +// attributes local function Name funcbody +// attributes function funcname funcbody +// attributes `declare function' Name`(' [parlist] `)' [`:` Type] +// declare Name '{' Name ':' attributes `(' [parlist] `)' [`:` Type] '}' +AstStat* Parser::parseAttributeStat() +{ + LUAU_ASSERT(FFlag::LuauAttributeSyntax); + + AstArray attributes = parseAttributes(); + + Lexeme::Type type = lexer.current().type; + + switch (type) + { + case Lexeme::Type::ReservedFunction: + return parseFunctionStat(attributes); + case Lexeme::Type::ReservedLocal: + return parseLocal(attributes); + case Lexeme::Type::Name: + if (options.allowDeclarationSyntax && !strcmp("declare", lexer.current().data)) + { + AstExpr* expr = parsePrimaryExpr(/* asStatement= */ true); + return parseDeclaration(expr->location, attributes); + } + default: + return reportStatError(lexer.current().location, {}, {}, + "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got %s instead", + lexer.current().toString().c_str()); + } +} + // local function Name funcbody | // local bindinglist [`=' explist] -AstStat* Parser::parseLocal() +AstStat* Parser::parseLocal(const AstArray& attributes) { Location start = lexer.current().location; @@ -695,7 +824,7 @@ AstStat* Parser::parseLocal() matchRecoveryStopOnToken[Lexeme::ReservedEnd]++; - auto [body, var] = parseFunctionBody(false, matchFunction, name.name, &name); + auto [body, var] = parseFunctionBody(false, matchFunction, name.name, &name, attributes); matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; @@ -705,6 +834,12 @@ AstStat* Parser::parseLocal() } else { + if (FFlag::LuauAttributeSyntax && attributes.size != 0) + { + return reportStatError(lexer.current().location, {}, {}, "Expected 'function' after local declaration with attribute, but got %s instead", + lexer.current().toString().c_str()); + } + matchRecoveryStopOnToken['=']++; TempVector names(scratchBinding); @@ -775,8 +910,16 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported) AstDeclaredClassProp Parser::parseDeclaredClassMethod() { + Location start; + + if (FFlag::LuauDeclarationExtraPropData) + start = lexer.current().location; + nextLexeme(); - Location start = lexer.current().location; + + if (!FFlag::LuauDeclarationExtraPropData) + start = lexer.current().location; + Name fnName = parseName("function name"); // TODO: generic method declarations CLI-39909 @@ -801,15 +944,15 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() expectMatchAndConsume(')', matchParen); AstTypeList retTypes = 
parseOptionalReturnType().value_or(AstTypeList{copy(nullptr, 0), nullptr}); - Location end = lexer.current().location; + Location end = FFlag::LuauDeclarationExtraPropData ? lexer.previousLocation() : lexer.current().location; TempVector vars(scratchType); TempVector> varNames(scratchOptArgName); if (args.size() == 0 || args[0].name.name != "self" || args[0].annotation != nullptr) { - return AstDeclaredClassProp{ - fnName.name, reportTypeError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), true}; + return AstDeclaredClassProp{fnName.name, FFlag::LuauDeclarationExtraPropData ? fnName.location : Location{}, + reportTypeError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), true}; } // Skip the first index. @@ -829,21 +972,21 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() AstType* fnType = allocator.alloc( Location(start, end), generics, genericPacks, AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes); - return AstDeclaredClassProp{fnName.name, fnType, true}; + return AstDeclaredClassProp{fnName.name, FFlag::LuauDeclarationExtraPropData ? fnName.location : Location{}, fnType, true, + FFlag::LuauDeclarationExtraPropData ? Location(start, end) : Location{}}; } -AstStat* Parser::parseDeclaration(const Location& start) +AstStat* Parser::parseDeclaration(const Location& start, const AstArray& attributes) { // `declare` token is already parsed at this point + + if (FFlag::LuauAttributeSyntax && (attributes.size != 0) && (lexer.current().type != Lexeme::ReservedFunction)) + return reportStatError(lexer.current().location, {}, {}, "Expected a function type declaration after attribute, but got %s instead", + lexer.current().toString().c_str()); + if (lexer.current().type == Lexeme::ReservedFunction) { nextLexeme(); - bool checkedFunction = false; - if (FFlag::LuauCheckedFunctionSyntax && lexer.current().type == Lexeme::ReservedChecked) - { - checkedFunction = true; - nextLexeme(); - } Name globalName = parseName("global function name"); auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); @@ -881,8 +1024,12 @@ AstStat* Parser::parseDeclaration(const Location& start) if (vararg && !varargAnnotation) return reportStatError(Location(start, end), {}, {}, "All declaration parameters must be annotated"); - return allocator.alloc(Location(start, end), globalName.name, generics, genericPacks, - AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes, checkedFunction); + if (FFlag::LuauDeclarationExtraPropData) + return allocator.alloc(Location(start, end), attributes, globalName.name, globalName.location, generics, + genericPacks, AstTypeList{copy(vars), varargAnnotation}, copy(varNames), vararg, varargLocation, retTypes); + else + return allocator.alloc(Location(start, end), attributes, globalName.name, Location{}, generics, genericPacks, + AstTypeList{copy(vars), varargAnnotation}, copy(varNames), false, Location{}, retTypes); } else if (AstName(lexer.current().name) == "class") { @@ -912,19 +1059,42 @@ AstStat* Parser::parseDeclaration(const Location& start) const Lexeme begin = lexer.current(); nextLexeme(); // [ - std::optional> chars = parseCharArray(); + if (FFlag::LuauDeclarationExtraPropData) + { + const Location nameBegin = lexer.current().location; + std::optional> chars = parseCharArray(); - expectMatchAndConsume(']', begin); - expectAndConsume(':', "property type annotation"); - AstType* type = parseType(); + const Location nameEnd = 
lexer.previousLocation(); - // since AstName contains a char*, it can't contain null - bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); + expectMatchAndConsume(']', begin); + expectAndConsume(':', "property type annotation"); + AstType* type = parseType(); - if (chars && !containsNull) - props.push_back(AstDeclaredClassProp{AstName(chars->data), type, false}); + // since AstName contains a char*, it can't contain null + bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); + + if (chars && !containsNull) + props.push_back(AstDeclaredClassProp{ + AstName(chars->data), Location(nameBegin, nameEnd), type, false, Location(begin.location, lexer.previousLocation())}); + else + report(begin.location, "String literal contains malformed escape sequence or \\0"); + } else - report(begin.location, "String literal contains malformed escape sequence or \\0"); + { + std::optional> chars = parseCharArray(); + + expectMatchAndConsume(']', begin); + expectAndConsume(':', "property type annotation"); + AstType* type = parseType(); + + // since AstName contains a char*, it can't contain null + bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); + + if (chars && !containsNull) + props.push_back(AstDeclaredClassProp{AstName(chars->data), Location{}, type, false}); + else + report(begin.location, "String literal contains malformed escape sequence or \\0"); + } } else if (lexer.current().type == '[') { @@ -942,12 +1112,21 @@ AstStat* Parser::parseDeclaration(const Location& start) indexer = parseTableIndexer(AstTableAccess::ReadWrite, std::nullopt); } } + else if (FFlag::LuauDeclarationExtraPropData) + { + Location propStart = lexer.current().location; + Name propName = parseName("property name"); + expectAndConsume(':', "property type annotation"); + AstType* propType = parseType(); + props.push_back( + AstDeclaredClassProp{propName.name, propName.location, propType, false, Location(propStart, lexer.previousLocation())}); + } else { Name propName = parseName("property name"); expectAndConsume(':', "property type annotation"); AstType* propType = parseType(); - props.push_back(AstDeclaredClassProp{propName.name, propType, false}); + props.push_back(AstDeclaredClassProp{propName.name, Location{}, propType, false}); } } @@ -961,7 +1140,8 @@ AstStat* Parser::parseDeclaration(const Location& start) expectAndConsume(':', "global variable declaration"); AstType* type = parseType(/* in declaration context */ true); - return allocator.alloc(Location(start, type->location), globalName->name, type); + return allocator.alloc( + Location(start, type->location), globalName->name, FFlag::LuauDeclarationExtraPropData ? globalName->location : Location{}, type); } else { @@ -1036,7 +1216,7 @@ std::pair> Parser::prepareFunctionArguments(const // funcbody ::= `(' [parlist] `)' [`:' ReturnType] block end // parlist ::= bindinglist [`,' `...'] | `...' 
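// With LuauAttributeSyntax and LuauAttributeSyntaxFunExpr enabled, attributes may also
// precede anonymous function expressions (handled in parseSimpleExpr below), e.g. this
// illustrative line:
//
//   local f = @native function(x: number) return x + 1 end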
std::pair Parser::parseFunctionBody( - bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName) + bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName, const AstArray& attributes) { Location start = matchFunction.location; @@ -1088,7 +1268,7 @@ std::pair Parser::parseFunctionBody( bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchFunction); body->hasEnd = hasEnd; - return {allocator.alloc(Location(start, end), generics, genericPacks, self, vars, vararg, varargLocation, body, + return {allocator.alloc(Location(start, end), attributes, generics, genericPacks, self, vars, vararg, varargLocation, body, functionStack.size(), debugname, typelist, varargAnnotation, argLocation), funLocal}; } @@ -1297,7 +1477,7 @@ std::pair Parser::parseReturnType() return {location, AstTypeList{copy(result), varargAnnotation}}; } - AstType* tail = parseFunctionTypeTail(begin, {}, {}, copy(result), copy(resultNames), varargAnnotation); + AstType* tail = parseFunctionTypeTail(begin, {nullptr, 0}, {}, {}, copy(result), copy(resultNames), varargAnnotation); return {Location{location, tail->location}, AstTypeList{copy(&tail, 1), varargAnnotation}}; } @@ -1340,22 +1520,19 @@ AstType* Parser::parseTableType(bool inDeclarationContext) AstTableAccess access = AstTableAccess::ReadWrite; std::optional accessLocation; - if (FFlag::LuauReadWritePropertySyntax || FFlag::DebugLuauDeferredConstraintResolution) + if (lexer.current().type == Lexeme::Name && lexer.lookahead().type != ':') { - if (lexer.current().type == Lexeme::Name && lexer.lookahead().type != ':') + if (AstName(lexer.current().name) == "read") { - if (AstName(lexer.current().name) == "read") - { - accessLocation = lexer.current().location; - access = AstTableAccess::Read; - lexer.next(); - } - else if (AstName(lexer.current().name) == "write") - { - accessLocation = lexer.current().location; - access = AstTableAccess::Write; - lexer.next(); - } + accessLocation = lexer.current().location; + access = AstTableAccess::Read; + lexer.next(); + } + else if (AstName(lexer.current().name) == "write") + { + accessLocation = lexer.current().location; + access = AstTableAccess::Write; + lexer.next(); } } @@ -1439,7 +1616,7 @@ AstType* Parser::parseTableType(bool inDeclarationContext) // ReturnType ::= Type | `(' TypeList `)' // FunctionType ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType -AstTypeOrPack Parser::parseFunctionType(bool allowPack, bool isCheckedFunction) +AstTypeOrPack Parser::parseFunctionType(bool allowPack, const AstArray& attributes) { incrementRecursionCounter("type annotation"); @@ -1487,11 +1664,12 @@ AstTypeOrPack Parser::parseFunctionType(bool allowPack, bool isCheckedFunction) AstArray> paramNames = copy(names); - return {parseFunctionTypeTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation, isCheckedFunction), {}}; + return {parseFunctionTypeTail(begin, attributes, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}}; } -AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, - AstArray params, AstArray> paramNames, AstTypePack* varargAnnotation, bool isCheckedFunction) +AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, const AstArray& attributes, AstArray generics, + AstArray genericPacks, AstArray params, AstArray> paramNames, + AstTypePack* varargAnnotation) { incrementRecursionCounter("type annotation"); @@ -1516,7 +1694,7 @@ AstType* 
Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray( - Location(begin.location, endLocation), generics, genericPacks, paramTypes, paramNames, returnTypeList, isCheckedFunction); + Location(begin.location, endLocation), attributes, generics, genericPacks, paramTypes, paramNames, returnTypeList); } // Type ::= @@ -1528,7 +1706,11 @@ AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray parts(scratchType); - parts.push_back(type); + + if (!FFlag::LuauLeadingBarAndAmpersand2 || type != nullptr) + { + parts.push_back(type); + } incrementRecursionCounter("type annotation"); @@ -1553,6 +1735,8 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin) } else if (c == '?') { + LUAU_ASSERT(parts.size() >= 1); + Location loc = lexer.current().location; nextLexeme(); @@ -1585,7 +1769,7 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin) } if (parts.size() == 1) - return type; + return FFlag::LuauLeadingBarAndAmpersand2 ? parts[0] : type; if (isUnion && isIntersection) { @@ -1628,15 +1812,34 @@ AstTypeOrPack Parser::parseTypeOrPack() AstType* Parser::parseType(bool inDeclarationContext) { unsigned int oldRecursionCount = recursionCounter; - // recursion counter is incremented in parseSimpleType + // recursion counter is incremented in parseSimpleType and/or parseTypeSuffix Location begin = lexer.current().location; - AstType* type = parseSimpleType(/* allowPack= */ false, /* in declaration context */ inDeclarationContext).type; + if (FFlag::LuauLeadingBarAndAmpersand2) + { + AstType* type = nullptr; - recursionCounter = oldRecursionCount; + Lexeme::Type c = lexer.current().type; + if (c != '|' && c != '&') + { + type = parseSimpleType(/* allowPack= */ false, /* in declaration context */ inDeclarationContext).type; + recursionCounter = oldRecursionCount; + } - return parseTypeSuffix(type, begin); + AstType* typeWithSuffix = parseTypeSuffix(type, begin); + recursionCounter = oldRecursionCount; + + return typeWithSuffix; + } + else + { + AstType* type = parseSimpleType(/* allowPack= */ false, /* in declaration context */ inDeclarationContext).type; + + recursionCounter = oldRecursionCount; + + return parseTypeSuffix(type, begin); + } } // Type ::= nil | Name[`.' Name] [ `<' Type [`,' ...] 
`>' ] | `typeof' `(' expr `)' | `{' [PropList] `}' @@ -1647,7 +1850,21 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext) Location start = lexer.current().location; - if (lexer.current().type == Lexeme::ReservedNil) + AstArray attributes{nullptr, 0}; + + if (lexer.current().type == Lexeme::Attribute) + { + if (!inDeclarationContext || !FFlag::LuauAttributeSyntax) + { + return {reportTypeError(start, {}, "attributes are not allowed in declaration context")}; + } + else + { + attributes = Parser::parseAttributes(); + return parseFunctionType(allowPack, attributes); + } + } + else if (lexer.current().type == Lexeme::ReservedNil) { nextLexeme(); return {allocator.alloc(start, std::nullopt, nameNil, std::nullopt, start), {}}; @@ -1735,15 +1952,9 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext) { return {parseTableType(/* inDeclarationContext */ inDeclarationContext), {}}; } - else if (FFlag::LuauCheckedFunctionSyntax && inDeclarationContext && lexer.current().type == Lexeme::ReservedChecked) - { - LUAU_ASSERT(FFlag::LuauCheckedFunctionSyntax); - nextLexeme(); - return parseFunctionType(allowPack, /* isCheckedFunction */ true); - } else if (lexer.current().type == '(' || lexer.current().type == '<') { - return parseFunctionType(allowPack); + return parseFunctionType(allowPack, AstArray({nullptr, 0})); } else if (lexer.current().type == Lexeme::ReservedFunction) { @@ -2213,11 +2424,24 @@ static ConstantNumberParseResult parseDouble(double& result, const char* data) return ConstantNumberParseResult::Ok; } -// simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | FUNCTION body | primaryexp +// simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | [attributes] FUNCTION body | primaryexp AstExpr* Parser::parseSimpleExpr() { Location start = lexer.current().location; + AstArray attributes{nullptr, 0}; + + if (FFlag::LuauAttributeSyntax && FFlag::LuauAttributeSyntaxFunExpr && lexer.current().type == Lexeme::Attribute) + { + attributes = parseAttributes(); + + if (lexer.current().type != Lexeme::ReservedFunction) + { + return reportExprError( + start, {}, "Expected 'function' declaration after attribute, but got %s instead", lexer.current().toString().c_str()); + } + } + if (lexer.current().type == Lexeme::ReservedNil) { nextLexeme(); @@ -2241,7 +2465,7 @@ AstExpr* Parser::parseSimpleExpr() Lexeme matchFunction = lexer.current(); nextLexeme(); - return parseFunctionBody(false, matchFunction, AstName(), nullptr).first; + return parseFunctionBody(false, matchFunction, AstName(), nullptr, attributes).first; } else if (lexer.current().type == Lexeme::Number) { @@ -2671,7 +2895,7 @@ std::optional> Parser::parseCharArray() LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::InterpStringSimple); - scratchData.assign(lexer.current().data, lexer.current().length); + scratchData.assign(lexer.current().data, lexer.current().getLength()); if (lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::InterpStringSimple) { @@ -2716,7 +2940,7 @@ AstExpr* Parser::parseInterpString() endLocation = currentLexeme.location; - scratchData.assign(currentLexeme.data, currentLexeme.length); + scratchData.assign(currentLexeme.data, currentLexeme.getLength()); if (!Lexer::fixupQuotedString(scratchData)) { @@ -2789,7 +3013,7 @@ AstExpr* Parser::parseNumber() { Location start = lexer.current().location; - 
scratchData.assign(lexer.current().data, lexer.current().length); + scratchData.assign(lexer.current().data, lexer.current().getLength()); // Remove all internal _ - they don't hold any meaning and this allows parsing code to just pass the string pointer to strtod et al if (scratchData.find('_') != std::string::npos) @@ -3144,11 +3368,11 @@ void Parser::nextLexeme() return; // Comments starting with ! are called "hot comments" and contain directives for type checking / linting / compiling - if (lexeme.type == Lexeme::Comment && lexeme.length && lexeme.data[0] == '!') + if (lexeme.type == Lexeme::Comment && lexeme.getLength() && lexeme.data[0] == '!') { const char* text = lexeme.data; - unsigned int end = lexeme.length; + unsigned int end = lexeme.getLength(); while (end > 0 && isSpace(text[end - 1])) --end; diff --git a/Ast/src/TimeTrace.cpp b/Ast/src/TimeTrace.cpp index bc3f3538..cfcf9ce2 100644 --- a/Ast/src/TimeTrace.cpp +++ b/Ast/src/TimeTrace.cpp @@ -250,6 +250,10 @@ void flushEvents(GlobalContext& context, uint32_t threadId, const std::vector& jumpTargets); -void analyzeBytecodeTypes(IrFunction& function); +void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/CodeGen.h b/CodeGen/include/Luau/CodeGen.h index 9765035b..7dd05660 100644 --- a/CodeGen/include/Luau/CodeGen.h +++ b/CodeGen/include/Luau/CodeGen.h @@ -12,6 +12,12 @@ struct lua_State; +#if defined(__x86_64__) || defined(_M_X64) +#define CODEGEN_TARGET_X64 +#elif defined(__aarch64__) || defined(_M_ARM64) +#define CODEGEN_TARGET_A64 +#endif + namespace Luau { namespace CodeGen @@ -40,8 +46,12 @@ enum class CodeGenCompilationResult CodeGenAssemblerFinalizationFailure = 7, // Failure during assembler finalization CodeGenLoweringFailure = 8, // Lowering failed AllocationFailed = 9, // Native codegen failed due to an allocation error + + Count = 10, }; +std::string toString(const CodeGenCompilationResult& result); + struct ProtoCompilationFailure { CodeGenCompilationResult result = CodeGenCompilationResult::Success; @@ -62,6 +72,97 @@ struct CompilationResult } }; +struct IrBuilder; +struct IrOp; + +using HostVectorOperationBytecodeType = uint8_t (*)(const char* member, size_t memberLength); +using HostVectorAccessHandler = bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos); +using HostVectorNamecallHandler = bool (*)( + IrBuilder& builder, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos); + +enum class HostMetamethod +{ + Add, + Sub, + Mul, + Div, + Idiv, + Mod, + Pow, + Minus, + Equal, + LessThan, + LessEqual, + Length, + Concat, +}; + +using HostUserdataOperationBytecodeType = uint8_t (*)(uint8_t type, const char* member, size_t memberLength); +using HostUserdataMetamethodBytecodeType = uint8_t (*)(uint8_t lhsTy, uint8_t rhsTy, HostMetamethod method); +using HostUserdataAccessHandler = bool (*)( + IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos); +using HostUserdataMetamethodHandler = bool (*)( + IrBuilder& builder, uint8_t lhsTy, uint8_t rhsTy, int resultReg, IrOp lhs, IrOp rhs, HostMetamethod method, int pcpos); +using HostUserdataNamecallHandler = bool (*)( + IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos); + +struct HostIrHooks +{ + // Suggest result 
type of a vector field access + HostVectorOperationBytecodeType vectorAccessBytecodeType = nullptr; + + // Suggest result type of a vector function namecall + HostVectorOperationBytecodeType vectorNamecallBytecodeType = nullptr; + + // Handle vector value field access + // 'sourceReg' is guaranteed to be a vector + // Guards should take a VM exit to 'pcpos' + HostVectorAccessHandler vectorAccess = nullptr; + + // Handle namecall performed on a vector value + // 'sourceReg' (self argument) is guaranteed to be a vector + // All other arguments can be of any type + // Guards should take a VM exit to 'pcpos' + HostVectorNamecallHandler vectorNamecall = nullptr; + + // Suggest result type of a userdata field access + HostUserdataOperationBytecodeType userdataAccessBytecodeType = nullptr; + + // Suggest result type of a metamethod call + HostUserdataMetamethodBytecodeType userdataMetamethodBytecodeType = nullptr; + + // Suggest result type of a userdata namecall + HostUserdataOperationBytecodeType userdataNamecallBytecodeType = nullptr; + + // Handle userdata value field access + // 'sourceReg' is guaranteed to be a userdata, but tag has to be checked + // Write to 'resultReg' might invalidate 'sourceReg' + // Guards should take a VM exit to 'pcpos' + HostUserdataAccessHandler userdataAccess = nullptr; + + // Handle metamethod operation on a userdata value + // 'lhs' and 'rhs' operands can be VM registers or constants + // Operand types have to be checked and userdata operand tags have to be checked + // Write to 'resultReg' might invalidate source operands + // Guards should take a VM exit to 'pcpos' + HostUserdataMetamethodHandler userdataMetamethod = nullptr; + + // Handle namecall performed on a userdata value + // 'sourceReg' (self argument) is guaranteed to be a userdata, but tag has to be checked + // All other arguments can be of any type + // Guards should take a VM exit to 'pcpos' + HostUserdataNamecallHandler userdataNamecall = nullptr; +}; + +struct CompilationOptions +{ + unsigned int flags = 0; + HostIrHooks hooks; + + // null-terminated array of userdata type names that might have custom lowering + const char* const* userdataTypes = nullptr; +}; + struct CompilationStats { size_t bytecodeSizeBytes = 0; @@ -101,8 +202,17 @@ using UniqueSharedCodeGenContext = std::unique_ptr; // Builds target function and all inner functions -CompilationResult compile(lua_State* L, int idx, unsigned int flags = 0, CompilationStats* stats = nullptr); -CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags = 0, CompilationStats* stats = nullptr); +CompilationResult compile(lua_State* L, int idx, unsigned int flags, CompilationStats* stats = nullptr); +CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats = nullptr); + +CompilationResult compile(lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats = nullptr); +CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats = nullptr); using AnnotatorFn = void (*)(void* context, std::string& result, int fid, int instpos); @@ -160,7 +279,7 @@ struct AssemblyOptions Target target = Host; - unsigned int flags = 0; + CompilationOptions compilationOptions; bool outputBinary = false; diff --git a/CodeGen/include/Luau/IrBuilder.h b/CodeGen/include/Luau/IrBuilder.h index 6c975e85..2077cce0 100644 --- a/CodeGen/include/Luau/IrBuilder.h +++ 
b/CodeGen/include/Luau/IrBuilder.h @@ -16,11 +16,11 @@ namespace Luau namespace CodeGen { -struct AssemblyOptions; +struct HostIrHooks; struct IrBuilder { - IrBuilder(); + IrBuilder(const HostIrHooks& hostHooks); void buildFunctionIr(Proto* proto); @@ -54,6 +54,7 @@ struct IrBuilder IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d); IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e); IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f); + IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f, IrOp g); IrOp block(IrBlockKind kind); // Requested kind can be ignored if we are in an outlined sequence IrOp blockAtInst(uint32_t index); @@ -64,13 +65,17 @@ struct IrBuilder IrOp vmExit(uint32_t pcpos); + const HostIrHooks& hostHooks; + bool inTerminatedBlock = false; bool interruptRequested = false; bool activeFastcallFallback = false; IrOp fastcallFallbackReturn; - int fastcallSkipTarget = -1; + + // Force builder to skip source commands + int cmdSkipTarget = -1; IrFunction function; diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index b00fffab..ae406bbc 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -31,7 +31,7 @@ enum // * Rn - VM stack register slot, n in 0..254 // * Kn - VM proto constant slot, n in 0..2^23-1 // * UPn - VM function upvalue slot, n in 0..199 -// * A, B, C, D, E are instruction arguments +// * A, B, C, D, E, F, G are instruction arguments enum class IrCmd : uint8_t { NOP, @@ -179,6 +179,10 @@ enum class IrCmd : uint8_t // A: double ABS_NUM, + // Get the sign of the argument (math.sign) + // A: double + SIGN_NUM, + // Add/Sub/Mul/Div/Idiv two vectors // A, B: TValue ADD_VEC, @@ -290,6 +294,11 @@ enum class IrCmd : uint8_t // C: block TRY_CALL_FASTGETTM, + // Create new tagged userdata + // A: int (size) + // B: int (tag) + NEW_USERDATA, + // Convert integer into a double number // A: int INT_TO_NUM, @@ -321,13 +330,12 @@ enum class IrCmd : uint8_t // This is used to recover after calling a variadic function ADJUST_STACK_TO_TOP, - // Execute fastcall builtin function in-place + // Execute fastcall builtin function with 1 argument in-place + // This is used for a few builtins that can have more than 1 result and cannot be represented as a regular instruction // A: unsigned int (builtin id) // B: Rn (result start) - // C: Rn (argument start) - // D: Rn or Kn or undef (optional second argument) - // E: int (argument count) - // F: int (result count) + // C: Rn (first argument) + // D: int (result count) FASTCALL, // Call the fastcall builtin function @@ -335,8 +343,9 @@ enum class IrCmd : uint8_t // B: Rn (result start) // C: Rn (argument start) // D: Rn or Kn or undef (optional second argument) - // E: int (argument count or -1 to use all arguments up to stack top) - // F: int (result count or -1 to preserve all results and adjust stack top) + // E: Rn or Kn or undef (optional third argument) + // F: int (argument count or -1 to use all arguments up to stack top) + // G: int (result count or -1 to preserve all results and adjust stack top) INVOKE_FASTCALL, // Check that fastcall builtin function invocation was successful (negative result count jumps to fallback) @@ -460,6 +469,13 @@ enum class IrCmd : uint8_t // When undef is specified instead of a block, execution is aborted on check failure CHECK_BUFFER_LEN, + // Guard against userdata tag mismatch + // A: pointer (userdata) + // B: int (tag) + // C: block/vmexit/undef + // When undef is specified instead of a block, 
execution is aborted on check failure + CHECK_USERDATA_TAG, + // Special operations // Check interrupt handler @@ -857,6 +873,7 @@ struct IrInst IrOp d; IrOp e; IrOp f; + IrOp g; uint32_t lastUse = 0; uint16_t useCount = 0; @@ -911,6 +928,7 @@ struct IrInstHash h = mix(h, key.d); h = mix(h, key.e); h = mix(h, key.f); + h = mix(h, key.g); // MurmurHash2 tail h ^= h >> 13; @@ -925,7 +943,7 @@ struct IrInstEq { bool operator()(const IrInst& a, const IrInst& b) const { - return a.cmd == b.cmd && a.a == b.a && a.b == b.b && a.c == b.c && a.d == b.d && a.e == b.e && a.f == b.f; + return a.cmd == b.cmd && a.a == b.a && a.b == b.b && a.c == b.c && a.d == b.d && a.e == b.e && a.f == b.f && a.g == b.g; } }; diff --git a/CodeGen/include/Luau/IrDump.h b/CodeGen/include/Luau/IrDump.h index dcca3c7b..d989a6c7 100644 --- a/CodeGen/include/Luau/IrDump.h +++ b/CodeGen/include/Luau/IrDump.h @@ -31,9 +31,11 @@ void toString(IrToStringContext& ctx, IrOp op); void toString(std::string& result, IrConst constant); -const char* getBytecodeTypeName(uint8_t type); +const char* getBytecodeTypeName_DEPRECATED(uint8_t type); +const char* getBytecodeTypeName(uint8_t type, const char* const* userdataTypes); -void toString(std::string& result, const BytecodeTypes& bcTypes); +void toString_DEPRECATED(std::string& result, const BytecodeTypes& bcTypes); +void toString(std::string& result, const BytecodeTypes& bcTypes, const char* const* userdataTypes); void toStringDetailed( IrToStringContext& ctx, const IrBlock& block, uint32_t blockIdx, const IrInst& inst, uint32_t instIdx, IncludeUseInfo includeUseInfo); diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 0c8495e8..8d48780f 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -11,6 +11,7 @@ namespace CodeGen { struct IrBuilder; +enum class HostMetamethod; inline bool isJumpD(LuauOpcode op) { @@ -63,6 +64,7 @@ inline bool isFastCall(LuauOpcode op) case LOP_FASTCALL1: case LOP_FASTCALL2: case LOP_FASTCALL2K: + case LOP_FASTCALL3: return true; default: @@ -129,6 +131,7 @@ inline bool isNonTerminatingJump(IrCmd cmd) case IrCmd::CHECK_NODE_NO_NEXT: case IrCmd::CHECK_NODE_VALUE: case IrCmd::CHECK_BUFFER_LEN: + case IrCmd::CHECK_USERDATA_TAG: return true; default: break; @@ -168,6 +171,7 @@ inline bool hasResult(IrCmd cmd) case IrCmd::ROUND_NUM: case IrCmd::SQRT_NUM: case IrCmd::ABS_NUM: + case IrCmd::SIGN_NUM: case IrCmd::ADD_VEC: case IrCmd::SUB_VEC: case IrCmd::MUL_VEC: @@ -182,6 +186,7 @@ inline bool hasResult(IrCmd cmd) case IrCmd::DUP_TABLE: case IrCmd::TRY_NUM_TO_INDEX: case IrCmd::TRY_CALL_FASTGETTM: + case IrCmd::NEW_USERDATA: case IrCmd::INT_TO_NUM: case IrCmd::UINT_TO_NUM: case IrCmd::NUM_TO_INT: @@ -241,6 +246,12 @@ IrValueKind getCmdValueKind(IrCmd cmd); bool isGCO(uint8_t tag); +// Optional bit has to be cleared at call site, otherwise, this will return 'false' for 'userdata?' 
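+// For example (a sketch; LBC_TYPE_OPTIONAL_BIT as defined in the bytecode headers):
+//   isUserdataBytecodeType(uint8_t(ty & ~LBC_TYPE_OPTIONAL_BIT))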
+bool isUserdataBytecodeType(uint8_t ty); +bool isCustomUserdataBytecodeType(uint8_t ty); + +HostMetamethod tmToHostMetamethod(int tm); + // Manually add or remove use of an operand void addUse(IrFunction& function, IrOp op); void removeUse(IrFunction& function, IrOp op); diff --git a/CodeGen/include/Luau/IrVisitUseDef.h b/CodeGen/include/Luau/IrVisitUseDef.h index 58c88661..6744bd65 100644 --- a/CodeGen/include/Luau/IrVisitUseDef.h +++ b/CodeGen/include/Luau/IrVisitUseDef.h @@ -4,7 +4,7 @@ #include "Luau/Common.h" #include "Luau/IrData.h" -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) +LUAU_FASTFLAG(LuauCodegenFastcall3) namespace Luau { @@ -112,12 +112,48 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i visitor.useRange(vmRegOp(inst.a), function.intOp(inst.b)); break; - // TODO: FASTCALL is more restrictive than INVOKE_FASTCALL; we should either determine the exact semantics, or rework it case IrCmd::FASTCALL: - case IrCmd::INVOKE_FASTCALL: - if (int count = function.intOp(inst.e); count != -1) + if (FFlag::LuauCodegenFastcall3) { - if (count >= 3) + visitor.use(inst.c); + + if (int nresults = function.intOp(inst.d); nresults != -1) + visitor.defRange(vmRegOp(inst.b), nresults); + } + else + { + if (int count = function.intOp(inst.e); count != -1) + { + if (count >= 3) + { + CODEGEN_ASSERT(inst.d.kind == IrOpKind::VmReg && vmRegOp(inst.d) == vmRegOp(inst.c) + 1); + + visitor.useRange(vmRegOp(inst.c), count); + } + else + { + if (count >= 1) + visitor.use(inst.c); + + if (count >= 2) + visitor.maybeUse(inst.d); // Argument can also be a VmConst + } + } + else + { + visitor.useVarargs(vmRegOp(inst.c)); + } + + // Multiple return sequences (count == -1) are defined by ADJUST_STACK_TO_REG + if (int count = function.intOp(inst.f); count != -1) + visitor.defRange(vmRegOp(inst.b), count); + } + break; + case IrCmd::INVOKE_FASTCALL: + if (int count = function.intOp(FFlag::LuauCodegenFastcall3 ? inst.f : inst.e); count != -1) + { + // Only LOP_FASTCALL3 lowering is allowed to have third optional argument + if (count >= 3 && (!FFlag::LuauCodegenFastcall3 || inst.e.kind == IrOpKind::Undef)) { CODEGEN_ASSERT(inst.d.kind == IrOpKind::VmReg && vmRegOp(inst.d) == vmRegOp(inst.c) + 1); @@ -130,6 +166,9 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i if (count >= 2) visitor.maybeUse(inst.d); // Argument can also be a VmConst + + if (FFlag::LuauCodegenFastcall3 && count >= 3) + visitor.maybeUse(inst.e); // Argument can also be a VmConst } } else @@ -138,7 +177,7 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i } // Multiple return sequences (count == -1) are defined by ADJUST_STACK_TO_REG - if (int count = function.intOp(inst.f); count != -1) + if (int count = function.intOp(FFlag::LuauCodegenFastcall3 ? 
inst.g : inst.f); count != -1) visitor.defRange(vmRegOp(inst.b), count); break; case IrCmd::FORGLOOP: @@ -188,15 +227,8 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i visitor.def(inst.b); break; case IrCmd::FALLBACK_FORGPREP: - if (FFlag::LuauCodegenRemoveDeadStores5) - { - // This instruction doesn't always redefine Rn, Rn+1, Rn+2, so we have to mark it as implicit use - visitor.useRange(vmRegOp(inst.b), 3); - } - else - { - visitor.use(inst.b); - } + // This instruction doesn't always redefine Rn, Rn+1, Rn+2, so we have to mark it as implicit use + visitor.useRange(vmRegOp(inst.b), 3); visitor.defRange(vmRegOp(inst.b), 3); break; @@ -214,12 +246,6 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i visitor.use(inst.a); break; - // After optimizations with DebugLuauAbortingChecks enabled, CHECK_TAG Rn, tag, block instructions are generated - case IrCmd::CHECK_TAG: - if (!FFlag::LuauCodegenRemoveDeadStores5) - visitor.maybeUse(inst.a); - break; - default: // All instructions which reference registers have to be handled explicitly CODEGEN_ASSERT(inst.a.kind != IrOpKind::VmReg); @@ -228,6 +254,7 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i CODEGEN_ASSERT(inst.d.kind != IrOpKind::VmReg); CODEGEN_ASSERT(inst.e.kind != IrOpKind::VmReg); CODEGEN_ASSERT(inst.f.kind != IrOpKind::VmReg); + CODEGEN_ASSERT(inst.g.kind != IrOpKind::VmReg); break; } } diff --git a/CodeGen/include/Luau/UnwindBuilder.h b/CodeGen/include/Luau/UnwindBuilder.h index 1ba377ba..03c9b56a 100644 --- a/CodeGen/include/Luau/UnwindBuilder.h +++ b/CodeGen/include/Luau/UnwindBuilder.h @@ -16,7 +16,7 @@ namespace CodeGen { // This value is used in 'finishFunction' to mark the function that spans to the end of the whole code block -static uint32_t kFullBlockFuncton = ~0u; +static uint32_t kFullBlockFunction = ~0u; class UnwindBuilder { @@ -52,11 +52,10 @@ public: virtual void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list gpr, const std::vector& simd) = 0; - virtual size_t getSize() const = 0; - virtual size_t getFunctionCount() const = 0; + virtual size_t getUnwindInfoSize(size_t blockSize) const = 0; // This will place the unwinding data at the target address and might update values of some fields - virtual void finalize(char* target, size_t offset, void* funcAddress, size_t funcSize) const = 0; + virtual size_t finalize(char* target, size_t offset, void* funcAddress, size_t blockSize) const = 0; }; } // namespace CodeGen diff --git a/CodeGen/include/Luau/UnwindBuilderDwarf2.h b/CodeGen/include/Luau/UnwindBuilderDwarf2.h index 741aaed2..1b634dec 100644 --- a/CodeGen/include/Luau/UnwindBuilderDwarf2.h +++ b/CodeGen/include/Luau/UnwindBuilderDwarf2.h @@ -33,10 +33,9 @@ public: void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list gpr, const std::vector& simd) override; - size_t getSize() const override; - size_t getFunctionCount() const override; + size_t getUnwindInfoSize(size_t blockSize = 0) const override; - void finalize(char* target, size_t offset, void* funcAddress, size_t funcSize) const override; + size_t finalize(char* target, size_t offset, void* funcAddress, size_t blockSize) const override; private: size_t beginOffset = 0; diff --git a/CodeGen/include/Luau/UnwindBuilderWin.h b/CodeGen/include/Luau/UnwindBuilderWin.h index 3a7e1b5a..bc43b94a 100644 --- a/CodeGen/include/Luau/UnwindBuilderWin.h +++ 
b/CodeGen/include/Luau/UnwindBuilderWin.h @@ -53,10 +53,9 @@ public: void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list gpr, const std::vector& simd) override; - size_t getSize() const override; - size_t getFunctionCount() const override; + size_t getUnwindInfoSize(size_t blockSize = 0) const override; - void finalize(char* target, size_t offset, void* funcAddress, size_t funcSize) const override; + size_t finalize(char* target, size_t offset, void* funcAddress, size_t blockSize) const override; private: size_t beginOffset = 0; diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index bed7e0e3..f999d753 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -826,7 +826,7 @@ void AssemblyBuilderX64::vcvtss2sd(OperandX64 dst, OperandX64 src1, OperandX64 s else CODEGEN_ASSERT(src2.memSize == SizeX64::dword); - placeAvx("vcvtsd2ss", dst, src1, src2, 0x5a, false, AVX_0F, AVX_F3); + placeAvx("vcvtss2sd", dst, src1, src2, 0x5a, false, AVX_0F, AVX_F3); } void AssemblyBuilderX64::vroundsd(OperandX64 dst, OperandX64 src1, OperandX64 src2, RoundingModeX64 roundingMode) diff --git a/CodeGen/src/BytecodeAnalysis.cpp b/CodeGen/src/BytecodeAnalysis.cpp index 7c39f5fc..b99d6336 100644 --- a/CodeGen/src/BytecodeAnalysis.cpp +++ b/CodeGen/src/BytecodeAnalysis.cpp @@ -2,35 +2,26 @@ #include "Luau/BytecodeAnalysis.h" #include "Luau/BytecodeUtils.h" +#include "Luau/CodeGen.h" #include "Luau/IrData.h" #include "Luau/IrUtils.h" #include "lobject.h" +#include "lstate.h" #include -LUAU_FASTFLAG(LuauCodegenDirectUserdataFlow) -LUAU_FASTFLAG(LuauLoadTypeInfo) // Because new VM typeinfo load changes the format used by Codegen, same flag is used -LUAU_FASTFLAGVARIABLE(LuauCodegenTypeInfo, false) // New analysis is flagged separately -LUAU_FASTFLAG(LuauTypeInfoLookupImprovement) +LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataOps, false) +LUAU_FASTFLAGVARIABLE(LuauCodegenFastcall3, false) namespace Luau { namespace CodeGen { -static bool hasTypedParameters(Proto* proto) -{ - CODEGEN_ASSERT(!FFlag::LuauLoadTypeInfo); - - return proto->typeinfo && proto->numparams != 0; -} - template static T read(uint8_t* data, size_t& offset) { - CODEGEN_ASSERT(FFlag::LuauLoadTypeInfo); - T result; memcpy(&result, data + offset, sizeof(T)); offset += sizeof(T); @@ -40,8 +31,6 @@ static T read(uint8_t* data, size_t& offset) static uint32_t readVarInt(uint8_t* data, size_t& offset) { - CODEGEN_ASSERT(FFlag::LuauLoadTypeInfo); - uint32_t result = 0; uint32_t shift = 0; @@ -59,25 +48,15 @@ static uint32_t readVarInt(uint8_t* data, size_t& offset) void loadBytecodeTypeInfo(IrFunction& function) { - CODEGEN_ASSERT(FFlag::LuauLoadTypeInfo); - Proto* proto = function.proto; - if (FFlag::LuauTypeInfoLookupImprovement) - { - if (!proto) - return; - } - else - { - if (!proto || !proto->typeinfo) - return; - } + if (!proto) + return; BytecodeTypeInfo& typeInfo = function.bcTypeInfo; // If there is no typeinfo, we generate default values for arguments and upvalues - if (FFlag::LuauTypeInfoLookupImprovement && !proto->typeinfo) + if (!proto->typeinfo) { typeInfo.argumentTypes.resize(proto->numparams, LBC_TYPE_ANY); typeInfo.upvalueTypes.resize(proto->nups, LBC_TYPE_ANY); @@ -91,8 +70,6 @@ void loadBytecodeTypeInfo(IrFunction& function) uint32_t upvalCount = readVarInt(data, offset); uint32_t localCount = readVarInt(data, offset); - CODEGEN_ASSERT(upvalCount == unsigned(proto->nups)); - if (typeSize != 0) { uint8_t* types = (uint8_t*)data + 
offset; @@ -110,6 +87,8 @@ void loadBytecodeTypeInfo(IrFunction& function) if (upvalCount != 0) { + CODEGEN_ASSERT(upvalCount == unsigned(proto->nups)); + typeInfo.upvalueTypes.resize(upvalCount); uint8_t* types = (uint8_t*)data + offset; @@ -137,8 +116,6 @@ void loadBytecodeTypeInfo(IrFunction& function) static void prepareRegTypeInfoLookups(BytecodeTypeInfo& typeInfo) { - CODEGEN_ASSERT(FFlag::LuauTypeInfoLookupImprovement); - // Sort by register first, then by end PC std::sort(typeInfo.regTypes.begin(), typeInfo.regTypes.end(), [](const BytecodeRegTypeInfo& a, const BytecodeRegTypeInfo& b) { if (a.reg != b.reg) @@ -171,47 +148,30 @@ static void prepareRegTypeInfoLookups(BytecodeTypeInfo& typeInfo) static BytecodeRegTypeInfo* findRegType(BytecodeTypeInfo& info, uint8_t reg, int pc) { - CODEGEN_ASSERT(FFlag::LuauCodegenTypeInfo); - - if (FFlag::LuauTypeInfoLookupImprovement) - { - auto b = info.regTypes.begin() + info.regTypeOffsets[reg]; - auto e = info.regTypes.begin() + info.regTypeOffsets[reg + 1]; - - // Doesn't have info - if (b == e) - return nullptr; - - // No info after the last live range - if (pc >= (e - 1)->endpc) - return nullptr; - - for (auto it = b; it != e; ++it) - { - CODEGEN_ASSERT(it->reg == reg); - - if (pc >= it->startpc && pc < it->endpc) - return &*it; - } + auto b = info.regTypes.begin() + info.regTypeOffsets[reg]; + auto e = info.regTypes.begin() + info.regTypeOffsets[reg + 1]; + // Doesn't have info + if (b == e) return nullptr; - } - else - { - for (BytecodeRegTypeInfo& el : info.regTypes) - { - if (reg == el.reg && pc >= el.startpc && pc < el.endpc) - return &el; - } + // No info after the last live range + if (pc >= (e - 1)->endpc) return nullptr; + + for (auto it = b; it != e; ++it) + { + CODEGEN_ASSERT(it->reg == reg); + + if (pc >= it->startpc && pc < it->endpc) + return &*it; } + + return nullptr; } static void refineRegType(BytecodeTypeInfo& info, uint8_t reg, int pc, uint8_t ty) { - CODEGEN_ASSERT(FFlag::LuauCodegenTypeInfo); - if (ty != LBC_TYPE_ANY) { if (BytecodeRegTypeInfo* regType = findRegType(info, reg, pc)) @@ -220,7 +180,7 @@ static void refineRegType(BytecodeTypeInfo& info, uint8_t reg, int pc, uint8_t t if (regType->type == LBC_TYPE_ANY) regType->type = ty; } - else if (FFlag::LuauTypeInfoLookupImprovement && reg < info.argumentTypes.size()) + else if (reg < info.argumentTypes.size()) { if (info.argumentTypes[reg] == LBC_TYPE_ANY) info.argumentTypes[reg] = ty; @@ -230,8 +190,6 @@ static void refineUpvalueType(BytecodeTypeInfo& info, int up, uint8_t ty) { - CODEGEN_ASSERT(FFlag::LuauCodegenTypeInfo); - if (ty != LBC_TYPE_ANY) { if (size_t(up) < info.upvalueTypes.size()) @@ -558,6 +516,49 @@ static void applyBuiltinCall(int bfid, BytecodeTypes& types) } } +static HostMetamethod opcodeToHostMetamethod(LuauOpcode op) +{ + switch (op) + { + case LOP_ADD: + return HostMetamethod::Add; + case LOP_SUB: + return HostMetamethod::Sub; + case LOP_MUL: + return HostMetamethod::Mul; + case LOP_DIV: + return HostMetamethod::Div; + case LOP_IDIV: + return HostMetamethod::Idiv; + case LOP_MOD: + return HostMetamethod::Mod; + case LOP_POW: + return HostMetamethod::Pow; + case LOP_ADDK: + return HostMetamethod::Add; + case LOP_SUBK: + return HostMetamethod::Sub; + case LOP_MULK: + return HostMetamethod::Mul; + case LOP_DIVK: + return HostMetamethod::Div; + case LOP_IDIVK: + return HostMetamethod::Idiv; + case LOP_MODK: + return HostMetamethod::Mod; + case LOP_POWK: + return
HostMetamethod::Pow; + case LOP_SUBRK: + return HostMetamethod::Sub; + case LOP_DIVRK: + return HostMetamethod::Div; + default: + CODEGEN_ASSERT(!"opcode is not assigned to a host metamethod"); + } + + return HostMetamethod::Add; +} + void buildBytecodeBlocks(IrFunction& function, const std::vector& jumpTargets) { Proto* proto = function.proto; @@ -607,15 +608,14 @@ } } -void analyzeBytecodeTypes(IrFunction& function) +void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) { Proto* proto = function.proto; CODEGEN_ASSERT(proto); BytecodeTypeInfo& bcTypeInfo = function.bcTypeInfo; - if (FFlag::LuauTypeInfoLookupImprovement) - prepareRegTypeInfoLookups(bcTypeInfo); + prepareRegTypeInfoLookups(bcTypeInfo); // Setup our current knowledge of type tags based on arguments uint8_t regTags[256]; @@ -631,48 +631,31 @@ // At the block start, reset our knowledge to the starting state // In the future we might be able to propagate some info between the blocks as well - if (FFlag::LuauLoadTypeInfo) + for (size_t i = 0; i < bcTypeInfo.argumentTypes.size(); i++) { - for (size_t i = 0; i < bcTypeInfo.argumentTypes.size(); i++) - { - uint8_t et = bcTypeInfo.argumentTypes[i]; + uint8_t et = bcTypeInfo.argumentTypes[i]; - // TODO: if argument is optional, this might force a VM exit unnecessarily - regTags[i] = et & ~LBC_TYPE_OPTIONAL_BIT; - } - } - else - { - if (hasTypedParameters(proto)) - { - for (int i = 0; i < proto->numparams; ++i) - { - uint8_t et = proto->typeinfo[2 + i]; - - // TODO: if argument is optional, this might force a VM exit unnecessarily - regTags[i] = et & ~LBC_TYPE_OPTIONAL_BIT; - } - } + // TODO: if argument is optional, this might force a VM exit unnecessarily + regTags[i] = et & ~LBC_TYPE_OPTIONAL_BIT; } for (int i = proto->numparams; i < proto->maxstacksize; ++i) regTags[i] = LBC_TYPE_ANY; + LuauBytecodeType knownNextCallResult = LBC_TYPE_ANY; + for (int i = block.startpc; i <= block.finishpc;) { const Instruction* pc = &proto->code[i]; LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc)); - if (FFlag::LuauCodegenTypeInfo) + // Assign known register types from local type information + // TODO: this is an expensive walk for each instruction + // TODO: it's best to lookup when register is actually used in the instruction + for (BytecodeRegTypeInfo& el : bcTypeInfo.regTypes) { - // Assign known register types from local type information - // TODO: this is an expensive walk for each instruction - // TODO: it's best to lookup when register is actually used in the instruction - for (BytecodeRegTypeInfo& el : bcTypeInfo.regTypes) - { - if (el.type != LBC_TYPE_ANY && i >= el.startpc && i < el.endpc) - regTags[el.reg] = el.type; - } + if (el.type != LBC_TYPE_ANY && i >= el.startpc && i < el.endpc) regTags[el.reg] = el.type; } BytecodeTypes& bcType = function.bcTypes[i]; @@ -694,8 +677,7 @@ void analyzeBytecodeTypes(IrFunction& function) regTags[ra] = LBC_TYPE_BOOLEAN; bcType.result = regTags[ra]; - if (FFlag::LuauCodegenTypeInfo) - refineRegType(bcTypeInfo, ra, i, bcType.result); + refineRegType(bcTypeInfo, ra, i, bcType.result); break; } case LOP_LOADN: @@ -704,8 +686,7 @@ { regTags[ra] = LBC_TYPE_NUMBER; bcType.result = regTags[ra]; - if (FFlag::LuauCodegenTypeInfo) - refineRegType(bcTypeInfo, ra, i, bcType.result); + refineRegType(bcTypeInfo, ra, i, bcType.result); break; } case LOP_LOADK: @@ -716,8 +697,7 @@
void analyzeBytecodeTypes(IrFunction& function) regTags[ra] = bcType.a; bcType.result = regTags[ra]; - if (FFlag::LuauCodegenTypeInfo) - refineRegType(bcTypeInfo, ra, i, bcType.result); + refineRegType(bcTypeInfo, ra, i, bcType.result); break; } case LOP_LOADKX: @@ -728,8 +708,7 @@ void analyzeBytecodeTypes(IrFunction& function) regTags[ra] = bcType.a; bcType.result = regTags[ra]; - if (FFlag::LuauCodegenTypeInfo) - refineRegType(bcTypeInfo, ra, i, bcType.result); + refineRegType(bcTypeInfo, ra, i, bcType.result); break; } case LOP_MOVE: @@ -740,8 +719,7 @@ void analyzeBytecodeTypes(IrFunction& function) regTags[ra] = regTags[rb]; bcType.result = regTags[ra]; - if (FFlag::LuauCodegenTypeInfo) - refineRegType(bcTypeInfo, ra, i, bcType.result); + refineRegType(bcTypeInfo, ra, i, bcType.result); break; } case LOP_GETTABLE: @@ -771,10 +749,51 @@ void analyzeBytecodeTypes(IrFunction& function) regTags[ra] = LBC_TYPE_ANY; - // Assuming that vector component is being indexed - // TODO: check what key is used - if (bcType.a == LBC_TYPE_VECTOR) - regTags[ra] = LBC_TYPE_NUMBER; + if (FFlag::LuauCodegenUserdataOps) + { + TString* str = gco2ts(function.proto->k[kc].value.gc); + const char* field = getstr(str); + + if (bcType.a == LBC_TYPE_VECTOR) + { + if (str->len == 1) + { + // Same handling as LOP_GETTABLEKS block in lvmexecute.cpp - case-insensitive comparison with "X" / "Y" / "Z" + char ch = field[0] | ' '; + + if (ch == 'x' || ch == 'y' || ch == 'z') + regTags[ra] = LBC_TYPE_NUMBER; + } + + if (regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType) + regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len); + } + else if (isCustomUserdataBytecodeType(bcType.a)) + { + if (regTags[ra] == LBC_TYPE_ANY && hostHooks.userdataAccessBytecodeType) + regTags[ra] = hostHooks.userdataAccessBytecodeType(bcType.a, field, str->len); + } + } + else + { + if (bcType.a == LBC_TYPE_VECTOR) + { + TString* str = gco2ts(function.proto->k[kc].value.gc); + const char* field = getstr(str); + + if (str->len == 1) + { + // Same handling as LOP_GETTABLEKS block in lvmexecute.cpp - case-insensitive comparison with "X" / "Y" / "Z" + char ch = field[0] | ' '; + + if (ch == 'x' || ch == 'y' || ch == 'z') + regTags[ra] = LBC_TYPE_NUMBER; + } + + if (regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType) + regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len); + } + } bcType.result = regTags[ra]; break; @@ -810,6 +829,9 @@ void analyzeBytecodeTypes(IrFunction& function) regTags[ra] = LBC_TYPE_NUMBER; else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; break; @@ -839,6 +861,11 @@ void analyzeBytecodeTypes(IrFunction& function) if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; } + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + { + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); + } bcType.result = regTags[ra]; break; @@ -857,6 +884,9 @@ void analyzeBytecodeTypes(IrFunction& function) if (bcType.a == 
LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER) regTags[ra] = LBC_TYPE_NUMBER; + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; break; @@ -877,6 +907,9 @@ void analyzeBytecodeTypes(IrFunction& function) regTags[ra] = LBC_TYPE_NUMBER; else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; break; @@ -906,6 +939,11 @@ void analyzeBytecodeTypes(IrFunction& function) if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; } + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + { + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); + } bcType.result = regTags[ra]; break; @@ -924,6 +962,9 @@ void analyzeBytecodeTypes(IrFunction& function) if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER) regTags[ra] = LBC_TYPE_NUMBER; + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; break; @@ -943,6 +984,9 @@ void analyzeBytecodeTypes(IrFunction& function) regTags[ra] = LBC_TYPE_NUMBER; else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; break; @@ -970,6 +1014,11 @@ void analyzeBytecodeTypes(IrFunction& function) if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; } + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + { + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); + } bcType.result = regTags[ra]; break; @@ -998,6 +1047,8 @@ void analyzeBytecodeTypes(IrFunction& function) regTags[ra] = LBC_TYPE_NUMBER; else if (bcType.a == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; + else if (FFlag::LuauCodegenUserdataOps && hostHooks.userdataMetamethodBytecodeType && isCustomUserdataBytecodeType(bcType.a)) + regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, LBC_TYPE_ANY, HostMetamethod::Minus); bcType.result = regTags[ra]; break; @@ -1036,8 +1087,7 @@ void analyzeBytecodeTypes(IrFunction& function) regTags[ra + 3] = bcType.c; regTags[ra] = bcType.result; - if (FFlag::LuauCodegenTypeInfo) - refineRegType(bcTypeInfo, ra, i, bcType.result); + 
refineRegType(bcTypeInfo, ra, i, bcType.result); break; } case LOP_FASTCALL1: @@ -1055,8 +1105,7 @@ void analyzeBytecodeTypes(IrFunction& function) regTags[LUAU_INSN_B(*pc)] = bcType.a; regTags[ra] = bcType.result; - if (FFlag::LuauCodegenTypeInfo) - refineRegType(bcTypeInfo, ra, i, bcType.result); + refineRegType(bcTypeInfo, ra, i, bcType.result); break; } case LOP_FASTCALL2: @@ -1074,8 +1123,29 @@ void analyzeBytecodeTypes(IrFunction& function) regTags[int(pc[1])] = bcType.b; regTags[ra] = bcType.result; - if (FFlag::LuauCodegenTypeInfo) - refineRegType(bcTypeInfo, ra, i, bcType.result); + refineRegType(bcTypeInfo, ra, i, bcType.result); + break; + } + case LOP_FASTCALL3: + { + CODEGEN_ASSERT(FFlag::LuauCodegenFastcall3); + + int bfid = LUAU_INSN_A(*pc); + int skip = LUAU_INSN_C(*pc); + int aux = pc[1]; + + Instruction call = pc[skip + 1]; + CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); + int ra = LUAU_INSN_A(call); + + applyBuiltinCall(bfid, bcType); + + regTags[LUAU_INSN_B(*pc)] = bcType.a; + regTags[aux & 0xff] = bcType.b; + regTags[(aux >> 8) & 0xff] = bcType.c; + regTags[ra] = bcType.result; + + refineRegType(bcTypeInfo, ra, i, bcType.result); break; } case LOP_FORNPREP: @@ -1086,12 +1156,9 @@ void analyzeBytecodeTypes(IrFunction& function) regTags[ra + 1] = LBC_TYPE_NUMBER; regTags[ra + 2] = LBC_TYPE_NUMBER; - if (FFlag::LuauCodegenTypeInfo) - { - refineRegType(bcTypeInfo, ra, i, regTags[ra]); - refineRegType(bcTypeInfo, ra + 1, i, regTags[ra + 1]); - refineRegType(bcTypeInfo, ra + 2, i, regTags[ra + 2]); - } + refineRegType(bcTypeInfo, ra, i, regTags[ra]); + refineRegType(bcTypeInfo, ra + 1, i, regTags[ra + 1]); + refineRegType(bcTypeInfo, ra + 2, i, regTags[ra + 2]); break; } case LOP_FORNLOOP: @@ -1121,61 +1188,88 @@ void analyzeBytecodeTypes(IrFunction& function) } case LOP_NAMECALL: { - if (FFlag::LuauCodegenDirectUserdataFlow) + int ra = LUAU_INSN_A(*pc); + int rb = LUAU_INSN_B(*pc); + uint32_t kc = pc[1]; + + bcType.a = regTags[rb]; + bcType.b = getBytecodeConstantTag(proto, kc); + + // While namecall might result in a callable table, we assume the function fast path + regTags[ra] = LBC_TYPE_FUNCTION; + + // Namecall places source register into target + 1 + regTags[ra + 1] = bcType.a; + + bcType.result = LBC_TYPE_FUNCTION; + + if (FFlag::LuauCodegenUserdataOps) { - int ra = LUAU_INSN_A(*pc); - int rb = LUAU_INSN_B(*pc); - uint32_t kc = pc[1]; + TString* str = gco2ts(function.proto->k[kc].value.gc); + const char* field = getstr(str); - bcType.a = regTags[rb]; - bcType.b = getBytecodeConstantTag(proto, kc); - - // While namecall might result in a callable table, we assume the function fast path - regTags[ra] = LBC_TYPE_FUNCTION; - - // Namecall places source register into target + 1 - regTags[ra + 1] = bcType.a; - - bcType.result = LBC_TYPE_FUNCTION; + if (bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType) + knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len)); + else if (isCustomUserdataBytecodeType(bcType.a) && hostHooks.userdataNamecallBytecodeType) + knownNextCallResult = LuauBytecodeType(hostHooks.userdataNamecallBytecodeType(bcType.a, field, str->len)); } + else + { + if (bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType) + { + TString* str = gco2ts(function.proto->k[kc].value.gc); + const char* field = getstr(str); + + knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len)); + } + } + break; + } + case LOP_CALL: + { + int ra = LUAU_INSN_A(*pc); + + if 
(knownNextCallResult != LBC_TYPE_ANY) + { + bcType.result = knownNextCallResult; + + knownNextCallResult = LBC_TYPE_ANY; + + regTags[ra] = bcType.result; + } + + refineRegType(bcTypeInfo, ra, i, bcType.result); break; } case LOP_GETUPVAL: { - if (FFlag::LuauCodegenTypeInfo) + int ra = LUAU_INSN_A(*pc); + int up = LUAU_INSN_B(*pc); + + bcType.a = LBC_TYPE_ANY; + + if (size_t(up) < bcTypeInfo.upvalueTypes.size()) { - int ra = LUAU_INSN_A(*pc); - int up = LUAU_INSN_B(*pc); + uint8_t et = bcTypeInfo.upvalueTypes[up]; - bcType.a = LBC_TYPE_ANY; - - if (size_t(up) < bcTypeInfo.upvalueTypes.size()) - { - uint8_t et = bcTypeInfo.upvalueTypes[up]; - - // TODO: if argument is optional, this might force a VM exit unnecessarily - bcType.a = et & ~LBC_TYPE_OPTIONAL_BIT; - } - - regTags[ra] = bcType.a; - bcType.result = regTags[ra]; + // TODO: if argument is optional, this might force a VM exit unnecessarily + bcType.a = et & ~LBC_TYPE_OPTIONAL_BIT; } + + regTags[ra] = bcType.a; + bcType.result = regTags[ra]; break; } case LOP_SETUPVAL: { - if (FFlag::LuauCodegenTypeInfo) - { - int ra = LUAU_INSN_A(*pc); - int up = LUAU_INSN_B(*pc); + int ra = LUAU_INSN_A(*pc); + int up = LUAU_INSN_B(*pc); - refineUpvalueType(bcTypeInfo, up, regTags[ra]); - } + refineUpvalueType(bcTypeInfo, up, regTags[ra]); break; } case LOP_GETGLOBAL: case LOP_SETGLOBAL: - case LOP_CALL: case LOP_RETURN: case LOP_JUMP: case LOP_JUMPBACK: diff --git a/CodeGen/src/BytecodeSummary.cpp b/CodeGen/src/BytecodeSummary.cpp index 0089f592..d0d71504 100644 --- a/CodeGen/src/BytecodeSummary.cpp +++ b/CodeGen/src/BytecodeSummary.cpp @@ -8,6 +8,8 @@ #include "lobject.h" #include "lstate.h" +LUAU_FASTFLAG(LuauNativeAttribute) + namespace Luau { namespace CodeGen @@ -56,7 +58,10 @@ std::vector summarizeBytecode(lua_State* L, int idx, un Proto* root = clvalue(func)->l.p; std::vector protos; - gatherFunctions(protos, root, CodeGen_ColdFunctions); + if (FFlag::LuauNativeAttribute) + gatherFunctions(protos, root, CodeGen_ColdFunctions, root->flags & LPF_NATIVE_FUNCTION); + else + gatherFunctions_DEPRECATED(protos, root, CodeGen_ColdFunctions); std::vector summaries; summaries.reserve(protos.size()); diff --git a/CodeGen/src/CodeBlockUnwind.cpp b/CodeGen/src/CodeBlockUnwind.cpp index ca1a489e..cb2d693a 100644 --- a/CodeGen/src/CodeBlockUnwind.cpp +++ b/CodeGen/src/CodeBlockUnwind.cpp @@ -7,7 +7,7 @@ #include #include -#if defined(_WIN32) && defined(_M_X64) +#if defined(_WIN32) && defined(CODEGEN_TARGET_X64) #ifndef WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN @@ -26,7 +26,7 @@ extern "C" void __deregister_frame(const void*) __attribute__((weak)); extern "C" void __unw_add_dynamic_fde() __attribute__((weak)); #endif -#if defined(__APPLE__) && defined(__aarch64__) +#if defined(__APPLE__) && defined(CODEGEN_TARGET_A64) #include #include #include @@ -48,7 +48,7 @@ namespace Luau namespace CodeGen { -#if defined(__APPLE__) && defined(__aarch64__) +#if defined(__APPLE__) && defined(CODEGEN_TARGET_A64) static int findDynamicUnwindSections(uintptr_t addr, unw_dynamic_unwind_sections_t* info) { // Define a minimal mach header for JIT'd code. 
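The hunk that follows switches createBlockUnwindInfo over to getUnwindInfoSize(blockSize) and keeps the existing rounding of unwindSize up to kCodeAlignment. That rounding idiom, restated in isolation as a sketch (it is only valid when the alignment is a power of two, which kCodeAlignment is):

#include <cstddef>

// Round 'size' up to the next multiple of 'align'; the mask form works because
// 'align' is a power of two, so (align - 1) is a contiguous run of low bits.
static size_t alignUp(size_t size, size_t align)
{
    return (size + (align - 1)) & ~(align - 1);
}

// e.g. alignUp(100, 32) == 128 and alignUp(64, 32) == 64, matching the
// kCodeAlignment rounding applied to unwindSize in the hunk below.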
@@ -102,17 +102,17 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz UnwindBuilder* unwind = (UnwindBuilder*)context; // All unwinding related data is placed together at the start of the block - size_t unwindSize = unwind->getSize(); + size_t unwindSize = unwind->getUnwindInfoSize(blockSize); unwindSize = (unwindSize + (kCodeAlignment - 1)) & ~(kCodeAlignment - 1); // Match code allocator alignment CODEGEN_ASSERT(blockSize >= unwindSize); char* unwindData = (char*)block; - unwind->finalize(unwindData, unwindSize, block, blockSize); + [[maybe_unused]] size_t functionCount = unwind->finalize(unwindData, unwindSize, block, blockSize); -#if defined(_WIN32) && defined(_M_X64) +#if defined(_WIN32) && defined(CODEGEN_TARGET_X64) #if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP | WINAPI_PARTITION_SYSTEM) - if (!RtlAddFunctionTable((RUNTIME_FUNCTION*)block, uint32_t(unwind->getFunctionCount()), uintptr_t(block))) + if (!RtlAddFunctionTable((RUNTIME_FUNCTION*)block, uint32_t(functionCount), uintptr_t(block))) { CODEGEN_ASSERT(!"Failed to allocate function table"); return nullptr; @@ -126,7 +126,7 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz visitFdeEntries(unwindData, __register_frame); #endif -#if defined(__APPLE__) && defined(__aarch64__) +#if defined(__APPLE__) && defined(CODEGEN_TARGET_A64) // Starting from macOS 14, we need to register unwind section callback to state that our ABI doesn't require pointer authentication // This might conflict with other JITs that do the same; unfortunately this is the best we can do for now. static unw_add_find_dynamic_unwind_sections_t unw_add_find_dynamic_unwind_sections = @@ -141,7 +141,7 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz void destroyBlockUnwindInfo(void* context, void* unwindData) { -#if defined(_WIN32) && defined(_M_X64) +#if defined(_WIN32) && defined(CODEGEN_TARGET_X64) #if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP | WINAPI_PARTITION_SYSTEM) if (!RtlDeleteFunctionTable((RUNTIME_FUNCTION*)unwindData)) @@ -161,12 +161,12 @@ void destroyBlockUnwindInfo(void* context, void* unwindData) bool isUnwindSupported() { -#if defined(_WIN32) && defined(_M_X64) +#if defined(_WIN32) && defined(CODEGEN_TARGET_X64) return true; #elif defined(__ANDROID__) // Current unwind information is not compatible with Android return false; -#elif defined(__APPLE__) && defined(__aarch64__) +#elif defined(__APPLE__) && defined(CODEGEN_TARGET_A64) char ver[256]; size_t verLength = sizeof(ver); // libunwind on macOS 12 and earlier (which maps to osrelease 21) assumes JIT frames use pointer authentication without a way to override that diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 9ef9980a..694a9f7e 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -27,7 +27,7 @@ #include #include -#if defined(__x86_64__) || defined(_M_X64) +#if defined(CODEGEN_TARGET_X64) #ifdef _MSC_VER #include // __cpuid #else @@ -35,7 +35,7 @@ #endif #endif -#if defined(__aarch64__) +#if defined(CODEGEN_TARGET_A64) #ifdef __APPLE__ #include #endif @@ -58,186 +58,41 @@ LUAU_FASTINTVARIABLE(CodegenHeuristicsBlockLimit, 32'768) // 32 K // Current value is based on some member variables being limited to 16 bits LUAU_FASTINTVARIABLE(CodegenHeuristicsBlockInstructionLimit, 65'536) // 64 K -LUAU_FASTFLAG(LuauCodegenContext) - namespace Luau { namespace CodeGen { -static const Instruction kCodeEntryInsn = LOP_NATIVECALL; - -void* gPerfLogContext = nullptr; 
-PerfLogFn gPerfLogFn = nullptr; - -struct OldNativeProto +std::string toString(const CodeGenCompilationResult& result) { - Proto* p; - void* execdata; - uintptr_t exectarget; -}; - -// Additional data attached to Proto::execdata -// Guaranteed to be aligned to 16 bytes -struct ExtraExecData -{ - size_t execDataSize; - size_t codeSize; -}; - -static int alignTo(int value, int align) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - CODEGEN_ASSERT(align > 0 && (align & (align - 1)) == 0); - return (value + (align - 1)) & ~(align - 1); -} - -// Returns the size of execdata required to store all code offsets and ExtraExecData structure at proper alignment -// Always a multiple of 4 bytes -static int calculateExecDataSize(Proto* proto) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - int size = proto->sizecode * sizeof(uint32_t); - - size = alignTo(size, 16); - size += sizeof(ExtraExecData); - - return size; -} - -// Returns pointer to the ExtraExecData inside the Proto::execdata -// Even though 'execdata' is a field in Proto, we require it to support cases where it's not attached to Proto during construction -ExtraExecData* getExtraExecData(Proto* proto, void* execdata) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - int size = proto->sizecode * sizeof(uint32_t); - - size = alignTo(size, 16); - - return reinterpret_cast(reinterpret_cast(execdata) + size); -} - -static OldNativeProto createOldNativeProto(Proto* proto, const IrBuilder& ir) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - - int execDataSize = calculateExecDataSize(proto); - CODEGEN_ASSERT(execDataSize % 4 == 0); - - uint32_t* execData = new uint32_t[execDataSize / 4]; - uint32_t instTarget = ir.function.entryLocation; - - for (int i = 0; i < proto->sizecode; i++) + switch (result) { - CODEGEN_ASSERT(ir.function.bcMapping[i].asmLocation >= instTarget); - - execData[i] = ir.function.bcMapping[i].asmLocation - instTarget; + case CodeGenCompilationResult::Success: + return "Success"; + case CodeGenCompilationResult::NothingToCompile: + return "NothingToCompile"; + case CodeGenCompilationResult::NotNativeModule: + return "NotNativeModule"; + case CodeGenCompilationResult::CodeGenNotInitialized: + return "CodeGenNotInitialized"; + case CodeGenCompilationResult::CodeGenOverflowInstructionLimit: + return "CodeGenOverflowInstructionLimit"; + case CodeGenCompilationResult::CodeGenOverflowBlockLimit: + return "CodeGenOverflowBlockLimit"; + case CodeGenCompilationResult::CodeGenOverflowBlockInstructionLimit: + return "CodeGenOverflowBlockInstructionLimit"; + case CodeGenCompilationResult::CodeGenAssemblerFinalizationFailure: + return "CodeGenAssemblerFinalizationFailure"; + case CodeGenCompilationResult::CodeGenLoweringFailure: + return "CodeGenLoweringFailure"; + case CodeGenCompilationResult::AllocationFailed: + return "AllocationFailed"; + case CodeGenCompilationResult::Count: + return "Count"; } - // Set first instruction offset to 0 so that entering this function still executes any generated entry code. 
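An aside on the old code path being deleted in this hunk: createOldNativeProto filled execdata with one native-code offset per bytecode instruction, relative to the function entry, and onEnter (removed further below) resolved the resume address with a single indexed load, roughly as in this sketch (not part of the patch; Proto field names as in the removed code):

// Sketch of the removed dispatch: execdata[i] is the native offset of bytecode
// instruction i, and exectarget is the base address of this function's native code.
static uintptr_t resumeAddress(Proto* proto, const Instruction* savedpc)
{
    const uint32_t* offsets = static_cast<const uint32_t*>(proto->execdata);
    return proto->exectarget + offsets[savedpc - proto->code];
}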
- execData[0] = 0; - - ExtraExecData* extra = getExtraExecData(proto, execData); - memset(extra, 0, sizeof(ExtraExecData)); - - extra->execDataSize = execDataSize; - - // entry target will be relocated when assembly is finalized - return {proto, execData, instTarget}; -} - -static void destroyExecData(void* execdata) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - - delete[] static_cast(execdata); -} - -static void logPerfFunction(Proto* p, uintptr_t addr, unsigned size) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - CODEGEN_ASSERT(p->source); - - const char* source = getstr(p->source); - source = (source[0] == '=' || source[0] == '@') ? source + 1 : "[string]"; - - char name[256]; - snprintf(name, sizeof(name), " %s:%d %s", source, p->linedefined, p->debugname ? getstr(p->debugname) : ""); - - if (gPerfLogFn) - gPerfLogFn(gPerfLogContext, addr, size, name); -} - -template -static std::optional createNativeFunction( - AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, uint32_t& totalIrInstCount, CodeGenCompilationResult& result) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - - IrBuilder ir; - ir.buildFunctionIr(proto); - - unsigned instCount = unsigned(ir.function.instructions.size()); - - if (totalIrInstCount + instCount >= unsigned(FInt::CodegenHeuristicsInstructionLimit.value)) - { - result = CodeGenCompilationResult::CodeGenOverflowInstructionLimit; - return std::nullopt; - } - totalIrInstCount += instCount; - - if (!lowerFunction(ir, build, helpers, proto, {}, /* stats */ nullptr, result)) - return std::nullopt; - - return createOldNativeProto(proto, ir); -} - -static NativeState* getNativeState(lua_State* L) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - - return static_cast(L->global->ecb.context); -} - -static void onCloseState(lua_State* L) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - - delete getNativeState(L); - L->global->ecb = lua_ExecutionCallbacks(); -} - -static void onDestroyFunction(lua_State* L, Proto* proto) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - - destroyExecData(proto->execdata); - proto->execdata = nullptr; - proto->exectarget = 0; - proto->codeentry = proto->code; -} - -static int onEnter(lua_State* L, Proto* proto) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - - NativeState* data = getNativeState(L); - - CODEGEN_ASSERT(proto->execdata); - CODEGEN_ASSERT(L->ci->savedpc >= proto->code && L->ci->savedpc < proto->code + proto->sizecode); - - uintptr_t target = proto->exectarget + static_cast(proto->execdata)[L->ci->savedpc - proto->code]; - - // Returns 1 to finish the function in the VM - return GateFn(data->context.gateEntry)(L, proto, target, &data->context); -} - -// used to disable native execution, unconditionally -static int onEnterDisabled(lua_State* L, Proto* proto) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - - return 1; + CODEGEN_ASSERT(false); + return ""; } void onDisable(lua_State* L, Proto* proto) @@ -279,18 +134,7 @@ void onDisable(lua_State* L, Proto* proto) }); } -static size_t getMemorySize(lua_State* L, Proto* proto) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - ExtraExecData* extra = getExtraExecData(proto, proto->execdata); - - // While execDataSize is exactly the size of the allocation we made and hold for 'execdata' field, the code size is approximate - // This is because code+data page is shared and owned by all Proto from a single module and each one can keep the whole region alive - // So individual Proto being freed by GC will not reflect memory use by native code correctly - return 
extra->execDataSize + extra->codeSize; -} - -#if defined(__aarch64__) +#if defined(CODEGEN_TARGET_A64) unsigned int getCpuFeaturesA64() { unsigned int result = 0; @@ -326,7 +170,7 @@ bool isSupported() return false; #endif -#if defined(__x86_64__) || defined(_M_X64) +#if defined(CODEGEN_TARGET_X64) int cpuinfo[4] = {}; #ifdef _MSC_VER __cpuid(cpuinfo, 1); @@ -341,273 +185,12 @@ bool isSupported() return false; return true; -#elif defined(__aarch64__) +#elif defined(CODEGEN_TARGET_A64) return true; #else return false; #endif } -static void create_OLD(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext) -{ - CODEGEN_ASSERT(!FFlag::LuauCodegenContext); - CODEGEN_ASSERT(isSupported()); - - std::unique_ptr data = std::make_unique(allocationCallback, allocationCallbackContext); - -#if defined(_WIN32) - data->unwindBuilder = std::make_unique(); -#else - data->unwindBuilder = std::make_unique(); -#endif - - data->codeAllocator.context = data->unwindBuilder.get(); - data->codeAllocator.createBlockUnwindInfo = createBlockUnwindInfo; - data->codeAllocator.destroyBlockUnwindInfo = destroyBlockUnwindInfo; - - initFunctions(*data); - -#if defined(__x86_64__) || defined(_M_X64) - if (!X64::initHeaderFunctions(*data)) - return; -#elif defined(__aarch64__) - if (!A64::initHeaderFunctions(*data)) - return; -#endif - - if (gPerfLogFn) - gPerfLogFn(gPerfLogContext, uintptr_t(data->context.gateEntry), 4096, ""); - - lua_ExecutionCallbacks* ecb = &L->global->ecb; - - ecb->context = data.release(); - ecb->close = onCloseState; - ecb->destroy = onDestroyFunction; - ecb->enter = onEnter; - ecb->disable = onDisable; - ecb->getmemorysize = getMemorySize; -} - -void create(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext) -{ - if (FFlag::LuauCodegenContext) - { - create_NEW(L, allocationCallback, allocationCallbackContext); - } - else - { - create_OLD(L, allocationCallback, allocationCallbackContext); - } -} - -void create(lua_State* L) -{ - if (FFlag::LuauCodegenContext) - { - create_NEW(L); - } - else - { - create(L, nullptr, nullptr); - } -} - -void create(lua_State* L, SharedCodeGenContext* codeGenContext) -{ - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - - create_NEW(L, codeGenContext); -} - -[[nodiscard]] bool isNativeExecutionEnabled(lua_State* L) -{ - if (FFlag::LuauCodegenContext) - { - return isNativeExecutionEnabled_NEW(L); - } - else - { - return getNativeState(L) ? (L->global->ecb.enter == onEnter) : false; - } -} - -void setNativeExecutionEnabled(lua_State* L, bool enabled) -{ - if (FFlag::LuauCodegenContext) - { - setNativeExecutionEnabled_NEW(L, enabled); - } - else - { - if (getNativeState(L)) - L->global->ecb.enter = enabled ? 
onEnter : onEnterDisabled; - } -} - -static CompilationResult compile_OLD(lua_State* L, int idx, unsigned int flags, CompilationStats* stats) -{ - CompilationResult compilationResult; - - CODEGEN_ASSERT(lua_isLfunction(L, idx)); - const TValue* func = luaA_toobject(L, idx); - - Proto* root = clvalue(func)->l.p; - - if ((flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0) - { - compilationResult.result = CodeGenCompilationResult::NotNativeModule; - return compilationResult; - } - - // If initialization has failed, do not compile any functions - NativeState* data = getNativeState(L); - if (!data) - { - compilationResult.result = CodeGenCompilationResult::CodeGenNotInitialized; - return compilationResult; - } - - std::vector protos; - gatherFunctions(protos, root, flags); - - // Skip protos that have been compiled during previous invocations of CodeGen::compile - protos.erase(std::remove_if(protos.begin(), protos.end(), - [](Proto* p) { - return p == nullptr || p->execdata != nullptr; - }), - protos.end()); - - if (protos.empty()) - { - compilationResult.result = CodeGenCompilationResult::NothingToCompile; - return compilationResult; - } - - if (stats != nullptr) - stats->functionsTotal = uint32_t(protos.size()); - -#if defined(__aarch64__) - static unsigned int cpuFeatures = getCpuFeaturesA64(); - A64::AssemblyBuilderA64 build(/* logText= */ false, cpuFeatures); -#else - X64::AssemblyBuilderX64 build(/* logText= */ false); -#endif - - ModuleHelpers helpers; -#if defined(__aarch64__) - A64::assembleHelpers(build, helpers); -#else - X64::assembleHelpers(build, helpers); -#endif - - std::vector results; - results.reserve(protos.size()); - - uint32_t totalIrInstCount = 0; - - for (Proto* p : protos) - { - CodeGenCompilationResult protoResult = CodeGenCompilationResult::Success; - - if (std::optional np = createNativeFunction(build, helpers, p, totalIrInstCount, protoResult)) - results.push_back(*np); - else - compilationResult.protoFailures.push_back({protoResult, p->debugname ? getstr(p->debugname) : "", p->linedefined}); - } - - // Very large modules might result in overflowing a jump offset; in this case we currently abandon the entire module - if (!build.finalize()) - { - for (OldNativeProto result : results) - destroyExecData(result.execdata); - - compilationResult.result = CodeGenCompilationResult::CodeGenAssemblerFinalizationFailure; - return compilationResult; - } - - // If no functions were assembled, we don't need to allocate/copy executable pages for helpers - if (results.empty()) - return compilationResult; - - uint8_t* nativeData = nullptr; - size_t sizeNativeData = 0; - uint8_t* codeStart = nullptr; - if (!data->codeAllocator.allocate(build.data.data(), int(build.data.size()), reinterpret_cast(build.code.data()), - int(build.code.size() * sizeof(build.code[0])), nativeData, sizeNativeData, codeStart)) - { - for (OldNativeProto result : results) - destroyExecData(result.execdata); - - compilationResult.result = CodeGenCompilationResult::AllocationFailed; - return compilationResult; - } - - if (gPerfLogFn && results.size() > 0) - gPerfLogFn(gPerfLogContext, uintptr_t(codeStart), uint32_t(results[0].exectarget), ""); - - for (size_t i = 0; i < results.size(); ++i) - { - uint32_t begin = uint32_t(results[i].exectarget); - uint32_t end = i + 1 < results.size() ? 
uint32_t(results[i + 1].exectarget) : uint32_t(build.code.size() * sizeof(build.code[0])); - CODEGEN_ASSERT(begin < end); - - if (gPerfLogFn) - logPerfFunction(results[i].p, uintptr_t(codeStart) + begin, end - begin); - - ExtraExecData* extra = getExtraExecData(results[i].p, results[i].execdata); - extra->codeSize = end - begin; - } - - for (const OldNativeProto& result : results) - { - // the memory is now managed by VM and will be freed via onDestroyFunction - result.p->execdata = result.execdata; - result.p->exectarget = uintptr_t(codeStart) + result.exectarget; - result.p->codeentry = &kCodeEntryInsn; - } - - if (stats != nullptr) - { - for (const OldNativeProto& result : results) - { - stats->bytecodeSizeBytes += result.p->sizecode * sizeof(Instruction); - - // Account for the native -> bytecode instruction offsets mapping: - stats->nativeMetadataSizeBytes += result.p->sizecode * sizeof(uint32_t); - } - - stats->functionsCompiled += uint32_t(results.size()); - stats->nativeCodeSizeBytes += build.code.size(); - stats->nativeDataSizeBytes += build.data.size(); - } - - return compilationResult; -} - -CompilationResult compile(lua_State* L, int idx, unsigned int flags, CompilationStats* stats) -{ - if (FFlag::LuauCodegenContext) - { - return compile_NEW(L, idx, flags, stats); - } - else - { - return compile_OLD(L, idx, flags, stats); - } -} - -CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats) -{ - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - - return compile_NEW(moduleId, L, idx, flags, stats); -} - -void setPerfLog(void* context, PerfLogFn logFn) -{ - gPerfLogContext = context; - gPerfLogFn = logFn; -} - } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/CodeGenA64.cpp b/CodeGen/src/CodeGenA64.cpp index a18278c9..06f64955 100644 --- a/CodeGen/src/CodeGenA64.cpp +++ b/CodeGen/src/CodeGenA64.cpp @@ -253,44 +253,11 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde // Our entry function is special; it spans the whole remaining code area unwind.startFunction(); unwind.prologueA64(prologueSize, kStackSize, {x29, x30, x19, x20, x21, x22, x23, x24, x25}); - unwind.finishFunction(build.getLabelOffset(locations.start), kFullBlockFuncton); + unwind.finishFunction(build.getLabelOffset(locations.start), kFullBlockFunction); return locations; } -bool initHeaderFunctions(NativeState& data) -{ - AssemblyBuilderA64 build(/* logText= */ false); - UnwindBuilder& unwind = *data.unwindBuilder.get(); - - unwind.startInfo(UnwindBuilder::A64); - - EntryLocations entryLocations = buildEntryFunction(build, unwind); - - build.finalize(); - - unwind.finishInfo(); - - CODEGEN_ASSERT(build.data.empty()); - - uint8_t* codeStart = nullptr; - if (!data.codeAllocator.allocate(build.data.data(), int(build.data.size()), reinterpret_cast(build.code.data()), - int(build.code.size() * sizeof(build.code[0])), data.gateData, data.gateDataSize, codeStart)) - { - CODEGEN_ASSERT(!"Failed to create entry function"); - return false; - } - - // Set the offset at the beginning so that functions in new blocks will not overlay the locations - // specified by the unwind information of the entry function - unwind.setBeginOffset(build.getLabelOffset(entryLocations.prologueEnd)); - - data.context.gateEntry = codeStart + build.getLabelOffset(entryLocations.start); - data.context.gateExit = codeStart + build.getLabelOffset(entryLocations.epilogueStart); - - return true; -} - bool initHeaderFunctions(BaseCodeGenContext&
codeGenContext) { AssemblyBuilderA64 build(/* logText= */ false); diff --git a/CodeGen/src/CodeGenA64.h b/CodeGen/src/CodeGenA64.h index 24fedd9a..2633f5ba 100644 --- a/CodeGen/src/CodeGenA64.h +++ b/CodeGen/src/CodeGenA64.h @@ -7,7 +7,6 @@ namespace CodeGen { class BaseCodeGenContext; -struct NativeState; struct ModuleHelpers; namespace A64 @@ -15,7 +14,6 @@ namespace A64 class AssemblyBuilderA64; -bool initHeaderFunctions(NativeState& data); bool initHeaderFunctions(BaseCodeGenContext& codeGenContext); void assembleHelpers(AssemblyBuilderA64& build, ModuleHelpers& helpers); diff --git a/CodeGen/src/CodeGenAssembly.cpp b/CodeGen/src/CodeGenAssembly.cpp index 96c73ce2..de8dcecf 100644 --- a/CodeGen/src/CodeGenAssembly.cpp +++ b/CodeGen/src/CodeGenAssembly.cpp @@ -12,13 +12,50 @@ #include "lapi.h" -LUAU_FASTFLAG(LuauCodegenTypeInfo) +LUAU_FASTFLAG(LuauLoadUserdataInfo) +LUAU_FASTFLAG(LuauNativeAttribute) namespace Luau { namespace CodeGen { +static const LocVar* tryFindLocal(const Proto* proto, int reg, int pcpos) +{ + for (int i = 0; i < proto->sizelocvars; i++) + { + const LocVar& local = proto->locvars[i]; + + if (reg == local.reg && pcpos >= local.startpc && pcpos < local.endpc) + return &local; + } + + return nullptr; +} + +const char* tryFindLocalName(const Proto* proto, int reg, int pcpos) +{ + const LocVar* var = tryFindLocal(proto, reg, pcpos); + + if (var && var->varname) + return getstr(var->varname); + + return nullptr; +} + +const char* tryFindUpvalueName(const Proto* proto, int upval) +{ + if (proto->upvalues) + { + CODEGEN_ASSERT(upval < proto->sizeupvalues); + + if (proto->upvalues[upval]) + return getstr(proto->upvalues[upval]); + } + + return nullptr; +} + template static void logFunctionHeader(AssemblyBuilder& build, Proto* proto) { @@ -29,10 +66,8 @@ static void logFunctionHeader(AssemblyBuilder& build, Proto* proto) for (int i = 0; i < proto->numparams; i++) { - LocVar* var = proto->locvars ? &proto->locvars[proto->sizelocvars - proto->numparams + i] : nullptr; - - if (var && var->varname) - build.logAppend("%s%s", i == 0 ? "" : ", ", getstr(var->varname)); + if (const char* name = tryFindLocalName(proto, i, 0)) + build.logAppend("%s%s", i == 0 ? "" : ", ", name); else build.logAppend("%s$arg%d", i == 0 ? 
"" : ", ", i); } @@ -49,9 +84,9 @@ static void logFunctionHeader(AssemblyBuilder& build, Proto* proto) } template -static void logFunctionTypes(AssemblyBuilder& build, const IrFunction& function) +static void logFunctionTypes_DEPRECATED(AssemblyBuilder& build, const IrFunction& function) { - CODEGEN_ASSERT(FFlag::LuauCodegenTypeInfo); + CODEGEN_ASSERT(!FFlag::LuauLoadUserdataInfo); const BytecodeTypeInfo& typeInfo = function.bcTypeInfo; @@ -60,7 +95,12 @@ static void logFunctionTypes(AssemblyBuilder& build, const IrFunction& function) uint8_t ty = typeInfo.argumentTypes[i]; if (ty != LBC_TYPE_ANY) - build.logAppend("; R%d: %s [argument]\n", int(i), getBytecodeTypeName(ty)); + { + if (const char* name = tryFindLocalName(function.proto, int(i), 0)) + build.logAppend("; R%d: %s [argument '%s']\n", int(i), getBytecodeTypeName_DEPRECATED(ty), name); + else + build.logAppend("; R%d: %s [argument]\n", int(i), getBytecodeTypeName_DEPRECATED(ty)); + } } for (size_t i = 0; i < typeInfo.upvalueTypes.size(); i++) @@ -68,12 +108,73 @@ static void logFunctionTypes(AssemblyBuilder& build, const IrFunction& function) uint8_t ty = typeInfo.upvalueTypes[i]; if (ty != LBC_TYPE_ANY) - build.logAppend("; U%d: %s\n", int(i), getBytecodeTypeName(ty)); + { + if (const char* name = tryFindUpvalueName(function.proto, int(i))) + build.logAppend("; U%d: %s ['%s']\n", int(i), getBytecodeTypeName_DEPRECATED(ty), name); + else + build.logAppend("; U%d: %s\n", int(i), getBytecodeTypeName_DEPRECATED(ty)); + } } for (const BytecodeRegTypeInfo& el : typeInfo.regTypes) { - build.logAppend("; R%d: %s from %d to %d\n", el.reg, getBytecodeTypeName(el.type), el.startpc, el.endpc); + // Using last active position as the PC because 'startpc' for type info is before local is initialized + if (const char* name = tryFindLocalName(function.proto, el.reg, el.endpc - 1)) + build.logAppend("; R%d: %s from %d to %d [local '%s']\n", el.reg, getBytecodeTypeName_DEPRECATED(el.type), el.startpc, el.endpc, name); + else + build.logAppend("; R%d: %s from %d to %d\n", el.reg, getBytecodeTypeName_DEPRECATED(el.type), el.startpc, el.endpc); + } +} + +template +static void logFunctionTypes(AssemblyBuilder& build, const IrFunction& function, const char* const* userdataTypes) +{ + CODEGEN_ASSERT(FFlag::LuauLoadUserdataInfo); + + const BytecodeTypeInfo& typeInfo = function.bcTypeInfo; + + for (size_t i = 0; i < typeInfo.argumentTypes.size(); i++) + { + uint8_t ty = typeInfo.argumentTypes[i]; + + const char* type = getBytecodeTypeName(ty, userdataTypes); + const char* optional = (ty & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" : ""; + + if (ty != LBC_TYPE_ANY) + { + if (const char* name = tryFindLocalName(function.proto, int(i), 0)) + build.logAppend("; R%d: %s%s [argument '%s']\n", int(i), type, optional, name); + else + build.logAppend("; R%d: %s%s [argument]\n", int(i), type, optional); + } + } + + for (size_t i = 0; i < typeInfo.upvalueTypes.size(); i++) + { + uint8_t ty = typeInfo.upvalueTypes[i]; + + const char* type = getBytecodeTypeName(ty, userdataTypes); + const char* optional = (ty & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" 
: ""; + + if (ty != LBC_TYPE_ANY) + { + if (const char* name = tryFindUpvalueName(function.proto, int(i))) + build.logAppend("; U%d: %s%s ['%s']\n", int(i), type, optional, name); + else + build.logAppend("; U%d: %s%s\n", int(i), type, optional); + } + } + + for (const BytecodeRegTypeInfo& el : typeInfo.regTypes) + { + const char* type = getBytecodeTypeName(el.type, userdataTypes); + const char* optional = (el.type & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" : ""; + + // Using last active position as the PC because 'startpc' for type info is before local is initialized + if (const char* name = tryFindLocalName(function.proto, el.reg, el.endpc - 1)) + build.logAppend("; R%d: %s%s from %d to %d [local '%s']\n", el.reg, type, optional, el.startpc, el.endpc, name); + else + build.logAppend("; R%d: %s%s from %d to %d\n", el.reg, type, optional, el.startpc, el.endpc); } } @@ -93,11 +194,14 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A { Proto* root = clvalue(func)->l.p; - if ((options.flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0) + if ((options.compilationOptions.flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0) return std::string(); std::vector protos; - gatherFunctions(protos, root, options.flags); + if (FFlag::LuauNativeAttribute) + gatherFunctions(protos, root, options.compilationOptions.flags, root->flags & LPF_NATIVE_FUNCTION); + else + gatherFunctions_DEPRECATED(protos, root, options.compilationOptions.flags); protos.erase(std::remove_if(protos.begin(), protos.end(), [](Proto* p) { @@ -125,7 +229,7 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A for (Proto* p : protos) { - IrBuilder ir; + IrBuilder ir(options.compilationOptions.hooks); ir.buildFunctionIr(p); unsigned asmSize = build.getCodeSize(); unsigned asmCount = build.getInstructionCount(); @@ -133,8 +237,13 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A if (options.includeAssembly || options.includeIr) logFunctionHeader(build, p); - if (FFlag::LuauCodegenTypeInfo && options.includeIrTypes) - logFunctionTypes(build, ir.function); + if (options.includeIrTypes) + { + if (FFlag::LuauLoadUserdataInfo) + logFunctionTypes(build, ir.function, options.compilationOptions.userdataTypes); + else + logFunctionTypes_DEPRECATED(build, ir.function); + } CodeGenCompilationResult result = CodeGenCompilationResult::Success; @@ -189,7 +298,7 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A return build.text; } -#if defined(__aarch64__) +#if defined(CODEGEN_TARGET_A64) unsigned int getCpuFeaturesA64(); #endif @@ -202,7 +311,7 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options, Lowering { case AssemblyOptions::Host: { -#if defined(__aarch64__) +#if defined(CODEGEN_TARGET_A64) static unsigned int cpuFeatures = getCpuFeaturesA64(); A64::AssemblyBuilderA64 build(/* logText= */ options.includeAssembly, cpuFeatures); #else diff --git a/CodeGen/src/CodeGenContext.cpp b/CodeGen/src/CodeGenContext.cpp index d9e3c4b3..a31a08ba 100644 --- a/CodeGen/src/CodeGenContext.cpp +++ b/CodeGen/src/CodeGenContext.cpp @@ -12,12 +12,9 @@ #include "lapi.h" - -LUAU_FASTFLAGVARIABLE(LuauCodegenContext, false) -LUAU_FASTFLAGVARIABLE(LuauCodegenCheckNullContext, false) - -LUAU_FASTINT(LuauCodeGenBlockSize) -LUAU_FASTINT(LuauCodeGenMaxTotalSize) +LUAU_FASTINTVARIABLE(LuauCodeGenBlockSize, 4 * 1024 * 1024) +LUAU_FASTINTVARIABLE(LuauCodeGenMaxTotalSize, 256 
* 1024 * 1024) +LUAU_FASTFLAG(LuauNativeAttribute) namespace Luau { @@ -27,14 +24,19 @@ namespace CodeGen static const Instruction kCodeEntryInsn = LOP_NATIVECALL; // From CodeGen.cpp -extern void* gPerfLogContext; -extern PerfLogFn gPerfLogFn; +static void* gPerfLogContext = nullptr; +static PerfLogFn gPerfLogFn = nullptr; unsigned int getCpuFeaturesA64(); +void setPerfLog(void* context, PerfLogFn logFn) +{ + gPerfLogContext = context; + gPerfLogFn = logFn; +} + static void logPerfFunction(Proto* p, uintptr_t addr, unsigned size) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); CODEGEN_ASSERT(p->source); const char* source = getstr(p->source); @@ -50,8 +52,6 @@ static void logPerfFunction(Proto* p, uintptr_t addr, unsigned size) static void logPerfFunctions( const std::vector& moduleProtos, const uint8_t* nativeModuleBaseAddress, const std::vector& nativeProtos) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - if (gPerfLogFn == nullptr) return; @@ -83,8 +83,6 @@ static void logPerfFunctions( template [[nodiscard]] static uint32_t bindNativeProtos(const std::vector& moduleProtos, NativeProtosVector& nativeProtos) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - uint32_t protosBound = 0; auto protoIt = moduleProtos.begin(); @@ -125,7 +123,6 @@ template BaseCodeGenContext::BaseCodeGenContext(size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext) : codeAllocator{blockSize, maxTotalSize, allocationCallback, allocationCallbackContext} { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); CODEGEN_ASSERT(isSupported()); #if defined(_WIN32) @@ -143,12 +140,10 @@ BaseCodeGenContext::BaseCodeGenContext(size_t blockSize, size_t maxTotalSize, Al [[nodiscard]] bool BaseCodeGenContext::initHeaderFunctions() { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - -#if defined(__x86_64__) || defined(_M_X64) +#if defined(CODEGEN_TARGET_X64) if (!X64::initHeaderFunctions(*this)) return false; -#elif defined(__aarch64__) +#elif defined(CODEGEN_TARGET_A64) if (!A64::initHeaderFunctions(*this)) return false; #endif @@ -164,13 +159,10 @@ StandaloneCodeGenContext::StandaloneCodeGenContext( size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext) : BaseCodeGenContext{blockSize, maxTotalSize, allocationCallback, allocationCallbackContext} { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); } [[nodiscard]] std::optional StandaloneCodeGenContext::tryBindExistingModule(const ModuleId&, const std::vector&) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - // The StandaloneCodeGenContext does not support sharing of native code return {}; } @@ -178,8 +170,6 @@ StandaloneCodeGenContext::StandaloneCodeGenContext( [[nodiscard]] ModuleBindResult StandaloneCodeGenContext::bindModule(const std::optional&, const std::vector& moduleProtos, std::vector nativeProtos, const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - uint8_t* nativeData = nullptr; size_t sizeNativeData = 0; uint8_t* codeStart = nullptr; @@ -205,8 +195,6 @@ StandaloneCodeGenContext::StandaloneCodeGenContext( void StandaloneCodeGenContext::onCloseState() noexcept { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - // The StandaloneCodeGenContext is owned by the one VM that owns it, so when // that VM is destroyed, we destroy *this as well: delete this; @@ -214,8 +202,6 @@ void StandaloneCodeGenContext::onCloseState() noexcept void StandaloneCodeGenContext::onDestroyFunction(void* execdata) noexcept { - 
CODEGEN_ASSERT(FFlag::LuauCodegenContext); - destroyNativeProtoExecData(static_cast(execdata)); } @@ -225,14 +211,11 @@ SharedCodeGenContext::SharedCodeGenContext( : BaseCodeGenContext{blockSize, maxTotalSize, allocationCallback, allocationCallbackContext} , sharedAllocator{&codeAllocator} { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); } [[nodiscard]] std::optional SharedCodeGenContext::tryBindExistingModule( const ModuleId& moduleId, const std::vector& moduleProtos) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - NativeModuleRef nativeModule = sharedAllocator.tryGetNativeModule(moduleId); if (nativeModule.empty()) { @@ -249,8 +232,6 @@ SharedCodeGenContext::SharedCodeGenContext( [[nodiscard]] ModuleBindResult SharedCodeGenContext::bindModule(const std::optional& moduleId, const std::vector& moduleProtos, std::vector nativeProtos, const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - const std::pair insertionResult = [&]() -> std::pair { if (moduleId.has_value()) { @@ -279,8 +260,6 @@ SharedCodeGenContext::SharedCodeGenContext( void SharedCodeGenContext::onCloseState() noexcept { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - // The lifetime of the SharedCodeGenContext is managed separately from the // VMs that use it. When a VM is destroyed, we don't need to do anything // here. @@ -288,23 +267,17 @@ void SharedCodeGenContext::onCloseState() noexcept void SharedCodeGenContext::onDestroyFunction(void* execdata) noexcept { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - getNativeProtoExecDataHeader(static_cast(execdata)).nativeModule->release(); } [[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext() { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - return createSharedCodeGenContext(size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), nullptr, nullptr); } [[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext(AllocationCallback* allocationCallback, void* allocationCallbackContext) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - return createSharedCodeGenContext( size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), allocationCallback, allocationCallbackContext); } @@ -312,8 +285,6 @@ void SharedCodeGenContext::onDestroyFunction(void* execdata) noexcept [[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext( size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - UniqueSharedCodeGenContext codeGenContext{new SharedCodeGenContext{blockSize, maxTotalSize, nullptr, nullptr}}; if (!codeGenContext->initHeaderFunctions()) @@ -324,38 +295,28 @@ void SharedCodeGenContext::onDestroyFunction(void* execdata) noexcept void destroySharedCodeGenContext(const SharedCodeGenContext* codeGenContext) noexcept { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - delete codeGenContext; } void SharedCodeGenContextDeleter::operator()(const SharedCodeGenContext* codeGenContext) const noexcept { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - destroySharedCodeGenContext(codeGenContext); } [[nodiscard]] static BaseCodeGenContext* getCodeGenContext(lua_State* L) noexcept { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - return static_cast(L->global->ecb.context); } static void onCloseState(lua_State* L) noexcept { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - getCodeGenContext(L)->onCloseState(); L->global->ecb = lua_ExecutionCallbacks{}; } static void 
onDestroyFunction(lua_State* L, Proto* proto) noexcept { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - getCodeGenContext(L)->onDestroyFunction(proto->execdata); proto->execdata = nullptr; proto->exectarget = 0; @@ -364,8 +325,6 @@ static void onDestroyFunction(lua_State* L, Proto* proto) noexcept static int onEnter(lua_State* L, Proto* proto) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - BaseCodeGenContext* codeGenContext = getCodeGenContext(L); CODEGEN_ASSERT(proto->execdata); @@ -379,8 +338,6 @@ static int onEnter(lua_State* L, Proto* proto) static int onEnterDisabled(lua_State* L, Proto* proto) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - return 1; } @@ -389,8 +346,6 @@ void onDisable(lua_State* L, Proto* proto); static size_t getMemorySize(lua_State* L, Proto* proto) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - const NativeProtoExecDataHeader& execDataHeader = getNativeProtoExecDataHeader(static_cast(proto->execdata)); const size_t execDataSize = sizeof(NativeProtoExecDataHeader) + execDataHeader.bytecodeInstructionCount * sizeof(Instruction); @@ -403,8 +358,7 @@ static size_t getMemorySize(lua_State* L, Proto* proto) static void initializeExecutionCallbacks(lua_State* L, BaseCodeGenContext* codeGenContext) noexcept { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - CODEGEN_ASSERT(!FFlag::LuauCodegenCheckNullContext || codeGenContext != nullptr); + CODEGEN_ASSERT(codeGenContext != nullptr); lua_ExecutionCallbacks* ecb = &L->global->ecb; @@ -416,24 +370,18 @@ static void initializeExecutionCallbacks(lua_State* L, BaseCodeGenContext* codeG ecb->getmemorysize = getMemorySize; } -void create_NEW(lua_State* L) +void create(lua_State* L) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - - return create_NEW(L, size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), nullptr, nullptr); + return create(L, size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), nullptr, nullptr); } -void create_NEW(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext) +void create(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - - return create_NEW(L, size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), allocationCallback, allocationCallbackContext); + return create(L, size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), allocationCallback, allocationCallbackContext); } -void create_NEW(lua_State* L, size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext) +void create(lua_State* L, size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - std::unique_ptr codeGenContext = std::make_unique(blockSize, maxTotalSize, allocationCallback, allocationCallbackContext); @@ -443,17 +391,13 @@ void create_NEW(lua_State* L, size_t blockSize, size_t maxTotalSize, AllocationC initializeExecutionCallbacks(L, codeGenContext.release()); } -void create_NEW(lua_State* L, SharedCodeGenContext* codeGenContext) +void create(lua_State* L, SharedCodeGenContext* codeGenContext) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - initializeExecutionCallbacks(L, codeGenContext); } [[nodiscard]] static NativeProtoExecDataPtr createNativeProtoExecData(Proto* proto, const IrBuilder& ir) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - NativeProtoExecDataPtr nativeExecData = 
createNativeProtoExecData(proto->sizecode); uint32_t instTarget = ir.function.entryLocation; @@ -478,12 +422,10 @@ void create_NEW(lua_State* L, SharedCodeGenContext* codeGenContext) } template -[[nodiscard]] static NativeProtoExecDataPtr createNativeFunction( - AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, uint32_t& totalIrInstCount, CodeGenCompilationResult& result) +[[nodiscard]] static NativeProtoExecDataPtr createNativeFunction(AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, + uint32_t& totalIrInstCount, const HostIrHooks& hooks, CodeGenCompilationResult& result) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - - IrBuilder ir; + IrBuilder ir(hooks); ir.buildFunctionIr(proto); unsigned instCount = unsigned(ir.function.instructions.size()); @@ -505,15 +447,14 @@ template } [[nodiscard]] static CompilationResult compileInternal( - const std::optional& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats) + const std::optional& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); CODEGEN_ASSERT(lua_isLfunction(L, idx)); const TValue* func = luaA_toobject(L, idx); Proto* root = clvalue(func)->l.p; - if ((flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0) + if ((options.flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0 && (root->flags & LPF_NATIVE_FUNCTION) == 0) return CompilationResult{CodeGenCompilationResult::NotNativeModule}; BaseCodeGenContext* codeGenContext = getCodeGenContext(L); @@ -521,7 +462,10 @@ template return CompilationResult{CodeGenCompilationResult::CodeGenNotInitialized}; std::vector protos; - gatherFunctions(protos, root, flags); + if (FFlag::LuauNativeAttribute) + gatherFunctions(protos, root, options.flags, root->flags & LPF_NATIVE_FUNCTION); + else + gatherFunctions_DEPRECATED(protos, root, options.flags); // Skip protos that have been compiled during previous invocations of CodeGen::compile protos.erase(std::remove_if(protos.begin(), protos.end(), @@ -547,7 +491,7 @@ template } } -#if defined(__aarch64__) +#if defined(CODEGEN_TARGET_A64) static unsigned int cpuFeatures = getCpuFeaturesA64(); A64::AssemblyBuilderA64 build(/* logText= */ false, cpuFeatures); #else @@ -555,7 +499,7 @@ template #endif ModuleHelpers helpers; -#if defined(__aarch64__) +#if defined(CODEGEN_TARGET_A64) A64::assembleHelpers(build, helpers); #else X64::assembleHelpers(build, helpers); @@ -572,7 +516,7 @@ template { CodeGenCompilationResult protoResult = CodeGenCompilationResult::Success; - NativeProtoExecDataPtr nativeExecData = createNativeFunction(build, helpers, protos[i], totalIrInstCount, protoResult); + NativeProtoExecDataPtr nativeExecData = createNativeFunction(build, helpers, protos[i], totalIrInstCount, options.hooks, protoResult); if (nativeExecData != nullptr) { nativeProtos.push_back(std::move(nativeExecData)); @@ -639,34 +583,60 @@ template return compilationResult; } -CompilationResult compile_NEW(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats) +CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - - return compileInternal(moduleId, L, idx, flags, stats); + return compileInternal(moduleId, L, idx, options, stats); } -CompilationResult compile_NEW(lua_State* L, int idx, unsigned int flags, 
CompilationStats* stats) +CompilationResult compile(lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - - return compileInternal({}, L, idx, flags, stats); + return compileInternal({}, L, idx, options, stats); } -[[nodiscard]] bool isNativeExecutionEnabled_NEW(lua_State* L) +CompilationResult compile(lua_State* L, int idx, unsigned int flags, CompilationStats* stats) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); + return compileInternal({}, L, idx, CompilationOptions{flags}, stats); +} +CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats) +{ + return compileInternal(moduleId, L, idx, CompilationOptions{flags}, stats); +} + +[[nodiscard]] bool isNativeExecutionEnabled(lua_State* L) +{ return getCodeGenContext(L) != nullptr && L->global->ecb.enter == onEnter; } -void setNativeExecutionEnabled_NEW(lua_State* L, bool enabled) +void setNativeExecutionEnabled(lua_State* L, bool enabled) { - CODEGEN_ASSERT(FFlag::LuauCodegenContext); - if (getCodeGenContext(L) != nullptr) L->global->ecb.enter = enabled ? onEnter : onEnterDisabled; } +static uint8_t userdataRemapperWrap(lua_State* L, const char* str, size_t len) +{ + if (BaseCodeGenContext* codegenCtx = getCodeGenContext(L)) + { + uint8_t index = codegenCtx->userdataRemapper(codegenCtx->userdataRemappingContext, str, len); + + if (index < (LBC_TYPE_TAGGED_USERDATA_END - LBC_TYPE_TAGGED_USERDATA_BASE)) + return LBC_TYPE_TAGGED_USERDATA_BASE + index; + } + + return LBC_TYPE_USERDATA; +} + +void setUserdataRemapper(lua_State* L, void* context, UserdataRemapperCallback cb) +{ + if (BaseCodeGenContext* codegenCtx = getCodeGenContext(L)) + { + codegenCtx->userdataRemappingContext = context; + codegenCtx->userdataRemapper = cb; + + L->global->ecb.gettypemapping = cb ? userdataRemapperWrap : nullptr; + } +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/CodeGenContext.h b/CodeGen/src/CodeGenContext.h index ca338da5..43099a9b 100644 --- a/CodeGen/src/CodeGenContext.h +++ b/CodeGen/src/CodeGenContext.h @@ -50,6 +50,9 @@ public: uint8_t* gateData = nullptr; size_t gateDataSize = 0; + void* userdataRemappingContext = nullptr; + UserdataRemapperCallback* userdataRemapper = nullptr; + NativeContext context; }; @@ -88,33 +91,5 @@ private: SharedCodeAllocator sharedAllocator; }; - -// The following will become the public interface, and can be moved into -// CodeGen.h after the shared allocator work is complete. When the old -// implementation is removed, the _NEW suffix can be dropped from these -// functions. - -// Initializes native code-gen on the provided Luau VM, using a VM-specific -// code-gen context and either the default allocator parameters or custom -// allocator parameters. -void create_NEW(lua_State* L); -void create_NEW(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext); -void create_NEW(lua_State* L, size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext); - -// Initializes native code-gen on the provided Luau VM, using the provided -// SharedCodeGenContext. Note that after this function is called, the -// SharedCodeGenContext must not be destroyed until after the Luau VM L is -// destroyed via lua_close. 
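The setUserdataRemapper hook added in CodeGenContext.cpp above lets the host map its userdata type names into the small range of tagged bytecode types. A minimal host-side sketch, assuming a hypothetical fixed table of type names — only setUserdataRemapper, CodeGen::create, and the LBC_TYPE_TAGGED_USERDATA_* constants come from this change:

    #include <cstring>

    // Hypothetical host-defined names; their order fixes the tagged type index.
    static const char* kUserdataNames[] = {"Vec3", "Mat4"};

    static uint8_t remapUserdataType(void* context, const char* name, size_t len)
    {
        for (uint8_t i = 0; i < uint8_t(sizeof(kUserdataNames) / sizeof(kUserdataNames[0])); ++i)
        {
            if (strlen(kUserdataNames[i]) == len && memcmp(kUserdataNames[i], name, len) == 0)
                return i; // userdataRemapperWrap folds this into LBC_TYPE_TAGGED_USERDATA_BASE + i
        }

        return 0xff; // out of range, so the wrapper falls back to plain LBC_TYPE_USERDATA
    }

    // After Luau::CodeGen::create(L):
    //     Luau::CodeGen::setUserdataRemapper(L, /* context */ nullptr, remapUserdataType);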
-void create_NEW(lua_State* L, SharedCodeGenContext* codeGenContext); - -CompilationResult compile_NEW(lua_State* L, int idx, unsigned int flags, CompilationStats* stats); -CompilationResult compile_NEW(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats); - -// Returns true if native execution is currently enabled for this VM -[[nodiscard]] bool isNativeExecutionEnabled_NEW(lua_State* L); - -// Enables or disables native excution for this VM -void setNativeExecutionEnabled_NEW(lua_State* L, bool enabled); - } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/CodeGenLower.h b/CodeGen/src/CodeGenLower.h index efd1034d..e7701361 100644 --- a/CodeGen/src/CodeGenLower.h +++ b/CodeGen/src/CodeGenLower.h @@ -27,14 +27,15 @@ LUAU_FASTFLAG(DebugCodegenSkipNumbering) LUAU_FASTINT(CodegenHeuristicsInstructionLimit) LUAU_FASTINT(CodegenHeuristicsBlockLimit) LUAU_FASTINT(CodegenHeuristicsBlockInstructionLimit) -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) +LUAU_FASTFLAG(LuauLoadUserdataInfo) +LUAU_FASTFLAG(LuauNativeAttribute) namespace Luau { namespace CodeGen { -inline void gatherFunctions(std::vector& results, Proto* proto, unsigned int flags) +inline void gatherFunctions_DEPRECATED(std::vector& results, Proto* proto, unsigned int flags) { if (results.size() <= size_t(proto->bytecodeid)) results.resize(proto->bytecodeid + 1); @@ -49,7 +50,36 @@ inline void gatherFunctions(std::vector& results, Proto* proto, unsigned // Recursively traverse child protos even if we aren't compiling this one for (int i = 0; i < proto->sizep; i++) - gatherFunctions(results, proto->p[i], flags); + gatherFunctions_DEPRECATED(results, proto->p[i], flags); +} + +inline void gatherFunctionsHelper( + std::vector& results, Proto* proto, const unsigned int flags, const bool hasNativeFunctions, const bool root) +{ + if (results.size() <= size_t(proto->bytecodeid)) + results.resize(proto->bytecodeid + 1); + + // Skip protos that we've already compiled in this run: this happens because at -O2, inlined functions get their protos reused + if (results[proto->bytecodeid]) + return; + + // If only individual functions are marked native, gather a non-root proto only when it carries the native attribute + // If the whole module is native, gather every proto unless it is cold and cold-function compilation wasn't requested + bool shouldGather = hasNativeFunctions ? (!root && (proto->flags & LPF_NATIVE_FUNCTION) != 0) + : ((proto->flags & LPF_NATIVE_COLD) == 0 || (flags & CodeGen_ColdFunctions) != 0); + + if (shouldGather) + results[proto->bytecodeid] = proto; + + // Recursively traverse child protos even if we aren't compiling this one + for (int i = 0; i < proto->sizep; i++) + gatherFunctionsHelper(results, proto->p[i], flags, hasNativeFunctions, false); +} + +inline void gatherFunctions(std::vector& results, Proto* root, const unsigned int flags, const bool hasNativeFunctions = false) +{ + LUAU_ASSERT(FFlag::LuauNativeAttribute); + gatherFunctionsHelper(results, root, flags, hasNativeFunctions, true); }
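Both call sites above (getAssemblyImpl and compileInternal) select between the two gathering entry points the same way; a condensed sketch of that pattern, assuming a root Proto* taken from the function being compiled:

    std::vector<Proto*> protos;

    if (FFlag::LuauNativeAttribute)
        gatherFunctions(protos, root, options.flags, root->flags & LPF_NATIVE_FUNCTION); // attribute-aware path
    else
        gatherFunctions_DEPRECATED(protos, root, options.flags); // old whole-module path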
inline unsigned getInstructionCount(const std::vector& instructions, IrCmd cmd) @@ -149,7 +179,11 @@ inline bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& if (bcTypes.result != LBC_TYPE_ANY || bcTypes.a != LBC_TYPE_ANY || bcTypes.b != LBC_TYPE_ANY || bcTypes.c != LBC_TYPE_ANY) { - toString(ctx.result, bcTypes); + if (FFlag::LuauLoadUserdataInfo) + toString(ctx.result, bcTypes, options.compilationOptions.userdataTypes); + else + toString_DEPRECATED(ctx.result, bcTypes); + build.logAppend("\n"); } } @@ -312,8 +346,7 @@ inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& build, ModuleHelpers& } } - if (FFlag::LuauCodegenRemoveDeadStores5) - markDeadStoresInBlockChains(ir); + markDeadStoresInBlockChains(ir); } std::vector sortedBlocks = getSortedBlockOrder(ir.function); diff --git a/CodeGen/src/CodeGenUtils.cpp b/CodeGen/src/CodeGenUtils.cpp index 973829ca..ad231e76 100644 --- a/CodeGen/src/CodeGenUtils.cpp +++ b/CodeGen/src/CodeGenUtils.cpp @@ -14,6 +14,7 @@ #include "lstate.h" #include "lstring.h" #include "ltable.h" +#include "ludata.h" #include @@ -219,6 +220,20 @@ void callEpilogC(lua_State* L, int nresults, int n) L->top = (nresults == LUA_MULTRET) ? 
res : cip->top; } +Udata* newUserdata(lua_State* L, size_t s, int tag) +{ + Udata* u = luaU_newudata(L, s, tag); + + if (Table* h = L->global->udatamt[tag]) + { + u->metatable = h; + + luaC_objbarrier(L, u, h); + } + + return u; +} + // Extracted as-is from lvmexecute.cpp with the exception of control flow (reentry) and removed interrupts/savedpc Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults) { diff --git a/CodeGen/src/CodeGenUtils.h b/CodeGen/src/CodeGenUtils.h index 515a81f0..15d4c95d 100644 --- a/CodeGen/src/CodeGenUtils.h +++ b/CodeGen/src/CodeGenUtils.h @@ -17,6 +17,8 @@ void forgPrepXnextFallback(lua_State* L, TValue* ra, int pc); Closure* callProlog(lua_State* L, TValue* ra, StkId argtop, int nresults); void callEpilogC(lua_State* L, int nresults, int n); +Udata* newUserdata(lua_State* L, size_t s, int tag); + #define CALL_FALLBACK_YIELD 1 Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults); diff --git a/CodeGen/src/CodeGenX64.cpp b/CodeGen/src/CodeGenX64.cpp index 5e450c9a..b8df3774 100644 --- a/CodeGen/src/CodeGenX64.cpp +++ b/CodeGen/src/CodeGenX64.cpp @@ -181,44 +181,11 @@ static EntryLocations buildEntryFunction(AssemblyBuilderX64& build, UnwindBuilde build.ret(); // Our entry function is special, it spans the whole remaining code area - unwind.finishFunction(build.getLabelOffset(locations.start), kFullBlockFuncton); + unwind.finishFunction(build.getLabelOffset(locations.start), kFullBlockFunction); return locations; } -bool initHeaderFunctions(NativeState& data) -{ - AssemblyBuilderX64 build(/* logText= */ false); - UnwindBuilder& unwind = *data.unwindBuilder.get(); - - unwind.startInfo(UnwindBuilder::X64); - - EntryLocations entryLocations = buildEntryFunction(build, unwind); - - build.finalize(); - - unwind.finishInfo(); - - CODEGEN_ASSERT(build.data.empty()); - - uint8_t* codeStart = nullptr; - if (!data.codeAllocator.allocate( - build.data.data(), int(build.data.size()), build.code.data(), int(build.code.size()), data.gateData, data.gateDataSize, codeStart)) - { - CODEGEN_ASSERT(!"Failed to create entry function"); - return false; - } - - // Set the offset at the begining so that functions in new blocks will not overlay the locations - // specified by the unwind information of the entry function - unwind.setBeginOffset(build.getLabelOffset(entryLocations.prologueEnd)); - - data.context.gateEntry = codeStart + build.getLabelOffset(entryLocations.start); - data.context.gateExit = codeStart + build.getLabelOffset(entryLocations.epilogueStart); - - return true; -} - bool initHeaderFunctions(BaseCodeGenContext& codeGenContext) { AssemblyBuilderX64 build(/* logText= */ false); diff --git a/CodeGen/src/CodeGenX64.h b/CodeGen/src/CodeGenX64.h index eb6ab81c..ce360b23 100644 --- a/CodeGen/src/CodeGenX64.h +++ b/CodeGen/src/CodeGenX64.h @@ -7,7 +7,6 @@ namespace CodeGen { class BaseCodeGenContext; -struct NativeState; struct ModuleHelpers; namespace X64 @@ -15,7 +14,6 @@ namespace X64 class AssemblyBuilderX64; -bool initHeaderFunctions(NativeState& data); bool initHeaderFunctions(BaseCodeGenContext& codeGenContext); void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers); diff --git a/CodeGen/src/EmitBuiltinsX64.cpp b/CodeGen/src/EmitBuiltinsX64.cpp index 96d22e13..15aab4b6 100644 --- a/CodeGen/src/EmitBuiltinsX64.cpp +++ b/CodeGen/src/EmitBuiltinsX64.cpp @@ -12,7 +12,7 @@ #include "lstate.h" -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) +LUAU_FASTFLAG(LuauCodegenMathSign) namespace Luau { @@ -29,17 +29,13 @@ static 
void emitBuiltinMathFrexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, callWrap.call(qword[rNativeContext + offsetof(NativeContext, libm_frexp)]); build.vmovsd(luauRegValue(ra), xmm0); - - if (FFlag::LuauCodegenRemoveDeadStores5) - build.mov(luauRegTag(ra), LUA_TNUMBER); + build.mov(luauRegTag(ra), LUA_TNUMBER); if (nresults > 1) { build.vcvtsi2sd(xmm0, xmm0, dword[sTemporarySlot + 0]); build.vmovsd(luauRegValue(ra + 1), xmm0); - - if (FFlag::LuauCodegenRemoveDeadStores5) - build.mov(luauRegTag(ra + 1), LUA_TNUMBER); + build.mov(luauRegTag(ra + 1), LUA_TNUMBER); } } @@ -52,21 +48,19 @@ static void emitBuiltinMathModf(IrRegAllocX64& regs, AssemblyBuilderX64& build, build.vmovsd(xmm1, qword[sTemporarySlot + 0]); build.vmovsd(luauRegValue(ra), xmm1); - - if (FFlag::LuauCodegenRemoveDeadStores5) - build.mov(luauRegTag(ra), LUA_TNUMBER); + build.mov(luauRegTag(ra), LUA_TNUMBER); if (nresults > 1) { build.vmovsd(luauRegValue(ra + 1), xmm0); - - if (FFlag::LuauCodegenRemoveDeadStores5) - build.mov(luauRegTag(ra + 1), LUA_TNUMBER); + build.mov(luauRegTag(ra + 1), LUA_TNUMBER); } } static void emitBuiltinMathSign(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg) { + CODEGEN_ASSERT(!FFlag::LuauCodegenMathSign); + ScopedRegX64 tmp0{regs, SizeX64::xmmword}; ScopedRegX64 tmp1{regs, SizeX64::xmmword}; ScopedRegX64 tmp2{regs, SizeX64::xmmword}; @@ -90,23 +84,22 @@ static void emitBuiltinMathSign(IrRegAllocX64& regs, AssemblyBuilderX64& build, build.vblendvpd(tmp0.reg, tmp2.reg, build.f64x2(1, 1), tmp0.reg); build.vmovsd(luauRegValue(ra), tmp0.reg); - - if (FFlag::LuauCodegenRemoveDeadStores5) - build.mov(luauRegTag(ra), LUA_TNUMBER); + build.mov(luauRegTag(ra), LUA_TNUMBER); } -void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, OperandX64 arg2, int nparams, int nresults) +void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, int nresults) { switch (bfid) { case LBF_MATH_FREXP: - CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); + CODEGEN_ASSERT(nresults == 1 || nresults == 2); return emitBuiltinMathFrexp(regs, build, ra, arg, nresults); case LBF_MATH_MODF: - CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); + CODEGEN_ASSERT(nresults == 1 || nresults == 2); return emitBuiltinMathModf(regs, build, ra, arg, nresults); case LBF_MATH_SIGN: - CODEGEN_ASSERT(nparams == 1 && nresults == 1); + CODEGEN_ASSERT(!FFlag::LuauCodegenMathSign); + CODEGEN_ASSERT(nresults == 1); return emitBuiltinMathSign(regs, build, ra, arg); default: CODEGEN_ASSERT(!"Missing x64 lowering"); diff --git a/CodeGen/src/EmitBuiltinsX64.h b/CodeGen/src/EmitBuiltinsX64.h index cd8b5251..72a1ad15 100644 --- a/CodeGen/src/EmitBuiltinsX64.h +++ b/CodeGen/src/EmitBuiltinsX64.h @@ -16,7 +16,7 @@ class AssemblyBuilderX64; struct OperandX64; struct IrRegAllocX64; -void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, OperandX64 arg2, int nparams, int nresults); +void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, int nresults); } // namespace X64 } // namespace CodeGen diff --git a/CodeGen/src/EmitCommonA64.h b/CodeGen/src/EmitCommonA64.h index 894570d9..d61fd2a7 100644 --- a/CodeGen/src/EmitCommonA64.h +++ b/CodeGen/src/EmitCommonA64.h @@ -22,8 +22,6 @@ namespace Luau namespace CodeGen { -struct NativeState; - namespace A64 { diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index c8d1e75a..79562b88 100644 --- 
a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -155,8 +155,37 @@ void callArithHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, Ope callWrap.addArgument(SizeX64::qword, luauRegAddress(ra)); callWrap.addArgument(SizeX64::qword, b); callWrap.addArgument(SizeX64::qword, c); - callWrap.addArgument(SizeX64::dword, tm); - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarith)]); + + switch (tm) + { + case TM_ADD: + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithadd)]); + break; + case TM_SUB: + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithsub)]); + break; + case TM_MUL: + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithmul)]); + break; + case TM_DIV: + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithdiv)]); + break; + case TM_IDIV: + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithidiv)]); + break; + case TM_MOD: + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithmod)]); + break; + case TM_POW: + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithpow)]); + break; + case TM_UNM: + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarithunm)]); + break; + default: + CODEGEN_ASSERT(!"Invalid doarith helper operation tag"); + break; + } emitUpdateBase(build); } diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index c29479e1..f88944e5 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -26,7 +26,6 @@ namespace CodeGen { enum class IrCondition : uint8_t; -struct NativeState; struct IrOp; namespace X64 diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index 30ed42a0..f78823df 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -13,6 +13,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauCodegenInstG, false) + namespace Luau { namespace CodeGen @@ -52,6 +54,9 @@ void updateUseCounts(IrFunction& function) checkOp(inst.d); checkOp(inst.e); checkOp(inst.f); + + if (FFlag::LuauCodegenInstG) + checkOp(inst.g); } } @@ -95,6 +100,9 @@ void updateLastUseLocations(IrFunction& function, const std::vector& s checkOp(inst.d); checkOp(inst.e); checkOp(inst.f); + + if (FFlag::LuauCodegenInstG) + checkOp(inst.g); } } } @@ -128,6 +136,12 @@ uint32_t getNextInstUse(IrFunction& function, uint32_t targetInstIdx, uint32_t s if (inst.f.kind == IrOpKind::Inst && inst.f.index == targetInstIdx) return i; + + if (FFlag::LuauCodegenInstG) + { + if (inst.g.kind == IrOpKind::Inst && inst.g.index == targetInstIdx) + return i; + } } // There must be a next use since there is the last use location @@ -165,6 +179,9 @@ std::pair getLiveInOutValueCount(IrFunction& function, IrBlo checkOp(inst.d); checkOp(inst.e); checkOp(inst.f); + + if (FFlag::LuauCodegenInstG) + checkOp(inst.g); } return std::make_pair(liveIns, liveOuts); @@ -488,6 +505,9 @@ static void computeCfgBlockEdges(IrFunction& function) checkOp(inst.d); checkOp(inst.e); checkOp(inst.f); + + if (FFlag::LuauCodegenInstG) + checkOp(inst.g); } } diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 7d285aaf..1f4342f6 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -13,8 +13,9 @@ #include -LUAU_FASTFLAG(LuauLoadTypeInfo) // Because new VM typeinfo load changes the format used by Codegen, same flag is used -LUAU_FASTFLAG(LuauTypeInfoLookupImprovement) +LUAU_FASTFLAG(LuauLoadUserdataInfo) 
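The checkOp additions in IrAnalysis.cpp above extend every operand walk (use counts, last-use locations, live-in/out counting, CFG edge discovery) to a new seventh IrInst operand 'g', guarded by LuauCodegenInstG, which is declared just below. A sketch of the traversal pattern those passes now share, assuming a caller-supplied checkOp callback:

    template<typename F>
    static void visitOperands(IrInst& inst, F&& checkOp)
    {
        checkOp(inst.a);
        checkOp(inst.b);
        checkOp(inst.c);
        checkOp(inst.d);
        checkOp(inst.e);
        checkOp(inst.f);

        // New seventh operand; INVOKE_FASTCALL, for example, carries nresults
        // here once LOP_FASTCALL3 lowering is in use.
        if (FFlag::LuauCodegenInstG)
            checkOp(inst.g);
    }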
+LUAU_FASTFLAG(LuauCodegenInstG) +LUAU_FASTFLAG(LuauCodegenFastcall3) namespace Luau { @@ -23,120 +24,25 @@ namespace CodeGen constexpr unsigned kNoAssociatedBlockIndex = ~0u; -IrBuilder::IrBuilder() - : constantMap({IrConstKind::Tag, ~0ull}) +IrBuilder::IrBuilder(const HostIrHooks& hostHooks) + : hostHooks(hostHooks) + , constantMap({IrConstKind::Tag, ~0ull}) { } -static bool hasTypedParameters_DEPRECATED(Proto* proto) -{ - CODEGEN_ASSERT(!FFlag::LuauLoadTypeInfo); - - return proto->typeinfo && proto->numparams != 0; -} - -static void buildArgumentTypeChecks_DEPRECATED(IrBuilder& build, Proto* proto) -{ - CODEGEN_ASSERT(!FFlag::LuauLoadTypeInfo); - CODEGEN_ASSERT(hasTypedParameters_DEPRECATED(proto)); - - for (int i = 0; i < proto->numparams; ++i) - { - uint8_t et = proto->typeinfo[2 + i]; - - uint8_t tag = et & ~LBC_TYPE_OPTIONAL_BIT; - uint8_t optional = et & LBC_TYPE_OPTIONAL_BIT; - - if (tag == LBC_TYPE_ANY) - continue; - - IrOp load = build.inst(IrCmd::LOAD_TAG, build.vmReg(i)); - - IrOp nextCheck; - if (optional) - { - nextCheck = build.block(IrBlockKind::Internal); - IrOp fallbackCheck = build.block(IrBlockKind::Internal); - - build.inst(IrCmd::JUMP_EQ_TAG, load, build.constTag(LUA_TNIL), nextCheck, fallbackCheck); - - build.beginBlock(fallbackCheck); - } - - switch (tag) - { - case LBC_TYPE_NIL: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TNIL), build.vmExit(kVmExitEntryGuardPc)); - break; - case LBC_TYPE_BOOLEAN: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TBOOLEAN), build.vmExit(kVmExitEntryGuardPc)); - break; - case LBC_TYPE_NUMBER: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TNUMBER), build.vmExit(kVmExitEntryGuardPc)); - break; - case LBC_TYPE_STRING: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TSTRING), build.vmExit(kVmExitEntryGuardPc)); - break; - case LBC_TYPE_TABLE: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TTABLE), build.vmExit(kVmExitEntryGuardPc)); - break; - case LBC_TYPE_FUNCTION: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TFUNCTION), build.vmExit(kVmExitEntryGuardPc)); - break; - case LBC_TYPE_THREAD: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TTHREAD), build.vmExit(kVmExitEntryGuardPc)); - break; - case LBC_TYPE_USERDATA: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TUSERDATA), build.vmExit(kVmExitEntryGuardPc)); - break; - case LBC_TYPE_VECTOR: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TVECTOR), build.vmExit(kVmExitEntryGuardPc)); - break; - case LBC_TYPE_BUFFER: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TBUFFER), build.vmExit(kVmExitEntryGuardPc)); - break; - } - - if (optional) - { - build.inst(IrCmd::JUMP, nextCheck); - build.beginBlock(nextCheck); - } - } - - // If the last argument is optional, we can skip creating a new internal block since one will already have been created. 
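The removed implementation here and its replacement below emit the same guard shape for an optional parameter annotation; a condensed sketch for a single number? parameter at register i, mirroring the calls visible in both versions:

    IrOp load = build.inst(IrCmd::LOAD_TAG, build.vmReg(i));

    IrOp nextCheck = build.block(IrBlockKind::Internal);
    IrOp fallbackCheck = build.block(IrBlockKind::Internal);

    // nil satisfies the optional annotation, so it bypasses the tag check
    build.inst(IrCmd::JUMP_EQ_TAG, load, build.constTag(LUA_TNIL), nextCheck, fallbackCheck);

    // any non-nil value must carry the annotated tag or exit to the VM
    build.beginBlock(fallbackCheck);
    build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TNUMBER), build.vmExit(kVmExitEntryGuardPc));
    build.inst(IrCmd::JUMP, nextCheck);

    build.beginBlock(nextCheck);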
- if (!(proto->typeinfo[2 + proto->numparams - 1] & LBC_TYPE_OPTIONAL_BIT)) - { - IrOp next = build.block(IrBlockKind::Internal); - build.inst(IrCmd::JUMP, next); - - build.beginBlock(next); - } -} static bool hasTypedParameters(const BytecodeTypeInfo& typeInfo) { - CODEGEN_ASSERT(FFlag::LuauLoadTypeInfo); - - if (FFlag::LuauTypeInfoLookupImprovement) + for (auto el : typeInfo.argumentTypes) { - for (auto el : typeInfo.argumentTypes) - { - if (el != LBC_TYPE_ANY) - return true; - } + if (el != LBC_TYPE_ANY) + return true; + } - return false; - } - else - { - return !typeInfo.argumentTypes.empty(); - } + return false; } static void buildArgumentTypeChecks(IrBuilder& build) { - CODEGEN_ASSERT(FFlag::LuauLoadTypeInfo); - const BytecodeTypeInfo& typeInfo = build.function.bcTypeInfo; CODEGEN_ASSERT(hasTypedParameters(typeInfo)); @@ -195,6 +101,19 @@ static void buildArgumentTypeChecks(IrBuilder& build) case LBC_TYPE_BUFFER: build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TBUFFER), build.vmExit(kVmExitEntryGuardPc)); break; + default: + if (FFlag::LuauLoadUserdataInfo) + { + if (tag >= LBC_TYPE_TAGGED_USERDATA_BASE && tag < LBC_TYPE_TAGGED_USERDATA_END) + { + build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TUSERDATA), build.vmExit(kVmExitEntryGuardPc)); + } + else + { + CODEGEN_ASSERT(!"unknown argument type tag"); + } + } + break; } if (optional) @@ -219,18 +138,17 @@ void IrBuilder::buildFunctionIr(Proto* proto) function.proto = proto; function.variadic = proto->is_vararg != 0; - if (FFlag::LuauLoadTypeInfo) - loadBytecodeTypeInfo(function); + loadBytecodeTypeInfo(function); // Reserve entry block - bool generateTypeChecks = FFlag::LuauLoadTypeInfo ? hasTypedParameters(function.bcTypeInfo) : hasTypedParameters_DEPRECATED(proto); + bool generateTypeChecks = hasTypedParameters(function.bcTypeInfo); IrOp entry = generateTypeChecks ? 
block(IrBlockKind::Internal) : IrOp{}; // Rebuild original control flow blocks rebuildBytecodeBasicBlocks(proto); // Infer register tags in bytecode - analyzeBytecodeTypes(function); + analyzeBytecodeTypes(function, hostHooks); function.bcMapping.resize(proto->sizecode, {~0u, ~0u}); @@ -238,10 +156,7 @@ void IrBuilder::buildFunctionIr(Proto* proto) { beginBlock(entry); - if (FFlag::LuauLoadTypeInfo) - buildArgumentTypeChecks(*this); - else - buildArgumentTypeChecks_DEPRECATED(*this, proto); + buildArgumentTypeChecks(*this); inst(IrCmd::JUMP, blockAtInst(0)); } @@ -283,10 +198,10 @@ void IrBuilder::buildFunctionIr(Proto* proto) translateInst(op, pc, i); - if (fastcallSkipTarget != -1) + if (cmdSkipTarget != -1) { - nexti = fastcallSkipTarget; - fastcallSkipTarget = -1; + nexti = cmdSkipTarget; + cmdSkipTarget = -1; } } @@ -535,16 +450,21 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) translateInstCloseUpvals(*this, pc); break; case LOP_FASTCALL: - handleFastcallFallback(translateFastCallN(*this, pc, i, false, 0, {}), pc, i); + handleFastcallFallback(translateFastCallN(*this, pc, i, false, 0, {}, {}), pc, i); break; case LOP_FASTCALL1: - handleFastcallFallback(translateFastCallN(*this, pc, i, true, 1, undef()), pc, i); + handleFastcallFallback(translateFastCallN(*this, pc, i, true, 1, undef(), undef()), pc, i); break; case LOP_FASTCALL2: - handleFastcallFallback(translateFastCallN(*this, pc, i, true, 2, vmReg(pc[1])), pc, i); + handleFastcallFallback(translateFastCallN(*this, pc, i, true, 2, vmReg(pc[1]), undef()), pc, i); break; case LOP_FASTCALL2K: - handleFastcallFallback(translateFastCallN(*this, pc, i, true, 2, vmConst(pc[1])), pc, i); + handleFastcallFallback(translateFastCallN(*this, pc, i, true, 2, vmConst(pc[1]), undef()), pc, i); + break; + case LOP_FASTCALL3: + CODEGEN_ASSERT(FFlag::LuauCodegenFastcall3); + + handleFastcallFallback(translateFastCallN(*this, pc, i, true, 3, vmReg(pc[1] & 0xff), vmReg((pc[1] >> 8) & 0xff)), pc, i); break; case LOP_FORNPREP: translateInstForNPrep(*this, pc, i); @@ -613,7 +533,8 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) translateInstCapture(*this, pc, i); break; case LOP_NAMECALL: - translateInstNamecall(*this, pc, i); + if (translateInstNamecall(*this, pc, i)) + cmdSkipTarget = i + 3; break; case LOP_PREPVARARGS: inst(IrCmd::FALLBACK_PREPVARARGS, constUint(i), constInt(LUAU_INSN_A(*pc))); @@ -654,7 +575,7 @@ void IrBuilder::handleFastcallFallback(IrOp fallbackOrUndef, const Instruction* } else { - fastcallSkipTarget = i + skip + 2; + cmdSkipTarget = i + skip + 2; } } @@ -725,6 +646,9 @@ void IrBuilder::clone(const IrBlock& source, bool removeCurrentTerminator) redirect(clone.e); redirect(clone.f); + if (FFlag::LuauCodegenInstG) + redirect(clone.g); + addUse(function, clone.a); addUse(function, clone.b); addUse(function, clone.c); @@ -732,11 +656,17 @@ void IrBuilder::clone(const IrBlock& source, bool removeCurrentTerminator) addUse(function, clone.e); addUse(function, clone.f); + if (FFlag::LuauCodegenInstG) + addUse(function, clone.g); + // Instructions that referenced the original will have to be adjusted to use the clone instRedir[index] = uint32_t(function.instructions.size()); // Reconstruct the fresh clone - inst(clone.cmd, clone.a, clone.b, clone.c, clone.d, clone.e, clone.f); + if (FFlag::LuauCodegenInstG) + inst(clone.cmd, clone.a, clone.b, clone.c, clone.d, clone.e, clone.f, clone.g); + else + inst(clone.cmd, clone.a, clone.b, clone.c, clone.d, clone.e, clone.f); } } @@ 
-834,8 +764,33 @@ IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e) IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f) { + if (FFlag::LuauCodegenInstG) + { + return inst(cmd, a, b, c, d, e, f, {}); + } + else + { + uint32_t index = uint32_t(function.instructions.size()); + function.instructions.push_back({cmd, a, b, c, d, e, f}); + + CODEGEN_ASSERT(!inTerminatedBlock); + + if (isBlockTerminator(cmd)) + { + function.blocks[activeBlockIdx].finish = index; + inTerminatedBlock = true; + } + + return {IrOpKind::Inst, index}; + } +} + +IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f, IrOp g) +{ + CODEGEN_ASSERT(FFlag::LuauCodegenInstG); + uint32_t index = uint32_t(function.instructions.size()); - function.instructions.push_back({cmd, a, b, c, d, e, f}); + function.instructions.push_back({cmd, a, b, c, d, e, f, g}); CODEGEN_ASSERT(!inTerminatedBlock); diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index 48a50ecb..c4114d89 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -7,6 +7,9 @@ #include +LUAU_FASTFLAG(LuauLoadUserdataInfo) +LUAU_FASTFLAG(LuauCodegenInstG) + namespace Luau { namespace CodeGen @@ -151,6 +154,8 @@ const char* getCmdName(IrCmd cmd) return "SQRT_NUM"; case IrCmd::ABS_NUM: return "ABS_NUM"; + case IrCmd::SIGN_NUM: + return "SIGN_NUM"; case IrCmd::ADD_VEC: return "ADD_VEC"; case IrCmd::SUB_VEC: @@ -197,6 +202,8 @@ const char* getCmdName(IrCmd cmd) return "TRY_NUM_TO_INDEX"; case IrCmd::TRY_CALL_FASTGETTM: return "TRY_CALL_FASTGETTM"; + case IrCmd::NEW_USERDATA: + return "NEW_USERDATA"; case IrCmd::INT_TO_NUM: return "INT_TO_NUM"; case IrCmd::UINT_TO_NUM: @@ -255,6 +262,8 @@ const char* getCmdName(IrCmd cmd) return "CHECK_NODE_VALUE"; case IrCmd::CHECK_BUFFER_LEN: return "CHECK_BUFFER_LEN"; + case IrCmd::CHECK_USERDATA_TAG: + return "CHECK_USERDATA_TAG"; case IrCmd::INTERRUPT: return "INTERRUPT"; case IrCmd::CHECK_GC: @@ -411,6 +420,9 @@ void toString(IrToStringContext& ctx, const IrInst& inst, uint32_t index) checkOp(inst.d, ", "); checkOp(inst.e, ", "); checkOp(inst.f, ", "); + + if (FFlag::LuauCodegenInstG) + checkOp(inst.g, ", "); } void toString(IrToStringContext& ctx, const IrBlock& block, uint32_t index) @@ -480,8 +492,10 @@ void toString(std::string& result, IrConst constant) } } -const char* getBytecodeTypeName(uint8_t type) +const char* getBytecodeTypeName_DEPRECATED(uint8_t type) { + CODEGEN_ASSERT(!FFlag::LuauLoadUserdataInfo); + switch (type & ~LBC_TYPE_OPTIONAL_BIT) { case LBC_TYPE_NIL: @@ -512,13 +526,78 @@ const char* getBytecodeTypeName(uint8_t type) return nullptr; } -void toString(std::string& result, const BytecodeTypes& bcTypes) +const char* getBytecodeTypeName(uint8_t type, const char* const* userdataTypes) { + CODEGEN_ASSERT(FFlag::LuauLoadUserdataInfo); + + // Optional bit should be handled externally + type = type & ~LBC_TYPE_OPTIONAL_BIT; + + if (type >= LBC_TYPE_TAGGED_USERDATA_BASE && type < LBC_TYPE_TAGGED_USERDATA_END) + { + if (userdataTypes) + return userdataTypes[type - LBC_TYPE_TAGGED_USERDATA_BASE]; + + return "userdata"; + } + + switch (type) + { + case LBC_TYPE_NIL: + return "nil"; + case LBC_TYPE_BOOLEAN: + return "boolean"; + case LBC_TYPE_NUMBER: + return "number"; + case LBC_TYPE_STRING: + return "string"; + case LBC_TYPE_TABLE: + return "table"; + case LBC_TYPE_FUNCTION: + return "function"; + case LBC_TYPE_THREAD: + return "thread"; + case LBC_TYPE_USERDATA: + return "userdata"; + case LBC_TYPE_VECTOR: + return "vector"; + case 
LBC_TYPE_BUFFER: + return "buffer"; + case LBC_TYPE_ANY: + return "any"; + } + + CODEGEN_ASSERT(!"Unhandled type in getBytecodeTypeName"); + return nullptr; +} + +void toString_DEPRECATED(std::string& result, const BytecodeTypes& bcTypes) +{ + CODEGEN_ASSERT(!FFlag::LuauLoadUserdataInfo); + if (bcTypes.c != LBC_TYPE_ANY) - append(result, "%s <- %s, %s, %s", getBytecodeTypeName(bcTypes.result), getBytecodeTypeName(bcTypes.a), getBytecodeTypeName(bcTypes.b), - getBytecodeTypeName(bcTypes.c)); + append(result, "%s <- %s, %s, %s", getBytecodeTypeName_DEPRECATED(bcTypes.result), getBytecodeTypeName_DEPRECATED(bcTypes.a), + getBytecodeTypeName_DEPRECATED(bcTypes.b), getBytecodeTypeName_DEPRECATED(bcTypes.c)); else - append(result, "%s <- %s, %s", getBytecodeTypeName(bcTypes.result), getBytecodeTypeName(bcTypes.a), getBytecodeTypeName(bcTypes.b)); + append(result, "%s <- %s, %s", getBytecodeTypeName_DEPRECATED(bcTypes.result), getBytecodeTypeName_DEPRECATED(bcTypes.a), + getBytecodeTypeName_DEPRECATED(bcTypes.b)); +} + +void toString(std::string& result, const BytecodeTypes& bcTypes, const char* const* userdataTypes) +{ + CODEGEN_ASSERT(FFlag::LuauLoadUserdataInfo); + + append(result, "%s%s", getBytecodeTypeName(bcTypes.result, userdataTypes), (bcTypes.result & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" : ""); + append(result, " <- "); + append(result, "%s%s", getBytecodeTypeName(bcTypes.a, userdataTypes), (bcTypes.a & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" : ""); + append(result, ", "); + append(result, "%s%s", getBytecodeTypeName(bcTypes.b, userdataTypes), (bcTypes.b & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" : ""); + + if (bcTypes.c != LBC_TYPE_ANY) + { + append(result, ", "); + append(result, "%s%s", getBytecodeTypeName(bcTypes.c, userdataTypes), (bcTypes.c & LBC_TYPE_OPTIONAL_BIT) != 0 ? "?" 
: ""); + } } static void appendBlockSet(IrToStringContext& ctx, BlockIteratorWrapper blocks) @@ -583,6 +662,8 @@ static RegisterSet getJumpTargetExtraLiveIn(IrToStringContext& ctx, const IrBloc op = inst.e; else if (inst.f.kind == IrOpKind::Block) op = inst.f; + else if (FFlag::LuauCodegenInstG && inst.g.kind == IrOpKind::Block) + op = inst.g; if (op.kind == IrOpKind::Block && op.index < ctx.cfg.in.size()) { @@ -867,6 +948,9 @@ std::string toDot(const IrFunction& function, bool includeInst) checkOp(inst.d); checkOp(inst.e); checkOp(inst.f); + + if (FFlag::LuauCodegenInstG) + checkOp(inst.g); } } diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index f35a15fa..ef51a4b1 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -11,7 +11,11 @@ #include "lstate.h" #include "lgc.h" -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) +LUAU_FASTFLAG(LuauCodegenUserdataOps) +LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataAlloc, false) +LUAU_FASTFLAGVARIABLE(LuauCodegenUserdataOpsFixA64, false) +LUAU_FASTFLAG(LuauCodegenFastcall3) +LUAU_FASTFLAG(LuauCodegenMathSign) namespace Luau { @@ -193,78 +197,51 @@ static void emitInvokeLibm1P(AssemblyBuilderA64& build, size_t func, int arg) build.blr(x1); } -static bool emitBuiltin( - AssemblyBuilderA64& build, IrFunction& function, IrRegAllocA64& regs, int bfid, int res, int arg, IrOp args, int nparams, int nresults) +static bool emitBuiltin(AssemblyBuilderA64& build, IrFunction& function, IrRegAllocA64& regs, int bfid, int res, int arg, int nresults) { switch (bfid) { case LBF_MATH_FREXP: { - if (FFlag::LuauCodegenRemoveDeadStores5) - { - CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); - emitInvokeLibm1P(build, offsetof(NativeContext, libm_frexp), arg); - build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); + CODEGEN_ASSERT(nresults == 1 || nresults == 2); + emitInvokeLibm1P(build, offsetof(NativeContext, libm_frexp), arg); + build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); - RegisterA64 temp = regs.allocTemp(KindA64::w); - build.mov(temp, LUA_TNUMBER); - build.str(temp, mem(rBase, res * sizeof(TValue) + offsetof(TValue, tt))); + RegisterA64 temp = regs.allocTemp(KindA64::w); + build.mov(temp, LUA_TNUMBER); + build.str(temp, mem(rBase, res * sizeof(TValue) + offsetof(TValue, tt))); - if (nresults == 2) - { - build.ldr(w0, sTemporary); - build.scvtf(d1, w0); - build.str(d1, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, value.n))); - build.str(temp, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, tt))); - } - } - else + if (nresults == 2) { - CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); - emitInvokeLibm1P(build, offsetof(NativeContext, libm_frexp), arg); - build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); - if (nresults == 2) - { - build.ldr(w0, sTemporary); - build.scvtf(d1, w0); - build.str(d1, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, value.n))); - } + build.ldr(w0, sTemporary); + build.scvtf(d1, w0); + build.str(d1, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, value.n))); + build.str(temp, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, tt))); } return true; } case LBF_MATH_MODF: { - if (FFlag::LuauCodegenRemoveDeadStores5) - { - CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); - emitInvokeLibm1P(build, offsetof(NativeContext, libm_modf), arg); - build.ldr(d1, sTemporary); - build.str(d1, mem(rBase, res * sizeof(TValue) + 
offsetof(TValue, value.n))); + CODEGEN_ASSERT(nresults == 1 || nresults == 2); + emitInvokeLibm1P(build, offsetof(NativeContext, libm_modf), arg); + build.ldr(d1, sTemporary); + build.str(d1, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); - RegisterA64 temp = regs.allocTemp(KindA64::w); - build.mov(temp, LUA_TNUMBER); - build.str(temp, mem(rBase, res * sizeof(TValue) + offsetof(TValue, tt))); + RegisterA64 temp = regs.allocTemp(KindA64::w); + build.mov(temp, LUA_TNUMBER); + build.str(temp, mem(rBase, res * sizeof(TValue) + offsetof(TValue, tt))); - if (nresults == 2) - { - build.str(d0, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, value.n))); - build.str(temp, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, tt))); - } - } - else + if (nresults == 2) { - CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); - emitInvokeLibm1P(build, offsetof(NativeContext, libm_modf), arg); - build.ldr(d1, sTemporary); - build.str(d1, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); - if (nresults == 2) - build.str(d0, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, value.n))); + build.str(d0, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, value.n))); + build.str(temp, mem(rBase, (res + 1) * sizeof(TValue) + offsetof(TValue, tt))); } return true; } case LBF_MATH_SIGN: { - CODEGEN_ASSERT(nparams == 1 && nresults == 1); + CODEGEN_ASSERT(!FFlag::LuauCodegenMathSign); + CODEGEN_ASSERT(nresults == 1); build.ldr(d0, mem(rBase, arg * sizeof(TValue) + offsetof(TValue, value.n))); build.fcmpz(d0); build.fmov(d0, 0.0); @@ -274,12 +251,10 @@ static bool emitBuiltin( build.fcsel(d0, d1, d0, getConditionFP(IrCondition::Less)); build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); - if (FFlag::LuauCodegenRemoveDeadStores5) - { - RegisterA64 temp = regs.allocTemp(KindA64::w); - build.mov(temp, LUA_TNUMBER); - build.str(temp, mem(rBase, res * sizeof(TValue) + offsetof(TValue, tt))); - } + RegisterA64 temp = regs.allocTemp(KindA64::w); + build.mov(temp, LUA_TNUMBER); + build.str(temp, mem(rBase, res * sizeof(TValue) + offsetof(TValue, tt))); + return true; } @@ -723,6 +698,24 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.fabs(inst.regA64, temp); break; } + case IrCmd::SIGN_NUM: + { + CODEGEN_ASSERT(FFlag::LuauCodegenMathSign); + + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a}); + + RegisterA64 temp = tempDouble(inst.a); + RegisterA64 temp0 = regs.allocTemp(KindA64::d); + RegisterA64 temp1 = regs.allocTemp(KindA64::d); + + build.fcmpz(temp); + build.fmov(temp0, 0.0); + build.fmov(temp1, 1.0); + build.fcsel(inst.regA64, temp1, temp0, getConditionFP(IrCondition::Greater)); + build.fmov(temp1, -1.0); + build.fcsel(inst.regA64, temp1, inst.regA64, getConditionFP(IrCondition::Less)); + break; + } case IrCmd::ADD_VEC: { inst.regA64 = regs.allocReuse(KindA64::q, index, {inst.a, inst.b}); @@ -1082,6 +1075,19 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) inst.regA64 = regs.takeReg(x0, index); break; } + case IrCmd::NEW_USERDATA: + { + CODEGEN_ASSERT(FFlag::LuauCodegenUserdataAlloc); + + regs.spill(build, index); + build.mov(x0, rState); + build.mov(x1, intOp(inst.a)); + build.mov(x2, intOp(inst.b)); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, newUserdata))); + build.blr(x3); + inst.regA64 = regs.takeReg(x0, index); + break; + } case IrCmd::INT_TO_NUM: { inst.regA64 = regs.allocReg(KindA64::d, index); @@ -1188,34 +1194,88 
@@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) } case IrCmd::FASTCALL: regs.spill(build, index); - error |= !emitBuiltin(build, function, regs, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), inst.d, intOp(inst.e), intOp(inst.f)); + + if (FFlag::LuauCodegenFastcall3) + error |= !emitBuiltin(build, function, regs, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.d)); + else + error |= !emitBuiltin(build, function, regs, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.f)); + break; case IrCmd::INVOKE_FASTCALL: { - regs.spill(build, index); - build.mov(x0, rState); - build.add(x1, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); - build.add(x2, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); - build.mov(w3, intOp(inst.f)); // nresults - - if (inst.d.kind == IrOpKind::VmReg) - build.add(x4, rBase, uint16_t(vmRegOp(inst.d) * sizeof(TValue))); - else if (inst.d.kind == IrOpKind::VmConst) - emitAddOffset(build, x4, rConstants, vmConstOp(inst.d) * sizeof(TValue)); - else - CODEGEN_ASSERT(inst.d.kind == IrOpKind::Undef); - - // nparams - if (intOp(inst.e) == LUA_MULTRET) + if (FFlag::LuauCodegenFastcall3) { - // L->top - (ra + 1) - build.ldr(x5, mem(rState, offsetof(lua_State, top))); - build.sub(x5, x5, rBase); - build.sub(x5, x5, uint16_t((vmRegOp(inst.b) + 1) * sizeof(TValue))); - build.lsr(x5, x5, kTValueSizeLog2); + // We might need a temporary and we have to preserve it over the spill + RegisterA64 temp = regs.allocTemp(KindA64::q); + regs.spill(build, index, {temp}); + + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); + build.add(x2, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); + build.mov(w3, intOp(inst.g)); // nresults + + // 'E' argument can only be produced by LOP_FASTCALL3 lowering + if (inst.e.kind != IrOpKind::Undef) + { + CODEGEN_ASSERT(intOp(inst.f) == 3); + + build.ldr(x4, mem(rState, offsetof(lua_State, top))); + + build.ldr(temp, mem(rBase, vmRegOp(inst.d) * sizeof(TValue))); + build.str(temp, mem(x4, 0)); + + build.ldr(temp, mem(rBase, vmRegOp(inst.e) * sizeof(TValue))); + build.str(temp, mem(x4, sizeof(TValue))); + } + else + { + if (inst.d.kind == IrOpKind::VmReg) + build.add(x4, rBase, uint16_t(vmRegOp(inst.d) * sizeof(TValue))); + else if (inst.d.kind == IrOpKind::VmConst) + emitAddOffset(build, x4, rConstants, vmConstOp(inst.d) * sizeof(TValue)); + else + CODEGEN_ASSERT(inst.d.kind == IrOpKind::Undef); + } + + // nparams + if (intOp(inst.f) == LUA_MULTRET) + { + // L->top - (ra + 1) + build.ldr(x5, mem(rState, offsetof(lua_State, top))); + build.sub(x5, x5, rBase); + build.sub(x5, x5, uint16_t((vmRegOp(inst.b) + 1) * sizeof(TValue))); + build.lsr(x5, x5, kTValueSizeLog2); + } + else + build.mov(w5, intOp(inst.f)); } else - build.mov(w5, intOp(inst.e)); + { + regs.spill(build, index); + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); + build.add(x2, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); + build.mov(w3, intOp(inst.f)); // nresults + + if (inst.d.kind == IrOpKind::VmReg) + build.add(x4, rBase, uint16_t(vmRegOp(inst.d) * sizeof(TValue))); + else if (inst.d.kind == IrOpKind::VmConst) + emitAddOffset(build, x4, rConstants, vmConstOp(inst.d) * sizeof(TValue)); + else + CODEGEN_ASSERT(inst.d.kind == IrOpKind::Undef); + + // nparams + if (intOp(inst.e) == LUA_MULTRET) + { + // L->top - (ra + 1) + build.ldr(x5, mem(rState, offsetof(lua_State, top))); + build.sub(x5, x5, rBase); + build.sub(x5, x5, 
uint16_t((vmRegOp(inst.b) + 1) * sizeof(TValue))); + build.lsr(x5, x5, kTValueSizeLog2); + } + else + build.mov(w5, intOp(inst.e)); + } build.ldr(x6, mem(rNativeContext, offsetof(NativeContext, luauF_table) + uintOp(inst.a) * sizeof(luau_FastFunction))); build.blr(x6); @@ -1242,9 +1302,38 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) else build.add(x3, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); - build.mov(w4, TMS(intOp(inst.d))); - build.ldr(x5, mem(rNativeContext, offsetof(NativeContext, luaV_doarith))); - build.blr(x5); + switch (TMS(intOp(inst.d))) + { + case TM_ADD: + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithadd))); + break; + case TM_SUB: + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithsub))); + break; + case TM_MUL: + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithmul))); + break; + case TM_DIV: + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithdiv))); + break; + case TM_IDIV: + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithidiv))); + break; + case TM_MOD: + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithmod))); + break; + case TM_POW: + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithpow))); + break; + case TM_UNM: + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_doarithunm))); + break; + default: + CODEGEN_ASSERT(!"Invalid doarith helper operation tag"); + break; + } + + build.blr(x4); emitUpdateBase(build); break; @@ -1388,35 +1477,14 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) Label fresh; // used when guard aborts execution or jumps to a VM exit Label& fail = getTargetLabel(inst.c, fresh); - if (FFlag::LuauCodegenRemoveDeadStores5) + if (tagOp(inst.b) == 0) { - if (tagOp(inst.b) == 0) - { - build.cbnz(regOp(inst.a), fail); - } - else - { - build.cmp(regOp(inst.a), tagOp(inst.b)); - build.b(ConditionA64::NotEqual, fail); - } + build.cbnz(regOp(inst.a), fail); } else { - // To support DebugLuauAbortingChecks, CHECK_TAG with VmReg has to be handled - RegisterA64 tag = inst.a.kind == IrOpKind::VmReg ? 
regs.allocTemp(KindA64::w) : regOp(inst.a); - - if (inst.a.kind == IrOpKind::VmReg) - build.ldr(tag, mem(rBase, vmRegOp(inst.a) * sizeof(TValue) + offsetof(TValue, tt))); - - if (tagOp(inst.b) == 0) - { - build.cbnz(tag, fail); - } - else - { - build.cmp(tag, tagOp(inst.b)); - build.b(ConditionA64::NotEqual, fail); - } + build.cmp(regOp(inst.a), tagOp(inst.b)); + build.b(ConditionA64::NotEqual, fail); } finalizeTargetLabel(inst.c, fresh); @@ -1638,6 +1706,24 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) finalizeTargetLabel(inst.d, fresh); break; } + case IrCmd::CHECK_USERDATA_TAG: + { + CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps); + + Label fresh; // used when guard aborts execution or jumps to a VM exit + Label& fail = getTargetLabel(inst.c, fresh); + RegisterA64 temp = regs.allocTemp(KindA64::w); + build.ldrb(temp, mem(regOp(inst.a), offsetof(Udata, tag))); + + if (FFlag::LuauCodegenUserdataOpsFixA64) + build.cmp(temp, intOp(inst.b)); + else + build.cmp(temp, tagOp(inst.b)); + + build.b(ConditionA64::NotEqual, fail); + finalizeTargetLabel(inst.c, fresh); + break; + } case IrCmd::INTERRUPT: { regs.spill(build, index); @@ -2269,7 +2355,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READI8: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b}); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldrsb(inst.regA64, addr); break; @@ -2278,7 +2364,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READU8: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b}); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldrb(inst.regA64, addr); break; @@ -2287,7 +2373,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_WRITEI8: { RegisterA64 temp = tempInt(inst.c); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d)); build.strb(temp, addr); break; @@ -2296,7 +2382,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READI16: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b}); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldrsh(inst.regA64, addr); break; @@ -2305,7 +2391,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READU16: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b}); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldrh(inst.regA64, addr); break; @@ -2314,7 +2400,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_WRITEI16: { RegisterA64 temp = tempInt(inst.c); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? 
LUA_TBUFFER : tagOp(inst.d)); build.strh(temp, addr); break; @@ -2323,7 +2409,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READI32: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.b}); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldr(inst.regA64, addr); break; @@ -2332,7 +2418,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_WRITEI32: { RegisterA64 temp = tempInt(inst.c); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d)); build.str(temp, addr); break; @@ -2342,7 +2428,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { inst.regA64 = regs.allocReg(KindA64::d, index); RegisterA64 temp = castReg(KindA64::s, inst.regA64); // safe to alias a fresh register - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldr(temp, addr); build.fcvt(inst.regA64, temp); @@ -2353,7 +2439,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { RegisterA64 temp1 = tempDouble(inst.c); RegisterA64 temp2 = regs.allocTemp(KindA64::s); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d)); build.fcvt(temp2, temp1); build.str(temp2, addr); @@ -2363,7 +2449,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_READF64: { inst.regA64 = regs.allocReg(KindA64::d, index); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c)); build.ldr(inst.regA64, addr); break; @@ -2372,7 +2458,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::BUFFER_WRITEF64: { RegisterA64 temp = tempDouble(inst.c); - AddressA64 addr = tempAddrBuffer(inst.a, inst.b); + AddressA64 addr = tempAddrBuffer(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d)); build.str(temp, addr); break; @@ -2600,32 +2686,68 @@ AddressA64 IrLoweringA64::tempAddr(IrOp op, int offset) } } -AddressA64 IrLoweringA64::tempAddrBuffer(IrOp bufferOp, IrOp indexOp) +AddressA64 IrLoweringA64::tempAddrBuffer(IrOp bufferOp, IrOp indexOp, uint8_t tag) { - if (indexOp.kind == IrOpKind::Inst) + if (FFlag::LuauCodegenUserdataOps) { - RegisterA64 temp = regs.allocTemp(KindA64::x); - build.add(temp, regOp(bufferOp), regOp(indexOp)); // implicit uxtw - return mem(temp, offsetof(Buffer, data)); - } - else if (indexOp.kind == IrOpKind::Constant) - { - // Since the resulting address may be used to load any size, including 1 byte, from an unaligned offset, we are limited by unscaled encoding - if (unsigned(intOp(indexOp)) + offsetof(Buffer, data) <= 255) - return mem(regOp(bufferOp), int(intOp(indexOp) + offsetof(Buffer, data))); + CODEGEN_ASSERT(tag == LUA_TUSERDATA || tag == LUA_TBUFFER); + int dataOffset = tag == LUA_TBUFFER ? 
offsetof(Buffer, data) : offsetof(Udata, data); - // indexOp can only be negative in dead code (since offsets are checked); this avoids assertion in emitAddOffset - if (intOp(indexOp) < 0) - return mem(regOp(bufferOp), offsetof(Buffer, data)); + if (indexOp.kind == IrOpKind::Inst) + { + RegisterA64 temp = regs.allocTemp(KindA64::x); + build.add(temp, regOp(bufferOp), regOp(indexOp)); // implicit uxtw + return mem(temp, dataOffset); + } + else if (indexOp.kind == IrOpKind::Constant) + { + // Since the resulting address may be used to load any size, including 1 byte, from an unaligned offset, we are limited by unscaled + // encoding + if (unsigned(intOp(indexOp)) + dataOffset <= 255) + return mem(regOp(bufferOp), int(intOp(indexOp) + dataOffset)); - RegisterA64 temp = regs.allocTemp(KindA64::x); - emitAddOffset(build, temp, regOp(bufferOp), size_t(intOp(indexOp))); - return mem(temp, offsetof(Buffer, data)); + // indexOp can only be negative in dead code (since offsets are checked); this avoids assertion in emitAddOffset + if (intOp(indexOp) < 0) + return mem(regOp(bufferOp), dataOffset); + + RegisterA64 temp = regs.allocTemp(KindA64::x); + emitAddOffset(build, temp, regOp(bufferOp), size_t(intOp(indexOp))); + return mem(temp, dataOffset); + } + else + { + CODEGEN_ASSERT(!"Unsupported instruction form"); + return noreg; + } } else { - CODEGEN_ASSERT(!"Unsupported instruction form"); - return noreg; + if (indexOp.kind == IrOpKind::Inst) + { + RegisterA64 temp = regs.allocTemp(KindA64::x); + build.add(temp, regOp(bufferOp), regOp(indexOp)); // implicit uxtw + return mem(temp, offsetof(Buffer, data)); + } + else if (indexOp.kind == IrOpKind::Constant) + { + // Since the resulting address may be used to load any size, including 1 byte, from an unaligned offset, we are limited by unscaled + // encoding + if (unsigned(intOp(indexOp)) + offsetof(Buffer, data) <= 255) + return mem(regOp(bufferOp), int(intOp(indexOp) + offsetof(Buffer, data))); + + // indexOp can only be negative in dead code (since offsets are checked); this avoids assertion in emitAddOffset + if (intOp(indexOp) < 0) + return mem(regOp(bufferOp), offsetof(Buffer, data)); + + RegisterA64 temp = regs.allocTemp(KindA64::x); + emitAddOffset(build, temp, regOp(bufferOp), size_t(intOp(indexOp))); + return mem(temp, offsetof(Buffer, data)); + } + else + { + CODEGEN_ASSERT(!"Unsupported instruction form"); + return noreg; + } } } diff --git a/CodeGen/src/IrLoweringA64.h b/CodeGen/src/IrLoweringA64.h index 5fb7f2b8..5f13f58e 100644 --- a/CodeGen/src/IrLoweringA64.h +++ b/CodeGen/src/IrLoweringA64.h @@ -44,7 +44,7 @@ struct IrLoweringA64 RegisterA64 tempInt(IrOp op); RegisterA64 tempUint(IrOp op); AddressA64 tempAddr(IrOp op, int offset); - AddressA64 tempAddrBuffer(IrOp bufferOp, IrOp indexOp); + AddressA64 tempAddrBuffer(IrOp bufferOp, IrOp indexOp, uint8_t tag); // May emit restore instructions RegisterA64 regOp(IrOp op); diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 66609cb7..f372a7ec 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -15,6 +15,11 @@ #include "lstate.h" #include "lgc.h" +LUAU_FASTFLAG(LuauCodegenUserdataOps) +LUAU_FASTFLAG(LuauCodegenUserdataAlloc) +LUAU_FASTFLAG(LuauCodegenFastcall3) +LUAU_FASTFLAG(LuauCodegenMathSign) + namespace Luau { namespace CodeGen @@ -586,6 +591,33 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.vandpd(inst.regX64, inst.regX64, build.i64(~(1LL << 63))); break; + case IrCmd::SIGN_NUM: + { + 
CODEGEN_ASSERT(FFlag::LuauCodegenMathSign);
+
+        inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a});
+
+        ScopedRegX64 tmp0{regs, SizeX64::xmmword};
+        ScopedRegX64 tmp1{regs, SizeX64::xmmword};
+        ScopedRegX64 tmp2{regs, SizeX64::xmmword};
+
+        build.vxorpd(tmp0.reg, tmp0.reg, tmp0.reg);
+
+        // Set tmp1 to -1 if arg < 0, else 0
+        build.vcmpltsd(tmp1.reg, regOp(inst.a), tmp0.reg);
+        build.vmovsd(tmp2.reg, build.f64(-1));
+        build.vandpd(tmp1.reg, tmp1.reg, tmp2.reg);
+
+        // Set mask bit to 1 if 0 < arg, else 0
+        build.vcmpltsd(inst.regX64, tmp0.reg, regOp(inst.a));
+
+        // Result = (mask-bit == 1) ? 1.0 : tmp1
+        // If arg < 0 then tmp1 is -1 and mask-bit is 0, result is -1
+        // If arg == 0 then tmp1 is 0 and mask-bit is 0, result is 0
+        // If arg > 0 then tmp1 is 0 and mask-bit is 1, result is 1
+        build.vblendvpd(inst.regX64, tmp1.reg, build.f64x2(1, 1), inst.regX64);
+        break;
+    }
     case IrCmd::ADD_VEC:
     {
         inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b});
@@ -905,6 +937,18 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
         inst.regX64 = regs.takeReg(rax, index);
         break;
     }
+    case IrCmd::NEW_USERDATA:
+    {
+        CODEGEN_ASSERT(FFlag::LuauCodegenUserdataAlloc);
+
+        IrCallWrapperX64 callWrap(regs, build, index);
+        callWrap.addArgument(SizeX64::qword, rState);
+        callWrap.addArgument(SizeX64::qword, intOp(inst.a));
+        callWrap.addArgument(SizeX64::dword, intOp(inst.b));
+        callWrap.call(qword[rNativeContext + offsetof(NativeContext, newUserdata)]);
+        inst.regX64 = regs.takeReg(rax, index);
+        break;
+    }
     case IrCmd::INT_TO_NUM:
         inst.regX64 = regs.allocReg(SizeX64::xmmword, index);
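For reference, the SIGN_NUM lowering above is branchless: each vcmpltsd produces an all-ones or all-zero mask, vandpd turns the "arg < 0" mask into -1.0 or +0.0, and vblendvpd picks 1.0 when "0 < arg". A minimal scalar model of what the sequence computes (signNumModel is a hypothetical name used only for illustration, not part of this change):

    #include <cstdint>
    #include <cstring>

    // NaN compares false on both tests, so the result for NaN is 0,
    // matching math.sign in the interpreter.
    static double signNumModel(double arg)
    {
        uint64_t negMask = arg < 0.0 ? ~0ull : 0ull;     // vcmpltsd(arg, 0)
        uint64_t bits = negMask & 0xBFF0000000000000ull; // vandpd with the bits of -1.0
        double tmp1;
        std::memcpy(&tmp1, &bits, sizeof(tmp1));
        return 0.0 < arg ? 1.0 : tmp1;                   // vblendvpd on the "0 < arg" mask
    }

@@ -993,9 +1037,10 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
     case IrCmd::FASTCALL:
     {
-        OperandX64 arg2 = inst.d.kind != IrOpKind::Undef ? 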
memRegDoubleOp(inst.d) : OperandX64{0};
-
-        emitBuiltin(regs, build, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), arg2, intOp(inst.e), intOp(inst.f));
+        if (FFlag::LuauCodegenFastcall3)
+            emitBuiltin(regs, build, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.d));
+        else
+            emitBuiltin(regs, build, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.f));
         break;
     }
     case IrCmd::INVOKE_FASTCALL:
@@ -1003,25 +1048,49 @@
         unsigned bfid = uintOp(inst.a);
 
         OperandX64 args = 0;
+        ScopedRegX64 argsAlt{regs};
 
-        if (inst.d.kind == IrOpKind::VmReg)
-            args = luauRegAddress(vmRegOp(inst.d));
-        else if (inst.d.kind == IrOpKind::VmConst)
-            args = luauConstantAddress(vmConstOp(inst.d));
+        // 'E' argument can only be produced by LOP_FASTCALL3
+        if (FFlag::LuauCodegenFastcall3 && inst.e.kind != IrOpKind::Undef)
+        {
+            CODEGEN_ASSERT(intOp(inst.f) == 3);
+
+            ScopedRegX64 tmp{regs, SizeX64::xmmword};
+            argsAlt.alloc(SizeX64::qword);
+
+            build.mov(argsAlt.reg, qword[rState + offsetof(lua_State, top)]);
+
+            build.vmovups(tmp.reg, luauReg(vmRegOp(inst.d)));
+            build.vmovups(xmmword[argsAlt.reg], tmp.reg);
+
+            build.vmovups(tmp.reg, luauReg(vmRegOp(inst.e)));
+            build.vmovups(xmmword[argsAlt.reg + sizeof(TValue)], tmp.reg);
+        }
         else
-            CODEGEN_ASSERT(inst.d.kind == IrOpKind::Undef);
+        {
+            if (inst.d.kind == IrOpKind::VmReg)
+                args = luauRegAddress(vmRegOp(inst.d));
+            else if (inst.d.kind == IrOpKind::VmConst)
+                args = luauConstantAddress(vmConstOp(inst.d));
+            else
+                CODEGEN_ASSERT(inst.d.kind == IrOpKind::Undef);
+        }
 
         int ra = vmRegOp(inst.b);
         int arg = vmRegOp(inst.c);
-        int nparams = intOp(inst.e);
-        int nresults = intOp(inst.f);
+        int nparams = intOp(FFlag::LuauCodegenFastcall3 ? inst.f : inst.e);
+        int nresults = intOp(FFlag::LuauCodegenFastcall3 ? inst.g : inst.f);
 
         IrCallWrapperX64 callWrap(regs, build, index);
         callWrap.addArgument(SizeX64::qword, rState);
         callWrap.addArgument(SizeX64::qword, luauRegAddress(ra));
         callWrap.addArgument(SizeX64::qword, luauRegAddress(arg));
         callWrap.addArgument(SizeX64::dword, nresults);
-        callWrap.addArgument(SizeX64::qword, args);
+
+        if (FFlag::LuauCodegenFastcall3 && inst.e.kind != IrOpKind::Undef)
+            callWrap.addArgument(SizeX64::qword, argsAlt);
+        else
+            callWrap.addArgument(SizeX64::qword, args);
 
         if (nparams == LUA_MULTRET)
         {
@@ -1350,6 +1419,14 @@
         }
         break;
     }
+    case IrCmd::CHECK_USERDATA_TAG:
+    {
+        CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps);
+
+        build.cmp(byte[regOp(inst.a) + offsetof(Udata, tag)], intOp(inst.b));
+        jumpOrAbortOnUndef(ConditionX64::NotEqual, inst.c, next);
+        break;
+    }
     case IrCmd::INTERRUPT:
     {
         unsigned pcpos = uintOp(inst.a);
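A note on the INVOKE_FASTCALL change above: with three arguments, the second and third live in non-adjacent VM registers, so the lowering stages them contiguously above the stack top before calling the builtin. A sketch of the equivalent staging in plain C++ (stageFastcall3Args is a hypothetical helper written only to illustrate the layout the two vmovups pairs produce; the assumption that slots above L->top are usable scratch space here is taken from the code above):

    #include "lstate.h" // lua_State, StkId, TValue (VM-internal header, as used by this file)

    static StkId stageFastcall3Args(lua_State* L, const TValue* arg2, const TValue* arg3)
    {
        StkId staging = L->top; // mov(argsAlt.reg, qword[rState + offsetof(lua_State, top)])
        staging[0] = *arg2;     // vmovups(xmmword[argsAlt.reg], ...) - simplified TValue copy
        staging[1] = *arg3;     // vmovups(xmmword[argsAlt.reg + sizeof(TValue)], ...)
        return staging;         // becomes the 'args' pointer the builtin receives
    }

@@ -1895,71 +1972,71 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
     case IrCmd::BUFFER_READI8:
         inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
 
-        build.movsx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b)]);
+        build.movsx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]);
         break;
 
     case IrCmd::BUFFER_READU8:
         inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
 
-        build.movzx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b)]);
+        build.movzx(inst.regX64, byte[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? 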
LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_WRITEI8: { OperandX64 value = inst.c.kind == IrOpKind::Inst ? byteReg(regOp(inst.c)) : OperandX64(int8_t(intOp(inst.c))); - build.mov(byte[bufferAddrOp(inst.a, inst.b)], value); + build.mov(byte[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], value); break; } case IrCmd::BUFFER_READI16: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b}); - build.movsx(inst.regX64, word[bufferAddrOp(inst.a, inst.b)]); + build.movsx(inst.regX64, word[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_READU16: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b}); - build.movzx(inst.regX64, word[bufferAddrOp(inst.a, inst.b)]); + build.movzx(inst.regX64, word[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_WRITEI16: { OperandX64 value = inst.c.kind == IrOpKind::Inst ? wordReg(regOp(inst.c)) : OperandX64(int16_t(intOp(inst.c))); - build.mov(word[bufferAddrOp(inst.a, inst.b)], value); + build.mov(word[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], value); break; } case IrCmd::BUFFER_READI32: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a, inst.b}); - build.mov(inst.regX64, dword[bufferAddrOp(inst.a, inst.b)]); + build.mov(inst.regX64, dword[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_WRITEI32: { OperandX64 value = inst.c.kind == IrOpKind::Inst ? regOp(inst.c) : OperandX64(intOp(inst.c)); - build.mov(dword[bufferAddrOp(inst.a, inst.b)], value); + build.mov(dword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], value); break; } case IrCmd::BUFFER_READF32: inst.regX64 = regs.allocReg(SizeX64::xmmword, index); - build.vcvtss2sd(inst.regX64, inst.regX64, dword[bufferAddrOp(inst.a, inst.b)]); + build.vcvtss2sd(inst.regX64, inst.regX64, dword[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_WRITEF32: - storeDoubleAsFloat(dword[bufferAddrOp(inst.a, inst.b)], inst.c); + storeDoubleAsFloat(dword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], inst.c); break; case IrCmd::BUFFER_READF64: inst.regX64 = regs.allocReg(SizeX64::xmmword, index); - build.vmovsd(inst.regX64, qword[bufferAddrOp(inst.a, inst.b)]); + build.vmovsd(inst.regX64, qword[bufferAddrOp(inst.a, inst.b, inst.c.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.c))]); break; case IrCmd::BUFFER_WRITEF64: @@ -1967,11 +2044,11 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { ScopedRegX64 tmp{regs, SizeX64::xmmword}; build.vmovsd(tmp.reg, build.f64(doubleOp(inst.c))); - build.vmovsd(qword[bufferAddrOp(inst.a, inst.b)], tmp.reg); + build.vmovsd(qword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? LUA_TBUFFER : tagOp(inst.d))], tmp.reg); } else if (inst.c.kind == IrOpKind::Inst) { - build.vmovsd(qword[bufferAddrOp(inst.a, inst.b)], regOp(inst.c)); + build.vmovsd(qword[bufferAddrOp(inst.a, inst.b, inst.d.kind == IrOpKind::None ? 
LUA_TBUFFER : tagOp(inst.d))], regOp(inst.c)); } else { @@ -2190,12 +2267,25 @@ RegisterX64 IrLoweringX64::regOp(IrOp op) return inst.regX64; } -OperandX64 IrLoweringX64::bufferAddrOp(IrOp bufferOp, IrOp indexOp) +OperandX64 IrLoweringX64::bufferAddrOp(IrOp bufferOp, IrOp indexOp, uint8_t tag) { - if (indexOp.kind == IrOpKind::Inst) - return regOp(bufferOp) + qwordReg(regOp(indexOp)) + offsetof(Buffer, data); - else if (indexOp.kind == IrOpKind::Constant) - return regOp(bufferOp) + intOp(indexOp) + offsetof(Buffer, data); + if (FFlag::LuauCodegenUserdataOps) + { + CODEGEN_ASSERT(tag == LUA_TUSERDATA || tag == LUA_TBUFFER); + int dataOffset = tag == LUA_TBUFFER ? offsetof(Buffer, data) : offsetof(Udata, data); + + if (indexOp.kind == IrOpKind::Inst) + return regOp(bufferOp) + qwordReg(regOp(indexOp)) + dataOffset; + else if (indexOp.kind == IrOpKind::Constant) + return regOp(bufferOp) + intOp(indexOp) + dataOffset; + } + else + { + if (indexOp.kind == IrOpKind::Inst) + return regOp(bufferOp) + qwordReg(regOp(indexOp)) + offsetof(Buffer, data); + else if (indexOp.kind == IrOpKind::Constant) + return regOp(bufferOp) + intOp(indexOp) + offsetof(Buffer, data); + } CODEGEN_ASSERT(!"Unsupported instruction form"); return noreg; diff --git a/CodeGen/src/IrLoweringX64.h b/CodeGen/src/IrLoweringX64.h index 5fb7b0fa..8fb311ea 100644 --- a/CodeGen/src/IrLoweringX64.h +++ b/CodeGen/src/IrLoweringX64.h @@ -50,7 +50,7 @@ struct IrLoweringX64 OperandX64 memRegUintOp(IrOp op); OperandX64 memRegTagOp(IrOp op); RegisterX64 regOp(IrOp op); - OperandX64 bufferAddrOp(IrOp bufferOp, IrOp indexOp); + OperandX64 bufferAddrOp(IrOp bufferOp, IrOp indexOp, uint8_t tag); RegisterX64 vecOp(IrOp op, ScopedRegX64& tmp); IrConst constOp(IrOp op) const; diff --git a/CodeGen/src/IrRegAllocA64.cpp b/CodeGen/src/IrRegAllocA64.cpp index 24b0b285..af63a2fc 100644 --- a/CodeGen/src/IrRegAllocA64.cpp +++ b/CodeGen/src/IrRegAllocA64.cpp @@ -11,6 +11,7 @@ #include LUAU_FASTFLAGVARIABLE(DebugCodegenChaosA64, false) +LUAU_FASTFLAG(LuauCodegenInstG) namespace Luau { @@ -256,6 +257,9 @@ void IrRegAllocA64::freeLastUseRegs(const IrInst& inst, uint32_t index) checkOp(inst.d); checkOp(inst.e); checkOp(inst.f); + + if (FFlag::LuauCodegenInstG) + checkOp(inst.g); } void IrRegAllocA64::freeTempRegs() diff --git a/CodeGen/src/IrRegAllocX64.cpp b/CodeGen/src/IrRegAllocX64.cpp index 2b5da623..60326074 100644 --- a/CodeGen/src/IrRegAllocX64.cpp +++ b/CodeGen/src/IrRegAllocX64.cpp @@ -6,6 +6,8 @@ #include "EmitCommonX64.h" +LUAU_FASTFLAG(LuauCodegenInstG) + namespace Luau { namespace CodeGen @@ -181,6 +183,9 @@ void IrRegAllocX64::freeLastUseRegs(const IrInst& inst, uint32_t instIdx) checkOp(inst.d); checkOp(inst.e); checkOp(inst.f); + + if (FFlag::LuauCodegenInstG) + checkOp(inst.g); } bool IrRegAllocX64::isLastUseReg(const IrInst& target, uint32_t instIdx) const diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index bec5deea..f6a77f21 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -8,7 +8,8 @@ #include -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) +LUAU_FASTFLAG(LuauCodegenFastcall3) +LUAU_FASTFLAGVARIABLE(LuauCodegenMathSign, false) // TODO: when nresults is less than our actual result count, we can skip computing/writing unused results @@ -42,23 +43,23 @@ static IrOp builtinLoadDouble(IrBuilder& build, IrOp arg) static BuiltinImplResult translateBuiltinNumberToNumber( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int 
nresults, int pcpos) { + CODEGEN_ASSERT(!FFlag::LuauCodegenMathSign); + if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; builtinCheckDouble(build, build.vmReg(arg), pcpos); - build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(1), build.constInt(1)); - if (!FFlag::LuauCodegenRemoveDeadStores5) - { - if (ra != arg) - build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); - } + if (FFlag::LuauCodegenFastcall3) + build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), build.constInt(1)); + else + build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(1), build.constInt(1)); return {BuiltinImplType::Full, 1}; } static BuiltinImplResult translateBuiltinNumberToNumberLibm( - IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) + IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, int nresults, int pcpos) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; @@ -109,17 +110,12 @@ static BuiltinImplResult translateBuiltinNumberTo2Number( return {BuiltinImplType::None, -1}; builtinCheckDouble(build, build.vmReg(arg), pcpos); - build.inst( - IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(1), build.constInt(nresults == 1 ? 1 : 2)); - if (!FFlag::LuauCodegenRemoveDeadStores5) - { - if (ra != arg) - build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); - - if (nresults != 1) - build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 1), build.constTag(LUA_TNUMBER)); - } + if (FFlag::LuauCodegenFastcall3) + build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), build.constInt(nresults == 1 ? 1 : 2)); + else + build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), build.undef(), build.constInt(1), + build.constInt(nresults == 1 ? 1 : 2)); return {BuiltinImplType::Full, 2}; } @@ -198,7 +194,8 @@ static BuiltinImplResult translateBuiltinMathLog(IrBuilder& build, int nparams, return {BuiltinImplType::Full, 1}; } -static BuiltinImplResult translateBuiltinMathMinMax(IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) +static BuiltinImplResult translateBuiltinMathMinMax( + IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, IrOp args, IrOp arg3, int nresults, int pcpos) { if (nparams < 2 || nparams > kMinMaxUnrolledParams || nresults > 1) return {BuiltinImplType::None, -1}; @@ -206,7 +203,10 @@ static BuiltinImplResult translateBuiltinMathMinMax(IrBuilder& build, IrCmd cmd, builtinCheckDouble(build, build.vmReg(arg), pcpos); builtinCheckDouble(build, args, pcpos); - for (int i = 3; i <= nparams; ++i) + if (FFlag::LuauCodegenFastcall3 && nparams >= 3) + builtinCheckDouble(build, arg3, pcpos); + + for (int i = (FFlag::LuauCodegenFastcall3 ? 
4 : 3); i <= nparams; ++i)
         builtinCheckDouble(build, build.vmReg(vmRegOp(args) + (i - 2)), pcpos);
 
     IrOp varg1 = builtinLoadDouble(build, build.vmReg(arg));
@@ -214,7 +214,13 @@
     IrOp res = build.inst(cmd, varg2, varg1); // Swapped arguments are required for consistency with VM builtins
 
-    for (int i = 3; i <= nparams; ++i)
+    if (FFlag::LuauCodegenFastcall3 && nparams >= 3)
+    {
+        IrOp arg = builtinLoadDouble(build, arg3);
+        res = build.inst(cmd, arg, res);
+    }
+
+    for (int i = (FFlag::LuauCodegenFastcall3 ? 4 : 3); i <= nparams; ++i)
     {
         IrOp arg = builtinLoadDouble(build, build.vmReg(vmRegOp(args) + (i - 2)));
         res = build.inst(cmd, arg, res);
@@ -228,7 +234,8 @@ static BuiltinImplResult translateBuiltinMathClamp(IrBuilder& build, int nparams
     return {BuiltinImplType::Full, 1};
 }
 
-static BuiltinImplResult translateBuiltinMathClamp(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback, int pcpos)
+static BuiltinImplResult translateBuiltinMathClamp(
+    IrBuilder& build, int nparams, int ra, int arg, IrOp args, IrOp arg3, int nresults, IrOp fallback, int pcpos)
 {
     if (nparams < 3 || nresults > 1)
         return {BuiltinImplType::None, -1};
@@ -239,10 +246,10 @@ static BuiltinImplResult translateBuiltinMathClamp
     builtinCheckDouble(build, build.vmReg(arg), pcpos);
     builtinCheckDouble(build, args, pcpos);
-    builtinCheckDouble(build, build.vmReg(vmRegOp(args) + 1), pcpos);
+    builtinCheckDouble(build, FFlag::LuauCodegenFastcall3 ? arg3 : build.vmReg(vmRegOp(args) + 1), pcpos);
 
     IrOp min = builtinLoadDouble(build, args);
-    IrOp max = builtinLoadDouble(build, build.vmReg(vmRegOp(args) + 1));
+    IrOp max = builtinLoadDouble(build, FFlag::LuauCodegenFastcall3 ? arg3 : build.vmReg(vmRegOp(args) + 1));
 
     build.inst(IrCmd::JUMP_CMP_NUM, min, max, build.cond(IrCondition::NotLessEqual), fallback, block);
     build.beginBlock(block);
@@ -305,7 +312,7 @@ static BuiltinImplResult translateBuiltinTypeof(IrBuilder& build, int nparams, i
 }
 
 static BuiltinImplResult translateBuiltinBit32BinaryOp(
-    IrBuilder& build, IrCmd cmd, bool btest, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos)
+    IrBuilder& build, IrCmd cmd, bool btest, int nparams, int ra, int arg, IrOp args, IrOp arg3, int nresults, int pcpos)
 {
     if (nparams < 2 || nparams > kBit32BinaryOpUnrolledParams || nresults > 1)
         return {BuiltinImplType::None, -1};
@@ -313,7 +320,10 @@ static BuiltinImplResult translateBuiltinBit32BinaryOp(
     builtinCheckDouble(build, build.vmReg(arg), pcpos);
     builtinCheckDouble(build, args, pcpos);
 
-    for (int i = 3; i <= nparams; ++i)
+    if (FFlag::LuauCodegenFastcall3 && nparams >= 3)
+        builtinCheckDouble(build, arg3, pcpos);
+
+    for (int i = (FFlag::LuauCodegenFastcall3 ? 4 : 3); i <= nparams; ++i)
         builtinCheckDouble(build, build.vmReg(vmRegOp(args) + (i - 2)), pcpos);
 
     IrOp va = builtinLoadDouble(build, build.vmReg(arg));
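With LuauCodegenFastcall3, the third argument of an unrolled builtin arrives in its own operand (arg3) rather than in the register after args, which is why the loops above now start at 4 under the flag. For math.min(a, b, c), the chain the translator emits reduces to the following scalar shape (a sketch only; minChain is a hypothetical name, and std::min models operand order here, not the exact NaN behavior of MIN_NUM):

    #include <algorithm>

    double minChain(double a, double b, double c)
    {
        double res = std::min(b, a); // seed uses swapped operands, matching the VM builtins
        res = std::min(c, res);      // arg3 folds in next: MIN_NUM(arg3, res)
        return res;
    }

@@ -324,7 +334,15 @@
     IrOp res = build.inst(cmd, vaui, vbui);
 
-    for (int i = 3; i <= nparams; ++i)
+    if (FFlag::LuauCodegenFastcall3 && nparams >= 3)
+    {
+        IrOp vc = builtinLoadDouble(build, arg3);
+        IrOp arg = build.inst(IrCmd::NUM_TO_UINT, vc);
+
+        res = build.inst(cmd, res, arg);
+    }
+
+    for (int i = (FFlag::LuauCodegenFastcall3 ? 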
4 : 3); i <= nparams; ++i) { IrOp vc = builtinLoadDouble(build, build.vmReg(vmRegOp(args) + (i - 2))); IrOp arg = build.inst(IrCmd::NUM_TO_UINT, vc); @@ -449,7 +467,7 @@ static BuiltinImplResult translateBuiltinBit32Rotate(IrBuilder& build, IrCmd cmd } static BuiltinImplResult translateBuiltinBit32Extract( - IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback, int pcpos) + IrBuilder& build, int nparams, int ra, int arg, IrOp args, IrOp arg3, int nresults, IrOp fallback, int pcpos) { if (nparams < 2 || nresults > 1) return {BuiltinImplType::None, -1}; @@ -497,8 +515,8 @@ static BuiltinImplResult translateBuiltinBit32Extract( { IrOp f = build.inst(IrCmd::NUM_TO_INT, vb); - builtinCheckDouble(build, build.vmReg(args.index + 1), pcpos); - IrOp vc = builtinLoadDouble(build, build.vmReg(args.index + 1)); + builtinCheckDouble(build, FFlag::LuauCodegenFastcall3 ? arg3 : build.vmReg(args.index + 1), pcpos); + IrOp vc = builtinLoadDouble(build, FFlag::LuauCodegenFastcall3 ? arg3 : build.vmReg(args.index + 1)); IrOp w = build.inst(IrCmd::NUM_TO_INT, vc); IrOp block1 = build.block(IrBlockKind::Internal); @@ -587,18 +605,18 @@ static BuiltinImplResult translateBuiltinBit32Unary(IrBuilder& build, IrCmd cmd, } static BuiltinImplResult translateBuiltinBit32Replace( - IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback, int pcpos) + IrBuilder& build, int nparams, int ra, int arg, IrOp args, IrOp arg3, int nresults, IrOp fallback, int pcpos) { if (nparams < 3 || nresults > 1) return {BuiltinImplType::None, -1}; builtinCheckDouble(build, build.vmReg(arg), pcpos); builtinCheckDouble(build, args, pcpos); - builtinCheckDouble(build, build.vmReg(args.index + 1), pcpos); + builtinCheckDouble(build, FFlag::LuauCodegenFastcall3 ? arg3 : build.vmReg(args.index + 1), pcpos); IrOp va = builtinLoadDouble(build, build.vmReg(arg)); IrOp vb = builtinLoadDouble(build, args); - IrOp vc = builtinLoadDouble(build, build.vmReg(args.index + 1)); + IrOp vc = builtinLoadDouble(build, FFlag::LuauCodegenFastcall3 ? arg3 : build.vmReg(args.index + 1)); IrOp n = build.inst(IrCmd::NUM_TO_UINT, va); IrOp v = build.inst(IrCmd::NUM_TO_UINT, vb); @@ -623,8 +641,8 @@ static BuiltinImplResult translateBuiltinBit32Replace( } else { - builtinCheckDouble(build, build.vmReg(args.index + 2), pcpos); - IrOp vd = builtinLoadDouble(build, build.vmReg(args.index + 2)); + builtinCheckDouble(build, FFlag::LuauCodegenFastcall3 ? build.vmReg(vmRegOp(args) + 2) : build.vmReg(args.index + 2), pcpos); + IrOp vd = builtinLoadDouble(build, FFlag::LuauCodegenFastcall3 ? 
build.vmReg(vmRegOp(args) + 2) : build.vmReg(args.index + 2)); IrOp w = build.inst(IrCmd::NUM_TO_INT, vd); IrOp block1 = build.block(IrBlockKind::Internal); @@ -661,7 +679,7 @@ static BuiltinImplResult translateBuiltinBit32Replace( return {BuiltinImplType::UsesFallback, 1}; } -static BuiltinImplResult translateBuiltinVector(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) +static BuiltinImplResult translateBuiltinVector(IrBuilder& build, int nparams, int ra, int arg, IrOp args, IrOp arg3, int nresults, int pcpos) { if (nparams < 3 || nresults > 1) return {BuiltinImplType::None, -1}; @@ -670,11 +688,11 @@ static BuiltinImplResult translateBuiltinVector(IrBuilder& build, int nparams, i builtinCheckDouble(build, build.vmReg(arg), pcpos); builtinCheckDouble(build, args, pcpos); - builtinCheckDouble(build, build.vmReg(vmRegOp(args) + 1), pcpos); + builtinCheckDouble(build, FFlag::LuauCodegenFastcall3 ? arg3 : build.vmReg(vmRegOp(args) + 1), pcpos); IrOp x = builtinLoadDouble(build, build.vmReg(arg)); IrOp y = builtinLoadDouble(build, args); - IrOp z = builtinLoadDouble(build, build.vmReg(vmRegOp(args) + 1)); + IrOp z = builtinLoadDouble(build, FFlag::LuauCodegenFastcall3 ? arg3 : build.vmReg(vmRegOp(args) + 1)); build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), x, y, z); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR)); @@ -736,13 +754,14 @@ static BuiltinImplResult translateBuiltinStringLen(IrBuilder& build, int nparams return {BuiltinImplType::Full, 1}; } -static void translateBufferArgsAndCheckBounds(IrBuilder& build, int nparams, int arg, IrOp args, int size, int pcpos, IrOp& buf, IrOp& intIndex) +static void translateBufferArgsAndCheckBounds( + IrBuilder& build, int nparams, int arg, IrOp args, IrOp arg3, int size, int pcpos, IrOp& buf, IrOp& intIndex) { build.loadAndCheckTag(build.vmReg(arg), LUA_TBUFFER, build.vmExit(pcpos)); builtinCheckDouble(build, args, pcpos); if (nparams == 3) - builtinCheckDouble(build, build.vmReg(vmRegOp(args) + 1), pcpos); + builtinCheckDouble(build, FFlag::LuauCodegenFastcall3 ? arg3 : build.vmReg(vmRegOp(args) + 1), pcpos); buf = build.inst(IrCmd::LOAD_POINTER, build.vmReg(arg)); @@ -753,13 +772,13 @@ static void translateBufferArgsAndCheckBounds(IrBuilder& build, int nparams, int } static BuiltinImplResult translateBuiltinBufferRead( - IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos, IrCmd readCmd, int size, IrCmd convCmd) + IrBuilder& build, int nparams, int ra, int arg, IrOp args, IrOp arg3, int nresults, int pcpos, IrCmd readCmd, int size, IrCmd convCmd) { if (nparams < 2 || nresults > 1) return {BuiltinImplType::None, -1}; IrOp buf, intIndex; - translateBufferArgsAndCheckBounds(build, nparams, arg, args, size, pcpos, buf, intIndex); + translateBufferArgsAndCheckBounds(build, nparams, arg, args, arg3, size, pcpos, buf, intIndex); IrOp result = build.inst(readCmd, buf, intIndex); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), convCmd == IrCmd::NOP ? 
result : build.inst(convCmd, result)); @@ -769,21 +788,22 @@ static BuiltinImplResult translateBuiltinBufferRead( } static BuiltinImplResult translateBuiltinBufferWrite( - IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos, IrCmd writeCmd, int size, IrCmd convCmd) + IrBuilder& build, int nparams, int ra, int arg, IrOp args, IrOp arg3, int nresults, int pcpos, IrCmd writeCmd, int size, IrCmd convCmd) { if (nparams < 3 || nresults > 0) return {BuiltinImplType::None, -1}; IrOp buf, intIndex; - translateBufferArgsAndCheckBounds(build, nparams, arg, args, size, pcpos, buf, intIndex); + translateBufferArgsAndCheckBounds(build, nparams, arg, args, arg3, size, pcpos, buf, intIndex); - IrOp numValue = builtinLoadDouble(build, build.vmReg(vmRegOp(args) + 1)); + IrOp numValue = builtinLoadDouble(build, FFlag::LuauCodegenFastcall3 ? arg3 : build.vmReg(vmRegOp(args) + 1)); build.inst(writeCmd, buf, intIndex, convCmd == IrCmd::NOP ? numValue : build.inst(convCmd, numValue)); return {BuiltinImplType::Full, 0}; } -BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults, IrOp fallback, int pcpos) +BuiltinImplResult translateBuiltin( + IrBuilder& build, int bfid, int ra, int arg, IrOp args, IrOp arg3, int nparams, int nresults, IrOp fallback, int pcpos) { // Builtins are not allowed to handle variadic arguments if (nparams == LUA_MULTRET) @@ -800,11 +820,11 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, case LBF_MATH_LOG: return translateBuiltinMathLog(build, nparams, ra, arg, args, nresults, pcpos); case LBF_MATH_MIN: - return translateBuiltinMathMinMax(build, IrCmd::MIN_NUM, nparams, ra, arg, args, nresults, pcpos); + return translateBuiltinMathMinMax(build, IrCmd::MIN_NUM, nparams, ra, arg, args, arg3, nresults, pcpos); case LBF_MATH_MAX: - return translateBuiltinMathMinMax(build, IrCmd::MAX_NUM, nparams, ra, arg, args, nresults, pcpos); + return translateBuiltinMathMinMax(build, IrCmd::MAX_NUM, nparams, ra, arg, args, arg3, nresults, pcpos); case LBF_MATH_CLAMP: - return translateBuiltinMathClamp(build, nparams, ra, arg, args, nresults, fallback, pcpos); + return translateBuiltinMathClamp(build, nparams, ra, arg, args, arg3, nresults, fallback, pcpos); case LBF_MATH_FLOOR: return translateBuiltinMathUnary(build, IrCmd::FLOOR_NUM, nparams, ra, arg, nresults, pcpos); case LBF_MATH_CEIL: @@ -826,9 +846,12 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, case LBF_MATH_TAN: case LBF_MATH_TANH: case LBF_MATH_LOG10: - return translateBuiltinNumberToNumberLibm(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, pcpos); + return translateBuiltinNumberToNumberLibm(build, LuauBuiltinFunction(bfid), nparams, ra, arg, nresults, pcpos); case LBF_MATH_SIGN: - return translateBuiltinNumberToNumber(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, pcpos); + if (FFlag::LuauCodegenMathSign) + return translateBuiltinMathUnary(build, IrCmd::SIGN_NUM, nparams, ra, arg, nresults, pcpos); + else + return translateBuiltinNumberToNumber(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, pcpos); case LBF_MATH_POW: case LBF_MATH_FMOD: case LBF_MATH_ATAN2: @@ -838,13 +861,13 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, case LBF_MATH_MODF: return translateBuiltinNumberTo2Number(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, pcpos); case LBF_BIT32_BAND: - return 
translateBuiltinBit32BinaryOp(build, IrCmd::BITAND_UINT, /* btest= */ false, nparams, ra, arg, args, nresults, pcpos); + return translateBuiltinBit32BinaryOp(build, IrCmd::BITAND_UINT, /* btest= */ false, nparams, ra, arg, args, arg3, nresults, pcpos); case LBF_BIT32_BOR: - return translateBuiltinBit32BinaryOp(build, IrCmd::BITOR_UINT, /* btest= */ false, nparams, ra, arg, args, nresults, pcpos); + return translateBuiltinBit32BinaryOp(build, IrCmd::BITOR_UINT, /* btest= */ false, nparams, ra, arg, args, arg3, nresults, pcpos); case LBF_BIT32_BXOR: - return translateBuiltinBit32BinaryOp(build, IrCmd::BITXOR_UINT, /* btest= */ false, nparams, ra, arg, args, nresults, pcpos); + return translateBuiltinBit32BinaryOp(build, IrCmd::BITXOR_UINT, /* btest= */ false, nparams, ra, arg, args, arg3, nresults, pcpos); case LBF_BIT32_BTEST: - return translateBuiltinBit32BinaryOp(build, IrCmd::BITAND_UINT, /* btest= */ true, nparams, ra, arg, args, nresults, pcpos); + return translateBuiltinBit32BinaryOp(build, IrCmd::BITAND_UINT, /* btest= */ true, nparams, ra, arg, args, arg3, nresults, pcpos); case LBF_BIT32_BNOT: return translateBuiltinBit32Bnot(build, nparams, ra, arg, args, nresults, pcpos); case LBF_BIT32_LSHIFT: @@ -858,7 +881,7 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, case LBF_BIT32_RROTATE: return translateBuiltinBit32Rotate(build, IrCmd::BITRROTATE_UINT, nparams, ra, arg, args, nresults, pcpos); case LBF_BIT32_EXTRACT: - return translateBuiltinBit32Extract(build, nparams, ra, arg, args, nresults, fallback, pcpos); + return translateBuiltinBit32Extract(build, nparams, ra, arg, args, arg3, nresults, fallback, pcpos); case LBF_BIT32_EXTRACTK: return translateBuiltinBit32ExtractK(build, nparams, ra, arg, args, nresults, pcpos); case LBF_BIT32_COUNTLZ: @@ -866,13 +889,13 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, case LBF_BIT32_COUNTRZ: return translateBuiltinBit32Unary(build, IrCmd::BITCOUNTRZ_UINT, nparams, ra, arg, args, nresults, pcpos); case LBF_BIT32_REPLACE: - return translateBuiltinBit32Replace(build, nparams, ra, arg, args, nresults, fallback, pcpos); + return translateBuiltinBit32Replace(build, nparams, ra, arg, args, arg3, nresults, fallback, pcpos); case LBF_TYPE: return translateBuiltinType(build, nparams, ra, arg, args, nresults); case LBF_TYPEOF: return translateBuiltinTypeof(build, nparams, ra, arg, args, nresults); case LBF_VECTOR: - return translateBuiltinVector(build, nparams, ra, arg, args, nresults, pcpos); + return translateBuiltinVector(build, nparams, ra, arg, args, arg3, nresults, pcpos); case LBF_TABLE_INSERT: return translateBuiltinTableInsert(build, nparams, ra, arg, args, nresults, pcpos); case LBF_STRING_LEN: @@ -880,31 +903,31 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, case LBF_BIT32_BYTESWAP: return translateBuiltinBit32Unary(build, IrCmd::BYTESWAP_UINT, nparams, ra, arg, args, nresults, pcpos); case LBF_BUFFER_READI8: - return translateBuiltinBufferRead(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_READI8, 1, IrCmd::INT_TO_NUM); + return translateBuiltinBufferRead(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_READI8, 1, IrCmd::INT_TO_NUM); case LBF_BUFFER_READU8: - return translateBuiltinBufferRead(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_READU8, 1, IrCmd::INT_TO_NUM); + return translateBuiltinBufferRead(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_READU8, 1, 
IrCmd::INT_TO_NUM); case LBF_BUFFER_WRITEU8: - return translateBuiltinBufferWrite(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_WRITEI8, 1, IrCmd::NUM_TO_UINT); + return translateBuiltinBufferWrite(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_WRITEI8, 1, IrCmd::NUM_TO_UINT); case LBF_BUFFER_READI16: - return translateBuiltinBufferRead(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_READI16, 2, IrCmd::INT_TO_NUM); + return translateBuiltinBufferRead(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_READI16, 2, IrCmd::INT_TO_NUM); case LBF_BUFFER_READU16: - return translateBuiltinBufferRead(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_READU16, 2, IrCmd::INT_TO_NUM); + return translateBuiltinBufferRead(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_READU16, 2, IrCmd::INT_TO_NUM); case LBF_BUFFER_WRITEU16: - return translateBuiltinBufferWrite(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_WRITEI16, 2, IrCmd::NUM_TO_UINT); + return translateBuiltinBufferWrite(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_WRITEI16, 2, IrCmd::NUM_TO_UINT); case LBF_BUFFER_READI32: - return translateBuiltinBufferRead(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_READI32, 4, IrCmd::INT_TO_NUM); + return translateBuiltinBufferRead(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_READI32, 4, IrCmd::INT_TO_NUM); case LBF_BUFFER_READU32: - return translateBuiltinBufferRead(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_READI32, 4, IrCmd::UINT_TO_NUM); + return translateBuiltinBufferRead(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_READI32, 4, IrCmd::UINT_TO_NUM); case LBF_BUFFER_WRITEU32: - return translateBuiltinBufferWrite(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_WRITEI32, 4, IrCmd::NUM_TO_UINT); + return translateBuiltinBufferWrite(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_WRITEI32, 4, IrCmd::NUM_TO_UINT); case LBF_BUFFER_READF32: - return translateBuiltinBufferRead(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_READF32, 4, IrCmd::NOP); + return translateBuiltinBufferRead(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_READF32, 4, IrCmd::NOP); case LBF_BUFFER_WRITEF32: - return translateBuiltinBufferWrite(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_WRITEF32, 4, IrCmd::NOP); + return translateBuiltinBufferWrite(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_WRITEF32, 4, IrCmd::NOP); case LBF_BUFFER_READF64: - return translateBuiltinBufferRead(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_READF64, 8, IrCmd::NOP); + return translateBuiltinBufferRead(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_READF64, 8, IrCmd::NOP); case LBF_BUFFER_WRITEF64: - return translateBuiltinBufferWrite(build, nparams, ra, arg, args, nresults, pcpos, IrCmd::BUFFER_WRITEF64, 8, IrCmd::NOP); + return translateBuiltinBufferWrite(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_WRITEF64, 8, IrCmd::NOP); default: return {BuiltinImplType::None, -1}; } diff --git a/CodeGen/src/IrTranslateBuiltins.h b/CodeGen/src/IrTranslateBuiltins.h index 8ae64b94..54a05aba 100644 --- a/CodeGen/src/IrTranslateBuiltins.h +++ b/CodeGen/src/IrTranslateBuiltins.h @@ -22,7 +22,8 @@ struct BuiltinImplResult int actualResultCount; }; -BuiltinImplResult translateBuiltin(IrBuilder& 
build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults, IrOp fallback, int pcpos);
+BuiltinImplResult translateBuiltin(
+    IrBuilder& build, int bfid, int ra, int arg, IrOp args, IrOp arg3, int nparams, int nresults, IrOp fallback, int pcpos);
 
 } // namespace CodeGen
 } // namespace Luau
diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp
index 20150f9a..db867fc9 100644
--- a/CodeGen/src/IrTranslation.cpp
+++ b/CodeGen/src/IrTranslation.cpp
@@ -3,6 +3,7 @@
 #include "Luau/Bytecode.h"
 #include "Luau/BytecodeUtils.h"
+#include "Luau/CodeGen.h"
 #include "Luau/IrBuilder.h"
 #include "Luau/IrUtils.h"
@@ -12,7 +13,8 @@
 #include "lstate.h"
 #include "ltm.h"
 
-LUAU_FASTFLAGVARIABLE(LuauCodegenDirectUserdataFlow, false)
+LUAU_FASTFLAG(LuauCodegenUserdataOps)
+LUAU_FASTFLAG(LuauCodegenFastcall3)
 
 namespace Luau
 {
@@ -442,6 +444,17 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc,
         return;
     }
 
+    if (FFlag::LuauCodegenUserdataOps && (isUserdataBytecodeType(bcTypes.a) || isUserdataBytecodeType(bcTypes.b)))
+    {
+        if (build.hostHooks.userdataMetamethod &&
+            build.hostHooks.userdataMetamethod(build, bcTypes.a, bcTypes.b, ra, opb, opc, tmToHostMetamethod(tm), pcpos))
+            return;
+
+        build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1));
+        build.inst(IrCmd::DO_ARITH, build.vmReg(ra), opb, opc, build.constInt(tm));
+        return;
+    }
+
     IrOp fallback;
 
     // fast-path: number
@@ -583,6 +596,17 @@ void translateInstMinus(IrBuilder& build, const Instruction* pc, int pcpos)
         return;
     }
 
+    if (FFlag::LuauCodegenUserdataOps && isUserdataBytecodeType(bcTypes.a))
+    {
+        if (build.hostHooks.userdataMetamethod &&
+            build.hostHooks.userdataMetamethod(build, bcTypes.a, bcTypes.b, ra, build.vmReg(rb), {}, tmToHostMetamethod(TM_UNM), pcpos))
+            return;
+
+        build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1));
+        build.inst(IrCmd::DO_ARITH, build.vmReg(ra), build.vmReg(rb), build.vmReg(rb), build.constInt(TM_UNM));
+        return;
+    }
+
     IrOp fallback;
 
     IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb));
@@ -604,8 +628,17 @@ void translateInstMinus(IrBuilder& build, const Instruction* pc, int pcpos)
         FallbackStreamScope scope(build, fallback, next);
 
         build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1));
-        build.inst(
-            IrCmd::DO_ARITH, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.constInt(TM_UNM));
+
+        if (FFlag::LuauCodegenUserdataOps)
+        {
+            build.inst(IrCmd::DO_ARITH, build.vmReg(ra), build.vmReg(rb), build.vmReg(rb), build.constInt(TM_UNM));
+        }
+        else
+        {
+            build.inst(
+                IrCmd::DO_ARITH, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.vmReg(LUAU_INSN_B(*pc)), build.constInt(TM_UNM));
+        }
+
         build.inst(IrCmd::JUMP, next);
     }
 }
@@ -617,6 +650,17 @@ void translateInstLength(IrBuilder& build, const Instruction* pc, int pcpos)
     int ra = LUAU_INSN_A(*pc);
     int rb = LUAU_INSN_B(*pc);
 
+    if (FFlag::LuauCodegenUserdataOps && isUserdataBytecodeType(bcTypes.a))
+    {
+        if (build.hostHooks.userdataMetamethod &&
+            build.hostHooks.userdataMetamethod(build, bcTypes.a, bcTypes.b, ra, build.vmReg(rb), {}, tmToHostMetamethod(TM_LEN), pcpos))
+            return;
+
+        build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1));
+        build.inst(IrCmd::DO_LEN, build.vmReg(ra), build.vmReg(rb));
+        return;
+    }
+
     IrOp fallback = build.block(IrBlockKind::Fallback);
 
     IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb));
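A note on the userdata fast paths above: when bytecode types mark an operand as tagged userdata, the translator first offers the operation to the embedder via hostHooks.userdataMetamethod and only emits the generic DO_ARITH/DO_LEN sequence if the hook declines. A sketch of what a host-side hook could look like (hypothetical body; only the parameter shape is taken from the call sites above):

    // Returning false declines, so the translator emits SET_SAVEDPC plus the
    // generic DO_ARITH/DO_LEN path; returning true means the hook emitted
    // replacement IR for this instruction itself.
    static bool onUserdataMetamethod(IrBuilder& build, uint8_t lhsTy, uint8_t rhsTy, int ra,
        IrOp lhs, IrOp rhs, HostMetamethod method, int pcpos)
    {
        if (method != HostMetamethod::Add)
            return false; // let the generic metamethod dispatch handle it

        // ... a host would emit CHECK_USERDATA_TAG guards and inline IR here ...
        return true;
    }

@@ -636,7 +680,12 @@ void translateInstLength(IrBuilder& build, const Instruction* pc, int pcpos)
     FallbackStreamScope 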
scope(build, fallback, next); build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); - build.inst(IrCmd::DO_LEN, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc))); + + if (FFlag::LuauCodegenUserdataOps) + build.inst(IrCmd::DO_LEN, build.vmReg(ra), build.vmReg(rb)); + else + build.inst(IrCmd::DO_LEN, build.vmReg(LUAU_INSN_A(*pc)), build.vmReg(LUAU_INSN_B(*pc))); + build.inst(IrCmd::JUMP, next); } @@ -693,7 +742,7 @@ void translateInstCloseUpvals(IrBuilder& build, const Instruction* pc) build.inst(IrCmd::CLOSE_UPVALS, build.vmReg(ra)); } -IrOp translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool customParams, int customParamCount, IrOp customArgs) +IrOp translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool customParams, int customParamCount, IrOp customArgs, IrOp customArg3) { LuauOpcode opcode = LuauOpcode(LUAU_INSN_OP(*pc)); int bfid = LUAU_INSN_A(*pc); @@ -719,13 +768,15 @@ IrOp translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool builtinArgs = build.constDouble(protok.value.n); } + IrOp builtinArg3 = FFlag::LuauCodegenFastcall3 ? (customParams ? customArg3 : build.vmReg(ra + 3)) : IrOp{}; + IrOp fallback = build.block(IrBlockKind::Fallback); // In unsafe environment, instead of retrying fastcall at 'pcpos' we side-exit directly to fallback sequence build.inst(IrCmd::CHECK_SAFE_ENV, build.vmExit(pcpos + getOpLength(opcode))); - BuiltinImplResult br = - translateBuiltin(build, LuauBuiltinFunction(bfid), ra, arg, builtinArgs, nparams, nresults, fallback, pcpos + getOpLength(opcode)); + BuiltinImplResult br = translateBuiltin( + build, LuauBuiltinFunction(bfid), ra, arg, builtinArgs, builtinArg3, nparams, nresults, fallback, pcpos + getOpLength(opcode)); if (br.type != BuiltinImplType::None) { @@ -742,6 +793,22 @@ IrOp translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool return build.undef(); } } + else if (FFlag::LuauCodegenFastcall3) + { + IrOp arg3 = customParams ? 
customArg3 : build.undef(); + + // TODO: we can skip saving pc for some well-behaved builtins which we didn't inline + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + getOpLength(opcode))); + + IrOp res = build.inst(IrCmd::INVOKE_FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, arg3, build.constInt(nparams), + build.constInt(nresults)); + build.inst(IrCmd::CHECK_FASTCALL_RES, res, fallback); + + if (nresults == LUA_MULTRET) + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(ra), res); + else if (nparams == LUA_MULTRET) + build.inst(IrCmd::ADJUST_STACK_TO_TOP); + } else { // TODO: we can skip saving pc for some well-behaved builtins which we didn't inline @@ -1197,19 +1264,19 @@ void translateInstGetTableKS(IrBuilder& build, const Instruction* pc, int pcpos) TString* str = gco2ts(build.function.proto->k[aux].value.gc); const char* field = getstr(str); - if (*field == 'X' || *field == 'x') + if (str->len == 1 && (*field == 'X' || *field == 'x')) { IrOp value = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(rb), build.constInt(0)); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), value); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); } - else if (*field == 'Y' || *field == 'y') + else if (str->len == 1 && (*field == 'Y' || *field == 'y')) { IrOp value = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(rb), build.constInt(4)); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), value); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); } - else if (*field == 'Z' || *field == 'z') + else if (str->len == 1 && (*field == 'Z' || *field == 'z')) { IrOp value = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(rb), build.constInt(8)); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), value); @@ -1217,16 +1284,28 @@ void translateInstGetTableKS(IrBuilder& build, const Instruction* pc, int pcpos) } else { + if (build.hostHooks.vectorAccess && build.hostHooks.vectorAccess(build, field, str->len, ra, rb, pcpos)) + return; + build.inst(IrCmd::FALLBACK_GETTABLEKS, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux)); } return; } - if (FFlag::LuauCodegenDirectUserdataFlow && bcTypes.a == LBC_TYPE_USERDATA) + if (FFlag::LuauCodegenUserdataOps ? isUserdataBytecodeType(bcTypes.a) : bcTypes.a == LBC_TYPE_USERDATA) { build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TUSERDATA), build.vmExit(pcpos)); + if (FFlag::LuauCodegenUserdataOps && build.hostHooks.userdataAccess) + { + TString* str = gco2ts(build.function.proto->k[aux].value.gc); + const char* field = getstr(str); + + if (build.hostHooks.userdataAccess(build, bcTypes.a, field, str->len, ra, rb, pcpos)) + return; + } + build.inst(IrCmd::FALLBACK_GETTABLEKS, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux)); return; } @@ -1261,7 +1340,7 @@ void translateInstSetTableKS(IrBuilder& build, const Instruction* pc, int pcpos) IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); - if (FFlag::LuauCodegenDirectUserdataFlow && bcTypes.a == LBC_TYPE_USERDATA) + if (FFlag::LuauCodegenUserdataOps ? 
isUserdataBytecodeType(bcTypes.a) : bcTypes.a == LBC_TYPE_USERDATA) { build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TUSERDATA), build.vmExit(pcpos)); @@ -1375,7 +1454,7 @@ void translateInstCapture(IrBuilder& build, const Instruction* pc, int pcpos) } } -void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos) +bool translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos) { int ra = LUAU_INSN_A(*pc); int rb = LUAU_INSN_B(*pc); @@ -1383,20 +1462,52 @@ void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos) BytecodeTypes bcTypes = build.function.getBytecodeTypesAt(pcpos); - if (FFlag::LuauCodegenDirectUserdataFlow && bcTypes.a == LBC_TYPE_VECTOR) + if (bcTypes.a == LBC_TYPE_VECTOR) { build.loadAndCheckTag(build.vmReg(rb), LUA_TVECTOR, build.vmExit(pcpos)); + if (build.hostHooks.vectorNamecall) + { + Instruction call = pc[2]; + CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); + + int callra = LUAU_INSN_A(call); + int nparams = LUAU_INSN_B(call) - 1; + int nresults = LUAU_INSN_C(call) - 1; + + TString* str = gco2ts(build.function.proto->k[aux].value.gc); + const char* field = getstr(str); + + if (build.hostHooks.vectorNamecall(build, field, str->len, callra, rb, nparams, nresults, pcpos)) + return true; + } + build.inst(IrCmd::FALLBACK_NAMECALL, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux)); - return; + return false; } - if (FFlag::LuauCodegenDirectUserdataFlow && bcTypes.a == LBC_TYPE_USERDATA) + if (FFlag::LuauCodegenUserdataOps ? isUserdataBytecodeType(bcTypes.a) : bcTypes.a == LBC_TYPE_USERDATA) { build.loadAndCheckTag(build.vmReg(rb), LUA_TUSERDATA, build.vmExit(pcpos)); + if (FFlag::LuauCodegenUserdataOps && build.hostHooks.userdataNamecall) + { + Instruction call = pc[2]; + CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); + + int callra = LUAU_INSN_A(call); + int nparams = LUAU_INSN_B(call) - 1; + int nresults = LUAU_INSN_C(call) - 1; + + TString* str = gco2ts(build.function.proto->k[aux].value.gc); + const char* field = getstr(str); + + if (build.hostHooks.userdataNamecall(build, bcTypes.a, field, str->len, callra, rb, nparams, nresults, pcpos)) + return true; + } + build.inst(IrCmd::FALLBACK_NAMECALL, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux)); - return; + return false; } IrOp next = build.blockAtInst(pcpos + getOpLength(LOP_NAMECALL)); @@ -1404,8 +1515,7 @@ void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos) IrOp firstFastPathSuccess = build.block(IrBlockKind::Internal); IrOp secondFastPath = build.block(IrBlockKind::Internal); - build.loadAndCheckTag( - build.vmReg(rb), LUA_TTABLE, FFlag::LuauCodegenDirectUserdataFlow && bcTypes.a == LBC_TYPE_TABLE ? build.vmExit(pcpos) : fallback); + build.loadAndCheckTag(build.vmReg(rb), LUA_TTABLE, bcTypes.a == LBC_TYPE_TABLE ? 
build.vmExit(pcpos) : fallback); IrOp table = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb)); CODEGEN_ASSERT(build.function.proto); @@ -1450,6 +1560,8 @@ void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos) build.inst(IrCmd::JUMP, next); build.beginBlock(next); + + return false; } void translateInstAndX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c) diff --git a/CodeGen/src/IrTranslation.h b/CodeGen/src/IrTranslation.h index b1f1e28b..8b514cc1 100644 --- a/CodeGen/src/IrTranslation.h +++ b/CodeGen/src/IrTranslation.h @@ -44,7 +44,8 @@ void translateInstDupTable(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstGetUpval(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstSetUpval(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstCloseUpvals(IrBuilder& build, const Instruction* pc); -IrOp translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool customParams, int customParamCount, IrOp customArgs); +IrOp translateFastCallN( + IrBuilder& build, const Instruction* pc, int pcpos, bool customParams, int customParamCount, IrOp customArgs, IrOp customArg3); void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstForGPrepNext(IrBuilder& build, const Instruction* pc, int pcpos); @@ -61,7 +62,7 @@ void translateInstGetGlobal(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstSetGlobal(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstConcat(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstCapture(IrBuilder& build, const Instruction* pc, int pcpos); -void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos); +bool translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstAndX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c); void translateInstOrX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c); void translateInstNewClosure(IrBuilder& build, const Instruction* pc, int pcpos); diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index caa6b178..129945df 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -12,6 +12,8 @@ #include #include +LUAU_FASTFLAG(LuauCodegenInstG) + namespace Luau { namespace CodeGen @@ -67,6 +69,7 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::ROUND_NUM: case IrCmd::SQRT_NUM: case IrCmd::ABS_NUM: + case IrCmd::SIGN_NUM: return IrValueKind::Double; case IrCmd::ADD_VEC: case IrCmd::SUB_VEC: @@ -99,6 +102,7 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::TRY_NUM_TO_INDEX: return IrValueKind::Int; case IrCmd::TRY_CALL_FASTGETTM: + case IrCmd::NEW_USERDATA: return IrValueKind::Pointer; case IrCmd::INT_TO_NUM: case IrCmd::UINT_TO_NUM: @@ -135,6 +139,7 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::CHECK_NODE_NO_NEXT: case IrCmd::CHECK_NODE_VALUE: case IrCmd::CHECK_BUFFER_LEN: + case IrCmd::CHECK_USERDATA_TAG: case IrCmd::INTERRUPT: case IrCmd::CHECK_GC: case IrCmd::BARRIER_OBJ: @@ -252,6 +257,54 @@ bool isGCO(uint8_t tag) return tag >= LUA_TSTRING; } +bool isUserdataBytecodeType(uint8_t ty) +{ + return ty == LBC_TYPE_USERDATA || isCustomUserdataBytecodeType(ty); +} + +bool isCustomUserdataBytecodeType(uint8_t ty) +{ + return ty >= LBC_TYPE_TAGGED_USERDATA_BASE && ty < LBC_TYPE_TAGGED_USERDATA_END; +} + +HostMetamethod tmToHostMetamethod(int tm) +{ + switch 
(TMS(tm))
+    {
+    case TM_ADD:
+        return HostMetamethod::Add;
+    case TM_SUB:
+        return HostMetamethod::Sub;
+    case TM_MUL:
+        return HostMetamethod::Mul;
+    case TM_DIV:
+        return HostMetamethod::Div;
+    case TM_IDIV:
+        return HostMetamethod::Idiv;
+    case TM_MOD:
+        return HostMetamethod::Mod;
+    case TM_POW:
+        return HostMetamethod::Pow;
+    case TM_UNM:
+        return HostMetamethod::Minus;
+    case TM_EQ:
+        return HostMetamethod::Equal;
+    case TM_LT:
+        return HostMetamethod::LessThan;
+    case TM_LE:
+        return HostMetamethod::LessEqual;
+    case TM_LEN:
+        return HostMetamethod::Length;
+    case TM_CONCAT:
+        return HostMetamethod::Concat;
+    default:
+        CODEGEN_ASSERT(!"invalid tag method for host");
+        break;
+    }
+
+    return HostMetamethod::Add;
+}
+
 void kill(IrFunction& function, IrInst& inst)
 {
     CODEGEN_ASSERT(inst.useCount == 0);
@@ -265,12 +318,18 @@ void kill(IrFunction& function, IrInst& inst)
     removeUse(function, inst.e);
     removeUse(function, inst.f);
 
+    if (FFlag::LuauCodegenInstG)
+        removeUse(function, inst.g);
+
     inst.a = {};
     inst.b = {};
     inst.c = {};
     inst.d = {};
     inst.e = {};
     inst.f = {};
+
+    if (FFlag::LuauCodegenInstG)
+        inst.g = {};
 }
 
 void kill(IrFunction& function, uint32_t start, uint32_t end)
@@ -320,6 +379,9 @@ void replace(IrFunction& function, IrBlock& block, uint32_t instIdx, IrInst repl
     addUse(function, replacement.e);
     addUse(function, replacement.f);
 
+    if (FFlag::LuauCodegenInstG)
+        addUse(function, replacement.g);
+
     // An extra reference is added so block will not remove itself
     block.useCount++;
 
@@ -342,6 +404,9 @@ void replace(IrFunction& function, IrBlock& block, uint32_t instIdx, IrInst repl
     removeUse(function, inst.e);
     removeUse(function, inst.f);
 
+    if (FFlag::LuauCodegenInstG)
+        removeUse(function, inst.g);
+
     // Inherit existing use count (last use is skipped as it will be defined later)
     replacement.useCount = inst.useCount;
 
@@ -367,12 +432,18 @@ void substitute(IrFunction& function, IrInst& inst, IrOp replacement)
     removeUse(function, inst.e);
     removeUse(function, inst.f);
 
+    if (FFlag::LuauCodegenInstG)
+        removeUse(function, inst.g);
+
     inst.a = replacement;
     inst.b = {};
     inst.c = {};
     inst.d = {};
     inst.e = {};
     inst.f = {};
+
+    if (FFlag::LuauCodegenInstG)
+        inst.g = {};
 }
 
 void applySubstitutions(IrFunction& function, IrOp& op)
@@ -416,6 +487,9 @@ void applySubstitutions(IrFunction& function, IrInst& inst)
     applySubstitutions(function, inst.d);
     applySubstitutions(function, inst.e);
     applySubstitutions(function, inst.f);
+
+    if (FFlag::LuauCodegenInstG)
+        applySubstitutions(function, inst.g);
 }
 
 bool compare(double a, double b, IrCondition cond)
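Once instructions can carry a seventh operand, every use-count walk above has to visit inst.g as well, guarded by LuauCodegenInstG. Separately, the new SIGN_NUM instruction becomes foldable when its operand is a compile-time constant, as shown in the next hunk; the folding rule reduces to this function (an illustrative restatement, not code from the diff):

    double foldSign(double v)
    {
        return v > 0.0 ? 1.0 : v < 0.0 ? -1.0 : 0.0;
    }
    // foldSign(42.0) == 1.0, foldSign(-3.5) == -1.0, foldSign(0.0) == 0.0;
    // both comparisons are false for NaN, so SIGN_NUM of a NaN constant folds to 0.0,
    // consistent with the interpreter's math.sign.

@@ -585,6 +659,14 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3
         if (inst.a.kind == IrOpKind::Constant)
             substitute(function, inst, build.constDouble(fabs(function.doubleOp(inst.a))));
         break;
+    case IrCmd::SIGN_NUM:
+        if (inst.a.kind == IrOpKind::Constant)
+        {
+            double v = function.doubleOp(inst.a);
+
+            substitute(function, inst, build.constDouble(v > 0.0 ? 1.0 : v < 0.0 ? 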
-1.0 : 0.0)); + } + break; case IrCmd::NOT_ANY: if (inst.a.kind == IrOpKind::Constant) { diff --git a/CodeGen/src/IrValueLocationTracking.cpp b/CodeGen/src/IrValueLocationTracking.cpp index 3dc72610..0224b49b 100644 --- a/CodeGen/src/IrValueLocationTracking.cpp +++ b/CodeGen/src/IrValueLocationTracking.cpp @@ -3,6 +3,8 @@ #include "Luau/IrUtils.h" +LUAU_FASTFLAG(LuauCodegenFastcall3) + namespace Luau { namespace CodeGen @@ -44,11 +46,11 @@ void IrValueLocationTracking::beforeInstLowering(IrInst& inst) invalidateRestoreVmRegs(vmRegOp(inst.a), -1); break; case IrCmd::FASTCALL: - invalidateRestoreVmRegs(vmRegOp(inst.b), function.intOp(inst.f)); + invalidateRestoreVmRegs(vmRegOp(inst.b), function.intOp(FFlag::LuauCodegenFastcall3 ? inst.d : inst.f)); break; case IrCmd::INVOKE_FASTCALL: // Multiple return sequences (count == -1) are defined by ADJUST_STACK_TO_REG - if (int count = function.intOp(inst.f); count != -1) + if (int count = function.intOp(FFlag::LuauCodegenFastcall3 ? inst.g : inst.f); count != -1) invalidateRestoreVmRegs(vmRegOp(inst.b), count); break; case IrCmd::DO_ARITH: @@ -119,7 +121,7 @@ void IrValueLocationTracking::beforeInstLowering(IrInst& inst) break; // These instructions read VmReg only after optimizeMemoryOperandsX64 - case IrCmd::CHECK_TAG: // TODO: remove with FFlagLuauCodegenRemoveDeadStores5 + case IrCmd::CHECK_TAG: case IrCmd::CHECK_TRUTHY: case IrCmd::ADD_NUM: case IrCmd::SUB_NUM: @@ -146,6 +148,7 @@ void IrValueLocationTracking::beforeInstLowering(IrInst& inst) CODEGEN_ASSERT(inst.d.kind != IrOpKind::VmReg); CODEGEN_ASSERT(inst.e.kind != IrOpKind::VmReg); CODEGEN_ASSERT(inst.f.kind != IrOpKind::VmReg); + CODEGEN_ASSERT(inst.g.kind != IrOpKind::VmReg); break; } } diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index 5f6df4b6..7aa35f23 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -14,104 +14,13 @@ #include #include -LUAU_FASTINTVARIABLE(LuauCodeGenBlockSize, 4 * 1024 * 1024) -LUAU_FASTINTVARIABLE(LuauCodeGenMaxTotalSize, 256 * 1024 * 1024) +LUAU_FASTFLAG(LuauCodegenUserdataAlloc) namespace Luau { namespace CodeGen { -NativeState::NativeState() - : NativeState(nullptr, nullptr) -{ -} - -NativeState::NativeState(AllocationCallback* allocationCallback, void* allocationCallbackContext) - : codeAllocator{size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), allocationCallback, allocationCallbackContext} -{ -} - -NativeState::~NativeState() = default; - -void initFunctions(NativeState& data) -{ - static_assert(sizeof(data.context.luauF_table) == sizeof(luauF_table), "fastcall tables are not of the same length"); - memcpy(data.context.luauF_table, luauF_table, sizeof(luauF_table)); - - data.context.luaV_lessthan = luaV_lessthan; - data.context.luaV_lessequal = luaV_lessequal; - data.context.luaV_equalval = luaV_equalval; - data.context.luaV_doarith = luaV_doarith; - data.context.luaV_dolen = luaV_dolen; - data.context.luaV_gettable = luaV_gettable; - data.context.luaV_settable = luaV_settable; - data.context.luaV_getimport = luaV_getimport; - data.context.luaV_concat = luaV_concat; - - data.context.luaH_getn = luaH_getn; - data.context.luaH_new = luaH_new; - data.context.luaH_clone = luaH_clone; - data.context.luaH_resizearray = luaH_resizearray; - data.context.luaH_setnum = luaH_setnum; - - data.context.luaC_barriertable = luaC_barriertable; - data.context.luaC_barrierf = luaC_barrierf; - data.context.luaC_barrierback = luaC_barrierback; - data.context.luaC_step = luaC_step; - - 
data.context.luaF_close = luaF_close; - data.context.luaF_findupval = luaF_findupval; - data.context.luaF_newLclosure = luaF_newLclosure; - - data.context.luaT_gettm = luaT_gettm; - data.context.luaT_objtypenamestr = luaT_objtypenamestr; - - data.context.libm_exp = exp; - data.context.libm_pow = pow; - data.context.libm_fmod = fmod; - data.context.libm_log = log; - data.context.libm_log2 = log2; - data.context.libm_log10 = log10; - data.context.libm_ldexp = ldexp; - data.context.libm_round = round; - data.context.libm_frexp = frexp; - data.context.libm_modf = modf; - - data.context.libm_asin = asin; - data.context.libm_sin = sin; - data.context.libm_sinh = sinh; - data.context.libm_acos = acos; - data.context.libm_cos = cos; - data.context.libm_cosh = cosh; - data.context.libm_atan = atan; - data.context.libm_atan2 = atan2; - data.context.libm_tan = tan; - data.context.libm_tanh = tanh; - - data.context.forgLoopTableIter = forgLoopTableIter; - data.context.forgLoopNodeIter = forgLoopNodeIter; - data.context.forgLoopNonTableFallback = forgLoopNonTableFallback; - data.context.forgPrepXnextFallback = forgPrepXnextFallback; - data.context.callProlog = callProlog; - data.context.callEpilogC = callEpilogC; - - data.context.callFallback = callFallback; - - data.context.executeGETGLOBAL = executeGETGLOBAL; - data.context.executeSETGLOBAL = executeSETGLOBAL; - data.context.executeGETTABLEKS = executeGETTABLEKS; - data.context.executeSETTABLEKS = executeSETTABLEKS; - - data.context.executeNAMECALL = executeNAMECALL; - data.context.executeFORGPREP = executeFORGPREP; - data.context.executeGETVARARGSMultRet = executeGETVARARGSMultRet; - data.context.executeGETVARARGSConst = executeGETVARARGSConst; - data.context.executeDUPCLOSURE = executeDUPCLOSURE; - data.context.executePREPVARARGS = executePREPVARARGS; - data.context.executeSETLIST = executeSETLIST; -} - void initFunctions(NativeContext& context) { static_assert(sizeof(context.luauF_table) == sizeof(luauF_table), "fastcall tables are not of the same length"); @@ -120,7 +29,16 @@ void initFunctions(NativeContext& context) context.luaV_lessthan = luaV_lessthan; context.luaV_lessequal = luaV_lessequal; context.luaV_equalval = luaV_equalval; - context.luaV_doarith = luaV_doarith; + + context.luaV_doarithadd = luaV_doarithimpl; + context.luaV_doarithsub = luaV_doarithimpl; + context.luaV_doarithmul = luaV_doarithimpl; + context.luaV_doarithdiv = luaV_doarithimpl; + context.luaV_doarithidiv = luaV_doarithimpl; + context.luaV_doarithmod = luaV_doarithimpl; + context.luaV_doarithpow = luaV_doarithimpl; + context.luaV_doarithunm = luaV_doarithimpl; + context.luaV_dolen = luaV_dolen; context.luaV_gettable = luaV_gettable; context.luaV_settable = luaV_settable; @@ -174,6 +92,9 @@ void initFunctions(NativeContext& context) context.callProlog = callProlog; context.callEpilogC = callEpilogC; + if (FFlag::LuauCodegenUserdataAlloc) + context.newUserdata = newUserdata; + context.callFallback = callFallback; context.executeGETGLOBAL = executeGETGLOBAL; diff --git a/CodeGen/src/NativeState.h b/CodeGen/src/NativeState.h index 3e7c85e9..941db252 100644 --- a/CodeGen/src/NativeState.h +++ b/CodeGen/src/NativeState.h @@ -33,7 +33,14 @@ struct NativeContext int (*luaV_lessthan)(lua_State* L, const TValue* l, const TValue* r) = nullptr; int (*luaV_lessequal)(lua_State* L, const TValue* l, const TValue* r) = nullptr; int (*luaV_equalval)(lua_State* L, const TValue* t1, const TValue* t2) = nullptr; - void (*luaV_doarith)(lua_State* L, StkId ra, const TValue* rb, const TValue* 
rc, TMS op) = nullptr; + void (*luaV_doarithadd)(lua_State* L, StkId ra, const TValue* rb, const TValue* rc) = nullptr; + void (*luaV_doarithsub)(lua_State* L, StkId ra, const TValue* rb, const TValue* rc) = nullptr; + void (*luaV_doarithmul)(lua_State* L, StkId ra, const TValue* rb, const TValue* rc) = nullptr; + void (*luaV_doarithdiv)(lua_State* L, StkId ra, const TValue* rb, const TValue* rc) = nullptr; + void (*luaV_doarithidiv)(lua_State* L, StkId ra, const TValue* rb, const TValue* rc) = nullptr; + void (*luaV_doarithmod)(lua_State* L, StkId ra, const TValue* rb, const TValue* rc) = nullptr; + void (*luaV_doarithpow)(lua_State* L, StkId ra, const TValue* rb, const TValue* rc) = nullptr; + void (*luaV_doarithunm)(lua_State* L, StkId ra, const TValue* rb, const TValue* rc) = nullptr; void (*luaV_dolen)(lua_State* L, StkId ra, const TValue* rb) = nullptr; void (*luaV_gettable)(lua_State* L, const TValue* t, TValue* key, StkId val) = nullptr; void (*luaV_settable)(lua_State* L, const TValue* t, TValue* key, StkId val) = nullptr; @@ -86,6 +93,7 @@ struct NativeContext void (*forgPrepXnextFallback)(lua_State* L, TValue* ra, int pc) = nullptr; Closure* (*callProlog)(lua_State* L, TValue* ra, StkId argtop, int nresults) = nullptr; void (*callEpilogC)(lua_State* L, int nresults, int n) = nullptr; + Udata* (*newUserdata)(lua_State* L, size_t s, int tag) = nullptr; Closure* (*callFallback)(lua_State* L, StkId ra, StkId argtop, int nresults) = nullptr; @@ -108,22 +116,6 @@ struct NativeContext using GateFn = int (*)(lua_State*, Proto*, uintptr_t, NativeContext*); -struct NativeState -{ - NativeState(); - NativeState(AllocationCallback* allocationCallback, void* allocationCallbackContext); - ~NativeState(); - - CodeAllocator codeAllocator; - std::unique_ptr<UnwindBuilder> unwindBuilder; - - uint8_t* gateData = nullptr; - size_t gateDataSize = 0; - - NativeContext context; -}; - -void initFunctions(NativeState& data); void initFunctions(NativeContext& context); } // namespace CodeGen diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index eae0baa3..ac90f8e5 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -16,9 +16,12 @@ LUAU_FASTINTVARIABLE(LuauCodeGenMinLinearBlockPath, 3) LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64) +LUAU_FASTINTVARIABLE(LuauCodeGenReuseUdataTagLimit, 64) LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks, false) -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) -LUAU_FASTFLAGVARIABLE(LuauCodegenLoadPropCheckRegLinkInTv, false) +LUAU_FASTFLAG(LuauCodegenUserdataOps) +LUAU_FASTFLAG(LuauCodegenUserdataAlloc) +LUAU_FASTFLAG(LuauCodegenFastcall3) +LUAU_FASTFLAG(LuauCodegenMathSign) namespace Luau { @@ -200,6 +203,11 @@ struct ConstPropState checkBufferLenCache.clear(); } + void invalidateUserdataData() + { + userdataTagCache.clear(); + } + void invalidateHeap() { for (int i = 0; i <= maxReg; ++i) @@ -417,6 +425,9 @@ struct ConstPropState invalidateValuePropagation(); invalidateHeapTableData(); invalidateHeapBufferData(); + + if (FFlag::LuauCodegenUserdataOps) + invalidateUserdataData(); } IrFunction& function; @@ -446,6 +457,9 @@ struct ConstPropState std::vector<uint32_t> checkArraySizeCache; // Additionally, fallback block argument might be different std::vector<uint32_t> checkBufferLenCache; // Additionally, fallback block argument might be different + + // Userdata tag cache can point to both NEW_USERDATA and CHECK_USERDATA_TAG instructions + std::vector<uint32_t> userdataTagCache; // Additionally, fallback block argument might be different };
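The userdataTagCache introduced above drives a small redundancy elimination in the constant-propagation hunk that follows: a CHECK_USERDATA_TAG is dropped when the same object was already checked against the same tag earlier in the block, or when the object comes straight from a NEW_USERDATA that was given that tag. A minimal standalone sketch of the matching rule, using simplified stand-ins for the real IrInst/IrOp types (struct and field names here are illustrative, not the actual IR definitions):

#include <cstdint>
#include <vector>

// Simplified stand-ins for the IR (illustrative only)
enum class Cmd
{
    NewUserdata,     // allocation that fixes the userdata tag
    CheckUserdataTag // guard that the userdata carries an expected tag
};

struct Inst
{
    Cmd cmd;
    uint32_t object = 0; // operand a: index of the instruction producing the userdata
    int tag = 0;         // operand b: the tag constant
};

// Returns true if 'check' (a CheckUserdataTag) is redundant given instruction
// indices recorded earlier in the same block
bool isRedundantTagCheck(const std::vector<Inst>& insts, const std::vector<uint32_t>& cache, const Inst& check)
{
    for (uint32_t prevIdx : cache)
    {
        const Inst& prev = insts[prevIdx];

        // The same object was already checked against the same tag
        if (prev.cmd == Cmd::CheckUserdataTag && prev.object == check.object && prev.tag == check.tag)
            return true;

        // The object was produced by an allocation carrying the tag being checked
        if (prev.cmd == Cmd::NewUserdata && prevIdx == check.object && prev.tag == check.tag)
            return true;
    }

    return false;
}

As with the other per-block check caches, the real pass caps this cache at LuauCodeGenReuseUdataTagLimit entries and flushes it whenever heap state is invalidated.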
static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid, uint32_t firstReturnReg, int nresults) @@ -607,16 +621,9 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& std::tie(activeLoadCmd, activeLoadValue) = state.getPreviousVersionedLoadForTag(value, source); if (state.tryGetTag(source) == value) - { - if (FFlag::DebugLuauAbortingChecks && !FFlag::LuauCodegenRemoveDeadStores5) - replace(function, block, index, {IrCmd::CHECK_TAG, inst.a, inst.b, build.undef()}); - else - kill(function, inst); - } + kill(function, inst); else - { state.saveTag(source, value); - } } else { @@ -739,7 +746,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& // If we know the tag, we can try extracting the value from a register used by LOAD_TVALUE // To do that, we have to ensure that the register link of the source value is still valid - if (tag != 0xff && (!FFlag::LuauCodegenLoadPropCheckRegLinkInTv || state.tryGetRegLink(inst.b) != nullptr)) + if (tag != 0xff && state.tryGetRegLink(inst.b) != nullptr) { if (IrInst* arg = function.asInstOp(inst.b); arg && arg->cmd == IrCmd::LOAD_TVALUE && arg->a.kind == IrOpKind::VmReg) { @@ -751,7 +758,18 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& } // If we have constant tag and value, replace TValue store with tag/value pair store - if (tag != 0xff && value.kind != IrOpKind::None && (tag == LUA_TBOOLEAN || tag == LUA_TNUMBER || isGCO(tag))) + bool canSplitTvalueStore = false; + + if (tag == LUA_TBOOLEAN && + (value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Int))) + canSplitTvalueStore = true; + else if (tag == LUA_TNUMBER && + (value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Double))) + canSplitTvalueStore = true; + else if (tag != 0xff && isGCO(tag) && value.kind == IrOpKind::Inst) + canSplitTvalueStore = true; + + if (canSplitTvalueStore) { replace(function, block, index, {IrCmd::STORE_SPLIT_TVALUE, inst.a, build.constTag(tag), value, inst.c}); @@ -1031,6 +1049,37 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& state.checkBufferLenCache.push_back(index); break; } + case IrCmd::CHECK_USERDATA_TAG: + { + CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps); + + for (uint32_t prevIdx : state.userdataTagCache) + { + IrInst& prev = function.instructions[prevIdx]; + + if (prev.cmd == IrCmd::CHECK_USERDATA_TAG) + { + if (prev.a != inst.a || prev.b != inst.b) + continue; + } + else if (FFlag::LuauCodegenUserdataAlloc && prev.cmd == IrCmd::NEW_USERDATA) + { + if (inst.a.kind != IrOpKind::Inst || prevIdx != inst.a.index || prev.b != inst.b) + continue; + } + + if (FFlag::DebugLuauAbortingChecks) + replace(function, inst.c, build.undef()); + else + kill(function, inst); + + return; // Break out from both the loop and the switch + } + + if (int(state.userdataTagCache.size()) < FInt::LuauCodeGenReuseUdataTagLimit) + state.userdataTagCache.push_back(index); + break; + } case IrCmd::BUFFER_READI8: case IrCmd::BUFFER_READU8: case IrCmd::BUFFER_WRITEI8: @@ -1075,39 +1124,34 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::FASTCALL: { - if (FFlag::LuauCodegenRemoveDeadStores5) + LuauBuiltinFunction bfid = LuauBuiltinFunction(function.uintOp(inst.a)); + int firstReturnReg = vmRegOp(inst.b); + int nresults = function.intOp(FFlag::LuauCodegenFastcall3 ?
inst.d : inst.f); + + // TODO: FASTCALL is more restrictive than INVOKE_FASTCALL; we should either determine the exact semantics, or rework it + handleBuiltinEffects(state, bfid, firstReturnReg, nresults); + + switch (bfid) { - LuauBuiltinFunction bfid = LuauBuiltinFunction(function.uintOp(inst.a)); - int firstReturnReg = vmRegOp(inst.b); - int nresults = function.intOp(inst.f); + case LBF_MATH_MODF: + case LBF_MATH_FREXP: + state.updateTag(IrOp{IrOpKind::VmReg, uint8_t(firstReturnReg)}, LUA_TNUMBER); - // TODO: FASTCALL is more restrictive than INVOKE_FASTCALL; we should either determine the exact semantics, or rework it - handleBuiltinEffects(state, bfid, firstReturnReg, nresults); - - switch (bfid) - { - case LBF_MATH_MODF: - case LBF_MATH_FREXP: - state.updateTag(IrOp{IrOpKind::VmReg, uint8_t(firstReturnReg)}, LUA_TNUMBER); - - if (nresults > 1) - state.updateTag(IrOp{IrOpKind::VmReg, uint8_t(firstReturnReg + 1)}, LUA_TNUMBER); - break; - case LBF_MATH_SIGN: - state.updateTag(IrOp{IrOpKind::VmReg, uint8_t(firstReturnReg)}, LUA_TNUMBER); - break; - default: - break; - } - } - else - { - handleBuiltinEffects(state, LuauBuiltinFunction(function.uintOp(inst.a)), vmRegOp(inst.b), function.intOp(inst.f)); + if (nresults > 1) + state.updateTag(IrOp{IrOpKind::VmReg, uint8_t(firstReturnReg + 1)}, LUA_TNUMBER); + break; + case LBF_MATH_SIGN: + CODEGEN_ASSERT(!FFlag::LuauCodegenMathSign); + state.updateTag(IrOp{IrOpKind::VmReg, uint8_t(firstReturnReg)}, LUA_TNUMBER); + break; + default: + break; } break; } case IrCmd::INVOKE_FASTCALL: - handleBuiltinEffects(state, LuauBuiltinFunction(function.uintOp(inst.a)), vmRegOp(inst.b), function.intOp(inst.f)); + handleBuiltinEffects( + state, LuauBuiltinFunction(function.uintOp(inst.a)), vmRegOp(inst.b), function.intOp(FFlag::LuauCodegenFastcall3 ? 
inst.g : inst.f)); break; // These instructions don't have an effect on register/memory state we are tracking @@ -1163,6 +1207,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::ROUND_NUM: case IrCmd::SQRT_NUM: case IrCmd::ABS_NUM: + case IrCmd::SIGN_NUM: case IrCmd::NOT_ANY: state.substituteOrRecord(inst, index); break; @@ -1198,6 +1243,12 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& break; case IrCmd::TRY_CALL_FASTGETTM: break; + case IrCmd::NEW_USERDATA: + CODEGEN_ASSERT(FFlag::LuauCodegenUserdataAlloc); + + if (int(state.userdataTagCache.size()) < FInt::LuauCodeGenReuseUdataTagLimit) + state.userdataTagCache.push_back(index); + break; case IrCmd::INT_TO_NUM: case IrCmd::UINT_TO_NUM: state.substituteOrRecord(inst, index); @@ -1482,6 +1533,9 @@ static void constPropInBlockChain(IrBuilder& build, std::vector& visite state.invalidateHeapTableData(); state.invalidateHeapBufferData(); + if (FFlag::LuauCodegenUserdataOps) + state.invalidateUserdataData(); + // Blocks in a chain are guaranteed to follow each other // We force that by giving all blocks the same sorting key, but consecutive chain keys block->sortkey = startSortkey; diff --git a/CodeGen/src/OptimizeDeadStore.cpp b/CodeGen/src/OptimizeDeadStore.cpp index 6c1d6aff..9fa6f062 100644 --- a/CodeGen/src/OptimizeDeadStore.cpp +++ b/CodeGen/src/OptimizeDeadStore.cpp @@ -9,7 +9,7 @@ #include "lobject.h" -LUAU_FASTFLAGVARIABLE(LuauCodegenRemoveDeadStores5, false) +LUAU_FASTFLAG(LuauCodegenUserdataOps) // TODO: optimization can be improved by knowing which registers are live in at each VM exit @@ -595,6 +595,11 @@ static void markDeadStoresInInst(RemoveDeadStoreState& state, IrBuilder& build, case IrCmd::CHECK_BUFFER_LEN: state.checkLiveIns(inst.d); break; + case IrCmd::CHECK_USERDATA_TAG: + CODEGEN_ASSERT(FFlag::LuauCodegenUserdataOps); + + state.checkLiveIns(inst.c); + break; case IrCmd::JUMP: // Ideally, we would be able to remove stores to registers that are not live out from a block diff --git a/CodeGen/src/UnwindBuilderDwarf2.cpp b/CodeGen/src/UnwindBuilderDwarf2.cpp index b1522e7b..2f090b52 100644 --- a/CodeGen/src/UnwindBuilderDwarf2.cpp +++ b/CodeGen/src/UnwindBuilderDwarf2.cpp @@ -202,7 +202,7 @@ void UnwindBuilderDwarf2::finishInfo() // Terminate section pos = writeu32(pos, 0); - CODEGEN_ASSERT(getSize() <= kRawDataLimit); + CODEGEN_ASSERT(getUnwindInfoSize() <= kRawDataLimit); } void UnwindBuilderDwarf2::prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list regs) @@ -271,19 +271,14 @@ void UnwindBuilderDwarf2::prologueX64(uint32_t prologueSize, uint32_t stackSize, CODEGEN_ASSERT(prologueOffset == prologueSize); } -size_t UnwindBuilderDwarf2::getSize() const +size_t UnwindBuilderDwarf2::getUnwindInfoSize(size_t blockSize) const { return size_t(pos - rawData); } -size_t UnwindBuilderDwarf2::getFunctionCount() const +size_t UnwindBuilderDwarf2::finalize(char* target, size_t offset, void* funcAddress, size_t blockSize) const { - return unwindFunctions.size(); -} - -void UnwindBuilderDwarf2::finalize(char* target, size_t offset, void* funcAddress, size_t funcSize) const -{ - memcpy(target, rawData, getSize()); + memcpy(target, rawData, getUnwindInfoSize()); for (const UnwindFunctionDwarf2& func : unwindFunctions) { @@ -291,11 +286,13 @@ void UnwindBuilderDwarf2::finalize(char* target, size_t offset, void* funcAddres writeu64(fdeEntry + kFdeInitialLocationOffset, uintptr_t(funcAddress) + offset + func.beginOffset); - if
(func.endOffset == kFullBlockFuncton) - writeu64(fdeEntry + kFdeAddressRangeOffset, funcSize - offset); + if (func.endOffset == kFullBlockFunction) + writeu64(fdeEntry + kFdeAddressRangeOffset, blockSize - offset); else writeu64(fdeEntry + kFdeAddressRangeOffset, func.endOffset - func.beginOffset); } + + return unwindFunctions.size(); } } // namespace CodeGen diff --git a/CodeGen/src/UnwindBuilderWin.cpp b/CodeGen/src/UnwindBuilderWin.cpp index 498470bd..2bcc0321 100644 --- a/CodeGen/src/UnwindBuilderWin.cpp +++ b/CodeGen/src/UnwindBuilderWin.cpp @@ -194,17 +194,12 @@ void UnwindBuilderWin::prologueX64(uint32_t prologueSize, uint32_t stackSize, bo this->prologSize = prologueSize; } -size_t UnwindBuilderWin::getSize() const +size_t UnwindBuilderWin::getUnwindInfoSize(size_t blockSize) const { return sizeof(UnwindFunctionWin) * unwindFunctions.size() + size_t(rawDataPos - rawData); } -size_t UnwindBuilderWin::getFunctionCount() const -{ - return unwindFunctions.size(); -} - -void UnwindBuilderWin::finalize(char* target, size_t offset, void* funcAddress, size_t funcSize) const +size_t UnwindBuilderWin::finalize(char* target, size_t offset, void* funcAddress, size_t blockSize) const { // Copy adjusted function information for (UnwindFunctionWin func : unwindFunctions) @@ -213,8 +208,8 @@ void UnwindBuilderWin::finalize(char* target, size_t offset, void* funcAddress, func.beginOffset += uint32_t(offset); // Whole block is a part of a 'single function' - if (func.endOffset == kFullBlockFuncton) - func.endOffset = uint32_t(funcSize); + if (func.endOffset == kFullBlockFunction) + func.endOffset = uint32_t(blockSize); else func.endOffset += uint32_t(offset); @@ -226,6 +221,8 @@ void UnwindBuilderWin::finalize(char* target, size_t offset, void* funcAddress, // Copy unwind codes memcpy(target, rawData, size_t(rawDataPos - rawData)); + + return unwindFunctions.size(); } } // namespace CodeGen diff --git a/CodeGen/src/lcodegen.cpp b/CodeGen/src/lcodegen.cpp index 0795cd48..1ad685a1 100644 --- a/CodeGen/src/lcodegen.cpp +++ b/CodeGen/src/lcodegen.cpp @@ -17,5 +17,6 @@ void luau_codegen_create(lua_State* L) void luau_codegen_compile(lua_State* L, int idx) { - Luau::CodeGen::compile(L, idx); + Luau::CodeGen::CompilationOptions options; + Luau::CodeGen::compile(L, idx, options); } diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index 7012d820..604b8b86 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -46,6 +46,7 @@ // Version 3: Adds FORGPREP/JUMPXEQK* and enhances AUX encoding for FORGLOOP. Removes FORGLOOP_NEXT/INEXT and JUMPIFEQK/JUMPIFNOTEQK. Currently supported. // Version 4: Adds Proto::flags, typeinfo, and floor division opcodes IDIV/IDIVK. Currently supported. // Version 5: Adds SUBRK/DIVRK and vector constants. Currently supported. +// Version 6: Adds FASTCALL3. Currently supported. // # Bytecode type information history // Version 1: (from bytecode version 4) Type information for function signature. Currently supported. 
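Bytecode version 6 exists to carry the new FASTCALL3 opcode, documented in the opcode list below: the builtin id travels in A, the first argument register in B, the jump offset to the following CALL in C, and the second and third argument registers are packed into the two low bytes of the AUX word. A minimal decode sketch under those assumptions (the byte layout follows the usual LUAU_INSN_A/B/C positions; the concrete values are made up for illustration):

#include <cstdint>
#include <cstdio>

int main()
{
    // Hypothetical encoded pair: primary word (low byte would hold the
    // LOP_FASTCALL3 opcode, omitted here) followed by the AUX word
    uint32_t insn = (20u << 8) | (2u << 16) | (1u << 24);
    uint32_t aux = 3u | (4u << 8);

    uint8_t bfid = (insn >> 8) & 0xff;  // A: builtin function id
    uint8_t arg1 = (insn >> 16) & 0xff; // B: first argument register
    uint8_t jump = (insn >> 24) & 0xff; // C: offset to the following CALL
    uint8_t arg2 = aux & 0xff;          // AUX: second argument register
    uint8_t arg3 = (aux >> 8) & 0xff;   // AUX: third argument register

    printf("builtin %d: R%d R%d R%d (CALL at +%d)\n", bfid, arg1, arg2, arg3, jump);
    return 0;
}

The same unpacking appears further down in validateInstructions and dumpInstruction, which read the two extra registers as insns[i + 1] & 0xff and (insns[i + 1] >> 8) & 0xff.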
@@ -299,8 +300,13 @@ enum LuauOpcode // A: target register (see FORGLOOP for register layout) LOP_FORGPREP_INEXT, - // removed in v3 - LOP_DEP_FORGLOOP_INEXT, + // FASTCALL3: perform a fast call of a built-in function using 3 register arguments + // A: builtin function id (see LuauBuiltinFunction) + // B: source argument register + // C: jump offset to get to following CALL + // AUX: source register 2 in least-significant byte + // AUX: source register 3 in second least-significant byte + LOP_FASTCALL3, // FORGPREP_NEXT: prepare FORGLOOP with 2 output variables (no AUX encoding), assuming generator is luaB_next, and jump to FORGLOOP // A: target register (see FORGLOOP for register layout) @@ -434,11 +440,12 @@ enum LuauBytecodeTag { // Bytecode version; runtime supports [MIN, MAX], compiler emits TARGET by default but may emit a higher version when flags are enabled LBC_VERSION_MIN = 3, - LBC_VERSION_MAX = 5, + LBC_VERSION_MAX = 6, LBC_VERSION_TARGET = 5, // Type encoding version - LBC_TYPE_VERSION_DEPRECATED = 1, - LBC_TYPE_VERSION = 2, + LBC_TYPE_VERSION_MIN = 1, + LBC_TYPE_VERSION_MAX = 3, + LBC_TYPE_VERSION_TARGET = 3, // Types of constant table entries LBC_CONSTANT_NIL = 0, LBC_CONSTANT_BOOLEAN, @@ -465,6 +472,10 @@ enum LuauBytecodeType LBC_TYPE_BUFFER, LBC_TYPE_ANY = 15, + + LBC_TYPE_TAGGED_USERDATA_BASE = 64, + LBC_TYPE_TAGGED_USERDATA_END = 64 + 32, + LBC_TYPE_OPTIONAL_BIT = 1 << 7, LBC_TYPE_INVALID = 256, @@ -606,4 +617,6 @@ enum LuauProtoFlag LPF_NATIVE_MODULE = 1 << 0, // used to tag individual protos as not profitable to compile natively LPF_NATIVE_COLD = 1 << 1, + // used to tag main proto for modules that have at least one function with native attribute + LPF_NATIVE_FUNCTION = 1 << 2, }; diff --git a/Common/include/Luau/BytecodeUtils.h b/Common/include/Luau/BytecodeUtils.h index 957c804c..6f110311 100644 --- a/Common/include/Luau/BytecodeUtils.h +++ b/Common/include/Luau/BytecodeUtils.h @@ -28,6 +28,7 @@ inline int getOpLength(LuauOpcode op) case LOP_LOADKX: case LOP_FASTCALL2: case LOP_FASTCALL2K: + case LOP_FASTCALL3: case LOP_JUMPXEQKNIL: case LOP_JUMPXEQKB: case LOP_JUMPXEQKN: diff --git a/Common/include/Luau/DenseHash.h b/Common/include/Luau/DenseHash.h index 507a9c48..39e50f92 100644 --- a/Common/include/Luau/DenseHash.h +++ b/Common/include/Luau/DenseHash.h @@ -120,12 +120,12 @@ public: return *this; } - void clear() + void clear(size_t thresholdToDestroy = 32) { if (count == 0) return; - if (capacity > 32) + if (capacity > thresholdToDestroy) { destroy(); } @@ -583,9 +583,9 @@ public: { } - void clear() + void clear(size_t thresholdToDestroy = 32) { - impl.clear(); + impl.clear(thresholdToDestroy); } // Note: this reference is invalidated by any insert operation (i.e. 
operator[]) diff --git a/Compiler/include/Luau/BytecodeBuilder.h b/Compiler/include/Luau/BytecodeBuilder.h index 7f0115bb..59d30d62 100644 --- a/Compiler/include/Luau/BytecodeBuilder.h +++ b/Compiler/include/Luau/BytecodeBuilder.h @@ -79,6 +79,9 @@ public: void pushLocalTypeInfo(LuauBytecodeType type, uint8_t reg, uint32_t startpc, uint32_t endpc); void pushUpvalTypeInfo(LuauBytecodeType type); + uint32_t addUserdataType(const char* name); + void useUserdataType(uint32_t index); + void setDebugFunctionName(StringRef name); void setDebugFunctionLineDefined(int line); void setDebugLine(int line); @@ -229,6 +232,13 @@ private: LuauBytecodeType type; }; + struct UserdataType + { + std::string name; + uint32_t nameRef = 0; + bool used = false; + }; + struct Jump { uint32_t source; @@ -277,6 +287,8 @@ private: std::vector typedLocals; std::vector typedUpvals; + std::vector userdataTypes; + DenseHashMap stringTable; std::vector debugStrings; @@ -308,6 +320,8 @@ private: int32_t addConstant(const ConstantKey& key, const Constant& value); unsigned int addStringTableEntry(StringRef value); + + const char* tryGetUserdataTypeName(LuauBytecodeType type) const; }; } // namespace Luau diff --git a/Compiler/include/Luau/Compiler.h b/Compiler/include/Luau/Compiler.h index 698a50c4..119e0aa2 100644 --- a/Compiler/include/Luau/Compiler.h +++ b/Compiler/include/Luau/Compiler.h @@ -46,6 +46,9 @@ struct CompileOptions // null-terminated array of globals that are mutable; disables the import optimization for fields accessed through these const char* const* mutableGlobals = nullptr; + + // null-terminated array of userdata types that will be included in the type information + const char* const* userdataTypes = nullptr; }; class CompileError : public std::exception diff --git a/Compiler/include/luacode.h b/Compiler/include/luacode.h index a470319d..1440a699 100644 --- a/Compiler/include/luacode.h +++ b/Compiler/include/luacode.h @@ -42,6 +42,9 @@ struct lua_CompileOptions // null-terminated array of globals that are mutable; disables the import optimization for fields accessed through these const char* const* mutableGlobals; + + // null-terminated array of userdata types that will be included in the type information + const char* const* userdataTypes; }; // compile source to bytecode; when source compilation fails, the resulting bytecode contains the encoded error. 
use free() to destroy diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp index 2b09b7e0..c576e3a4 100644 --- a/Compiler/src/Builtins.cpp +++ b/Compiler/src/Builtins.cpp @@ -454,7 +454,7 @@ BuiltinInfo getBuiltinInfo(int bfid) case LBF_BUFFER_WRITEF32: case LBF_BUFFER_WRITEF64: return {3, 0, BuiltinInfo::Flag_NoneSafe}; - }; + } LUAU_UNREACHABLE(); } diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 5386a528..f68884c5 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -7,9 +7,8 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauCompileNoJumpLineRetarget, false) -LUAU_FASTFLAG(LuauCompileRepeatUntilSkippedLocals) -LUAU_FASTFLAGVARIABLE(LuauCompileTypeInfo, false) +LUAU_FASTFLAG(LuauCompileUserdataInfo) +LUAU_FASTFLAG(LuauCompileFastcall3) namespace Luau { @@ -114,6 +113,7 @@ inline bool isFastCall(LuauOpcode op) case LOP_FASTCALL1: case LOP_FASTCALL2: case LOP_FASTCALL2K: + case LOP_FASTCALL3: return true; default: @@ -282,11 +282,8 @@ void BytecodeBuilder::endFunction(uint8_t maxstacksize, uint8_t numupvalues, uin debugLocals.clear(); debugUpvals.clear(); - if (FFlag::LuauCompileTypeInfo) - { - typedLocals.clear(); - typedUpvals.clear(); - } + typedLocals.clear(); + typedUpvals.clear(); constantMap.clear(); tableShapeMap.clear(); @@ -335,6 +332,18 @@ unsigned int BytecodeBuilder::addStringTableEntry(StringRef value) return index; } +const char* BytecodeBuilder::tryGetUserdataTypeName(LuauBytecodeType type) const +{ + LUAU_ASSERT(FFlag::LuauCompileUserdataInfo); + + unsigned index = unsigned((type & ~LBC_TYPE_OPTIONAL_BIT) - LBC_TYPE_TAGGED_USERDATA_BASE); + + if (index < userdataTypes.size()) + return userdataTypes[index].name.c_str(); + + return nullptr; +} + int32_t BytecodeBuilder::addConstantNil() { Constant c = {Constant::Type_Nil}; @@ -546,8 +555,6 @@ void BytecodeBuilder::setFunctionTypeInfo(std::string value) void BytecodeBuilder::pushLocalTypeInfo(LuauBytecodeType type, uint8_t reg, uint32_t startpc, uint32_t endpc) { - LUAU_ASSERT(FFlag::LuauCompileTypeInfo); - TypedLocal local; local.type = type; local.reg = reg; @@ -559,14 +566,31 @@ void BytecodeBuilder::pushLocalTypeInfo(LuauBytecodeType type, uint8_t reg, uint void BytecodeBuilder::pushUpvalTypeInfo(LuauBytecodeType type) { - LUAU_ASSERT(FFlag::LuauCompileTypeInfo); - TypedUpval upval; upval.type = type; typedUpvals.push_back(upval); } +uint32_t BytecodeBuilder::addUserdataType(const char* name) +{ + LUAU_ASSERT(FFlag::LuauCompileUserdataInfo); + + UserdataType ty; + + ty.name = name; + + userdataTypes.push_back(std::move(ty)); + return uint32_t(userdataTypes.size() - 1); +} + +void BytecodeBuilder::useUserdataType(uint32_t index) +{ + LUAU_ASSERT(FFlag::LuauCompileUserdataInfo); + + userdataTypes[index].used = true; +} + void BytecodeBuilder::setDebugFunctionName(StringRef name) { unsigned int index = addStringTableEntry(name); @@ -648,6 +672,15 @@ void BytecodeBuilder::finalize() { LUAU_ASSERT(bytecode.empty()); + if (FFlag::LuauCompileUserdataInfo) + { + for (auto& ty : userdataTypes) + { + if (ty.used) + ty.nameRef = addStringTableEntry(StringRef({ty.name.c_str(), ty.name.length()})); + } + } + // preallocate space for bytecode blob size_t capacity = 16; @@ -666,10 +699,24 @@ void BytecodeBuilder::finalize() bytecode = char(version); uint8_t typesversion = getTypeEncodingVersion(); + LUAU_ASSERT(typesversion >= LBC_TYPE_VERSION_MIN && typesversion <= LBC_TYPE_VERSION_MAX); writeByte(bytecode, typesversion); writeStringTable(bytecode); + 
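The block that follows writes a new section into the bytecode header, right after the string table: one entry per registered userdata type, each a 1-based index byte followed by a varint reference into the string table, with a zero byte terminating the list (types that were never referenced keep their default nameRef of 0). A sketch of a matching reader, assuming hypothetical readByte/readVarInt helpers that mirror the writeByte/writeVarInt encoding used here:

#include <cstdint>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

static uint8_t readByte(const std::string& data, size_t& offset)
{
    return uint8_t(data[offset++]);
}

static uint32_t readVarInt(const std::string& data, size_t& offset)
{
    // 7 bits per byte, least-significant group first, high bit marks continuation
    uint32_t result = 0;
    int shift = 0;
    uint8_t byte;

    do
    {
        byte = readByte(data, offset);
        result |= uint32_t(byte & 127) << shift;
        shift += 7;
    } while (byte & 128);

    return result;
}

// Reads (zero-based type index, string table reference) pairs until the 0 terminator
static std::vector<std::pair<uint8_t, uint32_t>> readUserdataTypeMapping(const std::string& data, size_t& offset)
{
    std::vector<std::pair<uint8_t, uint32_t>> mapping;

    while (uint8_t index = readByte(data, offset))
        mapping.push_back({uint8_t(index - 1), readVarInt(data, offset)});

    return mapping;
}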
if (FFlag::LuauCompileUserdataInfo) + { + // Write the mapping between used type name indices and their name + for (uint32_t i = 0; i < uint32_t(userdataTypes.size()); i++) + { + writeByte(bytecode, i + 1); + writeVarInt(bytecode, userdataTypes[i].nameRef); + } + + // 0 marks the end of the mapping + writeByte(bytecode, 0); + } + writeVarInt(bytecode, uint32_t(functions.size())); for (const Function& func : functions) @@ -692,42 +739,34 @@ void BytecodeBuilder::writeFunction(std::string& ss, uint32_t id, uint8_t flags) writeByte(ss, flags); - if (FFlag::LuauCompileTypeInfo) + if (!func.typeinfo.empty() || !typedUpvals.empty() || !typedLocals.empty()) { - if (!func.typeinfo.empty() || !typedUpvals.empty() || !typedLocals.empty()) + // collect type info into a temporary string to know the overall size of type data + tempTypeInfo.clear(); + writeVarInt(tempTypeInfo, uint32_t(func.typeinfo.size())); + writeVarInt(tempTypeInfo, uint32_t(typedUpvals.size())); + writeVarInt(tempTypeInfo, uint32_t(typedLocals.size())); + + tempTypeInfo.append(func.typeinfo); + + for (const TypedUpval& l : typedUpvals) + writeByte(tempTypeInfo, l.type); + + for (const TypedLocal& l : typedLocals) { - // collect type info into a temporary string to know the overall size of type data - tempTypeInfo.clear(); - writeVarInt(tempTypeInfo, uint32_t(func.typeinfo.size())); - writeVarInt(tempTypeInfo, uint32_t(typedUpvals.size())); - writeVarInt(tempTypeInfo, uint32_t(typedLocals.size())); - - tempTypeInfo.append(func.typeinfo); - - for (const TypedUpval& l : typedUpvals) - writeByte(tempTypeInfo, l.type); - - for (const TypedLocal& l : typedLocals) - { - writeByte(tempTypeInfo, l.type); - writeByte(tempTypeInfo, l.reg); - writeVarInt(tempTypeInfo, l.startpc); - LUAU_ASSERT(l.endpc >= l.startpc); - writeVarInt(tempTypeInfo, l.endpc - l.startpc); - } - - writeVarInt(ss, uint32_t(tempTypeInfo.size())); - ss.append(tempTypeInfo); - } - else - { - writeVarInt(ss, 0); + writeByte(tempTypeInfo, l.type); + writeByte(tempTypeInfo, l.reg); + writeVarInt(tempTypeInfo, l.startpc); + LUAU_ASSERT(l.endpc >= l.startpc); + writeVarInt(tempTypeInfo, l.endpc - l.startpc); } + + writeVarInt(ss, uint32_t(tempTypeInfo.size())); + ss.append(tempTypeInfo); } else { - writeVarInt(ss, uint32_t(func.typeinfo.size())); - ss.append(func.typeinfo); + writeVarInt(ss, 0); } // instructions @@ -1036,11 +1075,6 @@ void BytecodeBuilder::foldJumps() if (LUAU_INSN_OP(jumpInsn) == LOP_JUMP && LUAU_INSN_OP(targetInsn) == LOP_RETURN) { insns[jumpLabel] = targetInsn; - - if (!FFlag::LuauCompileNoJumpLineRetarget) - { - lines[jumpLabel] = lines[targetLabel]; - } } else if (int16_t(offset) == offset) { @@ -1193,12 +1227,18 @@ std::string BytecodeBuilder::getError(const std::string& message) uint8_t BytecodeBuilder::getVersion() { // This function usually returns LBC_VERSION_TARGET but may sometimes return a higher number (within LBC_VERSION_MIN/MAX) under fast flags + if (FFlag::LuauCompileFastcall3) + return 6; + return LBC_VERSION_TARGET; } uint8_t BytecodeBuilder::getTypeEncodingVersion() { - return FFlag::LuauCompileTypeInfo ? 
LBC_TYPE_VERSION : LBC_TYPE_VERSION_DEPRECATED; + if (FFlag::LuauCompileUserdataInfo) + return LBC_TYPE_VERSION_TARGET; + + return 2; } #ifdef LUAU_ASSERTENABLED @@ -1570,6 +1610,16 @@ void BytecodeBuilder::validateInstructions() const VCONSTANY(insns[i + 1]); break; + case LOP_FASTCALL3: + LUAU_ASSERT(FFlag::LuauCompileFastcall3); + + VREG(LUAU_INSN_B(insn)); + VJUMP(LUAU_INSN_C(insn)); + LUAU_ASSERT(LUAU_INSN_OP(insns[i + 1 + LUAU_INSN_C(insn)]) == LOP_CALL); + VREG(insns[i + 1] & 0xff); + VREG((insns[i + 1] >> 8) & 0xff); + break; + case LOP_COVERAGE: break; @@ -1666,7 +1716,7 @@ void BytecodeBuilder::validateVariadic() const if (LUAU_INSN_B(insn) == 0) { - // consumer instruction ens a variadic sequence + // consumer instruction ends a variadic sequence LUAU_ASSERT(variadicSeq); variadicSeq = false; } @@ -2184,6 +2234,13 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, code++; break; + case LOP_FASTCALL3: + LUAU_ASSERT(FFlag::LuauCompileFastcall3); + + formatAppend(result, "FASTCALL3 %d R%d R%d R%d L%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), *code & 0xff, (*code >> 8) & 0xff, targetLabel); + code++; + break; + case LOP_COVERAGE: formatAppend(result, "COVERAGE\n"); break; @@ -2275,7 +2332,7 @@ std::string BytecodeBuilder::dumpCurrentFunction(std::vector& dumpinstoffs) { const DebugLocal& l = debugLocals[i]; - if (FFlag::LuauCompileRepeatUntilSkippedLocals && l.startpc == l.endpc) + if (l.startpc == l.endpc) { LUAU_ASSERT(l.startpc < lines.size()); @@ -2295,12 +2352,48 @@ std::string BytecodeBuilder::dumpCurrentFunction(std::vector& dumpinstoffs) } } - if (FFlag::LuauCompileTypeInfo) + if (dumpFlags & Dump_Types) { - if (dumpFlags & Dump_Types) - { - const std::string& typeinfo = functions.back().typeinfo; + const std::string& typeinfo = functions.back().typeinfo; + if (FFlag::LuauCompileUserdataInfo) + { + // Arguments start from third byte in function typeinfo string + for (uint8_t i = 2; i < typeinfo.size(); ++i) + { + uint8_t et = typeinfo[i]; + + const char* userdata = tryGetUserdataTypeName(LuauBytecodeType(et)); + const char* name = userdata ? userdata : getBaseTypeString(et); + const char* optional = (et & LBC_TYPE_OPTIONAL_BIT) ? "?" : ""; + + formatAppend(result, "R%d: %s%s [argument]\n", i - 2, name, optional); + } + + for (size_t i = 0; i < typedUpvals.size(); ++i) + { + const TypedUpval& l = typedUpvals[i]; + + const char* userdata = tryGetUserdataTypeName(l.type); + const char* name = userdata ? userdata : getBaseTypeString(l.type); + const char* optional = (l.type & LBC_TYPE_OPTIONAL_BIT) ? "?" : ""; + + formatAppend(result, "U%d: %s%s\n", int(i), name, optional); + } + + for (size_t i = 0; i < typedLocals.size(); ++i) + { + const TypedLocal& l = typedLocals[i]; + + const char* userdata = tryGetUserdataTypeName(l.type); + const char* name = userdata ? userdata : getBaseTypeString(l.type); + const char* optional = (l.type & LBC_TYPE_OPTIONAL_BIT) ? "?" 
: ""; + + formatAppend(result, "R%d: %s%s from %d to %d\n", l.reg, name, optional, l.startpc, l.endpc); + } + } + else + { // Arguments start from third byte in function typeinfo string for (uint8_t i = 2; i < typeinfo.size(); ++i) { diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index df096d3a..98520a7f 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -26,9 +26,10 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) -LUAU_FASTFLAGVARIABLE(LuauCompileRepeatUntilSkippedLocals, false) -LUAU_FASTFLAG(LuauCompileTypeInfo) -LUAU_FASTFLAGVARIABLE(LuauTypeInfoLookupImprovement, false) +LUAU_FASTFLAGVARIABLE(LuauCompileUserdataInfo, false) +LUAU_FASTFLAGVARIABLE(LuauCompileFastcall3, false) + +LUAU_FASTFLAG(LuauNativeAttribute) namespace Luau { @@ -106,8 +107,11 @@ struct Compiler , locstants(nullptr) , tableShapes(nullptr) , builtins(nullptr) + , userdataTypes(AstName()) , functionTypes(nullptr) , localTypes(nullptr) + , exprTypes(nullptr) + , builtinTypes(options.vectorType) { // preallocate some buffers that are very likely to grow anyway; this works around std::vector's inefficient growth policy for small arrays localStack.reserve(16); @@ -192,7 +196,7 @@ struct Compiler return node->as(); } - uint32_t compileFunction(AstExprFunction* func, uint8_t protoflags) + uint32_t compileFunction(AstExprFunction* func, uint8_t& protoflags) { LUAU_TIMETRACE_SCOPE("Compiler::compileFunction", "Compiler"); @@ -209,13 +213,6 @@ struct Compiler setDebugLine(func); - if (!FFlag::LuauCompileTypeInfo) - { - // note: we move types out of typeMap which is safe because compileFunction is only called once per function - if (std::string* funcType = functionTypes.find(func)) - bytecode.setFunctionTypeInfo(std::move(*funcType)); - } - if (func->vararg) bytecode.emitABC(LOP_PREPVARARGS, uint8_t(self + func->args.size), 0, 0); @@ -227,8 +224,7 @@ struct Compiler for (size_t i = 0; i < func->args.size; ++i) pushLocal(func->args.data[i], uint8_t(args + self + i), kDefaultAllocPc); - if (FFlag::LuauCompileTypeInfo) - argCount = localStack.size(); + argCount = localStack.size(); AstStatBlock* stat = func->body; @@ -260,7 +256,7 @@ struct Compiler bytecode.pushDebugUpval(sref(l->name)); } - if (FFlag::LuauCompileTypeInfo && options.typeInfoLevel >= 1) + if (options.typeInfoLevel >= 1) { for (AstLocal* l : upvals) { @@ -283,17 +279,17 @@ struct Compiler if (bytecode.getInstructionCount() > kMaxInstructionCount) CompileError::raise(func->location, "Exceeded function instruction limit; split the function into parts to compile"); - if (FFlag::LuauCompileTypeInfo) - { - // note: we move types out of typeMap which is safe because compileFunction is only called once per function - if (std::string* funcType = functionTypes.find(func)) - bytecode.setFunctionTypeInfo(std::move(*funcType)); - } + // note: we move types out of typeMap which is safe because compileFunction is only called once per function + if (std::string* funcType = functionTypes.find(func)) + bytecode.setFunctionTypeInfo(std::move(*funcType)); // top-level code only executes once so it can be marked as cold if it has no loops; code with loops might be profitable to compile natively if (func->functionDepth == 0 && !hasLoops) protoflags |= LPF_NATIVE_COLD; + if (FFlag::LuauNativeAttribute && func->hasNativeAttribute()) + protoflags |= LPF_NATIVE_FUNCTION; + bytecode.endFunction(uint8_t(stackSize), 
uint8_t(upvals.size()), protoflags); Function& f = functions[func]; @@ -319,8 +315,7 @@ struct Compiler upvals.clear(); // note: instead of std::move above, we copy & clear to preserve capacity for future pushes stackSize = 0; - if (FFlag::LuauCompileTypeInfo) - argCount = 0; + argCount = 0; hasLoops = false; @@ -465,10 +460,32 @@ struct Compiler { LUAU_ASSERT(!expr->self); LUAU_ASSERT(expr->args.size >= 1); - LUAU_ASSERT(expr->args.size <= 2 || (bfid == LBF_BIT32_EXTRACTK && expr->args.size == 3)); + + if (FFlag::LuauCompileFastcall3) + LUAU_ASSERT(expr->args.size <= 3); + else + LUAU_ASSERT(expr->args.size <= 2 || (bfid == LBF_BIT32_EXTRACTK && expr->args.size == 3)); + LUAU_ASSERT(bfid == LBF_BIT32_EXTRACTK ? bfK >= 0 : bfK < 0); - LuauOpcode opc = expr->args.size == 1 ? LOP_FASTCALL1 : (bfK >= 0 || isConstant(expr->args.data[1])) ? LOP_FASTCALL2K : LOP_FASTCALL2; + LuauOpcode opc = LOP_NOP; + + if (FFlag::LuauCompileFastcall3) + { + if (expr->args.size == 1) + opc = LOP_FASTCALL1; + else if (bfK >= 0 || (expr->args.size == 2 && isConstant(expr->args.data[1]))) + opc = LOP_FASTCALL2K; + else if (expr->args.size == 2) + opc = LOP_FASTCALL2; + else + opc = LOP_FASTCALL3; + } + else + { + opc = expr->args.size == 1 ? LOP_FASTCALL1 + : (bfK >= 0 || (expr->args.size == 2 && isConstant(expr->args.data[1]))) ? LOP_FASTCALL2K : LOP_FASTCALL2; + } uint32_t args[3] = {}; @@ -496,8 +513,16 @@ struct Compiler size_t fastcallLabel = bytecode.emitLabel(); bytecode.emitABC(opc, uint8_t(bfid), uint8_t(args[0]), 0); - if (opc != LOP_FASTCALL1) + + if (FFlag::LuauCompileFastcall3 && opc == LOP_FASTCALL3) + { + LUAU_ASSERT(bfK < 0); + bytecode.emitAux(args[1] | (args[2] << 8)); + } + else if (opc != LOP_FASTCALL1) + { bytecode.emitAux(bfK >= 0 ? bfK : args[1]); + } // Set up a traditional Lua stack for the subsequent LOP_CALL. // Note, as with other instructions that immediately follow FASTCALL, these are normally not executed and are used as a fallback for @@ -620,7 +645,7 @@ struct Compiler // if the last argument can return multiple values, we need to compute all of them into the remaining arguments unsigned int tail = unsigned(func->args.size - expr->args.size) + 1; uint8_t reg = allocReg(arg, tail); - uint32_t allocpc = FFlag::LuauCompileTypeInfo ? bytecode.getDebugPC() : kDefaultAllocPc; + uint32_t allocpc = bytecode.getDebugPC(); if (AstExprCall* expr = arg->as()) compileExprCall(expr, reg, tail, /* targetTop= */ true); @@ -630,12 +655,7 @@ struct Compiler LUAU_ASSERT(!"Unexpected expression type"); for (size_t j = i; j < func->args.size; ++j) - { - if (FFlag::LuauCompileTypeInfo) - args.push_back({func->args.data[j], uint8_t(reg + (j - i)), {Constant::Type_Unknown}, allocpc}); - else - args.push_back({func->args.data[j], uint8_t(reg + (j - i))}); - } + args.push_back({func->args.data[j], uint8_t(reg + (j - i)), {Constant::Type_Unknown}, allocpc}); // all remaining function arguments have been allocated and assigned to break; @@ -644,17 +664,14 @@ struct Compiler { // if the argument is mutated, we need to allocate a fresh register even if it's a constant uint8_t reg = allocReg(arg, 1); - uint32_t allocpc = FFlag::LuauCompileTypeInfo ? 
bytecode.getDebugPC() : kDefaultAllocPc; + uint32_t allocpc = bytecode.getDebugPC(); if (arg) compileExprTemp(arg, reg); else bytecode.emitABC(LOP_LOADNIL, reg, 0, 0); - if (FFlag::LuauCompileTypeInfo) - args.push_back({var, reg, {Constant::Type_Unknown}, allocpc}); - else - args.push_back({var, reg}); + args.push_back({var, reg, {Constant::Type_Unknown}, allocpc}); } else if (arg == nullptr) { @@ -674,22 +691,16 @@ struct Compiler // if the argument is a local that isn't mutated, we will simply reuse the existing register if (int reg = le ? getExprLocalReg(le) : -1; reg >= 0 && (!lv || !lv->written)) { - if (FFlag::LuauTypeInfoLookupImprovement) - args.push_back({var, uint8_t(reg), {Constant::Type_Unknown}, kDefaultAllocPc}); - else - args.push_back({var, uint8_t(reg)}); + args.push_back({var, uint8_t(reg), {Constant::Type_Unknown}, kDefaultAllocPc}); } else { uint8_t temp = allocReg(arg, 1); - uint32_t allocpc = FFlag::LuauCompileTypeInfo ? bytecode.getDebugPC() : kDefaultAllocPc; + uint32_t allocpc = bytecode.getDebugPC(); compileExprTemp(arg, temp); - if (FFlag::LuauCompileTypeInfo) - args.push_back({var, temp, {Constant::Type_Unknown}, allocpc}); - else - args.push_back({var, temp}); + args.push_back({var, temp, {Constant::Type_Unknown}, allocpc}); } } } @@ -703,16 +714,9 @@ struct Compiler for (InlineArg& arg : args) { if (arg.value.type == Constant::Type_Unknown) - { - if (FFlag::LuauCompileTypeInfo) - pushLocal(arg.local, arg.reg, arg.allocpc); - else - pushLocal(arg.local, arg.reg, kDefaultAllocPc); - } + pushLocal(arg.local, arg.reg, arg.allocpc); else - { locstants[arg.local] = arg.value; - } } // the inline frame will be used to compile return statements as well as to reject recursive inlining attempts @@ -852,11 +856,28 @@ struct Compiler } } - // Optimization: for 1/2 argument fast calls use specialized opcodes - if (bfid >= 0 && expr->args.size >= 1 && expr->args.size <= 2) + unsigned maxFastcallArgs = 2; + + // Fastcall with 3 arguments is only used if it can help save one or more move instructions + if (FFlag::LuauCompileFastcall3 && bfid >= 0 && expr->args.size == 3) + { + for (size_t i = 0; i < expr->args.size; ++i) + { + if (int reg = getExprLocalReg(expr->args.data[i]); reg >= 0) + { + maxFastcallArgs = 3; + break; + } + } + } + + // Optimization: for 1/2/3 argument fast calls use specialized opcodes + if (bfid >= 0 && expr->args.size >= 1 && expr->args.size <= (FFlag::LuauCompileFastcall3 ? maxFastcallArgs : 2u)) { if (!isExprMultRet(expr->args.data[expr->args.size - 1])) + { return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); + } else if (options.optimizationLevel >= 2) { // when a builtin is none-safe with matching arity, even if the last expression returns 0 or >1 arguments, @@ -916,6 +937,8 @@ struct Compiler bytecode.emitABC(LOP_NAMECALL, regs, selfreg, uint8_t(BytecodeBuilder::getStringHash(iname))); bytecode.emitAux(cid); + + hintTemporaryExprRegType(fi->expr, selfreg, LBC_TYPE_TABLE, /* instLength */ 2); } else if (bfid >= 0) { @@ -1570,6 +1593,8 @@ struct Compiler uint8_t rl = compileExprAuto(expr->left, rs); bytecode.emitABC(getBinaryOpArith(expr->op, /* k= */ true), target, rl, uint8_t(rc)); + + hintTemporaryExprRegType(expr->left, rl, LBC_TYPE_NUMBER, /* instLength */ 1); } else { @@ -1583,6 +1608,8 @@ struct Compiler LuauOpcode op = (expr->op == AstExprBinary::Sub) ? 
LOP_SUBRK : LOP_DIVRK; bytecode.emitABC(op, target, uint8_t(lc), uint8_t(rr)); + + hintTemporaryExprRegType(expr->right, rr, LBC_TYPE_NUMBER, /* instLength */ 1); return; } } @@ -1591,6 +1618,9 @@ struct Compiler uint8_t rr = compileExprAuto(expr->right, rs); bytecode.emitABC(getBinaryOpArith(expr->op), target, rl, rr); + + hintTemporaryExprRegType(expr->left, rl, LBC_TYPE_NUMBER, /* instLength */ 1); + hintTemporaryExprRegType(expr->right, rr, LBC_TYPE_NUMBER, /* instLength */ 1); } } break; @@ -2030,6 +2060,8 @@ struct Compiler bytecode.emitABC(LOP_GETTABLEKS, target, reg, uint8_t(BytecodeBuilder::getStringHash(iname))); bytecode.emitAux(cid); + + hintTemporaryExprRegType(expr->expr, reg, LBC_TYPE_TABLE, /* instLength */ 2); } void compileExprIndexExpr(AstExprIndexExpr* expr, uint8_t target) @@ -2750,16 +2782,14 @@ struct Compiler { validateContinueUntil(loops.back().continueUsed, stat->condition, body, i + 1); continueValidated = true; - - if (FFlag::LuauCompileRepeatUntilSkippedLocals) - conditionLocals = localStack.size(); + conditionLocals = localStack.size(); } } // if continue was used, some locals might not have had their initialization completed // the lifetime of these locals has to end before the condition is executed // because referencing skipped locals is not possible from the condition, this earlier closure doesn't affect upvalues - if (FFlag::LuauCompileRepeatUntilSkippedLocals && continueValidated) + if (continueValidated) { // if continueValidated is set, it means we have visited at least one body node and size > 0 setDebugLineEnd(body->body.data[body->body.size - 1]); @@ -2915,7 +2945,7 @@ struct Compiler // note: allocReg in this case allocates into parent block register - note that we don't have RegScope here uint8_t vars = allocReg(stat, unsigned(stat->vars.size)); - uint32_t allocpc = FFlag::LuauCompileTypeInfo ? bytecode.getDebugPC() : kDefaultAllocPc; + uint32_t allocpc = bytecode.getDebugPC(); compileExprListTemp(stat->values, vars, uint8_t(stat->vars.size), /* targetTop= */ true); @@ -3047,7 +3077,7 @@ struct Compiler // this makes sure the code inside the loop can't interfere with the iteration process (other than modifying the table we're iterating // through) uint8_t varreg = regs + 2; - uint32_t varregallocpc = FFlag::LuauCompileTypeInfo ? bytecode.getDebugPC() : kDefaultAllocPc; + uint32_t varregallocpc = bytecode.getDebugPC(); if (Variable* il = variables.find(stat->var); il && il->written) varreg = allocReg(stat, 1); @@ -3114,7 +3144,7 @@ struct Compiler // note that we reserve at least 2 variables; this allows our fast path to assume that we need 2 variables instead of 1 or 2 uint8_t vars = allocReg(stat, std::max(unsigned(stat->vars.size), 2u)); LUAU_ASSERT(vars == regs + 3); - uint32_t varsallocpc = FFlag::LuauCompileTypeInfo ? bytecode.getDebugPC() : kDefaultAllocPc; + uint32_t varsallocpc = bytecode.getDebugPC(); LuauOpcode skipOp = LOP_FORGPREP; @@ -3410,6 +3440,11 @@ struct Compiler uint8_t rr = compileExprAuto(stat->value, rs); bytecode.emitABC(getBinaryOpArith(stat->op), target, target, rr); + + if (var.kind != LValue::Kind_Local) + hintTemporaryRegType(stat->var, target, LBC_TYPE_NUMBER, /* instLength */ 1); + + hintTemporaryExprRegType(stat->value, rr, LBC_TYPE_NUMBER, /* instLength */ 1); } } break; @@ -3643,9 +3678,7 @@ struct Compiler l.reg = reg; l.allocated = true; l.debugpc = bytecode.getDebugPC(); - - if (FFlag::LuauCompileTypeInfo) - l.allocpc = allocpc == kDefaultAllocPc ? l.debugpc : allocpc; + l.allocpc = allocpc == kDefaultAllocPc ? 
l.debugpc : allocpc; } bool areLocalsCaptured(size_t start) @@ -3708,7 +3741,7 @@ struct Compiler bytecode.pushDebugLocal(sref(localStack[i]->name), l->reg, l->debugpc, debugpc); } - if (FFlag::LuauCompileTypeInfo && options.typeInfoLevel >= 1 && i >= argCount) + if (options.typeInfoLevel >= 1 && i >= argCount) { uint32_t debugpc = bytecode.getDebugPC(); LuauBytecodeType ty = LBC_TYPE_ANY; @@ -3794,6 +3827,23 @@ struct Compiler return !node->is() && !node->is(); } + void hintTemporaryRegType(AstExpr* expr, int reg, LuauBytecodeType expectedType, int instLength) + { + // If we know the type of a temporary and it's not the type that would be expected by codegen, provide a hint + if (LuauBytecodeType* ty = exprTypes.find(expr)) + { + if (*ty != expectedType) + bytecode.pushLocalTypeInfo(*ty, reg, bytecode.getDebugPC() - instLength, bytecode.getDebugPC()); + } + } + + void hintTemporaryExprRegType(AstExpr* expr, int reg, LuauBytecodeType expectedType, int instLength) + { + // If we allocated a temporary register for the operation argument, try hinting its type + if (!getExprLocal(expr)) + hintTemporaryRegType(expr, reg, expectedType, instLength); + } + struct FenvVisitor : AstVisitor { bool& getfenvUsed; @@ -3818,13 +3868,12 @@ struct Compiler struct FunctionVisitor : AstVisitor { - Compiler* self; std::vector& functions; bool hasTypes = false; + bool hasNativeFunction = false; - FunctionVisitor(Compiler* self, std::vector& functions) - : self(self) - , functions(functions) + FunctionVisitor(std::vector& functions) + : functions(functions) { // preallocate the result; this works around std::vector's inefficient growth policy for small arrays functions.reserve(16); @@ -3840,6 +3889,9 @@ struct Compiler // this makes sure all functions that are used when compiling this one have been already added to the vector functions.push_back(node); + if (FFlag::LuauNativeAttribute && !hasNativeFunction && node->hasNativeAttribute()) + hasNativeFunction = true; + return false; } }; @@ -4044,8 +4096,12 @@ struct Compiler DenseHashMap locstants; DenseHashMap tableShapes; DenseHashMap builtins; + DenseHashMap userdataTypes; DenseHashMap functionTypes; DenseHashMap localTypes; + DenseHashMap exprTypes; + + BuiltinTypes builtinTypes; const DenseHashMap* builtinsFold = nullptr; bool builtinsFoldMathK = false; @@ -4068,6 +4124,12 @@ struct Compiler std::vector> interpStrings; }; +static void setCompileOptionsForNativeCompilation(CompileOptions& options) +{ + options.optimizationLevel = 2; // note: this might be removed in the future in favor of --!optimize + options.typeInfoLevel = 1; +} + void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, const AstNameTable& names, const CompileOptions& inputOptions) { LUAU_TIMETRACE_SCOPE("compileOrThrow", "Compiler"); @@ -4086,15 +4148,21 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c if (hc.header && hc.content == "native") { mainFlags |= LPF_NATIVE_MODULE; - options.optimizationLevel = 2; // note: this might be removed in the future in favor of --!optimize - - if (FFlag::LuauCompileTypeInfo) - options.typeInfoLevel = 1; + setCompileOptionsForNativeCompilation(options); } } AstStatBlock* root = parseResult.root; + // gathers all functions with the invariant that all function references are to functions earlier in the list + // for example, function foo() return function() end end will result in two vector entries, [0] = anonymous and [1] = foo + std::vector functions; + Compiler::FunctionVisitor 
functionVisitor(functions); + root->visit(&functionVisitor); + + if (functionVisitor.hasNativeFunction) + setCompileOptionsForNativeCompilation(options); + Compiler compiler(bytecode, options); // since access to some global objects may result in values that change over time, we block imports from non-readonly tables @@ -4131,28 +4199,40 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c predictTableShapes(compiler.tableShapes, root); } - // gathers all functions with the invariant that all function references are to functions earlier in the list - // for example, function foo() return function() end end will result in two vector entries, [0] = anonymous and [1] = foo - std::vector functions; - Compiler::FunctionVisitor functionVisitor(&compiler, functions); - root->visit(&functionVisitor); + if (FFlag::LuauCompileUserdataInfo) + { + if (const char* const* ptr = options.userdataTypes) + { + for (; *ptr; ++ptr) + { + // Type will only resolve to an AstName if it is actually mentioned in the source + if (AstName name = names.get(*ptr); name.value) + compiler.userdataTypes[name] = bytecode.addUserdataType(name.value); + } + + if (uintptr_t(ptr - options.userdataTypes) > (LBC_TYPE_TAGGED_USERDATA_END - LBC_TYPE_TAGGED_USERDATA_BASE)) + CompileError::raise(root->location, "Exceeded userdata type limit in the compilation options"); + } + } // computes type information for all functions based on type annotations - if (FFlag::LuauCompileTypeInfo) - { - if (options.typeInfoLevel >= 1) - buildTypeMap(compiler.functionTypes, compiler.localTypes, root, options.vectorType); - } - else - { - if (functionVisitor.hasTypes) - buildTypeMap(compiler.functionTypes, compiler.localTypes, root, options.vectorType); - } + if (options.typeInfoLevel >= 1) + buildTypeMap(compiler.functionTypes, compiler.localTypes, compiler.exprTypes, root, options.vectorType, compiler.userdataTypes, + compiler.builtinTypes, compiler.builtins, compiler.globals, bytecode); for (AstExprFunction* expr : functions) - compiler.compileFunction(expr, 0); + { + uint8_t protoflags = 0; + compiler.compileFunction(expr, protoflags); - AstExprFunction main(root->location, /*generics= */ AstArray(), /*genericPacks= */ AstArray(), + // If a function has native attribute and the whole module is not native, we set LPF_NATIVE_FUNCTION flag + // This ensures that LPF_NATIVE_MODULE and LPF_NATIVE_FUNCTION are exclusive. 
+ if (FFlag::LuauNativeAttribute && (protoflags & LPF_NATIVE_FUNCTION) && !(mainFlags & LPF_NATIVE_MODULE)) + mainFlags |= LPF_NATIVE_FUNCTION; + } + + AstExprFunction main(root->location, /*attributes=*/AstArray({nullptr, 0}), /*generics= */ AstArray(), + /*genericPacks= */ AstArray(), /* self= */ nullptr, AstArray(), /* vararg= */ true, /* varargLocation= */ Luau::Location(), root, /* functionDepth= */ 0, /* debugname= */ AstName()); uint32_t mainid = compiler.compileFunction(&main, mainFlags); diff --git a/Compiler/src/Types.cpp b/Compiler/src/Types.cpp index d05a7ba2..447b51d3 100644 --- a/Compiler/src/Types.cpp +++ b/Compiler/src/Types.cpp @@ -3,7 +3,7 @@ #include "Luau/BytecodeBuilder.h" -LUAU_FASTFLAG(LuauCompileTypeInfo) +LUAU_FASTFLAG(LuauCompileUserdataInfo) namespace Luau { @@ -37,10 +37,11 @@ static LuauBytecodeType getPrimitiveType(AstName name) return LBC_TYPE_INVALID; } -static LuauBytecodeType getType(AstType* ty, const AstArray& generics, const DenseHashMap& typeAliases, - bool resolveAliases, const char* vectorType) +static LuauBytecodeType getType(const AstType* ty, const AstArray& generics, + const DenseHashMap& typeAliases, bool resolveAliases, const char* vectorType, + const DenseHashMap& userdataTypes, BytecodeBuilder& bytecode) { - if (AstTypeReference* ref = ty->as()) + if (const AstTypeReference* ref = ty->as()) { if (ref->prefix) return LBC_TYPE_ANY; @@ -49,7 +50,7 @@ static LuauBytecodeType getType(AstType* ty, const AstArray& gen { // note: we only resolve aliases to the depth of 1 to avoid dealing with recursive aliases if (resolveAliases) - return getType((*alias)->type, (*alias)->generics, typeAliases, /* resolveAliases= */ false, vectorType); + return getType((*alias)->type, (*alias)->generics, typeAliases, /* resolveAliases= */ false, vectorType, userdataTypes, bytecode); else return LBC_TYPE_ANY; } @@ -63,25 +64,34 @@ static LuauBytecodeType getType(AstType* ty, const AstArray& gen if (LuauBytecodeType prim = getPrimitiveType(ref->name); prim != LBC_TYPE_INVALID) return prim; + if (FFlag::LuauCompileUserdataInfo) + { + if (const uint8_t* userdataIndex = userdataTypes.find(ref->name)) + { + bytecode.useUserdataType(*userdataIndex); + return LuauBytecodeType(LBC_TYPE_TAGGED_USERDATA_BASE + *userdataIndex); + } + } + // not primitive or alias or generic => host-provided, we assume userdata for now return LBC_TYPE_USERDATA; } - else if (AstTypeTable* table = ty->as()) + else if (const AstTypeTable* table = ty->as()) { return LBC_TYPE_TABLE; } - else if (AstTypeFunction* func = ty->as()) + else if (const AstTypeFunction* func = ty->as()) { return LBC_TYPE_FUNCTION; } - else if (AstTypeUnion* un = ty->as()) + else if (const AstTypeUnion* un = ty->as()) { bool optional = false; LuauBytecodeType type = LBC_TYPE_INVALID; for (AstType* ty : un->types) { - LuauBytecodeType et = getType(ty, generics, typeAliases, resolveAliases, vectorType); + LuauBytecodeType et = getType(ty, generics, typeAliases, resolveAliases, vectorType, userdataTypes, bytecode); if (et == LBC_TYPE_NIL) { @@ -104,7 +114,7 @@ static LuauBytecodeType getType(AstType* ty, const AstArray& gen return LuauBytecodeType(type | (optional && (type != LBC_TYPE_ANY) ? 
LBC_TYPE_OPTIONAL_BIT : 0)); } - else if (AstTypeIntersection* inter = ty->as()) + else if (const AstTypeIntersection* inter = ty->as()) { return LBC_TYPE_ANY; } @@ -112,7 +122,8 @@ static LuauBytecodeType getType(AstType* ty, const AstArray& gen return LBC_TYPE_ANY; } -static std::string getFunctionType(const AstExprFunction* func, const DenseHashMap& typeAliases, const char* vectorType) +static std::string getFunctionType(const AstExprFunction* func, const DenseHashMap& typeAliases, const char* vectorType, + const DenseHashMap& userdataTypes, BytecodeBuilder& bytecode) { bool self = func->self != 0; @@ -129,7 +140,8 @@ static std::string getFunctionType(const AstExprFunction* func, const DenseHashM for (AstLocal* arg : func->args) { LuauBytecodeType ty = - arg->annotation ? getType(arg->annotation, func->generics, typeAliases, /* resolveAliases= */ true, vectorType) : LBC_TYPE_ANY; + arg->annotation ? getType(arg->annotation, func->generics, typeAliases, /* resolveAliases= */ true, vectorType, userdataTypes, bytecode) + : LBC_TYPE_ANY; if (ty != LBC_TYPE_ANY) haveNonAnyParam = true; @@ -144,21 +156,47 @@ static std::string getFunctionType(const AstExprFunction* func, const DenseHashM return typeInfo; } +static bool isMatchingGlobal(const DenseHashMap& globals, AstExpr* node, const char* name) +{ + if (AstExprGlobal* expr = node->as()) + return Compile::getGlobalState(globals, expr->name) == Compile::Global::Default && expr->name == name; + + return false; +} + struct TypeMapVisitor : AstVisitor { DenseHashMap& functionTypes; DenseHashMap& localTypes; + DenseHashMap& exprTypes; const char* vectorType; + const DenseHashMap& userdataTypes; + const BuiltinTypes& builtinTypes; + const DenseHashMap& builtinCalls; + const DenseHashMap& globals; + BytecodeBuilder& bytecode; DenseHashMap typeAliases; std::vector> typeAliasStack; + DenseHashMap resolvedLocals; + DenseHashMap resolvedExprs; - TypeMapVisitor( - DenseHashMap& functionTypes, DenseHashMap& localTypes, const char* vectorType) + TypeMapVisitor(DenseHashMap& functionTypes, DenseHashMap& localTypes, + DenseHashMap& exprTypes, const char* vectorType, const DenseHashMap& userdataTypes, + const BuiltinTypes& builtinTypes, const DenseHashMap& builtinCalls, const DenseHashMap& globals, + BytecodeBuilder& bytecode) : functionTypes(functionTypes) , localTypes(localTypes) + , exprTypes(exprTypes) , vectorType(vectorType) + , userdataTypes(userdataTypes) + , builtinTypes(builtinTypes) + , builtinCalls(builtinCalls) + , globals(globals) + , bytecode(bytecode) , typeAliases(AstName()) + , resolvedLocals(nullptr) + , resolvedExprs(nullptr) { } @@ -189,6 +227,56 @@ struct TypeMapVisitor : AstVisitor } } + const AstType* resolveAliases(const AstType* ty) + { + if (const AstTypeReference* ref = ty->as()) + { + if (ref->prefix) + return ty; + + if (AstStatTypeAlias* const* alias = typeAliases.find(ref->name); alias && *alias) + return (*alias)->type; + } + + return ty; + } + + const AstTableIndexer* tryGetTableIndexer(AstExpr* expr) + { + if (const AstType** typePtr = resolvedExprs.find(expr)) + { + if (const AstTypeTable* tableTy = (*typePtr)->as()) + return tableTy->indexer; + } + + return nullptr; + } + + LuauBytecodeType recordResolvedType(AstExpr* expr, const AstType* ty) + { + ty = resolveAliases(ty); + + resolvedExprs[expr] = ty; + + LuauBytecodeType bty = getType(ty, {}, typeAliases, /* resolveAliases= */ true, vectorType, userdataTypes, bytecode); + exprTypes[expr] = bty; + return bty; + } + + LuauBytecodeType recordResolvedType(AstLocal* local, 
const AstType* ty) + { + ty = resolveAliases(ty); + + resolvedLocals[local] = ty; + + LuauBytecodeType bty = getType(ty, {}, typeAliases, /* resolveAliases= */ true, vectorType, userdataTypes, bytecode); + + if (bty != LBC_TYPE_ANY) + localTypes[local] = bty; + + return bty; + } + bool visit(AstStatBlock* node) override { size_t aliasStackTop = pushTypeAliases(node); @@ -216,39 +304,402 @@ struct TypeMapVisitor : AstVisitor return false; } + // for...in statement can contain type annotations on locals (we might even infer some for ipairs/pairs/generalized iteration) + bool visit(AstStatForIn* node) override + { + for (AstExpr* expr : node->values) + expr->visit(this); + + // This is similar to how Compiler matches builtin iteration, but we also handle generalized iteration case + if (node->vars.size == 2 && node->values.size == 1) + { + if (AstExprCall* call = node->values.data[0]->as(); call && call->args.size == 1) + { + AstExpr* func = call->func; + AstExpr* arg = call->args.data[0]; + + if (isMatchingGlobal(globals, func, "ipairs")) + { + if (const AstTableIndexer* indexer = tryGetTableIndexer(arg)) + { + recordResolvedType(node->vars.data[0], &builtinTypes.numberType); + recordResolvedType(node->vars.data[1], indexer->resultType); + } + } + else if (isMatchingGlobal(globals, func, "pairs")) + { + if (const AstTableIndexer* indexer = tryGetTableIndexer(arg)) + { + recordResolvedType(node->vars.data[0], indexer->indexType); + recordResolvedType(node->vars.data[1], indexer->resultType); + } + } + } + else if (const AstTableIndexer* indexer = tryGetTableIndexer(node->values.data[0])) + { + recordResolvedType(node->vars.data[0], indexer->indexType); + recordResolvedType(node->vars.data[1], indexer->resultType); + } + } + + for (size_t i = 0; i < node->vars.size; i++) + { + AstLocal* var = node->vars.data[i]; + + if (AstType* annotation = var->annotation) + recordResolvedType(var, annotation); + } + + node->body->visit(this); + + return false; + } + bool visit(AstExprFunction* node) override { - std::string type = getFunctionType(node, typeAliases, vectorType); + std::string type = getFunctionType(node, typeAliases, vectorType, userdataTypes, bytecode); if (!type.empty()) functionTypes[node] = std::move(type); - return true; + return true; // Let generic visitor step into all expressions } bool visit(AstExprLocal* node) override { - if (FFlag::LuauCompileTypeInfo) + AstLocal* local = node->local; + + if (AstType* annotation = local->annotation) { - AstLocal* local = node->local; + LuauBytecodeType ty = recordResolvedType(node, annotation); - if (AstType* annotation = local->annotation) + if (ty != LBC_TYPE_ANY) + localTypes[local] = ty; + } + else if (const AstType** typePtr = resolvedLocals.find(local)) + { + localTypes[local] = recordResolvedType(node, *typePtr); + } + + return false; + } + + bool visit(AstStatLocal* node) override + { + for (AstExpr* expr : node->values) + expr->visit(this); + + for (size_t i = 0; i < node->vars.size; i++) + { + AstLocal* var = node->vars.data[i]; + + // Propagate from the value that's being assigned + // This simple propagation doesn't handle type packs in tail position + if (var->annotation == nullptr) { - LuauBytecodeType ty = getType(annotation, {}, typeAliases, /* resolveAliases= */ true, vectorType); - - if (ty != LBC_TYPE_ANY) - localTypes[local] = ty; + if (i < node->values.size) + { + if (const AstType** typePtr = resolvedExprs.find(node->values.data[i])) + resolvedLocals[var] = *typePtr; + } } } - return true; + return false; } + + bool 
visit(AstExprIndexExpr* node) override + { + node->expr->visit(this); + node->index->visit(this); + + if (const AstTableIndexer* indexer = tryGetTableIndexer(node->expr)) + recordResolvedType(node, indexer->resultType); + + return false; + } + + bool visit(AstExprIndexName* node) override + { + node->expr->visit(this); + + if (const AstType** typePtr = resolvedExprs.find(node->expr)) + { + if (const AstTypeTable* tableTy = (*typePtr)->as()) + { + for (const AstTableProp& prop : tableTy->props) + { + if (prop.name == node->index) + { + recordResolvedType(node, prop.type); + return false; + } + } + } + } + + if (LuauBytecodeType* typeBcPtr = exprTypes.find(node->expr)) + { + if (*typeBcPtr == LBC_TYPE_VECTOR) + { + if (node->index == "X" || node->index == "Y" || node->index == "Z") + recordResolvedType(node, &builtinTypes.numberType); + } + } + + return false; + } + + bool visit(AstExprUnary* node) override + { + node->expr->visit(this); + + switch (node->op) + { + case AstExprUnary::Not: + recordResolvedType(node, &builtinTypes.booleanType); + break; + case AstExprUnary::Minus: + { + const AstType** typePtr = resolvedExprs.find(node->expr); + LuauBytecodeType* bcTypePtr = exprTypes.find(node->expr); + + if (!typePtr || !bcTypePtr) + return false; + + if (*bcTypePtr == LBC_TYPE_VECTOR) + recordResolvedType(node, *typePtr); + else if (*bcTypePtr == LBC_TYPE_NUMBER) + recordResolvedType(node, *typePtr); + + break; + } + case AstExprUnary::Len: + recordResolvedType(node, &builtinTypes.numberType); + break; + } + + return false; + } + + bool visit(AstExprBinary* node) override + { + node->left->visit(this); + node->right->visit(this); + + // Comparisons result in a boolean + if (node->op == AstExprBinary::CompareNe || node->op == AstExprBinary::CompareEq || node->op == AstExprBinary::CompareLt || + node->op == AstExprBinary::CompareLe || node->op == AstExprBinary::CompareGt || node->op == AstExprBinary::CompareGe) + { + recordResolvedType(node, &builtinTypes.booleanType); + return false; + } + + if (node->op == AstExprBinary::Concat || node->op == AstExprBinary::And || node->op == AstExprBinary::Or) + return false; + + const AstType** leftTypePtr = resolvedExprs.find(node->left); + LuauBytecodeType* leftBcTypePtr = exprTypes.find(node->left); + + if (!leftTypePtr || !leftBcTypePtr) + return false; + + const AstType** rightTypePtr = resolvedExprs.find(node->right); + LuauBytecodeType* rightBcTypePtr = exprTypes.find(node->right); + + if (!rightTypePtr || !rightBcTypePtr) + return false; + + if (*leftBcTypePtr == LBC_TYPE_VECTOR) + recordResolvedType(node, *leftTypePtr); + else if (*rightBcTypePtr == LBC_TYPE_VECTOR) + recordResolvedType(node, *rightTypePtr); + else if (*leftBcTypePtr == LBC_TYPE_NUMBER && *rightBcTypePtr == LBC_TYPE_NUMBER) + recordResolvedType(node, *leftTypePtr); + + return false; + } + + bool visit(AstExprGroup* node) override + { + node->expr->visit(this); + + if (const AstType** typePtr = resolvedExprs.find(node->expr)) + recordResolvedType(node, *typePtr); + + return false; + } + + bool visit(AstExprTypeAssertion* node) override + { + node->expr->visit(this); + + recordResolvedType(node, node->annotation); + + return false; + } + + bool visit(AstExprConstantBool* node) override + { + recordResolvedType(node, &builtinTypes.booleanType); + + return false; + } + + bool visit(AstExprConstantNumber* node) override + { + recordResolvedType(node, &builtinTypes.numberType); + + return false; + } + + bool visit(AstExprConstantString* node) override + { + recordResolvedType(node, 
&builtinTypes.stringType); + + return false; + } + + bool visit(AstExprInterpString* node) override + { + recordResolvedType(node, &builtinTypes.stringType); + + return false; + } + + bool visit(AstExprIfElse* node) override + { + node->condition->visit(this); + node->trueExpr->visit(this); + node->falseExpr->visit(this); + + const AstType** trueTypePtr = resolvedExprs.find(node->trueExpr); + LuauBytecodeType* trueBcTypePtr = exprTypes.find(node->trueExpr); + LuauBytecodeType* falseBcTypePtr = exprTypes.find(node->falseExpr); + + // Optimistic check that both expressions are of the same kind, as AstType* cannot be compared + if (trueTypePtr && trueBcTypePtr && falseBcTypePtr && *trueBcTypePtr == *falseBcTypePtr) + recordResolvedType(node, *trueTypePtr); + + return false; + } + + bool visit(AstExprCall* node) override + { + if (const int* bfid = builtinCalls.find(node)) + { + switch (LuauBuiltinFunction(*bfid)) + { + case LBF_NONE: + case LBF_ASSERT: + case LBF_RAWSET: + case LBF_RAWGET: + case LBF_TABLE_INSERT: + case LBF_TABLE_UNPACK: + case LBF_SELECT_VARARG: + case LBF_GETMETATABLE: + case LBF_SETMETATABLE: + case LBF_BUFFER_WRITEU8: + case LBF_BUFFER_WRITEU16: + case LBF_BUFFER_WRITEU32: + case LBF_BUFFER_WRITEF32: + case LBF_BUFFER_WRITEF64: + break; + case LBF_MATH_ABS: + case LBF_MATH_ACOS: + case LBF_MATH_ASIN: + case LBF_MATH_ATAN2: + case LBF_MATH_ATAN: + case LBF_MATH_CEIL: + case LBF_MATH_COSH: + case LBF_MATH_COS: + case LBF_MATH_DEG: + case LBF_MATH_EXP: + case LBF_MATH_FLOOR: + case LBF_MATH_FMOD: + case LBF_MATH_FREXP: + case LBF_MATH_LDEXP: + case LBF_MATH_LOG10: + case LBF_MATH_LOG: + case LBF_MATH_MAX: + case LBF_MATH_MIN: + case LBF_MATH_MODF: + case LBF_MATH_POW: + case LBF_MATH_RAD: + case LBF_MATH_SINH: + case LBF_MATH_SIN: + case LBF_MATH_SQRT: + case LBF_MATH_TANH: + case LBF_MATH_TAN: + case LBF_BIT32_ARSHIFT: + case LBF_BIT32_BAND: + case LBF_BIT32_BNOT: + case LBF_BIT32_BOR: + case LBF_BIT32_BXOR: + case LBF_BIT32_BTEST: + case LBF_BIT32_EXTRACT: + case LBF_BIT32_LROTATE: + case LBF_BIT32_LSHIFT: + case LBF_BIT32_REPLACE: + case LBF_BIT32_RROTATE: + case LBF_BIT32_RSHIFT: + case LBF_STRING_BYTE: + case LBF_STRING_LEN: + case LBF_MATH_CLAMP: + case LBF_MATH_SIGN: + case LBF_MATH_ROUND: + case LBF_BIT32_COUNTLZ: + case LBF_BIT32_COUNTRZ: + case LBF_RAWLEN: + case LBF_BIT32_EXTRACTK: + case LBF_TONUMBER: + case LBF_BIT32_BYTESWAP: + case LBF_BUFFER_READI8: + case LBF_BUFFER_READU8: + case LBF_BUFFER_READI16: + case LBF_BUFFER_READU16: + case LBF_BUFFER_READI32: + case LBF_BUFFER_READU32: + case LBF_BUFFER_READF32: + case LBF_BUFFER_READF64: + recordResolvedType(node, &builtinTypes.numberType); + break; + + case LBF_TYPE: + case LBF_STRING_CHAR: + case LBF_TYPEOF: + case LBF_STRING_SUB: + case LBF_TOSTRING: + recordResolvedType(node, &builtinTypes.stringType); + break; + + case LBF_RAWEQUAL: + recordResolvedType(node, &builtinTypes.booleanType); + break; + + case LBF_VECTOR: + recordResolvedType(node, &builtinTypes.vectorType); + break; + } + } + + return true; // Let generic visitor step into all expressions + } + + // AstExpr classes that are not covered: + // * AstExprConstantNil is not resolved to 'nil' because that doesn't help codegen operations and often used as an initializer before real value + // * AstExprGlobal is not supported as we don't have info on globals + // * AstExprVarargs cannot be resolved to a testable type + // * AstExprTable cannot be reconstructed into a specific AstTypeTable and table annotations don't really help codegen + // * 
AstExprCall is very complex (especially if builtins and registered globals are included), and will be extended in the future }; -void buildTypeMap(DenseHashMap& functionTypes, DenseHashMap& localTypes, AstNode* root, - const char* vectorType) +void buildTypeMap(DenseHashMap& functionTypes, DenseHashMap& localTypes, + DenseHashMap& exprTypes, AstNode* root, const char* vectorType, const DenseHashMap& userdataTypes, + const BuiltinTypes& builtinTypes, const DenseHashMap& builtinCalls, const DenseHashMap& globals, + BytecodeBuilder& bytecode) { - TypeMapVisitor visitor(functionTypes, localTypes, vectorType); + TypeMapVisitor visitor(functionTypes, localTypes, exprTypes, vectorType, userdataTypes, builtinTypes, builtinCalls, globals, bytecode); root->visit(&visitor); } diff --git a/Compiler/src/Types.h b/Compiler/src/Types.h index de11fde9..bd12ea77 100644 --- a/Compiler/src/Types.h +++ b/Compiler/src/Types.h @@ -4,13 +4,31 @@ #include "Luau/Ast.h" #include "Luau/Bytecode.h" #include "Luau/DenseHash.h" +#include "ValueTracking.h" #include namespace Luau { +class BytecodeBuilder; -void buildTypeMap(DenseHashMap& functionTypes, DenseHashMap& localTypes, AstNode* root, - const char* vectorType); +struct BuiltinTypes +{ + BuiltinTypes(const char* vectorType) + : vectorType{{}, std::nullopt, AstName{vectorType}, std::nullopt, {}} + { + } + + // AstName use here will not match the AstNameTable, but the way we use them here always forces a full string compare + AstTypeReference booleanType{{}, std::nullopt, AstName{"boolean"}, std::nullopt, {}}; + AstTypeReference numberType{{}, std::nullopt, AstName{"number"}, std::nullopt, {}}; + AstTypeReference stringType{{}, std::nullopt, AstName{"string"}, std::nullopt, {}}; + AstTypeReference vectorType; +}; + +void buildTypeMap(DenseHashMap& functionTypes, DenseHashMap& localTypes, + DenseHashMap& exprTypes, AstNode* root, const char* vectorType, const DenseHashMap& userdataTypes, + const BuiltinTypes& builtinTypes, const DenseHashMap& builtinCalls, const DenseHashMap& globals, + BytecodeBuilder& bytecode); } // namespace Luau diff --git a/Config/include/Luau/LinterConfig.h b/Config/include/Luau/LinterConfig.h index a598a3df..3a68c0d7 100644 --- a/Config/include/Luau/LinterConfig.h +++ b/Config/include/Luau/LinterConfig.h @@ -49,6 +49,7 @@ struct LintWarning Code_CommentDirective = 26, Code_IntegerParsing = 27, Code_ComparisonPrecedence = 28, + Code_RedundantNativeAttribute = 29, Code__Count }; @@ -115,6 +116,7 @@ static const char* kWarningNames[] = { "CommentDirective", "IntegerParsing", "ComparisonPrecedence", + "RedundantNativeAttribute", }; // clang-format on diff --git a/Config/src/Config.cpp b/Config/src/Config.cpp index 693e0f87..5fba9fa3 100644 --- a/Config/src/Config.cpp +++ b/Config/src/Config.cpp @@ -195,7 +195,7 @@ static Error parseJson(const std::string& contents, Action action) } else if (lexer.current().type == Lexeme::QuotedString) { - std::string value(lexer.current().data, lexer.current().length); + std::string value(lexer.current().data, lexer.current().getLength()); next(lexer); if (Error err = action(keys, value)) @@ -232,7 +232,7 @@ static Error parseJson(const std::string& contents, Action action) } else if (lexer.current().type == Lexeme::QuotedString) { - std::string key(lexer.current().data, lexer.current().length); + std::string key(lexer.current().data, lexer.current().getLength()); next(lexer); keys.push_back(key);
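A hedged sketch of how the new RedundantNativeAttribute warning could be exercised (the Fixture/lint helpers are assumed from the repository's existing lint tests and are not part of this diff; only the warning code itself is):

TEST_CASE_FIXTURE(Fixture, "RedundantNativeAttribute")
{
    // @native on a function adds nothing when the whole module is already --!native
    LintResult result = lint(R"(
        --!native
        @native
        local function inner(x)
            return x
        end
        return inner
    )");

    REQUIRE(1 == result.warnings.size());
    CHECK(result.warnings[0].code == LintWarning::Code_RedundantNativeAttribute);
}

@@ -250,7 +250,7 @@ static Error parseJson(const std::string& contents, Action action) lexer.current().type == 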
Lexeme::ReservedFalse) { std::string value = lexer.current().type == Lexeme::QuotedString - ? std::string(lexer.current().data, lexer.current().length) + ? std::string(lexer.current().data, lexer.current().getLength()) : (lexer.current().type == Lexeme::ReservedTrue ? "true" : "false"); next(lexer); diff --git a/SECURITY.md b/SECURITY.md index 48a6ccc4..ca3f5923 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,6 +1,6 @@ # Security Guarantees -Luau provides a safe sandbox that scripts can not escape from, short of vulnerabilities in custom C functions exposed by the host. This includes the virtual machine and builtin libraries. Notably this currently does *not* include the work-in-progress native code generation facilities. +Luau provides a safe sandbox that scripts cannot escape from, short of vulnerabilities in custom C functions exposed by the host. This includes the virtual machine, builtin libraries, and native code generation facilities. No source code can result in memory safety errors or crashes during its compilation or execution. Violations of memory safety are considered vulnerabilities. diff --git a/Sources.cmake b/Sources.cmake index 6adbf283..4c5504b6 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -181,6 +181,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Error.h Analysis/include/Luau/FileResolver.h Analysis/include/Luau/Frontend.h + Analysis/include/Luau/Generalization.h Analysis/include/Luau/GlobalTypes.h Analysis/include/Luau/InsertionOrderedMap.h Analysis/include/Luau/Instantiation.h @@ -251,6 +252,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/EmbeddedBuiltinDefinitions.cpp Analysis/src/Error.cpp Analysis/src/Frontend.cpp + Analysis/src/Generalization.cpp Analysis/src/GlobalTypes.cpp Analysis/src/Instantiation.cpp Analysis/src/Instantiation2.cpp @@ -266,7 +268,6 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Refinement.cpp Analysis/src/RequireTracer.cpp Analysis/src/Scope.cpp - Analysis/src/Set.cpp Analysis/src/Simplify.cpp Analysis/src/Substitution.cpp Analysis/src/Subtyping.cpp @@ -421,6 +422,7 @@ if(TARGET Luau.UnitTest) tests/Fixture.cpp tests/Fixture.h tests/Frontend.test.cpp + tests/Generalization.test.cpp tests/InsertionOrderedMap.test.cpp tests/Instantiation2.test.cpp tests/IostreamOptional.h @@ -494,6 +496,7 @@ if(TARGET Luau.Conformance) target_sources(Luau.Conformance PRIVATE tests/RegisterCallbacks.h tests/RegisterCallbacks.cpp + tests/ConformanceIrHooks.h tests/Conformance.test.cpp tests/IrLowering.test.cpp tests/SharedCodeAllocator.test.cpp diff --git a/VM/include/lua.h b/VM/include/lua.h index 4876b933..4ee9306e 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -324,6 +324,10 @@ typedef void (*lua_Destructor)(lua_State* L, void* userdata); LUA_API void lua_setuserdatadtor(lua_State* L, int tag, lua_Destructor dtor); LUA_API lua_Destructor lua_getuserdatadtor(lua_State* L, int tag); +// alternative access for metatables already registered with luaL_newmetatable +LUA_API void lua_setuserdatametatable(lua_State* L, int tag, int idx); +LUA_API void lua_getuserdatametatable(lua_State* L, int tag); + LUA_API void lua_setlightuserdataname(lua_State* L, int tag, const char* name); LUA_API const char* lua_getlightuserdataname(lua_State* L, int tag); diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 58c767f1..87f85af8 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -1427,6 +1427,33 @@ lua_Destructor lua_getuserdatadtor(lua_State* L, int tag) return L->global->udatagc[tag]; }
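A minimal host-side sketch of the two new entry points (the tag value and helper names below are hypothetical, not part of this change):

#include "lua.h"
#include "lualib.h"

static const int kVec3Tag = 1; // hypothetical tag owned by the host; must be below LUA_UTAG_LIMIT

static void registerVec3Metatable(lua_State* L)
{
    luaL_newmetatable(L, "Vec3");              // creates (or fetches) the metatable and pushes it
    lua_setuserdatametatable(L, kVec3Tag, -1); // binds it to the tag and pops it; rebinding a tag is not supported
}

static void pushVec3Metatable(lua_State* L)
{
    lua_getuserdatametatable(L, kVec3Tag); // pushes the bound metatable, or nil if none was set
}

+void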
lua_setuserdatametatable(lua_State* L, int tag, int idx) +{ + api_check(L, unsigned(tag) < LUA_UTAG_LIMIT); + api_check(L, !L->global->udatamt[tag]); // reassignment not supported + StkId o = index2addr(L, idx); + api_check(L, ttistable(o)); + L->global->udatamt[tag] = hvalue(o); + L->top--; +} + +void lua_getuserdatametatable(lua_State* L, int tag) +{ + api_check(L, unsigned(tag) < LUA_UTAG_LIMIT); + luaC_threadbarrier(L); + + if (Table* h = L->global->udatamt[tag]) + { + sethvalue(L, L->top, h); + } + else + { + setnilvalue(L->top); + } + + api_incr_top(L); +} + void lua_setlightuserdataname(lua_State* L, int tag, const char* name) { api_check(L, unsigned(tag) < LUA_LUTAG_LIMIT); diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index 7122b035..07cc117e 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -330,12 +330,16 @@ l_noret luaG_runerrorL(lua_State* L, const char* fmt, ...) vsnprintf(result, sizeof(result), fmt, argp); va_end(argp); + lua_rawcheckstack(L, 1); + pusherror(L, result); luaD_throw(L, LUA_ERRRUN); } void luaG_pusherror(lua_State* L, const char* error) { + lua_rawcheckstack(L, 1); + pusherror(L, error); } diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index b33fe9dd..2a1e45c4 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -6,8 +6,6 @@ #include "lmem.h" #include "lgc.h" -LUAU_FASTFLAGVARIABLE(LuauLoadTypeInfo, false) - Proto* luaF_newproto(lua_State* L) { Proto* f = luaM_newgco(L, Proto, sizeof(Proto), L->activememcat); @@ -52,9 +50,7 @@ Proto* luaF_newproto(lua_State* L) f->linegaplog2 = 0; f->linedefined = 0; f->bytecodeid = 0; - - if (FFlag::LuauLoadTypeInfo) - f->sizetypeinfo = 0; + f->sizetypeinfo = 0; return f; } @@ -178,16 +174,8 @@ void luaF_freeproto(lua_State* L, Proto* f, lua_Page* page) if (f->execdata) L->global->ecb.destroy(L, f); - if (FFlag::LuauLoadTypeInfo) - { - if (f->typeinfo) - luaM_freearray(L, f->typeinfo, f->sizetypeinfo, uint8_t, f->memcat); - } - else - { - if (f->typeinfo) - luaM_freearray(L, f->typeinfo, f->numparams + 2, uint8_t, f->memcat); - } + if (f->typeinfo) + luaM_freearray(L, f->typeinfo, f->sizetypeinfo, uint8_t, f->memcat); luaM_freegco(L, f, sizeof(Proto), f->memcat, page); } diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index f8389422..4473f04f 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -14,8 +14,6 @@ #include -LUAU_FASTFLAG(LuauLoadTypeInfo) - /* * Luau uses an incremental non-generational non-moving mark&sweep garbage collector. 
* @@ -507,16 +505,8 @@ static size_t propagatemark(global_State* g) g->gray = p->gclist; traverseproto(g, p); - if (FFlag::LuauLoadTypeInfo) - { - return sizeof(Proto) + sizeof(Instruction) * p->sizecode + sizeof(Proto*) * p->sizep + sizeof(TValue) * p->sizek + p->sizelineinfo + - sizeof(LocVar) * p->sizelocvars + sizeof(TString*) * p->sizeupvalues + p->sizetypeinfo; - } - else - { - return sizeof(Proto) + sizeof(Instruction) * p->sizecode + sizeof(Proto*) * p->sizep + sizeof(TValue) * p->sizek + p->sizelineinfo + - sizeof(LocVar) * p->sizelocvars + sizeof(TString*) * p->sizeupvalues; - } + return sizeof(Proto) + sizeof(Instruction) * p->sizecode + sizeof(Proto*) * p->sizep + sizeof(TValue) * p->sizek + p->sizelineinfo + + sizeof(LocVar) * p->sizelocvars + sizeof(TString*) * p->sizeupvalues + p->sizetypeinfo; } default: LUAU_ASSERT(0); diff --git a/VM/src/lgc.h b/VM/src/lgc.h index ba433c67..010d7e86 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -23,11 +23,10 @@ #define GCSsweep 4 /* -** macro to tell when main invariant (white objects cannot point to black -** ones) must be kept. During a collection, the sweep -** phase may break the invariant, as objects turned white may point to -** still-black objects. The invariant is restored when sweep ends and -** all objects are white again. +** The main invariant of the garbage collector, while marking objects, +** is that a black object can never point to a white one. This invariant +** is not enforced during the sweep phase, and is restored when the sweep +** ends. */ #define keepinvariant(g) ((g)->gcstate == GCSpropagate || (g)->gcstate == GCSpropagateagain || (g)->gcstate == GCSatomic) diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index 3de18cf9..5ff5de72 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -53,6 +53,10 @@ * for each block size there's a page free list that contains pages that have at least one free block * (global_State::freegcopages). This free list is used to make sure object allocation is O(1). * + * When LUAU_ASSERTENABLED is enabled, all non-GCO pages are also linked in a list (global_State::allpages). + * Because this list is not strictly required for runtime operations, it is only tracked for the purposes of + * debugging. While the overhead of linking those pages together is very small, the unnecessary work is still avoided in release builds. + * * Compared to GCOs, regular allocations have two important differences: they can be freed in isolation, * and they don't start with a GC header. 
Because of this, each allocation is prefixed with block metadata, * which contains the pointer to the page for allocated blocks, and the pointer to the next free block @@ -120,11 +124,16 @@ static_assert(offsetof(Udata, data) == ABISWITCH(16, 16, 12), "size mismatch for static_assert(sizeof(Table) == ABISWITCH(48, 32, 32), "size mismatch for table header"); static_assert(offsetof(Buffer, data) == ABISWITCH(8, 8, 8), "size mismatch for buffer header"); -LUAU_FASTFLAGVARIABLE(LuauExtendedSizeClasses, false) - const size_t kSizeClasses = LUA_SIZECLASSES; -const size_t kMaxSmallSize_DEPRECATED = 512; // TODO: remove with FFlagLuauExtendedSizeClasses + +// Controls the number of entries in SizeClassConfig and defines the maximum possible paged allocation size +// Modifications require updates to the SizeClassConfig initialization const size_t kMaxSmallSize = 1024; + +// Effective limit on object size to use paged allocation +// Can be modified without additional changes to code, provided it is smaller than or equal to kMaxSmallSize +const size_t kMaxSmallSizeUsed = 1024; + const size_t kLargePageThreshold = 512; // larger pages are used for objects larger than this size to fit more of them into a page // constant factor to reduce our page sizes by, to increase the chances that pages we allocate will @@ -183,13 +192,18 @@ struct SizeClassConfig const SizeClassConfig kSizeClassConfig; // size class for a block of size sz; returns -1 for size=0 because empty allocations take no space -#define sizeclass(sz) \ - (size_t((sz)-1) < (FFlag::LuauExtendedSizeClasses ? kMaxSmallSize : kMaxSmallSize_DEPRECATED) ? kSizeClassConfig.classForSize[sz] : -1) +#define sizeclass(sz) (size_t((sz)-1) < kMaxSmallSizeUsed ? kSizeClassConfig.classForSize[sz] : -1) // metadata for a block is stored in the first pointer of the block #define metadata(block) (*(void**)(block)) #define freegcolink(block) (*(void**)((char*)block + kGCOLinkOffset)) +#if defined(LUAU_ASSERTENABLED) +#define debugpageset(x) (x) +#else +#define debugpageset(x) NULL +#endif + struct lua_Page { // list of pages with free blocks @@ -265,34 +279,18 @@ static lua_Page* newpage(lua_State* L, lua_Page** pageset, int pageSize, int blo // if it is inlined, then the compiler may determine those functions are "too big" to be profitably inlined, which results in reduced performance LUAU_NOINLINE static lua_Page* newclasspage(lua_State* L, lua_Page** freepageset, lua_Page** pageset, uint8_t sizeClass, bool storeMetadata) { - if (FFlag::LuauExtendedSizeClasses) - { - int sizeOfClass = kSizeClassConfig.sizeOfClass[sizeClass]; - int pageSize = sizeOfClass > int(kLargePageThreshold) ? kLargePageSize : kSmallPageSize; - int blockSize = sizeOfClass + (storeMetadata ? kBlockHeader : 0); - int blockCount = (pageSize - offsetof(lua_Page, data)) / blockSize; + int sizeOfClass = kSizeClassConfig.sizeOfClass[sizeClass]; + int pageSize = sizeOfClass > int(kLargePageThreshold) ? kLargePageSize : kSmallPageSize; + int blockSize = sizeOfClass + (storeMetadata ? kBlockHeader : 0); + int blockCount = (pageSize - offsetof(lua_Page, data)) / blockSize; - lua_Page* page = newpage(L, pageset, pageSize, blockSize, blockCount); + lua_Page* page = newpage(L, pageset, pageSize, blockSize, blockCount); - // prepend a page to page freelist (which is empty because we only ever allocate a new page when it is!) - LUAU_ASSERT(!freepageset[sizeClass]); - freepageset[sizeClass] = page;
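+ // (illustrative aside, not part of this change: with the constants above, a
+ //  512-byte size class still takes the small-page path, while anything larger,
+ //  such as the newly enabled 1024-byte classes, exceeds kLargePageThreshold and
+ //  is placed on larger pages so that enough blocks still fit per page)
+ // prepend a page to page freelist (which is empty because we only ever allocate a new page when it is!) 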
+ LUAU_ASSERT(!freepageset[sizeClass]); + freepageset[sizeClass] = page; - return page; - } - else - { - int blockSize = kSizeClassConfig.sizeOfClass[sizeClass] + (storeMetadata ? kBlockHeader : 0); - int blockCount = (kSmallPageSize - offsetof(lua_Page, data)) / blockSize; - - lua_Page* page = newpage(L, pageset, kSmallPageSize, blockSize, blockCount); - - // prepend a page to page freelist (which is empty because we only ever allocate a new page when it is!) - LUAU_ASSERT(!freepageset[sizeClass]); - freepageset[sizeClass] = page; - - return page; - } + return page; } static void freepage(lua_State* L, lua_Page** pageset, lua_Page* page) @@ -336,7 +334,7 @@ static void* newblock(lua_State* L, int sizeClass) // slow path: no page in the freelist, allocate a new one if (!page) - page = newclasspage(L, g->freepages, NULL, sizeClass, true); + page = newclasspage(L, g->freepages, debugpageset(&g->allpages), sizeClass, true); LUAU_ASSERT(!page->prev); LUAU_ASSERT(page->freeList || page->freeNext >= 0); @@ -457,7 +455,7 @@ static void freeblock(lua_State* L, int sizeClass, void* block) // if it's the last block in the page, we don't need the page if (page->busyBlocks == 0) - freeclasspage(L, g->freepages, NULL, page, sizeClass); + freeclasspage(L, g->freepages, debugpageset(&g->allpages), page, sizeClass); } static void freegcoblock(lua_State* L, int sizeClass, void* block, lua_Page* page) diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index 858f61a3..6b7a9aa0 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -204,12 +204,16 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) g->freepages[i] = NULL; g->freegcopages[i] = NULL; } + g->allpages = NULL; g->allgcopages = NULL; g->sweepgcopage = NULL; for (i = 0; i < LUA_T_COUNT; i++) g->mt[i] = NULL; for (i = 0; i < LUA_UTAG_LIMIT; i++) + { g->udatagc[i] = NULL; + g->udatamt[i] = NULL; + } for (i = 0; i < LUA_LUTAG_LIMIT; i++) g->lightuserdataname[i] = NULL; for (i = 0; i < LUA_MEMORY_CATEGORIES; i++) diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 97546511..f8caa69b 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -156,6 +156,7 @@ struct lua_ExecutionCallbacks int (*enter)(lua_State* L, Proto* proto); // called when function is about to start/resume (when execdata is present), return 0 to exit VM void (*disable)(lua_State* L, Proto* proto); // called when function has to be switched from native to bytecode in the debugger size_t (*getmemorysize)(lua_State* L, Proto* proto); // called to request the size of memory associated with native part of the Proto + uint8_t (*gettypemapping)(lua_State* L, const char* str, size_t len); // called to get the userdata type index }; /* @@ -188,7 +189,8 @@ typedef struct global_State struct lua_Page* freepages[LUA_SIZECLASSES]; // free page linked list for each size class for non-collectable objects struct lua_Page* freegcopages[LUA_SIZECLASSES]; // free page linked list for each size class for collectable objects - struct lua_Page* allgcopages; // page linked list with all pages for all classes + struct lua_Page* allpages; // page linked list with all pages for all non-collectable object classes (available with LUAU_ASSERTENABLED) + struct lua_Page* allgcopages; // page linked list with all pages for all collectable object classes struct lua_Page* sweepgcopage; // position of the sweep in `allgcopages' size_t memcatbytes[LUA_MEMORY_CATEGORIES]; // total amount of memory used by each memory category @@ -215,6 +217,7 @@ typedef struct global_State lua_ExecutionCallbacks ecb; void 
(*udatagc[LUA_UTAG_LIMIT])(lua_State*, void*); // for each userdata tag, a gc callback to be called immediately before freeing memory + Table* udatamt[LUA_UTAG_LIMIT]; // metatables for tagged userdata TString* lightuserdataname[LUA_LUTAG_LIMIT]; // names for tagged lightuserdata diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 27c08f11..75d9f400 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -11,8 +11,6 @@ #include "ldebug.h" #include "lvm.h" -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauFastCrossTableMove, false) - static int foreachi(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); @@ -54,17 +52,28 @@ static int maxn(lua_State* L) { double max = 0; luaL_checktype(L, 1, LUA_TTABLE); - lua_pushnil(L); // first key - while (lua_next(L, 1)) + + Table* t = hvalue(L->base); + + for (int i = 0; i < t->sizearray; i++) { - lua_pop(L, 1); // remove value - if (lua_type(L, -1) == LUA_TNUMBER) + if (!ttisnil(&t->array[i])) + max = i + 1; + } + + for (int i = 0; i < sizenode(t); i++) + { + LuaNode* n = gnode(t, i); + + if (!ttisnil(gval(n)) && ttisnumber(gkey(n))) { - double v = lua_tonumber(L, -1); + double v = nvalue(gkey(n)); + if (v > max) max = v; } } + lua_pushnumber(L, max); return 1; } @@ -115,68 +124,6 @@ static void moveelements(lua_State* L, int srct, int dstt, int f, int e, int t) luaC_barrierfast(L, dst); } - else if (DFFlag::LuauFastCrossTableMove && dst != src) - { - // compute the array slice we have to copy over - int slicestart = f < 1 ? 0 : (f > src->sizearray ? src->sizearray : f - 1); - int sliceend = e < 1 ? 0 : (e > src->sizearray ? src->sizearray : e); - LUAU_ASSERT(slicestart <= sliceend); - - int slicecount = sliceend - slicestart; - - if (slicecount > 0) - { - // array slice starting from INT_MIN is impossible, so we don't have to worry about int overflow - int dstslicestart = f < 1 ? 
-f + 1 : 0; - - // copy over the slice - for (int i = 0; i < slicecount; ++i) - { - lua_rawgeti(L, srct, slicestart + i + 1); - lua_rawseti(L, dstt, dstslicestart + t + i); - } - } - - // copy the remaining elements that could be in the hash part - int hashpartsize = sizenode(src); - - // select the strategy with the least amount of steps - if (n <= hashpartsize) - { - for (int i = 0; i < n; ++i) - { - // skip array slice elements that were already copied over - if (cast_to(unsigned int, f + i - 1) < cast_to(unsigned int, src->sizearray)) - continue; - - lua_rawgeti(L, srct, f + i); - lua_rawseti(L, dstt, t + i); - } - } - else - { - // source and destination tables are different, so we can iterate over source hash part directly - int i = hashpartsize; - - while (i--) - { - LuaNode* node = gnode(src, i); - if (ttisnumber(gkey(node))) - { - double n = nvalue(gkey(node)); - - int k; - luai_num2int(k, n); - - if (luai_numeq(cast_num(k), n) && k >= f && k <= e) - { - lua_rawgeti(L, srct, k); - lua_rawseti(L, dstt, t - f + k); - } - } - } - } - } else { if (t > e || t <= f || dst != src) @@ -282,31 +229,42 @@ static int tmove(lua_State* L) return 1; } -static void addfield(lua_State* L, luaL_Strbuf* b, int i) +static void addfield(lua_State* L, luaL_Strbuf* b, int i, Table* t) { - int tt = lua_rawgeti(L, 1, i); - if (tt != LUA_TSTRING && tt != LUA_TNUMBER) - luaL_error(L, "invalid value (%s) at index %d in table for 'concat'", luaL_typename(L, -1), i); - luaL_addvalue(b); + if (t && unsigned(i - 1) < unsigned(t->sizearray) && ttisstring(&t->array[i - 1])) + { + TString* ts = tsvalue(&t->array[i - 1]); + luaL_addlstring(b, getstr(ts), ts->len); + } + else + { + int tt = lua_rawgeti(L, 1, i); + if (tt != LUA_TSTRING && tt != LUA_TNUMBER) + luaL_error(L, "invalid value (%s) at index %d in table for 'concat'", luaL_typename(L, -1), i); + luaL_addvalue(b); + } } static int tconcat(lua_State* L) { - luaL_Strbuf b; size_t lsep; - int i, last; const char* sep = luaL_optlstring(L, 2, "", &lsep); luaL_checktype(L, 1, LUA_TTABLE); - i = luaL_optinteger(L, 3, 1); - last = luaL_opt(L, luaL_checkinteger, 4, lua_objlen(L, 1)); + int i = luaL_optinteger(L, 3, 1); + int last = luaL_opt(L, luaL_checkinteger, 4, lua_objlen(L, 1)); + + Table* t = hvalue(L->base); + + luaL_Strbuf b; luaL_buffinit(L, &b); for (; i < last; i++) { - addfield(L, &b, i); - luaL_addlstring(&b, sep, lsep); + addfield(L, &b, i, t); + if (lsep != 0) + luaL_addlstring(&b, sep, lsep); } if (i == last) // add last value (if interval was not empty) - addfield(L, &b, i, t); + addfield(L, &b, i, t); luaL_pushresult(&b); return 1; } diff --git a/VM/src/lvm.h b/VM/src/lvm.h index 5ec7bc16..0b8690be 100644 --- a/VM/src/lvm.h +++ b/VM/src/lvm.h @@ -15,7 +15,10 @@ LUAI_FUNC int luaV_strcmp(const TString* ls, const TString* rs); LUAI_FUNC int luaV_lessthan(lua_State* L, const TValue* l, const TValue* r); LUAI_FUNC int luaV_lessequal(lua_State* L, const TValue* l, const TValue* r); LUAI_FUNC int luaV_equalval(lua_State* L, const TValue* t1, const TValue* t2); -LUAI_FUNC void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TMS op); + +template<TMS op> +void luaV_doarithimpl(lua_State* L, StkId ra, const TValue* rb, const TValue* rc); + LUAI_FUNC void luaV_dolen(lua_State* L, StkId ra, const TValue* rb); LUAI_FUNC const TValue* luaV_tonumber(const TValue* obj, TValue* n); LUAI_FUNC const float* luaV_tovector(const TValue* obj);
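A brief illustration of the shape of this change (the call-site rewrite appears in lvmexecute.cpp below; this note is not part of the diff): making the metamethod a compile-time template parameter gives each arithmetic opcode its own specialized slow path, so the old runtime dispatch on the operator can be folded away per instantiation.

// before: one shared slow path, selecting the operator at runtime
//   luaV_doarith(L, ra, rb, rc, TM_ADD);
// after: a dedicated instantiation per operator
//   luaV_doarithimpl<TM_ADD>(L, ra, rb, rc);

diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 74e30c94..fb253c6a 100644 --- 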
a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -96,7 +96,7 @@ VM_DISPATCH_OP(LOP_POWK), VM_DISPATCH_OP(LOP_AND), VM_DISPATCH_OP(LOP_OR), VM_DISPATCH_OP(LOP_ANDK), VM_DISPATCH_OP(LOP_ORK), \ VM_DISPATCH_OP(LOP_CONCAT), VM_DISPATCH_OP(LOP_NOT), VM_DISPATCH_OP(LOP_MINUS), VM_DISPATCH_OP(LOP_LENGTH), VM_DISPATCH_OP(LOP_NEWTABLE), \ VM_DISPATCH_OP(LOP_DUPTABLE), VM_DISPATCH_OP(LOP_SETLIST), VM_DISPATCH_OP(LOP_FORNPREP), VM_DISPATCH_OP(LOP_FORNLOOP), \ - VM_DISPATCH_OP(LOP_FORGLOOP), VM_DISPATCH_OP(LOP_FORGPREP_INEXT), VM_DISPATCH_OP(LOP_DEP_FORGLOOP_INEXT), VM_DISPATCH_OP(LOP_FORGPREP_NEXT), \ + VM_DISPATCH_OP(LOP_FORGLOOP), VM_DISPATCH_OP(LOP_FORGPREP_INEXT), VM_DISPATCH_OP(LOP_FASTCALL3), VM_DISPATCH_OP(LOP_FORGPREP_NEXT), \ VM_DISPATCH_OP(LOP_NATIVECALL), VM_DISPATCH_OP(LOP_GETVARARGS), VM_DISPATCH_OP(LOP_DUPCLOSURE), VM_DISPATCH_OP(LOP_PREPVARARGS), \ VM_DISPATCH_OP(LOP_LOADKX), VM_DISPATCH_OP(LOP_JUMPX), VM_DISPATCH_OP(LOP_FASTCALL), VM_DISPATCH_OP(LOP_COVERAGE), \ VM_DISPATCH_OP(LOP_CAPTURE), VM_DISPATCH_OP(LOP_SUBRK), VM_DISPATCH_OP(LOP_DIVRK), VM_DISPATCH_OP(LOP_FASTCALL1), \ @@ -1487,7 +1487,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_ADD)); + VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); VM_NEXT(); } } @@ -1533,7 +1533,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_SUB)); + VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); VM_NEXT(); } } @@ -1594,7 +1594,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_MUL)); + VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); VM_NEXT(); } } @@ -1655,7 +1655,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_DIV)); + VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); VM_NEXT(); } } @@ -1703,7 +1703,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_IDIV)); + VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); VM_NEXT(); } } @@ -1727,7 +1727,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_MOD)); + VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); VM_NEXT(); } } @@ -1748,7 +1748,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_POW)); + VM_PROTECT(luaV_doarithimpl(L, ra, rb, rc)); VM_NEXT(); } } @@ -1769,7 +1769,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_ADD)); + VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); VM_NEXT(); } } @@ -1790,7 +1790,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_SUB)); + VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); VM_NEXT(); } } @@ -1835,7 +1835,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_MUL)); + VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); VM_NEXT(); } } @@ -1881,7 +1881,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_DIV)); + VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); VM_NEXT(); } } @@ -1928,7 +1928,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_IDIV)); + VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); VM_NEXT(); } } @@ -1952,7 +1952,7 @@ reentry: else { // slow-path, may invoke C/Lua via 
metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_MOD)); + VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); VM_NEXT(); } } @@ -1979,7 +1979,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_POW)); + VM_PROTECT(luaV_doarithimpl(L, ra, rb, kv)); VM_NEXT(); } } @@ -2092,7 +2092,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, rb, rb, TM_UNM)); + VM_PROTECT(luaV_doarithimpl(L, ra, rb, rb)); VM_NEXT(); } } @@ -2432,12 +2432,6 @@ reentry: VM_NEXT(); } - VM_CASE(LOP_DEP_FORGLOOP_INEXT) - { - LUAU_ASSERT(!"Unsupported deprecated opcode"); - LUAU_UNREACHABLE(); - } - VM_CASE(LOP_FORGPREP_NEXT) { Instruction insn = *pc++; @@ -2711,7 +2705,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, kv, rc, TM_SUB)); + VM_PROTECT(luaV_doarithimpl(L, ra, kv, rc)); VM_NEXT(); } } @@ -2739,7 +2733,7 @@ reentry: else { // slow-path, may invoke C/Lua via metamethods - VM_PROTECT(luaV_doarith(L, ra, kv, rc, TM_DIV)); + VM_PROTECT(luaV_doarithimpl(L, ra, kv, rc)); VM_NEXT(); } } @@ -2892,6 +2886,60 @@ reentry: } } + VM_CASE(LOP_FASTCALL3) + { + Instruction insn = *pc++; + int bfid = LUAU_INSN_A(insn); + int skip = LUAU_INSN_C(insn) - 1; + uint32_t aux = *pc++; + TValue* arg1 = VM_REG(LUAU_INSN_B(insn)); + TValue* arg2 = VM_REG(aux & 0xff); + TValue* arg3 = VM_REG((aux >> 8) & 0xff); + + LUAU_ASSERT(unsigned(pc - cl->l.p->code + skip) < unsigned(cl->l.p->sizecode)); + + Instruction call = pc[skip]; + LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); + + StkId ra = VM_REG(LUAU_INSN_A(call)); + + int nparams = 3; + int nresults = LUAU_INSN_C(call) - 1; + + luau_FastFunction f = luauF_table[bfid]; + LUAU_ASSERT(f); + + if (cl->env->safeenv) + { + VM_PROTECT_PC(); // f may fail due to OOM + + setobj2s(L, L->top, arg2); + setobj2s(L, L->top + 1, arg3); + + int n = f(L, ra, arg1, nresults, L->top, nparams); + + if (n >= 0) + { + if (nresults == LUA_MULTRET) + L->top = ra + n; + + pc += skip + 1; // skip instructions that compute function as well as CALL + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + else + { + // continue execution through the fallback code + VM_NEXT(); + } + } + else + { + // continue execution through the fallback code + VM_NEXT(); + } + } + VM_CASE(LOP_BREAK) { LUAU_ASSERT(cl->l.p->debuginsn); diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index f13c0f21..112a7197 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -13,7 +13,7 @@ #include -LUAU_FASTFLAG(LuauLoadTypeInfo) +LUAU_FASTFLAGVARIABLE(LuauLoadUserdataInfo, false) // TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens template @@ -187,6 +187,65 @@ static void resolveImportSafe(lua_State* L, Table* env, TValue* k, uint32_t id) } } +static void remapUserdataTypes(char* data, size_t size, uint8_t* userdataRemapping, uint32_t count) +{ + LUAU_ASSERT(FFlag::LuauLoadUserdataInfo); + + size_t offset = 0; + + uint32_t typeSize = readVarInt(data, size, offset); + uint32_t upvalCount = readVarInt(data, size, offset); + uint32_t localCount = readVarInt(data, size, offset); + + if (typeSize != 0) + { + uint8_t* types = (uint8_t*)data + offset; + + // Skip two bytes of function type introduction + for (uint32_t i = 2; i < typeSize; i++) + { + uint32_t index = uint32_t(types[i] - LBC_TYPE_TAGGED_USERDATA_BASE); + + if (index < count) + types[i] = userdataRemapping[index]; + } + + offset += typeSize; + } + + if 
(upvalCount != 0) + { + uint8_t* types = (uint8_t*)data + offset; + + for (uint32_t i = 0; i < upvalCount; i++) + { + uint32_t index = uint32_t(types[i] - LBC_TYPE_TAGGED_USERDATA_BASE); + + if (index < count) + types[i] = userdataRemapping[index]; + } + + offset += upvalCount; + } + + if (localCount != 0) + { + for (uint32_t i = 0; i < localCount; i++) + { + uint32_t index = uint32_t(data[offset] - LBC_TYPE_TAGGED_USERDATA_BASE); + + if (index < count) + data[offset] = userdataRemapping[index]; + + offset += 2; + readVarInt(data, size, offset); + readVarInt(data, size, offset); + } + } + + LUAU_ASSERT(offset == size); +} + int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size, int env) { size_t offset = 0; @@ -227,6 +286,18 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size if (version >= 4) { typesversion = read(data, size, offset); + + if (FFlag::LuauLoadUserdataInfo) + { + if (typesversion < LBC_TYPE_VERSION_MIN || typesversion > LBC_TYPE_VERSION_MAX) + { + char chunkbuf[LUA_IDSIZE]; + const char* chunkid = luaO_chunkid(chunkbuf, sizeof(chunkbuf), chunkname, strlen(chunkname)); + lua_pushfstring(L, "%s: bytecode type version mismatch (expected [%d..%d], got %d)", chunkid, LBC_TYPE_VERSION_MIN, + LBC_TYPE_VERSION_MAX, typesversion); + return 1; + } + } } // string table @@ -241,6 +312,31 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size offset += length; } + // userdata type remapping table + // for unknown userdata types, the entry will remap to common 'userdata' type + const uint32_t userdataTypeLimit = LBC_TYPE_TAGGED_USERDATA_END - LBC_TYPE_TAGGED_USERDATA_BASE; + uint8_t userdataRemapping[userdataTypeLimit]; + + if (FFlag::LuauLoadUserdataInfo && typesversion == 3) + { + memset(userdataRemapping, LBC_TYPE_USERDATA, userdataTypeLimit); + + uint8_t index = read(data, size, offset); + + while (index != 0) + { + TString* name = readString(strings, data, size, offset); + + if (uint32_t(index - 1) < userdataTypeLimit) + { + if (auto cb = L->global->ecb.gettypemapping) + userdataRemapping[index - 1] = cb(L, getstr(name), name->len); + } + + index = read(data, size, offset); + } + } + // proto table unsigned int protoCount = readVarInt(data, size, offset); TempBuffer protos(L, protoCount); @@ -260,65 +356,11 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size { p->flags = read(data, size, offset); - if (FFlag::LuauLoadTypeInfo) - { - if (typesversion == 1) - { - uint32_t typesize = readVarInt(data, size, offset); - - if (typesize) - { - uint8_t* types = (uint8_t*)data + offset; - - LUAU_ASSERT(typesize == unsigned(2 + p->numparams)); - LUAU_ASSERT(types[0] == LBC_TYPE_FUNCTION); - LUAU_ASSERT(types[1] == p->numparams); - - // transform v1 into v2 format - int headersize = typesize > 127 ? 
4 : 3; - - p->typeinfo = luaM_newarray(L, headersize + typesize, uint8_t, p->memcat); - p->sizetypeinfo = headersize + typesize; - - if (headersize == 4) - { - p->typeinfo[0] = (typesize & 127) | (1 << 7); - p->typeinfo[1] = typesize >> 7; - p->typeinfo[2] = 0; - p->typeinfo[3] = 0; - } - else - { - p->typeinfo[0] = uint8_t(typesize); - p->typeinfo[1] = 0; - p->typeinfo[2] = 0; - } - - memcpy(p->typeinfo + headersize, types, typesize); - } - - offset += typesize; - } - else if (typesversion == 2) - { - uint32_t typesize = readVarInt(data, size, offset); - - if (typesize) - { - uint8_t* types = (uint8_t*)data + offset; - - p->typeinfo = luaM_newarray(L, typesize, uint8_t, p->memcat); - p->sizetypeinfo = typesize; - memcpy(p->typeinfo, types, typesize); - offset += typesize; - } - } - } - else + if (typesversion == 1) { uint32_t typesize = readVarInt(data, size, offset); - if (typesize && typesversion == LBC_TYPE_VERSION_DEPRECATED) + if (typesize) { uint8_t* types = (uint8_t*)data + offset; @@ -326,12 +368,50 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size LUAU_ASSERT(types[0] == LBC_TYPE_FUNCTION); LUAU_ASSERT(types[1] == p->numparams); - p->typeinfo = luaM_newarray(L, typesize, uint8_t, p->memcat); - memcpy(p->typeinfo, types, typesize); + // transform v1 into v2 format + int headersize = typesize > 127 ? 4 : 3; + + p->typeinfo = luaM_newarray(L, headersize + typesize, uint8_t, p->memcat); + p->sizetypeinfo = headersize + typesize; + + if (headersize == 4) + { + p->typeinfo[0] = (typesize & 127) | (1 << 7); + p->typeinfo[1] = typesize >> 7; + p->typeinfo[2] = 0; + p->typeinfo[3] = 0; + } + else + { + p->typeinfo[0] = uint8_t(typesize); + p->typeinfo[1] = 0; + p->typeinfo[2] = 0; + } + + memcpy(p->typeinfo + headersize, types, typesize); } offset += typesize; } + else if (typesversion == 2 || (FFlag::LuauLoadUserdataInfo && typesversion == 3)) + { + uint32_t typesize = readVarInt(data, size, offset); + + if (typesize) + { + uint8_t* types = (uint8_t*)data + offset; + + p->typeinfo = luaM_newarray(L, typesize, uint8_t, p->memcat); + p->sizetypeinfo = typesize; + memcpy(p->typeinfo, types, typesize); + offset += typesize; + + if (FFlag::LuauLoadUserdataInfo && typesversion == 3) + { + remapUserdataTypes((char*)(uint8_t*)p->typeinfo, p->sizetypeinfo, userdataRemapping, userdataTypeLimit); + } + } + } } const int sizecode = readVarInt(data, size, offset); diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index 4db8bba7..41990742 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -373,13 +373,102 @@ void luaV_concat(lua_State* L, int total, int last) } while (total > 1); // repeat until only 1 result left } -void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TMS op) +template +void luaV_doarithimpl(lua_State* L, StkId ra, const TValue* rb, const TValue* rc) { TValue tempb, tempc; const TValue *b, *c; + + // vector operations that we support: + // v+v v-v -v (add/sub/neg) + // v*v s*v v*s (mul) + // v/v s/v v/s (div) + // v//v s//v v//s (floor div) + const float* vb = ttisvector(rb) ? vvalue(rb) : nullptr; + const float* vc = ttisvector(rc) ? 
vvalue(rc) : nullptr; + + if (vb && vc) + { + switch (op) + { + case TM_ADD: + setvvalue(ra, vb[0] + vc[0], vb[1] + vc[1], vb[2] + vc[2], vb[3] + vc[3]); + return; + case TM_SUB: + setvvalue(ra, vb[0] - vc[0], vb[1] - vc[1], vb[2] - vc[2], vb[3] - vc[3]); + return; + case TM_MUL: + setvvalue(ra, vb[0] * vc[0], vb[1] * vc[1], vb[2] * vc[2], vb[3] * vc[3]); + return; + case TM_DIV: + setvvalue(ra, vb[0] / vc[0], vb[1] / vc[1], vb[2] / vc[2], vb[3] / vc[3]); + return; + case TM_IDIV: + setvvalue(ra, float(luai_numidiv(vb[0], vc[0])), float(luai_numidiv(vb[1], vc[1])), float(luai_numidiv(vb[2], vc[2])), + float(luai_numidiv(vb[3], vc[3]))); + return; + case TM_UNM: + setvvalue(ra, -vb[0], -vb[1], -vb[2], -vb[3]); + return; + default: + break; + } + } + else if (vb) + { + c = ttisnumber(rc) ? rc : luaV_tonumber(rc, &tempc); + + if (c) + { + float nc = cast_to(float, nvalue(c)); + + switch (op) + { + case TM_MUL: + setvvalue(ra, vb[0] * nc, vb[1] * nc, vb[2] * nc, vb[3] * nc); + return; + case TM_DIV: + setvvalue(ra, vb[0] / nc, vb[1] / nc, vb[2] / nc, vb[3] / nc); + return; + case TM_IDIV: + setvvalue(ra, float(luai_numidiv(vb[0], nc)), float(luai_numidiv(vb[1], nc)), float(luai_numidiv(vb[2], nc)), + float(luai_numidiv(vb[3], nc))); + return; + default: + break; + } + } + } + else if (vc) + { + b = ttisnumber(rb) ? rb : luaV_tonumber(rb, &tempb); + + if (b) + { + float nb = cast_to(float, nvalue(b)); + + switch (op) + { + case TM_MUL: + setvvalue(ra, nb * vc[0], nb * vc[1], nb * vc[2], nb * vc[3]); + return; + case TM_DIV: + setvvalue(ra, nb / vc[0], nb / vc[1], nb / vc[2], nb / vc[3]); + return; + case TM_IDIV: + setvvalue(ra, float(luai_numidiv(nb, vc[0])), float(luai_numidiv(nb, vc[1])), float(luai_numidiv(nb, vc[2])), + float(luai_numidiv(nb, vc[3]))); + return; + default: + break; + } + } + } + if ((b = luaV_tonumber(rb, &tempb)) != NULL && (c = luaV_tonumber(rc, &tempc)) != NULL) { double nb = nvalue(b), nc = nvalue(c); + switch (op) { case TM_ADD: @@ -413,93 +502,6 @@ void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TM } else { - // vector operations that we support: - // v+v v-v -v (add/sub/neg) - // v*v s*v v*s (mul) - // v/v s/v v/s (div) - // v//v s//v v//s (floor div) - - const float* vb = luaV_tovector(rb); - const float* vc = luaV_tovector(rc); - - if (vb && vc) - { - switch (op) - { - case TM_ADD: - setvvalue(ra, vb[0] + vc[0], vb[1] + vc[1], vb[2] + vc[2], vb[3] + vc[3]); - return; - case TM_SUB: - setvvalue(ra, vb[0] - vc[0], vb[1] - vc[1], vb[2] - vc[2], vb[3] - vc[3]); - return; - case TM_MUL: - setvvalue(ra, vb[0] * vc[0], vb[1] * vc[1], vb[2] * vc[2], vb[3] * vc[3]); - return; - case TM_DIV: - setvvalue(ra, vb[0] / vc[0], vb[1] / vc[1], vb[2] / vc[2], vb[3] / vc[3]); - return; - case TM_IDIV: - setvvalue(ra, float(luai_numidiv(vb[0], vc[0])), float(luai_numidiv(vb[1], vc[1])), float(luai_numidiv(vb[2], vc[2])), - float(luai_numidiv(vb[3], vc[3]))); - return; - case TM_UNM: - setvvalue(ra, -vb[0], -vb[1], -vb[2], -vb[3]); - return; - default: - break; - } - } - else if (vb) - { - c = luaV_tonumber(rc, &tempc); - - if (c) - { - float nc = cast_to(float, nvalue(c)); - - switch (op) - { - case TM_MUL: - setvvalue(ra, vb[0] * nc, vb[1] * nc, vb[2] * nc, vb[3] * nc); - return; - case TM_DIV: - setvvalue(ra, vb[0] / nc, vb[1] / nc, vb[2] / nc, vb[3] / nc); - return; - case TM_IDIV: - setvvalue(ra, float(luai_numidiv(vb[0], nc)), float(luai_numidiv(vb[1], nc)), float(luai_numidiv(vb[2], nc)), - float(luai_numidiv(vb[3], nc))); - return; - default: - 
break;
-            }
-        }
-    }
-    else if (vc)
-    {
-        b = luaV_tonumber(rb, &tempb);
-
-        if (b)
-        {
-            float nb = cast_to(float, nvalue(b));
-
-            switch (op)
-            {
-            case TM_MUL:
-                setvvalue(ra, nb * vc[0], nb * vc[1], nb * vc[2], nb * vc[3]);
-                return;
-            case TM_DIV:
-                setvvalue(ra, nb / vc[0], nb / vc[1], nb / vc[2], nb / vc[3]);
-                return;
-            case TM_IDIV:
-                setvvalue(ra, float(luai_numidiv(nb, vc[0])), float(luai_numidiv(nb, vc[1])), float(luai_numidiv(nb, vc[2])),
-                    float(luai_numidiv(nb, vc[3])));
-                return;
-            default:
-                break;
-            }
-        }
-    }
-
     if (!call_binTM(L, rb, rc, ra, op))
     {
         luaG_aritherror(L, rb, rc, op);
@@ -507,6 +509,16 @@ void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TM
     }
 }

+// instantiate private template implementation for external callers
+template void luaV_doarithimpl<TM_ADD>(lua_State* L, StkId ra, const TValue* rb, const TValue* rc);
+template void luaV_doarithimpl<TM_SUB>(lua_State* L, StkId ra, const TValue* rb, const TValue* rc);
+template void luaV_doarithimpl<TM_MUL>(lua_State* L, StkId ra, const TValue* rb, const TValue* rc);
+template void luaV_doarithimpl<TM_DIV>(lua_State* L, StkId ra, const TValue* rb, const TValue* rc);
+template void luaV_doarithimpl<TM_IDIV>(lua_State* L, StkId ra, const TValue* rb, const TValue* rc);
+template void luaV_doarithimpl<TM_MOD>(lua_State* L, StkId ra, const TValue* rb, const TValue* rc);
+template void luaV_doarithimpl<TM_POW>(lua_State* L, StkId ra, const TValue* rb, const TValue* rc);
+template void luaV_doarithimpl<TM_UNM>(lua_State* L, StkId ra, const TValue* rb, const TValue* rc);
+
 void luaV_dolen(lua_State* L, StkId ra, const TValue* rb)
 {
     const TValue* tm = NULL;
diff --git a/fuzz/CMakeLists.txt b/fuzz/CMakeLists.txt
index c18fbba5..be40b811 100644
--- a/fuzz/CMakeLists.txt
+++ b/fuzz/CMakeLists.txt
@@ -1,6 +1,6 @@
 # This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
 if(${CMAKE_VERSION} VERSION_LESS "3.26")
-    message(WARNING "Building the Luau fuzzer requires Clang version 3.26 of higher.")
+    message(WARNING "Building the Luau fuzzer requires CMake version 3.26 or higher.")
     return()
 endif()

diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp
index 46a775f4..b9af5247 100644
--- a/fuzz/proto.cpp
+++ b/fuzz/proto.cpp
@@ -379,7 +379,7 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message)
     if (luau_load(globalState, "=fuzz", bytecode.data(), bytecode.size(), 0) == 0)
     {
         Luau::CodeGen::AssemblyOptions options;
-        options.flags = Luau::CodeGen::CodeGen_ColdFunctions;
+        options.compilationOptions.flags = Luau::CodeGen::CodeGen_ColdFunctions;
         options.outputBinary = true;
         options.target = kFuzzCodegenTarget;
         Luau::CodeGen::getAssembly(globalState, -1, options);
diff --git a/tests/AstJsonEncoder.test.cpp b/tests/AstJsonEncoder.test.cpp
index 82e8f139..1c8b2127 100644
--- a/tests/AstJsonEncoder.test.cpp
+++ b/tests/AstJsonEncoder.test.cpp
@@ -9,6 +9,8 @@
 #include
 #include

+LUAU_FASTFLAG(LuauDeclarationExtraPropData)
+
 using namespace Luau;

 struct JsonEncoderFixture
@@ -408,16 +410,32 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatTypeAlias")

 TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareFunction")
 {
+    ScopedFastFlag luauDeclarationExtraPropData{FFlag::LuauDeclarationExtraPropData, true};
+
     AstStat* statement = expectParseStatement("declare function foo(x: number): string");

     std::string_view expected =
-        R"({"type":"AstStatDeclareFunction","location":"0,0 - 0,39","name":"foo","params":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,24 - 0,30","name":"number","nameLocation":"0,24 - 
0,30","parameters":[]}]},"retTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,33 - 0,39","name":"string","nameLocation":"0,33 - 0,39","parameters":[]}]},"generics":[],"genericPacks":[]})"; + R"({"type":"AstStatDeclareFunction","location":"0,0 - 0,39","name":"foo","nameLocation":"0,17 - 0,20","params":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,24 - 0,30","name":"number","nameLocation":"0,24 - 0,30","parameters":[]}]},"paramNames":[{"type":"AstArgumentName","name":"x","location":"0,21 - 0,22"}],"vararg":false,"varargLocation":"0,0 - 0,0","retTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,33 - 0,39","name":"string","nameLocation":"0,33 - 0,39","parameters":[]}]},"generics":[],"genericPacks":[]})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareFunction2") +{ + ScopedFastFlag luauDeclarationExtraPropData{FFlag::LuauDeclarationExtraPropData, true}; + + AstStat* statement = expectParseStatement("declare function foo(x: number, ...: string): string"); + + std::string_view expected = + R"({"type":"AstStatDeclareFunction","location":"0,0 - 0,52","name":"foo","nameLocation":"0,17 - 0,20","params":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,24 - 0,30","name":"number","nameLocation":"0,24 - 0,30","parameters":[]}],"tailType":{"type":"AstTypePackVariadic","location":"0,37 - 0,43","variadicType":{"type":"AstTypeReference","location":"0,37 - 0,43","name":"string","nameLocation":"0,37 - 0,43","parameters":[]}}},"paramNames":[{"type":"AstArgumentName","name":"x","location":"0,21 - 0,22"}],"vararg":true,"varargLocation":"0,32 - 0,35","retTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,46 - 0,52","name":"string","nameLocation":"0,46 - 0,52","parameters":[]}]},"generics":[],"genericPacks":[]})"; CHECK(toJson(statement) == expected); } TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareClass") { + ScopedFastFlag luauDeclarationExtraPropData{FFlag::LuauDeclarationExtraPropData, true}; + AstStatBlock* root = expectParse(R"( declare class Foo prop: number @@ -432,11 +450,11 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareClass") REQUIRE(2 == root->body.size); std::string_view expected1 = - R"({"type":"AstStatDeclareClass","location":"1,22 - 4,11","name":"Foo","props":[{"name":"prop","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeReference","location":"2,18 - 2,24","name":"number","nameLocation":"2,18 - 2,24","parameters":[]}},{"name":"method","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeFunction","location":"3,21 - 4,11","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,39 - 3,45","name":"number","nameLocation":"3,39 - 3,45","parameters":[]}]},"argNames":[{"type":"AstArgumentName","name":"foo","location":"3,34 - 3,37"}],"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,48 - 3,54","name":"string","nameLocation":"3,48 - 3,54","parameters":[]}]}}}],"indexer":null})"; + R"({"type":"AstStatDeclareClass","location":"1,22 - 4,11","name":"Foo","props":[{"name":"prop","nameLocation":"2,12 - 2,16","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeReference","location":"2,18 - 2,24","name":"number","nameLocation":"2,18 - 2,24","parameters":[]},"location":"2,12 - 2,24"},{"name":"method","nameLocation":"3,21 - 
3,27","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeFunction","location":"3,12 - 3,54","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,39 - 3,45","name":"number","nameLocation":"3,39 - 3,45","parameters":[]}]},"argNames":[{"type":"AstArgumentName","name":"foo","location":"3,34 - 3,37"}],"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,48 - 3,54","name":"string","nameLocation":"3,48 - 3,54","parameters":[]}]}},"location":"3,12 - 3,54"}],"indexer":null})"; CHECK(toJson(root->body.data[0]) == expected1); std::string_view expected2 = - R"({"type":"AstStatDeclareClass","location":"6,22 - 8,11","name":"Bar","superName":"Foo","props":[{"name":"prop2","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeReference","location":"7,19 - 7,25","name":"string","nameLocation":"7,19 - 7,25","parameters":[]}}],"indexer":null})"; + R"({"type":"AstStatDeclareClass","location":"6,22 - 8,11","name":"Bar","superName":"Foo","props":[{"name":"prop2","nameLocation":"7,12 - 7,17","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeReference","location":"7,19 - 7,25","name":"string","nameLocation":"7,19 - 7,25","parameters":[]},"location":"7,12 - 7,25"}],"indexer":null})"; CHECK(toJson(root->body.data[1]) == expected2); } diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index 769637a5..c53fe731 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -6,6 +6,8 @@ #include "doctest.h" #include "Fixture.h" +LUAU_FASTFLAG(LuauFixBindingForGlobalPos); + using namespace Luau; struct DocumentationSymbolFixture : BuiltinsFixture @@ -331,4 +333,16 @@ TEST_CASE_FIXTURE(Fixture, "find_expr_ancestry") CHECK(ancestry.back()->is()); } +TEST_CASE_FIXTURE(BuiltinsFixture, "find_binding_at_position_global_start_of_file") +{ + ScopedFastFlag sff{FFlag::LuauFixBindingForGlobalPos, true}; + check("local x = string.char(1)"); + const Position pos(0, 12); + + std::optional binding = findBindingAtPosition(*getMainModule(), *getMainSourceModule(), pos); + + REQUIRE(binding); + CHECK_EQ(binding->location, Location{Position{0, 0}, Position{0, 0}}); +} + TEST_SUITE_END(); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index d0d4e9be..4e8a0442 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -35,6 +35,7 @@ struct ACFixtureImpl : BaseType { FrontendOptions opts; opts.forAutocomplete = true; + opts.retainFullTypeGraphs = true; this->frontend.check("MainModule", opts); return Luau::autocomplete(this->frontend, "MainModule", Position{row, column}, nullCallback); @@ -44,6 +45,7 @@ struct ACFixtureImpl : BaseType { FrontendOptions opts; opts.forAutocomplete = true; + opts.retainFullTypeGraphs = true; this->frontend.check("MainModule", opts); return Luau::autocomplete(this->frontend, "MainModule", getPosition(marker), callback); @@ -53,6 +55,7 @@ struct ACFixtureImpl : BaseType { FrontendOptions opts; opts.forAutocomplete = true; + opts.retainFullTypeGraphs = true; this->frontend.check(name, opts); return Luau::autocomplete(this->frontend, name, pos, callback); @@ -3272,9 +3275,9 @@ TEST_CASE_FIXTURE(ACFixture, "string_singleton_in_if_statement") // https://github.com/Roblox/luau/issues/858 TEST_CASE_FIXTURE(ACFixture, "string_singleton_in_if_statement2") { - ScopedFastFlag sff[]{ - {FFlag::DebugLuauDeferredConstraintResolution, true}, - }; + // don't run this when the DCR flag isn't set + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; 
check(R"( --!strict @@ -3681,6 +3684,8 @@ a.@1 auto ac = autocomplete('1'); + CHECK(2 == ac.entryMap.size()); + CHECK(ac.entryMap.count("x")); CHECK(ac.entryMap.count("y")); @@ -3733,11 +3738,13 @@ TEST_CASE_FIXTURE(ACFixture, "string_contents_is_available_to_callback") declare function require(path: string): any )"); - std::optional require = frontend.globalsForAutocomplete.globalScope->linearSearchForBinding("require"); + GlobalTypes& globals = FFlag::DebugLuauDeferredConstraintResolution ? frontend.globals : frontend.globalsForAutocomplete; + + std::optional require = globals.globalScope->linearSearchForBinding("require"); REQUIRE(require); - Luau::unfreeze(frontend.globalsForAutocomplete.globalTypes); + Luau::unfreeze(globals.globalTypes); attachTag(require->typeId, "RequireCall"); - Luau::freeze(frontend.globalsForAutocomplete.globalTypes); + Luau::freeze(globals.globalTypes); check(R"( local x = require("testing/@1") @@ -3837,11 +3844,13 @@ TEST_CASE_FIXTURE(ACFixture, "string_completion_outside_quotes") declare function require(path: string): any )"); - std::optional require = frontend.globalsForAutocomplete.globalScope->linearSearchForBinding("require"); + GlobalTypes& globals = FFlag::DebugLuauDeferredConstraintResolution ? frontend.globals : frontend.globalsForAutocomplete; + + std::optional require = globals.globalScope->linearSearchForBinding("require"); REQUIRE(require); - Luau::unfreeze(frontend.globalsForAutocomplete.globalTypes); + Luau::unfreeze(globals.globalTypes); attachTag(require->typeId, "RequireCall"); - Luau::freeze(frontend.globalsForAutocomplete.globalTypes); + Luau::freeze(globals.globalTypes); check(R"( local x = require(@1"@2"@3) diff --git a/tests/CodeAllocator.test.cpp b/tests/CodeAllocator.test.cpp index 9d65a5a7..21228d6b 100644 --- a/tests/CodeAllocator.test.cpp +++ b/tests/CodeAllocator.test.cpp @@ -191,7 +191,7 @@ TEST_CASE("WindowsUnwindCodesX64") unwind.finishInfo(); std::vector data; - data.resize(unwind.getSize()); + data.resize(unwind.getUnwindInfoSize()); unwind.finalize(data.data(), 0, nullptr, 0); std::vector expected{0x44, 0x33, 0x22, 0x11, 0x22, 0x33, 0x44, 0x55, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x17, 0x0a, 0x05, 0x17, 0x82, 0x13, @@ -215,7 +215,7 @@ TEST_CASE("Dwarf2UnwindCodesX64") unwind.finishInfo(); std::vector data; - data.resize(unwind.getSize()); + data.resize(unwind.getUnwindInfoSize()); unwind.finalize(data.data(), 0, nullptr, 0); std::vector expected{0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x78, 0x10, 0x0c, 0x07, 0x08, 0x90, 0x01, 0x00, @@ -241,7 +241,7 @@ TEST_CASE("Dwarf2UnwindCodesA64") unwind.finishInfo(); std::vector data; - data.resize(unwind.getSize()); + data.resize(unwind.getUnwindInfoSize()); unwind.finalize(data.data(), 0, nullptr, 0); std::vector expected{0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x78, 0x1e, 0x0c, 0x1f, 0x00, 0x2c, 0x00, 0x00, @@ -253,7 +253,7 @@ TEST_CASE("Dwarf2UnwindCodesA64") CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0); } -#if defined(__x86_64__) || defined(_M_X64) +#if defined(CODEGEN_TARGET_X64) #if defined(_WIN32) // Windows x64 ABI @@ -774,7 +774,7 @@ TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64") #endif -#if defined(__aarch64__) +#if defined(CODEGEN_TARGET_A64) TEST_CASE("GeneratedCodeExecutionA64") { diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index e8927837..250de6e4 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -22,8 +22,8 @@ 
LUAU_FASTINT(LuauCompileLoopUnrollThreshold) LUAU_FASTINT(LuauCompileLoopUnrollThresholdMaxBoost) LUAU_FASTINT(LuauRecursionLimit) -LUAU_FASTFLAG(LuauCompileNoJumpLineRetarget) -LUAU_FASTFLAG(LuauCompileRepeatUntilSkippedLocals) +LUAU_FASTFLAG(LuauCompileUserdataInfo) +LUAU_FASTFLAG(LuauCompileFastcall3) using namespace Luau; @@ -2106,8 +2106,6 @@ RETURN R0 0 TEST_CASE("LoopContinueEarlyCleanup") { - ScopedFastFlag luauCompileRepeatUntilSkippedLocals{FFlag::LuauCompileRepeatUntilSkippedLocals, true}; - // locals after a potential 'continue' are not accessible inside the condition and can be closed at the end of a block CHECK_EQ("\n" + compileFunction(R"( local y @@ -2788,8 +2786,6 @@ end TEST_CASE("DebugLineInfoWhile") { - ScopedFastFlag luauCompileNoJumpLineRetarget{FFlag::LuauCompileNoJumpLineRetarget, true}; - Luau::BytecodeBuilder bcb; bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); Luau::compileOrThrow(bcb, R"( @@ -3136,8 +3132,6 @@ local 8: reg 3, start pc 35 line 21, end pc 35 line 21 TEST_CASE("DebugLocals2") { - ScopedFastFlag luauCompileRepeatUntilSkippedLocals{FFlag::LuauCompileRepeatUntilSkippedLocals, true}; - const char* source = R"( function foo(x) repeat @@ -3167,9 +3161,6 @@ local 2: reg 0, start pc 0 line 4, end pc 2 line 6 TEST_CASE("DebugLocals3") { - ScopedFastFlag luauCompileRepeatUntilSkippedLocals{FFlag::LuauCompileRepeatUntilSkippedLocals, true}; - ScopedFastFlag luauCompileNoJumpLineRetarget{FFlag::LuauCompileNoJumpLineRetarget, true}; - const char* source = R"( function foo(x) repeat @@ -3203,6 +3194,7 @@ local 4: reg 0, start pc 0 line 4, end pc 5 line 8 8: RETURN R0 0 )"); } + TEST_CASE("DebugRemarks") { Luau::BytecodeBuilder bcb; @@ -3230,6 +3222,78 @@ RETURN R0 0 )"); } +TEST_CASE("DebugTypes") +{ + ScopedFastFlag luauCompileUserdataInfo{FFlag::LuauCompileUserdataInfo, true}; + + const char* source = R"( +local up: number = 2 + +function foo(e: vector, f: mat3, g: sequence) + local h = e * e + + for i=1,3 do + print(i) + end + + print(e * f) + print(g) + print(h) + + up += a + return a +end +)"; + + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Types); + bcb.setDumpSource(source); + + Luau::CompileOptions options; + options.vectorCtor = "vector"; + options.vectorType = "vector"; + + options.typeInfoLevel = 1; + + static const char* kUserdataCompileTypes[] = {"vec2", "color", "mat3", nullptr}; + options.userdataTypes = kUserdataCompileTypes; + + Luau::compileOrThrow(bcb, source, options); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +R0: vector [argument] +R1: mat3 [argument] +R2: userdata [argument] +U0: number +R6: any from 1 to 9 +R3: vector from 0 to 30 +MUL R3 R0 R0 +LOADN R6 1 +LOADN R4 3 +LOADN R5 1 +FORNPREP R4 L1 +L0: GETIMPORT R7 1 [print] +MOVE R8 R6 +CALL R7 1 0 +FORNLOOP R4 L0 +L1: GETIMPORT R4 1 [print] +MUL R5 R0 R1 +CALL R4 1 0 +GETIMPORT R4 1 [print] +MOVE R5 R2 +CALL R4 1 0 +GETIMPORT R4 1 [print] +MOVE R5 R3 +CALL R4 1 0 +GETUPVAL R4 0 +GETIMPORT R5 3 [a] +ADD R4 R4 R5 +SETUPVAL R4 0 +GETIMPORT R4 3 [a] +RETURN R4 1 +)"); +} + TEST_CASE("SourceRemarks") { const char* source = R"( @@ -3419,6 +3483,33 @@ RETURN R1 -1 )"); } +TEST_CASE("Fastcall3") +{ + ScopedFastFlag luauCompileFastcall3{FFlag::LuauCompileFastcall3, true}; + + CHECK_EQ("\n" + compileFunction0(R"( +local a, b, c = ... 
+return math.min(a, b, c) + math.clamp(a, b, c) +)"), + R"( +GETVARARGS R0 3 +FASTCALL3 19 R0 R1 R2 L0 +MOVE R5 R0 +MOVE R6 R1 +MOVE R7 R2 +GETIMPORT R4 2 [math.min] +CALL R4 3 1 +L0: FASTCALL3 46 R0 R1 R2 L1 +MOVE R6 R0 +MOVE R7 R1 +MOVE R8 R2 +GETIMPORT R5 4 [math.clamp] +CALL R5 3 1 +L1: ADD R3 R4 R5 +RETURN R3 1 +)"); +} + TEST_CASE("FastcallSelect") { // select(_, ...) compiles to a builtin call @@ -4158,8 +4249,6 @@ RETURN R0 0 TEST_CASE("Coverage") { - ScopedFastFlag luauCompileNoJumpLineRetarget{FFlag::LuauCompileNoJumpLineRetarget, true}; - // basic statement coverage CHECK_EQ("\n" + compileFunction0Coverage(R"( print(1) @@ -4603,6 +4692,34 @@ L0: RETURN R0 -1 )"); } +TEST_CASE("VectorFastCall3") +{ + ScopedFastFlag luauCompileFastcall3{FFlag::LuauCompileFastcall3, true}; + + const char* source = R"( +local a, b, c = ... +return Vector3.new(a, b, c) +)"; + + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::CompileOptions options; + options.vectorLib = "Vector3"; + options.vectorCtor = "new"; + Luau::compileOrThrow(bcb, source, options); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +GETVARARGS R0 3 +FASTCALL3 54 R0 R1 R2 L0 +MOVE R4 R0 +MOVE R5 R1 +MOVE R6 R2 +GETIMPORT R3 2 [Vector3.new] +CALL R3 3 -1 +L0: RETURN R3 -1 +)"); +} + TEST_CASE("VectorLiterals") { CHECK_EQ("\n" + compileFunction("return Vector3.new(1, 2, 3)", 0, 2, /*enableVectors*/ true), R"( diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 9333cb19..65af4e4d 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -16,6 +16,7 @@ #include "doctest.h" #include "ScopedFlags.h" +#include "ConformanceIrHooks.h" #include #include @@ -32,8 +33,8 @@ void luaC_validate(lua_State* L); LUAU_FASTFLAG(DebugLuauAbortingChecks) LUAU_FASTINT(CodegenHeuristicsInstructionLimit) -LUAU_FASTFLAG(LuauCompileRepeatUntilSkippedLocals) -LUAU_DYNAMIC_FASTFLAG(LuauFastCrossTableMove) +LUAU_FASTFLAG(LuauAttributeSyntax) +LUAU_FASTFLAG(LuauNativeAttribute) static lua_CompileOptions defaultOptions() { @@ -48,6 +49,13 @@ static lua_CompileOptions defaultOptions() return copts; } +static Luau::CodeGen::CompilationOptions defaultCodegenOptions() +{ + Luau::CodeGen::CompilationOptions opts = {}; + opts.flags = Luau::CodeGen::CodeGen_ColdFunctions; + return opts; +} + static int lua_collectgarbage(lua_State* L) { static const char* const opts[] = {"stop", "restart", "collect", "count", "isrunning", "step", "setgoal", "setstepmul", "setstepsize", nullptr}; @@ -118,6 +126,15 @@ static int lua_vector_dot(lua_State* L) return 1; } +static int lua_vector_cross(lua_State* L) +{ + const float* a = luaL_checkvector(L, 1); + const float* b = luaL_checkvector(L, 2); + + lua_pushvector(L, a[1] * b[2] - a[2] * b[1], a[2] * b[0] - a[0] * b[2], a[0] * b[1] - a[1] * b[0]); + return 1; +} + static int lua_vector_index(lua_State* L) { const float* v = luaL_checkvector(L, 1); @@ -129,6 +146,14 @@ static int lua_vector_index(lua_State* L) return 1; } + if (strcmp(name, "Unit") == 0) + { + float invSqrt = 1.0f / sqrtf(v[0] * v[0] + v[1] * v[1] + v[2] * v[2]); + + lua_pushvector(L, v[0] * invSqrt, v[1] * invSqrt, v[2] * invSqrt); + return 1; + } + if (strcmp(name, "Dot") == 0) { lua_pushcfunction(L, lua_vector_dot, "Dot"); @@ -144,6 +169,9 @@ static int lua_vector_namecall(lua_State* L) { if (strcmp(str, "Dot") == 0) return lua_vector_dot(L); + + if (strcmp(str, "Cross") == 0) + return lua_vector_cross(L); } luaL_error(L, "%s is not a valid method of vector", luaL_checkstring(L, 1)); @@ 
-157,7 +185,8 @@ int lua_silence(lua_State* L) using StateRef = std::unique_ptr; static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = nullptr, void (*yield)(lua_State* L) = nullptr, - lua_State* initialLuaState = nullptr, lua_CompileOptions* options = nullptr, bool skipCodegen = false) + lua_State* initialLuaState = nullptr, lua_CompileOptions* options = nullptr, bool skipCodegen = false, + Luau::CodeGen::CompilationOptions* codegenOptions = nullptr) { #ifdef LUAU_CONFORMANCE_SOURCE_DIR std::string path = LUAU_CONFORMANCE_SOURCE_DIR; @@ -238,7 +267,11 @@ static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = n free(bytecode); if (result == 0 && codegen && !skipCodegen && luau_codegen_supported()) - Luau::CodeGen::compile(L, -1, Luau::CodeGen::CodeGen_ColdFunctions); + { + Luau::CodeGen::CompilationOptions nativeOpts = codegenOptions ? *codegenOptions : defaultCodegenOptions(); + + Luau::CodeGen::compile(L, -1, nativeOpts); + } int status = (result == 0) ? lua_resume(L, nullptr, 0) : LUA_ERRSYNTAX; @@ -310,6 +343,209 @@ void setupVectorHelpers(lua_State* L) lua_pop(L, 1); } +Vec2* lua_vec2_push(lua_State* L) +{ + Vec2* data = (Vec2*)lua_newuserdatatagged(L, sizeof(Vec2), kTagVec2); + + lua_getuserdatametatable(L, kTagVec2); + lua_setmetatable(L, -2); + + return data; +} + +Vec2* lua_vec2_get(lua_State* L, int idx) +{ + Vec2* a = (Vec2*)lua_touserdatatagged(L, idx, kTagVec2); + + if (a) + return a; + + luaL_typeerror(L, idx, "vec2"); +} + +static int lua_vec2(lua_State* L) +{ + double x = luaL_checknumber(L, 1); + double y = luaL_checknumber(L, 2); + + Vec2* data = lua_vec2_push(L); + + data->x = float(x); + data->y = float(y); + + return 1; +} + +static int lua_vec2_dot(lua_State* L) +{ + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + + lua_pushnumber(L, a->x * b->x + a->y * b->y); + return 1; +} + +static int lua_vec2_min(lua_State* L) +{ + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + + Vec2* data = lua_vec2_push(L); + + data->x = a->x < b->x ? a->x : b->x; + data->y = a->y < b->y ? 
a->y : b->y; + + return 1; +} + +static int lua_vec2_index(lua_State* L) +{ + Vec2* v = lua_vec2_get(L, 1); + const char* name = luaL_checkstring(L, 2); + + if (strcmp(name, "X") == 0) + { + lua_pushnumber(L, v->x); + return 1; + } + + if (strcmp(name, "Y") == 0) + { + lua_pushnumber(L, v->y); + return 1; + } + + if (strcmp(name, "Magnitude") == 0) + { + lua_pushnumber(L, sqrtf(v->x * v->x + v->y * v->y)); + return 1; + } + + if (strcmp(name, "Unit") == 0) + { + float invSqrt = 1.0f / sqrtf(v->x * v->x + v->y * v->y); + + Vec2* data = lua_vec2_push(L); + + data->x = v->x * invSqrt; + data->y = v->y * invSqrt; + return 1; + } + + luaL_error(L, "%s is not a valid member of vector", name); +} + +static int lua_vec2_namecall(lua_State* L) +{ + if (const char* str = lua_namecallatom(L, nullptr)) + { + if (strcmp(str, "Dot") == 0) + return lua_vec2_dot(L); + + if (strcmp(str, "Min") == 0) + return lua_vec2_min(L); + } + + luaL_error(L, "%s is not a valid method of vector", luaL_checkstring(L, 1)); +} + +void setupUserdataHelpers(lua_State* L) +{ + // create metatable with all the metamethods + luaL_newmetatable(L, "vec2"); + luaL_getmetatable(L, "vec2"); + lua_pushvalue(L, -1); + lua_setuserdatametatable(L, kTagVec2, -1); + + lua_pushcfunction(L, lua_vec2_index, nullptr); + lua_setfield(L, -2, "__index"); + + lua_pushcfunction(L, lua_vec2_namecall, nullptr); + lua_setfield(L, -2, "__namecall"); + + lua_pushcclosurek( + L, + [](lua_State* L) { + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + Vec2* data = lua_vec2_push(L); + + data->x = a->x + b->x; + data->y = a->y + b->y; + + return 1; + }, + nullptr, 0, nullptr); + lua_setfield(L, -2, "__add"); + + lua_pushcclosurek( + L, + [](lua_State* L) { + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + Vec2* data = lua_vec2_push(L); + + data->x = a->x - b->x; + data->y = a->y - b->y; + + return 1; + }, + nullptr, 0, nullptr); + lua_setfield(L, -2, "__sub"); + + lua_pushcclosurek( + L, + [](lua_State* L) { + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + Vec2* data = lua_vec2_push(L); + + data->x = a->x * b->x; + data->y = a->y * b->y; + + return 1; + }, + nullptr, 0, nullptr); + lua_setfield(L, -2, "__mul"); + + lua_pushcclosurek( + L, + [](lua_State* L) { + Vec2* a = lua_vec2_get(L, 1); + Vec2* b = lua_vec2_get(L, 2); + Vec2* data = lua_vec2_push(L); + + data->x = a->x / b->x; + data->y = a->y / b->y; + + return 1; + }, + nullptr, 0, nullptr); + lua_setfield(L, -2, "__div"); + + lua_pushcclosurek( + L, + [](lua_State* L) { + Vec2* a = lua_vec2_get(L, 1); + Vec2* data = lua_vec2_push(L); + + data->x = -a->x; + data->y = -a->y; + + return 1; + }, + nullptr, 0, nullptr); + lua_setfield(L, -2, "__unm"); + + lua_setreadonly(L, -1, true); + + // ctor + lua_pushcfunction(L, lua_vec2, "vec2"); + lua_setglobal(L, "vec2"); + + lua_pop(L, 1); +} + static void setupNativeHelpers(lua_State* L) { lua_pushcclosurek( @@ -410,8 +646,6 @@ TEST_CASE("Sort") TEST_CASE("Move") { - ScopedFastFlag luauFastCrossTableMove{DFFlag::LuauFastCrossTableMove, true}; - runConformance("move.lua"); } @@ -533,12 +767,51 @@ TEST_CASE("Pack") TEST_CASE("Vector") { + lua_CompileOptions copts = defaultOptions(); + Luau::CodeGen::CompilationOptions nativeOpts = defaultCodegenOptions(); + + SUBCASE("NoIrHooks") + { + SUBCASE("O0") + { + copts.optimizationLevel = 0; + } + SUBCASE("O1") + { + copts.optimizationLevel = 1; + } + SUBCASE("O2") + { + copts.optimizationLevel = 2; + } + } + SUBCASE("IrHooks") + { + 
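+        // with these hooks set, vector member accesses (Magnitude/Unit) and namecalls
+        // (Dot/Cross) are lowered directly to IR by ConformanceIrHooks.h instead of
+        // falling back to generic C calls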
nativeOpts.hooks.vectorAccessBytecodeType = vectorAccessBytecodeType; + nativeOpts.hooks.vectorNamecallBytecodeType = vectorNamecallBytecodeType; + nativeOpts.hooks.vectorAccess = vectorAccess; + nativeOpts.hooks.vectorNamecall = vectorNamecall; + + SUBCASE("O0") + { + copts.optimizationLevel = 0; + } + SUBCASE("O1") + { + copts.optimizationLevel = 1; + } + SUBCASE("O2") + { + copts.optimizationLevel = 2; + } + } + runConformance( "vector.lua", [](lua_State* L) { setupVectorHelpers(L); }, - nullptr, nullptr, nullptr); + nullptr, nullptr, &copts, false, &nativeOpts); } static void populateRTTI(lua_State* L, Luau::TypeId type) @@ -645,8 +918,6 @@ TEST_CASE("Debugger") static bool singlestep = false; static int stephits = 0; - ScopedFastFlag luauCompileRepeatUntilSkippedLocals{FFlag::LuauCompileRepeatUntilSkippedLocals, true}; - SUBCASE("") { singlestep = false; @@ -1761,16 +2032,36 @@ TEST_CASE("UserdataApi") luaL_newmetatable(L, "udata2"); void* ud5 = lua_newuserdata(L, 0); - lua_getfield(L, LUA_REGISTRYINDEX, "udata1"); + luaL_getmetatable(L, "udata1"); lua_setmetatable(L, -2); void* ud6 = lua_newuserdata(L, 0); - lua_getfield(L, LUA_REGISTRYINDEX, "udata2"); + luaL_getmetatable(L, "udata2"); lua_setmetatable(L, -2); CHECK(luaL_checkudata(L, -2, "udata1") == ud5); CHECK(luaL_checkudata(L, -1, "udata2") == ud6); + // tagged user data with fast metatable access + luaL_newmetatable(L, "udata3"); + luaL_getmetatable(L, "udata3"); + lua_setuserdatametatable(L, 50, -1); + + luaL_newmetatable(L, "udata4"); + luaL_getmetatable(L, "udata4"); + lua_setuserdatametatable(L, 51, -1); + + void* ud7 = lua_newuserdatatagged(L, 16, 50); + lua_getuserdatametatable(L, 50); + lua_setmetatable(L, -2); + + void* ud8 = lua_newuserdatatagged(L, 16, 51); + lua_getuserdatametatable(L, 51); + lua_setmetatable(L, -2); + + CHECK(luaL_checkudata(L, -2, "udata3") == ud7); + CHECK(luaL_checkudata(L, -1, "udata4") == ud8); + globalState.reset(); CHECK(dtorhits == 42); @@ -1844,7 +2135,6 @@ TEST_CASE("Iter") } const int kInt64Tag = 1; -static int gInt64MT = -1; static int64_t getInt64(lua_State* L, int idx) { @@ -1861,7 +2151,7 @@ static void pushInt64(lua_State* L, int64_t value) { void* p = lua_newuserdatatagged(L, sizeof(int64_t), kInt64Tag); - lua_getref(L, gInt64MT); + luaL_getmetatable(L, "int64"); lua_setmetatable(L, -2); *static_cast(p) = value; @@ -1871,8 +2161,7 @@ TEST_CASE("Userdata") { runConformance("userdata.lua", [](lua_State* L) { // create metatable with all the metamethods - lua_newtable(L); - gInt64MT = lua_ref(L, -1); + luaL_newmetatable(L, "int64"); // __index lua_pushcfunction( @@ -2095,6 +2384,86 @@ TEST_CASE("NativeTypeAnnotations") }); } +TEST_CASE("NativeUserdata") +{ + lua_CompileOptions copts = defaultOptions(); + Luau::CodeGen::CompilationOptions nativeOpts = defaultCodegenOptions(); + + static const char* kUserdataCompileTypes[] = {"vec2", "color", "mat3", nullptr}; + copts.userdataTypes = kUserdataCompileTypes; + + SUBCASE("NoIrHooks") + { + SUBCASE("O0") + { + copts.optimizationLevel = 0; + } + SUBCASE("O1") + { + copts.optimizationLevel = 1; + } + SUBCASE("O2") + { + copts.optimizationLevel = 2; + } + } + SUBCASE("IrHooks") + { + nativeOpts.hooks.vectorAccessBytecodeType = vectorAccessBytecodeType; + nativeOpts.hooks.vectorNamecallBytecodeType = vectorNamecallBytecodeType; + nativeOpts.hooks.vectorAccess = vectorAccess; + nativeOpts.hooks.vectorNamecall = vectorNamecall; + + nativeOpts.hooks.userdataAccessBytecodeType = userdataAccessBytecodeType; + 
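+        // the userdata hooks below mirror the vector ones for tagged userdata (vec2):
+        // member reads, arithmetic metamethods and namecalls are lowered to IR guarded
+        // by CHECK_USERDATA_TAG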
nativeOpts.hooks.userdataMetamethodBytecodeType = userdataMetamethodBytecodeType; + nativeOpts.hooks.userdataNamecallBytecodeType = userdataNamecallBytecodeType; + nativeOpts.hooks.userdataAccess = userdataAccess; + nativeOpts.hooks.userdataMetamethod = userdataMetamethod; + nativeOpts.hooks.userdataNamecall = userdataNamecall; + + nativeOpts.userdataTypes = kUserdataRunTypes; + + SUBCASE("O0") + { + copts.optimizationLevel = 0; + } + SUBCASE("O1") + { + copts.optimizationLevel = 1; + } + SUBCASE("O2") + { + copts.optimizationLevel = 2; + } + } + + runConformance( + "native_userdata.lua", + [](lua_State* L) { + Luau::CodeGen::setUserdataRemapper(L, kUserdataRunTypes, [](void* context, const char* str, size_t len) -> uint8_t { + const char** types = (const char**)context; + + uint8_t index = 0; + + std::string_view sv{str, len}; + + for (; *types; ++types) + { + if (sv == *types) + return index; + + index++; + } + + return 0xff; + }); + + setupVectorHelpers(L); + setupUserdataHelpers(L); + }, + nullptr, nullptr, &copts, false, &nativeOpts); +} + [[nodiscard]] static std::string makeHugeFunctionSource() { std::string source; @@ -2141,7 +2510,10 @@ TEST_CASE("HugeFunction") REQUIRE(result == 0); if (codegen && luau_codegen_supported()) - Luau::CodeGen::compile(L, -1, Luau::CodeGen::CodeGen_ColdFunctions); + { + Luau::CodeGen::CompilationOptions nativeOptions{Luau::CodeGen::CodeGen_ColdFunctions}; + Luau::CodeGen::compile(L, -1, nativeOptions); + } int status = lua_resume(L, nullptr, 0); REQUIRE(status == 0); @@ -2263,8 +2635,9 @@ TEST_CASE("IrInstructionLimit") REQUIRE(result == 0); + Luau::CodeGen::CompilationOptions nativeOptions{Luau::CodeGen::CodeGen_ColdFunctions}; Luau::CodeGen::CompilationStats nativeStats = {}; - Luau::CodeGen::CompilationResult nativeResult = Luau::CodeGen::compile(L, -1, Luau::CodeGen::CodeGen_ColdFunctions, &nativeStats); + Luau::CodeGen::CompilationResult nativeResult = Luau::CodeGen::compile(L, -1, nativeOptions, &nativeStats); // Limit is not hit immediately, so with some functions compiled it should be a success CHECK(nativeResult.result == Luau::CodeGen::CodeGenCompilationResult::Success); @@ -2333,4 +2706,57 @@ end 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})); } +TEST_CASE("NativeAttribute") +{ + if (!codegen || !luau_codegen_supported()) + return; + + ScopedFastFlag sffs[] = {{FFlag::LuauAttributeSyntax, true}, {FFlag::LuauNativeAttribute, true}}; + + std::string source = R"R( + @native + local function sum(x, y) + local function sumHelper(z) + return (x+y+z) + end + return sumHelper + end + + local function sub(x, y) + @native + local function subHelper(z) + return (x+y-z) + end + return subHelper + end)R"; + + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + luau_codegen_create(L); + + luaL_openlibs(L); + luaL_sandbox(L); + luaL_sandboxthread(L); + + size_t bytecodeSize = 0; + char* bytecode = luau_compile(source.data(), source.size(), nullptr, &bytecodeSize); + int result = luau_load(L, "=Code", bytecode, bytecodeSize, 0); + free(bytecode); + + REQUIRE(result == 0); + + Luau::CodeGen::CompilationOptions nativeOptions{Luau::CodeGen::CodeGen_ColdFunctions}; + Luau::CodeGen::CompilationStats nativeStats = {}; + Luau::CodeGen::CompilationResult nativeResult = Luau::CodeGen::compile(L, -1, nativeOptions, &nativeStats); + + CHECK(nativeResult.result == Luau::CodeGen::CodeGenCompilationResult::Success); + + 
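+    // the aggregate result and the per-proto failure list should both report success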
CHECK(!nativeResult.hasErrors()); + REQUIRE(nativeResult.protoFailures.empty()); + + // We should be able to compile at least one of our functions + CHECK_EQ(nativeStats.functionsCompiled, 2); +} + TEST_SUITE_END(); diff --git a/tests/ConformanceIrHooks.h b/tests/ConformanceIrHooks.h new file mode 100644 index 00000000..ab5b86d4 --- /dev/null +++ b/tests/ConformanceIrHooks.h @@ -0,0 +1,542 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/IrBuilder.h" + +static const char* kUserdataRunTypes[] = {"extra", "color", "vec2", "mat3", nullptr}; + +constexpr uint8_t kUserdataExtra = 0; +constexpr uint8_t kUserdataColor = 1; +constexpr uint8_t kUserdataVec2 = 2; +constexpr uint8_t kUserdataMat3 = 3; + +// Userdata tags can be different from userdata bytecode type indices +constexpr uint8_t kTagVec2 = 12; + +struct Vec2 +{ + float x; + float y; +}; + +inline bool compareMemberName(const char* member, size_t memberLength, const char* str) +{ + return memberLength == strlen(str) && strcmp(member, str) == 0; +} + +inline uint8_t typeToUserdataIndex(uint8_t type) +{ + // Underflow will push the type into a value that is not comparable to any kUserdata* constants + return type - LBC_TYPE_TAGGED_USERDATA_BASE; +} + +inline uint8_t userdataIndexToType(uint8_t userdataIndex) +{ + return LBC_TYPE_TAGGED_USERDATA_BASE + userdataIndex; +} + +inline uint8_t vectorAccessBytecodeType(const char* member, size_t memberLength) +{ + using namespace Luau::CodeGen; + + if (compareMemberName(member, memberLength, "Magnitude")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "Unit")) + return LBC_TYPE_VECTOR; + + return LBC_TYPE_ANY; +} + +inline bool vectorAccess(Luau::CodeGen::IrBuilder& build, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos) +{ + using namespace Luau::CodeGen; + + if (compareMemberName(member, memberLength, "Magnitude")) + { + IrOp x = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0)); + IrOp y = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4)); + IrOp z = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(8)); + + IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x); + IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); + IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z); + + IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2); + + IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(resultReg), mag); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TNUMBER)); + + return true; + } + + if (compareMemberName(member, memberLength, "Unit")) + { + IrOp x = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0)); + IrOp y = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4)); + IrOp z = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(8)); + + IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x); + IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); + IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z); + + IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2); + + IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); + IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag); + + IrOp xr = build.inst(IrCmd::MUL_NUM, x, inv); + IrOp yr = build.inst(IrCmd::MUL_NUM, y, inv); + IrOp zr = build.inst(IrCmd::MUL_NUM, z, inv); + + build.inst(IrCmd::STORE_VECTOR, 
build.vmReg(resultReg), xr, yr, zr); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TVECTOR)); + + return true; + } + + return false; +} + +inline uint8_t vectorNamecallBytecodeType(const char* member, size_t memberLength) +{ + if (compareMemberName(member, memberLength, "Dot")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "Cross")) + return LBC_TYPE_VECTOR; + + return LBC_TYPE_ANY; +} + +inline bool vectorNamecall( + Luau::CodeGen::IrBuilder& build, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos) +{ + using namespace Luau::CodeGen; + + if (compareMemberName(member, memberLength, "Dot") && params == 2 && results <= 1) + { + build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TVECTOR, build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0)); + IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(argResReg + 2), build.constInt(0)); + IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2); + + IrOp y1 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4)); + IrOp y2 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(argResReg + 2), build.constInt(4)); + IrOp yy = build.inst(IrCmd::MUL_NUM, y1, y2); + + IrOp z1 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(8)); + IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(argResReg + 2), build.constInt(8)); + IrOp zz = build.inst(IrCmd::MUL_NUM, z1, z2); + + IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, xx, yy), zz); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(argResReg), sum); + build.inst(IrCmd::STORE_TAG, build.vmReg(argResReg), build.constTag(LUA_TNUMBER)); + + // If the function is called in multi-return context, stack has to be adjusted + if (results == LUA_MULTRET) + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(argResReg), build.constInt(1)); + + return true; + } + + if (compareMemberName(member, memberLength, "Cross") && params == 2 && results <= 1) + { + build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TVECTOR, build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0)); + IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(argResReg + 2), build.constInt(0)); + + IrOp y1 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4)); + IrOp y2 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(argResReg + 2), build.constInt(4)); + + IrOp z1 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(8)); + IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(argResReg + 2), build.constInt(8)); + + IrOp y1z2 = build.inst(IrCmd::MUL_NUM, y1, z2); + IrOp z1y2 = build.inst(IrCmd::MUL_NUM, z1, y2); + IrOp xr = build.inst(IrCmd::SUB_NUM, y1z2, z1y2); + + IrOp z1x2 = build.inst(IrCmd::MUL_NUM, z1, x2); + IrOp x1z2 = build.inst(IrCmd::MUL_NUM, x1, z2); + IrOp yr = build.inst(IrCmd::SUB_NUM, z1x2, x1z2); + + IrOp x1y2 = build.inst(IrCmd::MUL_NUM, x1, y2); + IrOp y1x2 = build.inst(IrCmd::MUL_NUM, y1, x2); + IrOp zr = build.inst(IrCmd::SUB_NUM, x1y2, y1x2); + + build.inst(IrCmd::STORE_VECTOR, build.vmReg(argResReg), xr, yr, zr); + build.inst(IrCmd::STORE_TAG, build.vmReg(argResReg), build.constTag(LUA_TVECTOR)); + + // If the function is called in multi-return context, stack has to be adjusted + if (results == LUA_MULTRET) + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(argResReg), build.constInt(1)); + + return true; + } + + return false; +} + +inline uint8_t 
userdataAccessBytecodeType(uint8_t type, const char* member, size_t memberLength) +{ + switch (typeToUserdataIndex(type)) + { + case kUserdataColor: + if (compareMemberName(member, memberLength, "R")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "G")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "B")) + return LBC_TYPE_NUMBER; + break; + case kUserdataVec2: + if (compareMemberName(member, memberLength, "X")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "Y")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "Magnitude")) + return LBC_TYPE_NUMBER; + + if (compareMemberName(member, memberLength, "Unit")) + return userdataIndexToType(kUserdataVec2); + break; + case kUserdataMat3: + if (compareMemberName(member, memberLength, "Row1")) + return LBC_TYPE_VECTOR; + + if (compareMemberName(member, memberLength, "Row2")) + return LBC_TYPE_VECTOR; + + if (compareMemberName(member, memberLength, "Row3")) + return LBC_TYPE_VECTOR; + break; + } + + return LBC_TYPE_ANY; +} + +inline bool userdataAccess( + Luau::CodeGen::IrBuilder& build, uint8_t type, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos) +{ + using namespace Luau::CodeGen; + + switch (typeToUserdataIndex(type)) + { + case kUserdataColor: + break; + case kUserdataVec2: + if (compareMemberName(member, memberLength, "X")) + { + IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp value = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(resultReg), value); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TNUMBER)); + return true; + } + + if (compareMemberName(member, memberLength, "Y")) + { + IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp value = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(resultReg), value); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TNUMBER)); + return true; + } + + if (compareMemberName(member, memberLength, "Magnitude")) + { + IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp y = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + + IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x); + IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); + + IrOp sum = build.inst(IrCmd::ADD_NUM, x2, y2); + + IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(resultReg), mag); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TNUMBER)); + return true; + } + + if (compareMemberName(member, memberLength, "Unit")) + { + IrOp udata = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x = build.inst(IrCmd::BUFFER_READF32, udata, 
build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp y = build.inst(IrCmd::BUFFER_READF32, udata, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + + IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x); + IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); + + IrOp sum = build.inst(IrCmd::ADD_NUM, x2, y2); + + IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); + IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag); + + IrOp xr = build.inst(IrCmd::MUL_NUM, x, inv); + IrOp yr = build.inst(IrCmd::MUL_NUM, y, inv); + + build.inst(IrCmd::CHECK_GC); + IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2)); + + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), xr, build.constTag(LUA_TUSERDATA)); + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), yr, build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA)); + return true; + } + break; + case kUserdataMat3: + break; + } + + return false; +} + +inline uint8_t userdataMetamethodBytecodeType(uint8_t lhsTy, uint8_t rhsTy, Luau::CodeGen::HostMetamethod method) +{ + switch (method) + { + case Luau::CodeGen::HostMetamethod::Add: + case Luau::CodeGen::HostMetamethod::Sub: + case Luau::CodeGen::HostMetamethod::Mul: + case Luau::CodeGen::HostMetamethod::Div: + if (typeToUserdataIndex(lhsTy) == kUserdataVec2 || typeToUserdataIndex(rhsTy) == kUserdataVec2) + return userdataIndexToType(kUserdataVec2); + break; + case Luau::CodeGen::HostMetamethod::Minus: + if (typeToUserdataIndex(lhsTy) == kUserdataVec2) + return userdataIndexToType(kUserdataVec2); + break; + default: + break; + } + + return LBC_TYPE_ANY; +} + +inline bool userdataMetamethod(Luau::CodeGen::IrBuilder& build, uint8_t lhsTy, uint8_t rhsTy, int resultReg, Luau::CodeGen::IrOp lhs, + Luau::CodeGen::IrOp rhs, Luau::CodeGen::HostMetamethod method, int pcpos) +{ + using namespace Luau::CodeGen; + + switch (method) + { + case Luau::CodeGen::HostMetamethod::Add: + if (typeToUserdataIndex(lhsTy) == kUserdataVec2 && typeToUserdataIndex(rhsTy) == kUserdataVec2) + { + build.loadAndCheckTag(lhs, LUA_TUSERDATA, build.vmExit(pcpos)); + build.loadAndCheckTag(rhs, LUA_TUSERDATA, build.vmExit(pcpos)); + + IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, lhs); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, rhs); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp mx = build.inst(IrCmd::ADD_NUM, x1, x2); + + IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp my = build.inst(IrCmd::ADD_NUM, y1, y2); + + build.inst(IrCmd::CHECK_GC); + IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2)); + + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA)); + 
build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA)); + + return true; + } + break; + case Luau::CodeGen::HostMetamethod::Mul: + if (typeToUserdataIndex(lhsTy) == kUserdataVec2 && typeToUserdataIndex(rhsTy) == kUserdataVec2) + { + build.loadAndCheckTag(lhs, LUA_TUSERDATA, build.vmExit(pcpos)); + build.loadAndCheckTag(rhs, LUA_TUSERDATA, build.vmExit(pcpos)); + + IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, lhs); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, rhs); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp mx = build.inst(IrCmd::MUL_NUM, x1, x2); + + IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp my = build.inst(IrCmd::MUL_NUM, y1, y2); + + build.inst(IrCmd::CHECK_GC); + IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2)); + + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA)); + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA)); + + return true; + } + break; + case Luau::CodeGen::HostMetamethod::Minus: + if (typeToUserdataIndex(lhsTy) == kUserdataVec2) + { + build.loadAndCheckTag(lhs, LUA_TUSERDATA, build.vmExit(pcpos)); + + IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, lhs); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp y = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp mx = build.inst(IrCmd::UNM_NUM, x); + IrOp my = build.inst(IrCmd::UNM_NUM, y); + + build.inst(IrCmd::CHECK_GC); + IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2)); + + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA)); + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(resultReg), udatar); + build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TUSERDATA)); + + return true; + } + break; + default: + break; + } + + return false; +} + +inline uint8_t userdataNamecallBytecodeType(uint8_t type, const char* member, size_t memberLength) +{ + switch (typeToUserdataIndex(type)) + { + case kUserdataColor: + break; + case kUserdataVec2: + if (compareMemberName(member, memberLength, "Dot")) + return LBC_TYPE_NUMBER; 
+ + if (compareMemberName(member, memberLength, "Min")) + return userdataIndexToType(kUserdataVec2); + break; + case kUserdataMat3: + break; + } + + return LBC_TYPE_ANY; +} + +inline bool userdataNamecall(Luau::CodeGen::IrBuilder& build, uint8_t type, const char* member, size_t memberLength, int argResReg, int sourceReg, + int params, int results, int pcpos) +{ + using namespace Luau::CodeGen; + + switch (typeToUserdataIndex(type)) + { + case kUserdataColor: + break; + case kUserdataVec2: + if (compareMemberName(member, memberLength, "Dot")) + { + IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos)); + + build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TUSERDATA, build.vmExit(pcpos)); + + IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(argResReg + 2)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2); + + IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp yy = build.inst(IrCmd::MUL_NUM, y1, y2); + + IrOp sum = build.inst(IrCmd::ADD_NUM, xx, yy); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(argResReg), sum); + build.inst(IrCmd::STORE_TAG, build.vmReg(argResReg), build.constTag(LUA_TNUMBER)); + + // If the function is called in multi-return context, stack has to be adjusted + if (results == LUA_MULTRET) + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(argResReg), build.constInt(1)); + + return true; + } + + if (compareMemberName(member, memberLength, "Min")) + { + IrOp udata1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(sourceReg)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata1, build.constInt(kTagVec2), build.vmExit(pcpos)); + + build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TUSERDATA, build.vmExit(pcpos)); + + IrOp udata2 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(argResReg + 2)); + build.inst(IrCmd::CHECK_USERDATA_TAG, udata2, build.constInt(kTagVec2), build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp x2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, x)), build.constTag(LUA_TUSERDATA)); + IrOp mx = build.inst(IrCmd::MIN_NUM, x1, x2); + + IrOp y1 = build.inst(IrCmd::BUFFER_READF32, udata1, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp y2 = build.inst(IrCmd::BUFFER_READF32, udata2, build.constInt(offsetof(Vec2, y)), build.constTag(LUA_TUSERDATA)); + IrOp my = build.inst(IrCmd::MIN_NUM, y1, y2); + + build.inst(IrCmd::CHECK_GC); + IrOp udatar = build.inst(IrCmd::NEW_USERDATA, build.constInt(sizeof(Vec2)), build.constInt(kTagVec2)); + + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, x)), mx, build.constTag(LUA_TUSERDATA)); + build.inst(IrCmd::BUFFER_WRITEF32, udatar, build.constInt(offsetof(Vec2, y)), my, build.constTag(LUA_TUSERDATA)); + + build.inst(IrCmd::STORE_POINTER, build.vmReg(argResReg), udatar); + build.inst(IrCmd::STORE_TAG, 
build.vmReg(argResReg), build.constTag(LUA_TUSERDATA)); + + // If the function is called in multi-return context, stack has to be adjusted + if (results == LUA_MULTRET) + build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(argResReg), build.constInt(1)); + + return true; + } + break; + case kUserdataMat3: + break; + } + + return false; +} diff --git a/tests/Error.test.cpp b/tests/Error.test.cpp index 677e3217..00a5a2e7 100644 --- a/tests/Error.test.cpp +++ b/tests/Error.test.cpp @@ -6,6 +6,8 @@ using namespace Luau; +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) + TEST_SUITE_BEGIN("ErrorTests"); TEST_CASE("TypeError_code_should_return_nonzero_code") @@ -34,4 +36,46 @@ local x: Account = 5 CHECK_EQ("Type 'number' could not be converted into 'Account'", toString(result.errors[0])); } +TEST_CASE_FIXTURE(BuiltinsFixture, "binary_op_type_family_errors") +{ + frontend.options.retainFullTypeGraphs = false; + + CheckResult result = check(R"( + --!strict + local x = 1 + "foo" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("Operator '+' could not be applied to operands of types number and string; there is no corresponding overload for __add", + toString(result.errors[0])); + else + CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "unary_op_type_family_errors") +{ + frontend.options.retainFullTypeGraphs = false; + + CheckResult result = check(R"( + --!strict + local x = -"foo" + )"); + + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ( + "Operator '-' could not be applied to operand of type string; there is no corresponding overload for __unm", toString(result.errors[0])); + CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[1])); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); + } +} + TEST_SUITE_END(); diff --git a/tests/Fixture.h b/tests/Fixture.h index 481f79d3..e0c04e8b 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -241,3 +241,21 @@ using DifferFixtureWithBuiltins = DifferFixtureGeneric; } while (false) #define LUAU_REQUIRE_NO_ERRORS(result) LUAU_REQUIRE_ERROR_COUNT(0, result) + +#define LUAU_CHECK_ERRORS(result) \ + do \ + { \ + auto&& r = (result); \ + validateErrors(r.errors); \ + CHECK(!r.errors.empty()); \ + } while (false) + +#define LUAU_CHECK_ERROR_COUNT(count, result) \ + do \ + { \ + auto&& r = (result); \ + validateErrors(r.errors); \ + CHECK_MESSAGE(count == r.errors.size(), getErrors(r)); \ + } while (false) + +#define LUAU_CHECK_NO_ERRORS(result) LUAU_CHECK_ERROR_COUNT(0, result) diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 411d4914..967dea43 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -1333,4 +1333,58 @@ TEST_CASE_FIXTURE(FrontendFixture, "checked_modules_have_the_correct_mode") CHECK(moduleC->mode == Mode::Strict); } +TEST_CASE_FIXTURE(FrontendFixture, "separate_caches_for_autocomplete") +{ + ScopedFastFlag sff{FFlag::DebugLuauDeferredConstraintResolution, false}; + + fileResolver.source["game/A"] = R"( + --!nonstrict + local exports = {} + function exports.hello() end + return exports + )"; + + FrontendOptions opts; + opts.forAutocomplete = true; + + frontend.check("game/A", opts); + + CHECK(nullptr == frontend.moduleResolver.getModule("game/A")); + + ModulePtr acModule = 
+    REQUIRE(acModule != nullptr);
+    CHECK(acModule->mode == Mode::Strict);
+
+    frontend.check("game/A");
+
+    ModulePtr module = frontend.moduleResolver.getModule("game/A");
+
+    REQUIRE(module != nullptr);
+    CHECK(module->mode == Mode::Nonstrict);
+}
+
+TEST_CASE_FIXTURE(FrontendFixture, "no_separate_caches_with_the_new_solver")
+{
+    ScopedFastFlag sff{FFlag::DebugLuauDeferredConstraintResolution, true};
+
+    fileResolver.source["game/A"] = R"(
+        --!nonstrict
+        local exports = {}
+        function exports.hello() end
+        return exports
+    )";
+
+    FrontendOptions opts;
+    opts.forAutocomplete = true;
+
+    frontend.check("game/A", opts);
+
+    CHECK(nullptr == frontend.moduleResolverForAutocomplete.getModule("game/A"));
+
+    ModulePtr module = frontend.moduleResolver.getModule("game/A");
+
+    REQUIRE(module != nullptr);
+    CHECK(module->mode == Mode::Nonstrict);
+}
+
 TEST_SUITE_END();
diff --git a/tests/Generalization.test.cpp b/tests/Generalization.test.cpp
new file mode 100644
index 00000000..901461ae
--- /dev/null
+++ b/tests/Generalization.test.cpp
@@ -0,0 +1,250 @@
+// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
+
+#include "Luau/Generalization.h"
+#include "Luau/Scope.h"
+#include "Luau/ToString.h"
+#include "Luau/Type.h"
+#include "Luau/TypeArena.h"
+#include "Luau/Error.h"
+
+#include "Fixture.h"
+#include "ScopedFlags.h"
+
+#include "doctest.h"
+
+using namespace Luau;
+
+LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
+
+TEST_SUITE_BEGIN("Generalization");
+
+struct GeneralizationFixture
+{
+    TypeArena arena;
+    BuiltinTypes builtinTypes;
+    ScopePtr globalScope = std::make_shared<Scope>(builtinTypes.anyTypePack);
+    ScopePtr scope = std::make_shared<Scope>(globalScope);
+    ToStringOptions opts;
+
+    DenseHashSet<TypeId> generalizedTypes_{nullptr};
+    NotNull<DenseHashSet<TypeId>> generalizedTypes{&generalizedTypes_};
+
+    ScopedFastFlag sff{FFlag::DebugLuauDeferredConstraintResolution, true};
+
+    std::pair<TypeId, FreeType*> freshType()
+    {
+        FreeType ft{scope.get(), builtinTypes.neverType, builtinTypes.unknownType};
+
+        TypeId ty = arena.addType(ft);
+        FreeType* ftv = getMutable<FreeType>(ty);
+        REQUIRE(ftv != nullptr);
+
+        return {ty, ftv};
+    }
+
+    std::string toString(TypeId ty)
+    {
+        return ::Luau::toString(ty, opts);
+    }
+
+    std::string toString(TypePackId ty)
+    {
+        return ::Luau::toString(ty, opts);
+    }
+
+    std::optional<TypeId> generalize(TypeId ty)
+    {
+        return ::Luau::generalize(NotNull{&arena}, NotNull{&builtinTypes}, NotNull{scope.get()}, generalizedTypes, ty);
+    }
+};
+
+TEST_CASE_FIXTURE(GeneralizationFixture, "generalize_a_type_that_is_bounded_by_another_generalizable_type")
+{
+    auto [t1, ft1] = freshType();
+    auto [t2, ft2] = freshType();
+
+    // t2 <: t1 <: unknown
+    // unknown <: t2 <: t1
+
+    ft1->lowerBound = t2;
+    ft2->upperBound = t1;
+    ft2->lowerBound = builtinTypes.unknownType;
+
+    auto t2generalized = generalize(t2);
+    REQUIRE(t2generalized);
+
+    CHECK(follow(t1) == follow(t2));
+
+    auto t1generalized = generalize(t1);
+    REQUIRE(t1generalized);
+
+    CHECK(builtinTypes.unknownType == follow(t1));
+    CHECK(builtinTypes.unknownType == follow(t2));
+}
+
+// Same as generalize_a_type_that_is_bounded_by_another_generalizable_type
+// except that we generalize the types in the opposite order
+TEST_CASE_FIXTURE(GeneralizationFixture, "generalize_a_type_that_is_bounded_by_another_generalizable_type_in_reverse_order")
+{
+    auto [t1, ft1] = freshType();
+    auto [t2, ft2] = freshType();
+
+    // t2 <: t1 <: unknown
+    // unknown <: t2 <: t1
+
+    ft1->lowerBound = t2;
+    ft2->upperBound = t1;
+    ft2->lowerBound = builtinTypes.unknownType;
+
+    auto t1generalized = generalize(t1);
+    REQUIRE(t1generalized);
+
+    CHECK(follow(t1) == follow(t2));
+
+    auto t2generalized = generalize(t2);
+    REQUIRE(t2generalized);
+
+    CHECK(builtinTypes.unknownType == follow(t1));
+    CHECK(builtinTypes.unknownType == follow(t2));
+}
+
+TEST_CASE_FIXTURE(GeneralizationFixture, "dont_traverse_into_class_types_when_generalizing")
+{
+    auto [propTy, _] = freshType();
+
+    TypeId cursedClass = arena.addType(ClassType{"Cursed", {{"oh_no", Property::readonly(propTy)}}, std::nullopt, std::nullopt, {}, {}, ""});
+
+    auto genClass = generalize(cursedClass);
+    REQUIRE(genClass);
+
+    auto genPropTy = get<ClassType>(*genClass)->props.at("oh_no").readTy;
+    CHECK(is<FreeType>(*genPropTy));
+}
+
+TEST_CASE_FIXTURE(GeneralizationFixture, "cache_fully_generalized_types")
+{
+    CHECK(generalizedTypes->empty());
+
+    TypeId tinyTable = arena.addType(TableType{
+        TableType::Props{{"one", builtinTypes.numberType}, {"two", builtinTypes.stringType}}, std::nullopt, TypeLevel{}, TableState::Sealed});
+
+    generalize(tinyTable);
+
+    CHECK(generalizedTypes->contains(tinyTable));
+    CHECK(generalizedTypes->contains(builtinTypes.numberType));
+    CHECK(generalizedTypes->contains(builtinTypes.stringType));
+}
+
+TEST_CASE_FIXTURE(GeneralizationFixture, "dont_cache_types_that_arent_done_yet")
+{
+    TypeId freeTy = arena.addType(FreeType{NotNull{globalScope.get()}, builtinTypes.neverType, builtinTypes.stringType});
+
+    TypeId fnTy = arena.addType(FunctionType{builtinTypes.emptyTypePack, arena.addTypePack(TypePack{{builtinTypes.numberType}})});
+
+    TypeId tableTy = arena.addType(TableType{
+        TableType::Props{{"one", builtinTypes.numberType}, {"two", freeTy}, {"three", fnTy}}, std::nullopt, TypeLevel{}, TableState::Sealed});
+
+    generalize(tableTy);
+
+    CHECK(generalizedTypes->contains(fnTy));
+    CHECK(generalizedTypes->contains(builtinTypes.numberType));
+    CHECK(generalizedTypes->contains(builtinTypes.neverType));
+    CHECK(generalizedTypes->contains(builtinTypes.stringType));
+    CHECK(!generalizedTypes->contains(freeTy));
+    CHECK(!generalizedTypes->contains(tableTy));
+}
+
+TEST_CASE_FIXTURE(GeneralizationFixture, "functions_containing_cyclic_tables_can_be_cached")
+{
+    TypeId selfTy = arena.addType(BlockedType{});
+
+    TypeId methodTy = arena.addType(FunctionType{
+        arena.addTypePack({selfTy}),
+        arena.addTypePack({builtinTypes.numberType}),
+    });
+
+    asMutable(selfTy)->ty.emplace<TableType>(
+        TableType::Props{{"count", builtinTypes.numberType}, {"method", methodTy}}, std::nullopt, TypeLevel{}, TableState::Sealed);
+
+    generalize(methodTy);
+
+    CHECK(generalizedTypes->contains(methodTy));
+    CHECK(generalizedTypes->contains(selfTy));
+    CHECK(generalizedTypes->contains(builtinTypes.numberType));
+}
+
+TEST_CASE_FIXTURE(GeneralizationFixture, "union_type_traversal_doesnt_crash")
+{
+    // t1 where t1 = ('h <: (t1 <: 'i)) | ('j <: (t1 <: 'i))
+    TypeId i = arena.addType(FreeType{NotNull{globalScope.get()}});
+    TypeId h = arena.addType(FreeType{NotNull{globalScope.get()}});
+    TypeId j = arena.addType(FreeType{NotNull{globalScope.get()}});
+    TypeId unionType = arena.addType(UnionType{{h, j}});
+    getMutable<FreeType>(h)->upperBound = i;
+    getMutable<FreeType>(h)->lowerBound = builtinTypes.neverType;
+    getMutable<FreeType>(i)->upperBound = builtinTypes.unknownType;
+    getMutable<FreeType>(i)->lowerBound = unionType;
+    getMutable<FreeType>(j)->upperBound = i;
+    getMutable<FreeType>(j)->lowerBound = builtinTypes.neverType;
+
+    generalize(unionType);
+}
+
+TEST_CASE_FIXTURE(GeneralizationFixture,
"intersection_type_traversal_doesnt_crash") +{ + // t1 where t1 = ('h <: (t1 <: 'i)) & ('j <: (t1 <: 'i)) + TypeId i = arena.addType(FreeType{NotNull{globalScope.get()}}); + TypeId h = arena.addType(FreeType{NotNull{globalScope.get()}}); + TypeId j = arena.addType(FreeType{NotNull{globalScope.get()}}); + TypeId intersectionType = arena.addType(IntersectionType{{h, j}}); + + getMutable(h)->upperBound = i; + getMutable(h)->lowerBound = builtinTypes.neverType; + getMutable(i)->upperBound = builtinTypes.unknownType; + getMutable(i)->lowerBound = intersectionType; + getMutable(j)->upperBound = i; + getMutable(j)->lowerBound = builtinTypes.neverType; + + generalize(intersectionType); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "generalization_traversal_should_re_traverse_unions_if_they_change_type") +{ + // This test case should just not assert + CheckResult result = check(R"( +function byId(p) + return p.id +end + +function foo() + + local productButtonPairs = {} + local func = byId + local dir = -1 + + local function updateSearch() + for product, button in pairs(productButtonPairs) do + button.LayoutOrder = func(product) * dir + end + end + + function(mode) + if mode == 'Name'then + else + if mode == 'New'then + func = function(p) + return p.id + end + elseif mode == 'Price'then + func = function(p) + return p.price + end + end + + end + end +end +)"); +} + +TEST_SUITE_END(); diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index 2f198e65..611eb7b5 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -12,15 +12,21 @@ #include -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) LUAU_FASTFLAG(DebugLuauAbortingChecks) -LUAU_FASTFLAG(LuauCodegenLoadPropCheckRegLinkInTv) +LUAU_FASTFLAG(LuauCodegenInstG) +LUAU_FASTFLAG(LuauCodegenFastcall3) +LUAU_FASTFLAG(LuauCodegenMathSign) using namespace Luau::CodeGen; class IrBuilderFixture { public: + IrBuilderFixture() + : build(hooks) + { + } + void constantFold() { for (IrBlock& block : build.function.blocks) @@ -109,6 +115,7 @@ public: computeCfgDominanceTreeChildren(build.function); } + HostIrHooks hooks; IrBuilder build; // Luau.VM headers are not accessible @@ -328,6 +335,8 @@ TEST_SUITE_BEGIN("ConstantFolding"); TEST_CASE_FIXTURE(IrBuilderFixture, "Numeric") { + ScopedFastFlag luauCodegenMathSign{FFlag::LuauCodegenMathSign, true}; + IrOp block = build.block(IrBlockKind::Internal); build.beginBlock(block); @@ -358,6 +367,8 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Numeric") build.inst(IrCmd::STORE_INT, build.vmReg(20), build.inst(IrCmd::NOT_ANY, build.constTag(tboolean), build.constInt(0))); build.inst(IrCmd::STORE_INT, build.vmReg(21), build.inst(IrCmd::NOT_ANY, build.constTag(tboolean), build.constInt(1))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(22), build.inst(IrCmd::SIGN_NUM, build.constDouble(-4))); + build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); @@ -386,6 +397,7 @@ bb_0: STORE_INT R19, 0i STORE_INT R20, 1i STORE_INT R21, 0i + STORE_DOUBLE R22, -1 RETURN 0u )"); @@ -1111,6 +1123,8 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "BuiltinFastcallsMayInvalidateMemory") { + ScopedFastFlag luauCodegenInstG{FFlag::LuauCodegenInstG, true}; + IrOp block = build.block(IrBlockKind::Internal); IrOp fallback = build.block(IrBlockKind::Fallback); @@ -1123,8 +1137,8 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "BuiltinFastcallsMayInvalidateMemory") build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback); build.inst(IrCmd::CHECK_READONLY, table, fallback); - build.inst(IrCmd::INVOKE_FASTCALL, 
build.constUint(LBF_SETMETATABLE), build.vmReg(1), build.vmReg(2), build.vmReg(3), build.constInt(3), - build.constInt(1)); + build.inst(IrCmd::INVOKE_FASTCALL, build.constUint(LBF_SETMETATABLE), build.vmReg(1), build.vmReg(2), build.vmReg(3), build.undef(), + build.constInt(3), build.constInt(1)); build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback); build.inst(IrCmd::CHECK_READONLY, table, fallback); @@ -1145,7 +1159,7 @@ bb_0: %1 = LOAD_POINTER R0 CHECK_NO_METATABLE %1, bb_fallback_1 CHECK_READONLY %1, bb_fallback_1 - %4 = INVOKE_FASTCALL 61u, R1, R2, R3, 3i, 1i + %4 = INVOKE_FASTCALL 61u, R1, R2, R3, undef, 3i, 1i CHECK_NO_METATABLE %1, bb_fallback_1 CHECK_READONLY %1, bb_fallback_1 STORE_DOUBLE R1, 0.5 @@ -2540,8 +2554,6 @@ bb_0: ; useCount: 0 TEST_CASE_FIXTURE(IrBuilderFixture, "ForgprepInvalidation") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp block = build.block(IrBlockKind::Internal); IrOp followup = build.block(IrBlockKind::Internal); @@ -2581,14 +2593,14 @@ bb_1: TEST_CASE_FIXTURE(IrBuilderFixture, "FastCallEffects1") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; + ScopedFastFlag luauCodegenFastcall3{FFlag::LuauCodegenFastcall3, true}; IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); - build.inst(IrCmd::FASTCALL, build.constUint(LBF_MATH_FREXP), build.vmReg(1), build.vmReg(2), build.undef(), build.constInt(1), build.constInt(2)); - build.inst(IrCmd::CHECK_TAG, build.vmReg(1), build.constTag(tnumber), build.vmExit(1)); - build.inst(IrCmd::CHECK_TAG, build.vmReg(2), build.constTag(tnumber), build.vmExit(1)); + build.inst(IrCmd::FASTCALL, build.constUint(LBF_MATH_FREXP), build.vmReg(1), build.vmReg(2), build.constInt(2)); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(1)), build.constTag(tnumber), build.vmExit(1)); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(2)), build.constTag(tnumber), build.vmExit(1)); build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(2)); updateUseCounts(build.function); @@ -2598,7 +2610,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FastCallEffects1") CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: ; in regs: R2 - FASTCALL 14u, R1, R2, undef, 1i, 2i + FASTCALL 14u, R1, R2, 2i RETURN R1, 2i )"); @@ -2606,14 +2618,14 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "FastCallEffects2") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; + ScopedFastFlag luauCodegenFastcall3{FFlag::LuauCodegenFastcall3, true}; IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); - build.inst(IrCmd::FASTCALL, build.constUint(LBF_MATH_MODF), build.vmReg(1), build.vmReg(2), build.undef(), build.constInt(1), build.constInt(1)); - build.inst(IrCmd::CHECK_TAG, build.vmReg(1), build.constTag(tnumber), build.vmExit(1)); - build.inst(IrCmd::CHECK_TAG, build.vmReg(2), build.constTag(tnumber), build.vmExit(1)); + build.inst(IrCmd::FASTCALL, build.constUint(LBF_MATH_MODF), build.vmReg(1), build.vmReg(2), build.constInt(1)); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(1)), build.constTag(tnumber), build.vmExit(1)); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(2)), build.constTag(tnumber), build.vmExit(1)); build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(2)); updateUseCounts(build.function); @@ -2623,8 +2635,9 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FastCallEffects2") CHECK("\n" + 
toString(build.function, IncludeUseInfo::No) == R"( bb_0: ; in regs: R2 - FASTCALL 20u, R1, R2, undef, 1i, 1i - CHECK_TAG R2, tnumber, exit(1) + FASTCALL 20u, R1, R2, 1i + %3 = LOAD_TAG R2 + CHECK_TAG %3, tnumber, exit(1) RETURN R1, 2i )"); @@ -2636,7 +2649,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "InferNumberTagFromLimitedContext") build.beginBlock(entry); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(2.0)); - build.inst(IrCmd::CHECK_TAG, build.vmReg(0), build.constTag(ttable), build.vmExit(1)); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(ttable), build.vmExit(1)); build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); @@ -2652,6 +2665,58 @@ bb_0: )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "DoNotProduceInvalidSplitStore1") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(1)); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(ttable), build.vmExit(1)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_INT R0, 1i + %1 = LOAD_TAG R0 + CHECK_TAG %1, ttable, exit(1) + %3 = LOAD_TVALUE R0 + STORE_TVALUE R1, %3 + RETURN R1, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "DoNotProduceInvalidSplitStore2") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(1)); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(tnumber), build.vmExit(1)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_INT R0, 1i + %1 = LOAD_TAG R0 + CHECK_TAG %1, tnumber, exit(1) + %3 = LOAD_TVALUE R0 + STORE_TVALUE R1, %3 + RETURN R1, 1i + +)"); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("Analysis"); @@ -2749,13 +2814,16 @@ bb_1: TEST_CASE_FIXTURE(IrBuilderFixture, "ExplicitUseOfRegisterInVarargSequence") { + ScopedFastFlag luauCodegenInstG{FFlag::LuauCodegenInstG, true}; + ScopedFastFlag luauCodegenFastcall3{FFlag::LuauCodegenFastcall3, true}; + IrOp entry = build.block(IrBlockKind::Internal); IrOp exit = build.block(IrBlockKind::Internal); build.beginBlock(entry); build.inst(IrCmd::FALLBACK_GETVARARGS, build.constUint(0), build.vmReg(1), build.constInt(-1)); - IrOp results = build.inst( - IrCmd::INVOKE_FASTCALL, build.constUint(0), build.vmReg(0), build.vmReg(1), build.vmReg(2), build.constInt(-1), build.constInt(-1)); + IrOp results = build.inst(IrCmd::INVOKE_FASTCALL, build.constUint(0), build.vmReg(0), build.vmReg(1), build.vmReg(2), build.undef(), + build.constInt(-1), build.constInt(-1)); build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(0), results); build.inst(IrCmd::JUMP, exit); @@ -2770,7 +2838,7 @@ bb_0: ; successors: bb_1 ; out regs: R0... 
FALLBACK_GETVARARGS 0u, R1, -1i - %1 = INVOKE_FASTCALL 0u, R0, R1, R2, -1i, -1i + %1 = INVOKE_FASTCALL 0u, R0, R1, R2, undef, -1i, -1i ADJUST_STACK_TO_REG R0, %1 JUMP bb_1 @@ -2963,8 +3031,6 @@ bb_1: TEST_CASE_FIXTURE(IrBuilderFixture, "ForgprepImplicitUse") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); IrOp direct = build.block(IrBlockKind::Internal); IrOp fallback = build.block(IrBlockKind::Internal); @@ -3469,8 +3535,6 @@ bb_1: TEST_CASE_FIXTURE(IrBuilderFixture, "TaggedValuePropagationIntoTvalueChecksRegisterVersion") { - ScopedFastFlag luauCodegenLoadPropCheckRegLinkInTv{FFlag::LuauCodegenLoadPropCheckRegLinkInTv, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -3527,8 +3591,6 @@ TEST_SUITE_BEGIN("DeadStoreRemoval"); TEST_CASE_FIXTURE(IrBuilderFixture, "SimpleDoubleStore") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -3578,8 +3640,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "UnusedAtReturn") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -3611,8 +3671,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "UnusedAtReturnPartial") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -3641,8 +3699,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "HiddenPointerUse1") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -3671,8 +3727,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "HiddenPointerUse2") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -3705,8 +3759,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "HiddenPointerUse3") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -3735,8 +3787,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "HiddenPointerUse4") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -3769,8 +3819,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "PartialVsFullStoresWithRecombination") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -3794,8 +3842,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "IgnoreFastcallAdjustment") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -3822,8 +3868,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "JumpImplicitLiveOut") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); IrOp next = build.block(IrBlockKind::Internal); @@ -3859,8 +3903,6 @@ bb_1: TEST_CASE_FIXTURE(IrBuilderFixture, "KeepCapturedRegisterStores") { - ScopedFastFlag 
luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -3898,7 +3940,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "StoreCannotBeReplacedWithCheck") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; ScopedFastFlag debugLuauAbortingChecks{FFlag::DebugLuauAbortingChecks, true}; IrOp block = build.block(IrBlockKind::Internal); @@ -3967,8 +4008,6 @@ bb_2: TEST_CASE_FIXTURE(IrBuilderFixture, "FullStoreHasToBeObservableFromFallbacks") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); IrOp fallback = build.block(IrBlockKind::Fallback); IrOp last = build.block(IrBlockKind::Internal); @@ -4025,8 +4064,6 @@ bb_2: TEST_CASE_FIXTURE(IrBuilderFixture, "FullStoreHasToBeObservableFromFallbacks2") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); IrOp fallback = build.block(IrBlockKind::Fallback); IrOp last = build.block(IrBlockKind::Internal); @@ -4081,8 +4118,6 @@ bb_2: TEST_CASE_FIXTURE(IrBuilderFixture, "FullStoreHasToBeObservableFromFallbacks3") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); IrOp fallback = build.block(IrBlockKind::Fallback); IrOp last = build.block(IrBlockKind::Internal); @@ -4140,8 +4175,6 @@ bb_2: TEST_CASE_FIXTURE(IrBuilderFixture, "SafePartialValueStoresWithPreservedTag") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); IrOp fallback = build.block(IrBlockKind::Fallback); IrOp last = build.block(IrBlockKind::Internal); @@ -4195,8 +4228,6 @@ bb_2: TEST_CASE_FIXTURE(IrBuilderFixture, "SafePartialValueStoresWithPreservedTag2") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); IrOp fallback = build.block(IrBlockKind::Fallback); IrOp last = build.block(IrBlockKind::Internal); @@ -4249,8 +4280,6 @@ bb_2: TEST_CASE_FIXTURE(IrBuilderFixture, "DoNotReturnWithPartialStores") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); IrOp success = build.block(IrBlockKind::Internal); IrOp fail = build.block(IrBlockKind::Internal); @@ -4321,8 +4350,6 @@ bb_3: TEST_CASE_FIXTURE(IrBuilderFixture, "PartialOverFullValue") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); diff --git a/tests/IrLowering.test.cpp b/tests/IrLowering.test.cpp index 479329b4..0ff8a12c 100644 --- a/tests/IrLowering.test.cpp +++ b/tests/IrLowering.test.cpp @@ -6,23 +6,38 @@ #include "Luau/CodeGen.h" #include "Luau/Compiler.h" #include "Luau/Parser.h" +#include "Luau/IrBuilder.h" #include "doctest.h" #include "ScopedFlags.h" +#include "ConformanceIrHooks.h" #include +#include -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) -LUAU_FASTFLAG(LuauCodegenDirectUserdataFlow) -LUAU_FASTFLAG(LuauCompileTypeInfo) -LUAU_FASTFLAG(LuauLoadTypeInfo) -LUAU_FASTFLAG(LuauCodegenTypeInfo) -LUAU_FASTFLAG(LuauTypeInfoLookupImprovement) +LUAU_FASTFLAG(LuauCompileUserdataInfo) +LUAU_FASTFLAG(LuauLoadUserdataInfo) +LUAU_FASTFLAG(LuauCodegenUserdataOps) 
+LUAU_FASTFLAG(LuauCodegenUserdataAlloc) +LUAU_FASTFLAG(LuauCompileFastcall3) +LUAU_FASTFLAG(LuauCodegenFastcall3) -static std::string getCodegenAssembly(const char* source, bool includeIrTypes = false) +static std::string getCodegenAssembly(const char* source, bool includeIrTypes = false, int debugLevel = 1) { Luau::CodeGen::AssemblyOptions options; + options.compilationOptions.hooks.vectorAccessBytecodeType = vectorAccessBytecodeType; + options.compilationOptions.hooks.vectorNamecallBytecodeType = vectorNamecallBytecodeType; + options.compilationOptions.hooks.vectorAccess = vectorAccess; + options.compilationOptions.hooks.vectorNamecall = vectorNamecall; + + options.compilationOptions.hooks.userdataAccessBytecodeType = userdataAccessBytecodeType; + options.compilationOptions.hooks.userdataMetamethodBytecodeType = userdataMetamethodBytecodeType; + options.compilationOptions.hooks.userdataNamecallBytecodeType = userdataNamecallBytecodeType; + options.compilationOptions.hooks.userdataAccess = userdataAccess; + options.compilationOptions.hooks.userdataMetamethod = userdataMetamethod; + options.compilationOptions.hooks.userdataNamecall = userdataNamecall; + // For IR, we don't care about assembly, but we want a stable target options.target = Luau::CodeGen::AssemblyOptions::Target::X64_SystemV; @@ -47,11 +62,14 @@ static std::string getCodegenAssembly(const char* source, bool includeIrTypes = Luau::CompileOptions copts = {}; copts.optimizationLevel = 2; - copts.debugLevel = 1; + copts.debugLevel = debugLevel; copts.typeInfoLevel = 1; copts.vectorCtor = "vector"; copts.vectorType = "vector"; + static const char* kUserdataCompileTypes[] = {"vec2", "color", "mat3", nullptr}; + copts.userdataTypes = kUserdataCompileTypes; + Luau::BytecodeBuilder bcb; Luau::compileOrThrow(bcb, result, names, copts); @@ -59,6 +77,33 @@ static std::string getCodegenAssembly(const char* source, bool includeIrTypes = std::unique_ptr globalState(luaL_newstate(), lua_close); lua_State* L = globalState.get(); + // Runtime mapping is specifically created to NOT match the compilation mapping + options.compilationOptions.userdataTypes = kUserdataRunTypes; + + if (Luau::CodeGen::isSupported()) + { + // Type remapper requires the codegen runtime + Luau::CodeGen::create(L); + + Luau::CodeGen::setUserdataRemapper(L, kUserdataRunTypes, [](void* context, const char* str, size_t len) -> uint8_t { + const char** types = (const char**)context; + + uint8_t index = 0; + + std::string_view sv{str, len}; + + for (; *types; ++types) + { + if (sv == *types) + return index; + + index++; + } + + return 0xff; + }); + } + if (luau_load(L, "name", bytecode.data(), bytecode.size(), 0) == 0) return Luau::CodeGen::getAssembly(L, -1, options, nullptr); @@ -66,6 +111,20 @@ static std::string getCodegenAssembly(const char* source, bool includeIrTypes = return ""; } +static std::string getCodegenHeader(const char* source) +{ + std::string assembly = getCodegenAssembly(source, /* includeIrTypes */ true, /* debugLevel */ 2); + + auto bytecodeStart = assembly.find("bb_bytecode_0:"); + + if (bytecodeStart == std::string::npos) + bytecodeStart = assembly.find("bb_0:"); + + REQUIRE(bytecodeStart != std::string::npos); + + return assembly.substr(0, bytecodeStart); +} + TEST_SUITE_BEGIN("IrLowering"); TEST_CASE("VectorReciprocal") @@ -95,8 +154,6 @@ bb_bytecode_1: TEST_CASE("VectorComponentRead") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function compsum(a: vector) 
return a.X + a.Y + a.Z @@ -174,8 +231,6 @@ bb_bytecode_1: TEST_CASE("VectorSubMulDiv") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function vec3combo(a: vector, b: vector, c: vector, d: vector) return a * b - c / d @@ -208,8 +263,6 @@ bb_bytecode_1: TEST_CASE("VectorSubMulDiv2") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function vec3combo(a: vector) local tmp = a * a @@ -238,8 +291,6 @@ bb_bytecode_1: TEST_CASE("VectorMulDivMixed") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function vec3combo(a: vector, b: vector, c: vector, d: vector) return a * 2 + b / 4 + 0.5 * c + 40 / d @@ -280,8 +331,6 @@ bb_bytecode_1: TEST_CASE("ExtraMathMemoryOperands") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a: number, b: number, c: number, d: number, e: number) return math.floor(a) + math.ceil(b) + math.round(c) + math.sqrt(d) + math.abs(e) @@ -318,8 +367,6 @@ bb_bytecode_1: TEST_CASE("DseInitialStackState") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo() while {} do @@ -358,7 +405,7 @@ bb_5: TEST_CASE("DseInitialStackState2") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; + ScopedFastFlag luauCodegenFastcall3{FFlag::LuauCodegenFastcall3, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a) @@ -371,28 +418,7 @@ end bb_bytecode_0: CHECK_SAFE_ENV exit(1) CHECK_TAG R0, tnumber, exit(1) - FASTCALL 14u, R1, R0, undef, 1i, 2i - INTERRUPT 5u - RETURN R0, 1i -)"); -} - -TEST_CASE("DseInitialStackState3") -{ - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - - CHECK_EQ("\n" + getCodegenAssembly(R"( -local function foo(a) - math.sign(a) - return a -end -)"), - R"( -; function foo($arg0) line 2 -bb_bytecode_0: - CHECK_SAFE_ENV exit(1) - CHECK_TAG R0, tnumber, exit(1) - FASTCALL 47u, R1, R0, undef, 1i, 1i + FASTCALL 14u, R1, R0, 2i INTERRUPT 5u RETURN R0, 1i )"); @@ -400,8 +426,6 @@ bb_bytecode_0: TEST_CASE("VectorConstantTag") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function vecrcp(a: vector) return vector(1, 2, 3) + a @@ -427,8 +451,6 @@ bb_bytecode_1: TEST_CASE("VectorNamecall") { - ScopedFastFlag luauCodegenDirectUserdataFlow{FFlag::LuauCodegenDirectUserdataFlow, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function abs(a: vector) return a:Abs() @@ -451,10 +473,297 @@ bb_bytecode_1: )"); } +TEST_CASE("VectorRandomProp") +{ + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: vector) + return a.XX + a.YY + a.ZZ +end +)"), + R"( +; function foo($arg0) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + FALLBACK_GETTABLEKS 0u, R3, R0, K0 + FALLBACK_GETTABLEKS 2u, R4, R0, K1 + CHECK_TAG R3, tnumber, bb_fallback_3 + CHECK_TAG R4, tnumber, bb_fallback_3 + %14 = LOAD_DOUBLE R3 + %16 = ADD_NUM %14, R4 + STORE_DOUBLE R2, %16 + STORE_TAG R2, tnumber + JUMP bb_4 +bb_4: + CHECK_TAG R0, tvector, exit(5) + FALLBACK_GETTABLEKS 5u, R3, R0, K2 + CHECK_TAG R2, tnumber, 
bb_fallback_5 + CHECK_TAG R3, tnumber, bb_fallback_5 + %30 = LOAD_DOUBLE R2 + %32 = ADD_NUM %30, R3 + STORE_DOUBLE R1, %32 + STORE_TAG R1, tnumber + JUMP bb_6 +bb_6: + INTERRUPT 8u + RETURN R1, 1i +)"); +} + +TEST_CASE("VectorCustomAccess") +{ + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function vec3magn(a: vector) + return a.Magnitude * 2 +end +)"), + R"( +; function vec3magn($arg0) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = LOAD_FLOAT R0, 0i + %7 = LOAD_FLOAT R0, 4i + %8 = LOAD_FLOAT R0, 8i + %9 = MUL_NUM %6, %6 + %10 = MUL_NUM %7, %7 + %11 = MUL_NUM %8, %8 + %12 = ADD_NUM %9, %10 + %13 = ADD_NUM %12, %11 + %14 = SQRT_NUM %13 + %20 = MUL_NUM %14, 2 + STORE_DOUBLE R1, %20 + STORE_TAG R1, tnumber + INTERRUPT 3u + RETURN R1, 1i +)"); +} + +TEST_CASE("VectorCustomNamecall") +{ + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function vec3dot(a: vector, b: vector) + return (a:Dot(b)) +end +)"), + R"( +; function vec3dot($arg0, $arg1) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + CHECK_TAG R1, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = LOAD_TVALUE R1 + STORE_TVALUE R4, %6 + %12 = LOAD_FLOAT R0, 0i + %13 = LOAD_FLOAT R4, 0i + %14 = MUL_NUM %12, %13 + %15 = LOAD_FLOAT R0, 4i + %16 = LOAD_FLOAT R4, 4i + %17 = MUL_NUM %15, %16 + %18 = LOAD_FLOAT R0, 8i + %19 = LOAD_FLOAT R4, 8i + %20 = MUL_NUM %18, %19 + %21 = ADD_NUM %14, %17 + %22 = ADD_NUM %21, %20 + STORE_DOUBLE R2, %22 + STORE_TAG R2, tnumber + INTERRUPT 4u + RETURN R2, 1i +)"); +} + +TEST_CASE("VectorCustomAccessChain") +{ + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: vector, b: vector) + return a.Unit * b.Magnitude +end +)"), + R"( +; function foo($arg0, $arg1) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + CHECK_TAG R1, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %8 = LOAD_FLOAT R0, 0i + %9 = LOAD_FLOAT R0, 4i + %10 = LOAD_FLOAT R0, 8i + %11 = MUL_NUM %8, %8 + %12 = MUL_NUM %9, %9 + %13 = MUL_NUM %10, %10 + %14 = ADD_NUM %11, %12 + %15 = ADD_NUM %14, %13 + %16 = SQRT_NUM %15 + %17 = DIV_NUM 1, %16 + %18 = MUL_NUM %8, %17 + %19 = MUL_NUM %9, %17 + %20 = MUL_NUM %10, %17 + STORE_VECTOR R3, %18, %19, %20 + STORE_TAG R3, tvector + %25 = LOAD_FLOAT R1, 0i + %26 = LOAD_FLOAT R1, 4i + %27 = LOAD_FLOAT R1, 8i + %28 = MUL_NUM %25, %25 + %29 = MUL_NUM %26, %26 + %30 = MUL_NUM %27, %27 + %31 = ADD_NUM %28, %29 + %32 = ADD_NUM %31, %30 + %33 = SQRT_NUM %32 + %40 = LOAD_TVALUE R3 + %42 = NUM_TO_VEC %33 + %43 = MUL_VEC %40, %42 + %44 = TAG_VECTOR %43 + STORE_TVALUE R2, %44 + INTERRUPT 5u + RETURN R2, 1i +)"); +} + +TEST_CASE("VectorCustomNamecallChain") +{ + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(n: vector, b: vector, t: vector) + return n:Cross(t):Dot(b) + 1 +end +)"), + R"( +; function foo($arg0, $arg1, $arg2) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + CHECK_TAG R1, tvector, exit(entry) + CHECK_TAG R2, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %8 = LOAD_TVALUE R2 + STORE_TVALUE R6, %8 + %14 = LOAD_FLOAT R0, 0i + %15 = LOAD_FLOAT R6, 0i + %16 = LOAD_FLOAT R0, 4i + %17 = LOAD_FLOAT R6, 4i + %18 = LOAD_FLOAT R0, 8i + %19 = LOAD_FLOAT R6, 8i + %20 = MUL_NUM %16, %19 + %21 = MUL_NUM %18, %17 + %22 = SUB_NUM %20, %21 + %23 = MUL_NUM %18, %15 + %24 = MUL_NUM %14, %19 + %25 = SUB_NUM %23, %24 + %26 = MUL_NUM %14, %17 + %27 = MUL_NUM %16, %15 + %28 = SUB_NUM %26, %27 + STORE_VECTOR R4, %22, %25, %28 + STORE_TAG R4, 
tvector + %31 = LOAD_TVALUE R1 + STORE_TVALUE R6, %31 + %37 = LOAD_FLOAT R4, 0i + %38 = LOAD_FLOAT R6, 0i + %39 = MUL_NUM %37, %38 + %40 = LOAD_FLOAT R4, 4i + %41 = LOAD_FLOAT R6, 4i + %42 = MUL_NUM %40, %41 + %43 = LOAD_FLOAT R4, 8i + %44 = LOAD_FLOAT R6, 8i + %45 = MUL_NUM %43, %44 + %46 = ADD_NUM %39, %42 + %47 = ADD_NUM %46, %45 + %53 = ADD_NUM %47, 1 + STORE_DOUBLE R3, %53 + STORE_TAG R3, tnumber + INTERRUPT 9u + RETURN R3, 1i +)"); +} + +TEST_CASE("VectorCustomNamecallChain2") +{ + CHECK_EQ("\n" + getCodegenAssembly(R"( +type Vertex = {n: vector, b: vector} + +local function foo(v: Vertex, t: vector) + return v.n:Cross(t):Dot(v.b) + 1 +end +)"), + R"( +; function foo($arg0, $arg1) line 4 +bb_0: + CHECK_TAG R0, ttable, exit(entry) + CHECK_TAG R1, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %8 = LOAD_POINTER R0 + %9 = GET_SLOT_NODE_ADDR %8, 0u, K1 + CHECK_SLOT_MATCH %9, K1, bb_fallback_3 + %11 = LOAD_TVALUE %9, 0i + STORE_TVALUE R3, %11 + JUMP bb_4 +bb_4: + %16 = LOAD_TVALUE R1 + STORE_TVALUE R5, %16 + CHECK_TAG R3, tvector, exit(3) + CHECK_TAG R5, tvector, exit(3) + %22 = LOAD_FLOAT R3, 0i + %23 = LOAD_FLOAT R5, 0i + %24 = LOAD_FLOAT R3, 4i + %25 = LOAD_FLOAT R5, 4i + %26 = LOAD_FLOAT R3, 8i + %27 = LOAD_FLOAT R5, 8i + %28 = MUL_NUM %24, %27 + %29 = MUL_NUM %26, %25 + %30 = SUB_NUM %28, %29 + %31 = MUL_NUM %26, %23 + %32 = MUL_NUM %22, %27 + %33 = SUB_NUM %31, %32 + %34 = MUL_NUM %22, %25 + %35 = MUL_NUM %24, %23 + %36 = SUB_NUM %34, %35 + STORE_VECTOR R3, %30, %33, %36 + CHECK_TAG R0, ttable, exit(6) + %41 = LOAD_POINTER R0 + %42 = GET_SLOT_NODE_ADDR %41, 6u, K3 + CHECK_SLOT_MATCH %42, K3, bb_fallback_5 + %44 = LOAD_TVALUE %42, 0i + STORE_TVALUE R5, %44 + JUMP bb_6 +bb_6: + CHECK_TAG R3, tvector, exit(8) + CHECK_TAG R5, tvector, exit(8) + %53 = LOAD_FLOAT R3, 0i + %54 = LOAD_FLOAT R5, 0i + %55 = MUL_NUM %53, %54 + %56 = LOAD_FLOAT R3, 4i + %57 = LOAD_FLOAT R5, 4i + %58 = MUL_NUM %56, %57 + %59 = LOAD_FLOAT R3, 8i + %60 = LOAD_FLOAT R5, 8i + %61 = MUL_NUM %59, %60 + %62 = ADD_NUM %55, %58 + %63 = ADD_NUM %62, %61 + %69 = ADD_NUM %63, 1 + STORE_DOUBLE R2, %69 + STORE_TAG R2, tnumber + INTERRUPT 12u + RETURN R2, 1i +)"); +} + TEST_CASE("UserDataGetIndex") { - ScopedFastFlag luauCodegenDirectUserdataFlow{FFlag::LuauCodegenDirectUserdataFlow, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function getxy(a: Point) return a.x + a.y @@ -485,8 +794,6 @@ bb_4: TEST_CASE("UserDataSetIndex") { - ScopedFastFlag luauCodegenDirectUserdataFlow{FFlag::LuauCodegenDirectUserdataFlow, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function setxy(a: Point) a.x = 3 @@ -513,8 +820,6 @@ bb_bytecode_1: TEST_CASE("UserDataNamecall") { - ScopedFastFlag luauCodegenDirectUserdataFlow{FFlag::LuauCodegenDirectUserdataFlow, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function getxy(a: Point) return a:GetX() + a:GetY() @@ -551,9 +856,6 @@ bb_4: TEST_CASE("ExplicitUpvalueAndLocalTypes") { - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local y: vector = ... 
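For context on how the userdata lowering exercised by the tests above is enabled: a minimal host-side sketch, assuming the hook functions from ConformanceIrHooks.h shown earlier in this patch, an existing lua_State* L, and the kUserdataRunTypes table; remapUserdataType stands in for the remapper callback used in the test setup above and is an illustrative name, not a prescribed API.

    Luau::CodeGen::AssemblyOptions options;

    // bytecode-type hooks tell the compiler which type a userdata member access yields
    options.compilationOptions.hooks.userdataAccessBytecodeType = userdataAccessBytecodeType;
    options.compilationOptions.hooks.userdataNamecallBytecodeType = userdataNamecallBytecodeType;

    // IR hooks emit inline IR (as in userdataNamecall above) instead of a generic fallback
    options.compilationOptions.hooks.userdataAccess = userdataAccess;
    options.compilationOptions.hooks.userdataNamecall = userdataNamecall;

    // the runtime side needs codegen state plus a type name -> tag remapping
    if (Luau::CodeGen::isSupported())
    {
        Luau::CodeGen::create(L);
        Luau::CodeGen::setUserdataRemapper(L, kUserdataRunTypes, remapUserdataType);
    }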
@@ -595,8 +897,7 @@ bb_bytecode_0: TEST_CASE("FastcallTypeInferThroughLocal") { - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileFastcall3, true}, {FFlag::LuauCodegenFastcall3, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function getsum(x, c) @@ -611,43 +912,40 @@ end /* includeIrTypes */ true), R"( ; function getsum($arg0, $arg1) line 2 -; R2: vector from 0 to 17 +; R2: vector from 0 to 18 bb_bytecode_0: - %0 = LOAD_TVALUE R0 - STORE_TVALUE R3, %0 STORE_DOUBLE R4, 2 STORE_TAG R4, tnumber STORE_DOUBLE R5, 3 STORE_TAG R5, tnumber CHECK_SAFE_ENV exit(4) - CHECK_TAG R3, tnumber, exit(4) - %13 = LOAD_DOUBLE R3 - STORE_VECTOR R2, %13, 2, 3 + CHECK_TAG R0, tnumber, exit(4) + %11 = LOAD_DOUBLE R0 + STORE_VECTOR R2, %11, 2, 3 STORE_TAG R2, tvector JUMP_IF_FALSY R1, bb_bytecode_1, bb_3 bb_3: - CHECK_TAG R2, tvector, exit(8) - %21 = LOAD_FLOAT R2, 0i - %26 = LOAD_FLOAT R2, 4i - %35 = ADD_NUM %21, %26 - STORE_DOUBLE R3, %35 + CHECK_TAG R2, tvector, exit(9) + %19 = LOAD_FLOAT R2, 0i + %24 = LOAD_FLOAT R2, 4i + %33 = ADD_NUM %19, %24 + STORE_DOUBLE R3, %33 STORE_TAG R3, tnumber - INTERRUPT 13u + INTERRUPT 14u RETURN R3, 1i bb_bytecode_1: - CHECK_TAG R2, tvector, exit(14) - %42 = LOAD_FLOAT R2, 8i - STORE_DOUBLE R3, %42 + CHECK_TAG R2, tvector, exit(15) + %40 = LOAD_FLOAT R2, 8i + STORE_DOUBLE R3, %40 STORE_TAG R3, tnumber - INTERRUPT 16u + INTERRUPT 17u RETURN R3, 1i )"); } TEST_CASE("FastcallTypeInferThroughUpvalue") { - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileFastcall3, true}, {FFlag::LuauCodegenFastcall3, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local v = ... 
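To make the encoding change behind these expectation updates concrete, here is the FASTCALL difference that LuauCodegenFastcall3 introduces, using the math.frexp invocation from the IrBuilder tests earlier in this patch; both lines appear verbatim above.

    // old form: explicit third-argument slot (undef here) plus separate argument and result counts
    build.inst(IrCmd::FASTCALL, build.constUint(LBF_MATH_FREXP), build.vmReg(1), build.vmReg(2), build.undef(), build.constInt(1), build.constInt(2));

    // new form: arguments are implied by the builtin and its registers; only the result count remains
    build.inst(IrCmd::FASTCALL, build.constUint(LBF_MATH_FREXP), build.vmReg(1), build.vmReg(2), build.constInt(2));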
@@ -666,49 +964,44 @@ end ; function getsum($arg0, $arg1) line 4 ; U0: vector bb_bytecode_0: - %0 = LOAD_TVALUE R0 - STORE_TVALUE R3, %0 STORE_DOUBLE R4, 2 STORE_TAG R4, tnumber STORE_DOUBLE R5, 3 STORE_TAG R5, tnumber CHECK_SAFE_ENV exit(4) - CHECK_TAG R3, tnumber, exit(4) - %13 = LOAD_DOUBLE R3 - STORE_VECTOR R2, %13, 2, 3 + CHECK_TAG R0, tnumber, exit(4) + %11 = LOAD_DOUBLE R0 + STORE_VECTOR R2, %11, 2, 3 STORE_TAG R2, tvector SET_UPVALUE U0, R2, tvector JUMP_IF_FALSY R1, bb_bytecode_1, bb_3 bb_3: GET_UPVALUE R4, U0 - CHECK_TAG R4, tvector, exit(10) - %23 = LOAD_FLOAT R4, 0i - STORE_DOUBLE R3, %23 + CHECK_TAG R4, tvector, exit(11) + %21 = LOAD_FLOAT R4, 0i + STORE_DOUBLE R3, %21 STORE_TAG R3, tnumber GET_UPVALUE R5, U0 - CHECK_TAG R5, tvector, exit(13) - %29 = LOAD_FLOAT R5, 4i - %38 = ADD_NUM %23, %29 - STORE_DOUBLE R2, %38 + CHECK_TAG R5, tvector, exit(14) + %27 = LOAD_FLOAT R5, 4i + %36 = ADD_NUM %21, %27 + STORE_DOUBLE R2, %36 STORE_TAG R2, tnumber - INTERRUPT 16u + INTERRUPT 17u RETURN R2, 1i bb_bytecode_1: GET_UPVALUE R3, U0 - CHECK_TAG R3, tvector, exit(18) - %46 = LOAD_FLOAT R3, 8i - STORE_DOUBLE R2, %46 + CHECK_TAG R3, tvector, exit(19) + %44 = LOAD_FLOAT R3, 8i + STORE_DOUBLE R2, %44 STORE_TAG R2, tnumber - INTERRUPT 20u + INTERRUPT 21u RETURN R2, 1i )"); } TEST_CASE("LoadAndMoveTypePropagation") { - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function getsum(n) local seqsum = 0 @@ -774,8 +1067,7 @@ bb_bytecode_4: TEST_CASE("ArgumentTypeRefinement") { - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}}; + ScopedFastFlag sffs[]{{FFlag::LuauCompileFastcall3, true}, {FFlag::LuauCodegenFastcall3, true}}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function getsum(x, y) @@ -790,32 +1082,27 @@ end bb_bytecode_0: STORE_DOUBLE R3, 1 STORE_TAG R3, tnumber - %2 = LOAD_TVALUE R1 - STORE_TVALUE R4, %2 STORE_DOUBLE R5, 3 STORE_TAG R5, tnumber CHECK_SAFE_ENV exit(4) - CHECK_TAG R4, tnumber, exit(4) - %14 = LOAD_DOUBLE R4 - STORE_VECTOR R2, 1, %14, 3 + CHECK_TAG R1, tnumber, exit(4) + %12 = LOAD_DOUBLE R1 + STORE_VECTOR R2, 1, %12, 3 STORE_TAG R2, tvector - %18 = LOAD_TVALUE R2 - STORE_TVALUE R0, %18 - %22 = LOAD_FLOAT R0, 4i - %27 = LOAD_FLOAT R0, 8i - %36 = ADD_NUM %22, %27 - STORE_DOUBLE R2, %36 + %16 = LOAD_TVALUE R2 + STORE_TVALUE R0, %16 + %20 = LOAD_FLOAT R0, 4i + %25 = LOAD_FLOAT R0, 8i + %34 = ADD_NUM %20, %25 + STORE_DOUBLE R2, %34 STORE_TAG R2, tnumber - INTERRUPT 13u + INTERRUPT 14u RETURN R2, 1i )"); } TEST_CASE("InlineFunctionType") { - ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true}, - {FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function inl(v: vector, s: number) return v.Y * s @@ -860,4 +1147,770 @@ bb_bytecode_0: )"); } +TEST_CASE("ResolveTablePathTypes") +{ + CHECK_EQ("\n" + getCodegenAssembly(R"( +type Vertex = {pos: vector, normal: vector} + +local function foo(arr: {Vertex}, i) + local v = arr[i] + + return v.pos.Y +end +)", + /* includeIrTypes */ true, /* debugLevel */ 2), + R"( +; function foo(arr, i) line 4 +; 
R0: table [argument 'arr'] +; R2: table from 0 to 6 [local 'v'] +; R4: vector from 3 to 5 +bb_0: + CHECK_TAG R0, ttable, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + CHECK_TAG R1, tnumber, bb_fallback_3 + %8 = LOAD_POINTER R0 + %9 = LOAD_DOUBLE R1 + %10 = TRY_NUM_TO_INDEX %9, bb_fallback_3 + %11 = SUB_INT %10, 1i + CHECK_ARRAY_SIZE %8, %11, bb_fallback_3 + CHECK_NO_METATABLE %8, bb_fallback_3 + %14 = GET_ARR_ADDR %8, %11 + %15 = LOAD_TVALUE %14 + STORE_TVALUE R2, %15 + JUMP bb_4 +bb_4: + CHECK_TAG R2, ttable, exit(1) + %23 = LOAD_POINTER R2 + %24 = GET_SLOT_NODE_ADDR %23, 1u, K0 + CHECK_SLOT_MATCH %24, K0, bb_fallback_5 + %26 = LOAD_TVALUE %24, 0i + STORE_TVALUE R4, %26 + JUMP bb_6 +bb_6: + CHECK_TAG R4, tvector, exit(3) + %33 = LOAD_FLOAT R4, 4i + STORE_DOUBLE R3, %33 + STORE_TAG R3, tnumber + INTERRUPT 5u + RETURN R3, 1i +)"); +} + +TEST_CASE("ResolvableSimpleMath") +{ + CHECK_EQ("\n" + getCodegenHeader(R"( +type Vertex = { p: vector, uv: vector, n: vector, t: vector, b: vector, h: number } +local mesh: { vertices: {Vertex}, indices: {number} } = ... + +local function compute() + for i = 1,#mesh.indices,3 do + local a = mesh.vertices[mesh.indices[i]] + local b = mesh.vertices[mesh.indices[i + 1]] + local c = mesh.vertices[mesh.indices[i + 2]] + + local vba = b.p - a.p + local vca = c.p - a.p + + local uvba = b.uv - a.uv + local uvca = c.uv - a.uv + + local r = 1.0 / (uvba.X * uvca.Y - uvca.X * uvba.Y); + + local sdir = (uvca.Y * vba - uvba.Y * vca) * r + + a.t += sdir + end +end +)"), + R"( +; function compute() line 5 +; U0: table ['mesh'] +; R2: number from 0 to 78 [local 'i'] +; R3: table from 7 to 78 [local 'a'] +; R4: table from 15 to 78 [local 'b'] +; R5: table from 24 to 78 [local 'c'] +; R6: vector from 33 to 78 [local 'vba'] +; R7: vector from 37 to 38 +; R7: vector from 38 to 78 [local 'vca'] +; R8: vector from 37 to 38 +; R8: vector from 42 to 43 +; R8: vector from 43 to 78 [local 'uvba'] +; R9: vector from 42 to 43 +; R9: vector from 47 to 48 +; R9: vector from 48 to 78 [local 'uvca'] +; R10: vector from 47 to 48 +; R10: vector from 52 to 53 +; R10: number from 53 to 78 [local 'r'] +; R11: vector from 52 to 53 +; R11: vector from 65 to 78 [local 'sdir'] +; R12: vector from 72 to 73 +; R12: vector from 75 to 76 +; R13: vector from 71 to 72 +; R14: vector from 71 to 72 +)"); +} + +TEST_CASE("ResolveVectorNamecalls") +{ + CHECK_EQ("\n" + getCodegenAssembly(R"( +type Vertex = {pos: vector, normal: vector} + +local function foo(arr: {Vertex}, i) + return arr[i].normal:Dot(vector(0.707, 0, 0.707)) +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0, $arg1) line 4 +; R0: table [argument] +; R2: vector from 4 to 6 +bb_0: + CHECK_TAG R0, ttable, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + CHECK_TAG R1, tnumber, bb_fallback_3 + %8 = LOAD_POINTER R0 + %9 = LOAD_DOUBLE R1 + %10 = TRY_NUM_TO_INDEX %9, bb_fallback_3 + %11 = SUB_INT %10, 1i + CHECK_ARRAY_SIZE %8, %11, bb_fallback_3 + CHECK_NO_METATABLE %8, bb_fallback_3 + %14 = GET_ARR_ADDR %8, %11 + %15 = LOAD_TVALUE %14 + STORE_TVALUE R3, %15 + JUMP bb_4 +bb_4: + CHECK_TAG R3, ttable, bb_fallback_5 + %23 = LOAD_POINTER R3 + %24 = GET_SLOT_NODE_ADDR %23, 1u, K0 + CHECK_SLOT_MATCH %24, K0, bb_fallback_5 + %26 = LOAD_TVALUE %24, 0i + STORE_TVALUE R2, %26 + JUMP bb_6 +bb_6: + %31 = LOAD_TVALUE K1, 0i, tvector + STORE_TVALUE R4, %31 + CHECK_TAG R2, tvector, exit(4) + %37 = LOAD_FLOAT R2, 0i + %38 = LOAD_FLOAT R4, 0i + %39 = MUL_NUM %37, %38 + %40 = LOAD_FLOAT R2, 4i + %41 = 
LOAD_FLOAT R4, 4i + %42 = MUL_NUM %40, %41 + %43 = LOAD_FLOAT R2, 8i + %44 = LOAD_FLOAT R4, 8i + %45 = MUL_NUM %43, %44 + %46 = ADD_NUM %39, %42 + %47 = ADD_NUM %46, %45 + STORE_DOUBLE R2, %47 + STORE_TAG R2, tnumber + ADJUST_STACK_TO_REG R2, 1i + INTERRUPT 7u + RETURN R2, -1i +)"); +} + +TEST_CASE("ImmediateTypeAnnotationHelp") +{ + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(arr, i) + return (arr[i] :: vector) / 5 +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0, $arg1) line 2 +; R3: vector from 1 to 2 +bb_bytecode_0: + CHECK_TAG R0, ttable, bb_fallback_1 + CHECK_TAG R1, tnumber, bb_fallback_1 + %4 = LOAD_POINTER R0 + %5 = LOAD_DOUBLE R1 + %6 = TRY_NUM_TO_INDEX %5, bb_fallback_1 + %7 = SUB_INT %6, 1i + CHECK_ARRAY_SIZE %4, %7, bb_fallback_1 + CHECK_NO_METATABLE %4, bb_fallback_1 + %10 = GET_ARR_ADDR %4, %7 + %11 = LOAD_TVALUE %10 + STORE_TVALUE R3, %11 + JUMP bb_2 +bb_2: + CHECK_TAG R3, tvector, exit(1) + %19 = LOAD_TVALUE R3 + %20 = NUM_TO_VEC 5 + %21 = DIV_VEC %19, %20 + %22 = TAG_VECTOR %21 + STORE_TVALUE R2, %22 + INTERRUPT 2u + RETURN R2, 1i +)"); +} + +TEST_CASE("UnaryTypeResolve") +{ + ScopedFastFlag sffs[]{{FFlag::LuauCompileFastcall3, true}, {FFlag::LuauCodegenFastcall3, true}}; + + CHECK_EQ("\n" + getCodegenHeader(R"( +local function foo(a, b: vector, c) + local d = not a + local e = -b + local f = #c + return (if d then e else vector(f, 2, 3)).X +end +)"), + R"( +; function foo(a, b, c) line 2 +; R1: vector [argument 'b'] +; R3: boolean from 0 to 17 [local 'd'] +; R4: vector from 1 to 17 [local 'e'] +; R5: number from 2 to 17 [local 'f'] +; R7: vector from 14 to 16 +)"); +} + +TEST_CASE("ForInManualAnnotation") +{ + CHECK_EQ("\n" + getCodegenAssembly(R"( +type Vertex = {pos: vector, normal: vector} + +local function foo(a: {Vertex}) + local sum = 0 + for k, v: Vertex in ipairs(a) do + sum += v.pos.X + end + return sum +end +)", + /* includeIrTypes */ true, /* debugLevel */ 2), + R"( +; function foo(a) line 4 +; R0: table [argument 'a'] +; R1: number from 0 to 14 [local 'sum'] +; R5: number from 5 to 11 [local 'k'] +; R6: table from 5 to 11 [local 'v'] +; R8: vector from 8 to 10 +bb_0: + CHECK_TAG R0, ttable, exit(entry) + JUMP bb_4 +bb_4: + JUMP bb_bytecode_1 +bb_bytecode_1: + STORE_DOUBLE R1, 0 + STORE_TAG R1, tnumber + CHECK_SAFE_ENV exit(1) + JUMP_EQ_TAG K1, tnil, bb_fallback_6, bb_5 +bb_5: + %9 = LOAD_TVALUE K1 + STORE_TVALUE R2, %9 + JUMP bb_7 +bb_7: + %15 = LOAD_TVALUE R0 + STORE_TVALUE R3, %15 + INTERRUPT 4u + SET_SAVEDPC 5u + CALL R2, 1i, 3i + CHECK_SAFE_ENV exit(5) + CHECK_TAG R3, ttable, bb_fallback_8 + CHECK_TAG R4, tnumber, bb_fallback_8 + JUMP_CMP_NUM R4, 0, not_eq, bb_fallback_8, bb_9 +bb_9: + STORE_TAG R2, tnil + STORE_POINTER R4, 0i + STORE_EXTRA R4, 128i + STORE_TAG R4, tlightuserdata + JUMP bb_bytecode_3 +bb_bytecode_2: + CHECK_TAG R6, ttable, exit(6) + %35 = LOAD_POINTER R6 + %36 = GET_SLOT_NODE_ADDR %35, 6u, K2 + CHECK_SLOT_MATCH %36, K2, bb_fallback_10 + %38 = LOAD_TVALUE %36, 0i + STORE_TVALUE R8, %38 + JUMP bb_11 +bb_11: + CHECK_TAG R8, tvector, exit(8) + %45 = LOAD_FLOAT R8, 0i + STORE_DOUBLE R7, %45 + STORE_TAG R7, tnumber + CHECK_TAG R1, tnumber, exit(10) + %52 = LOAD_DOUBLE R1 + %54 = ADD_NUM %52, %45 + STORE_DOUBLE R1, %54 + JUMP bb_bytecode_3 +bb_bytecode_3: + INTERRUPT 11u + CHECK_TAG R2, tnil, bb_fallback_13 + %60 = LOAD_POINTER R3 + %61 = LOAD_INT R4 + %62 = GET_ARR_ADDR %60, %61 + CHECK_ARRAY_SIZE %60, %61, bb_12 + %64 = LOAD_TAG %62 + JUMP_EQ_TAG %64, tnil, bb_12, bb_14 +bb_14: + %66 = ADD_INT %61, 1i + 
STORE_INT R4, %66 + %68 = INT_TO_NUM %66 + STORE_DOUBLE R5, %68 + STORE_TAG R5, tnumber + %71 = LOAD_TVALUE %62 + STORE_TVALUE R6, %71 + JUMP bb_bytecode_2 +bb_12: + INTERRUPT 13u + RETURN R1, 1i +)"); +} + +TEST_CASE("ForInAutoAnnotationIpairs") +{ + CHECK_EQ("\n" + getCodegenHeader(R"( +type Vertex = {pos: vector, normal: vector} + +local function foo(a: {Vertex}) + local sum = 0 + for k, v in ipairs(a) do + local n = v.pos.X + sum += n + end + return sum +end +)"), + R"( +; function foo(a) line 4 +; R0: table [argument 'a'] +; R1: number from 0 to 14 [local 'sum'] +; R5: number from 5 to 11 [local 'k'] +; R6: table from 5 to 11 [local 'v'] +; R7: number from 6 to 11 [local 'n'] +; R8: vector from 8 to 10 +)"); +} + +TEST_CASE("ForInAutoAnnotationPairs") +{ + CHECK_EQ("\n" + getCodegenHeader(R"( +type Vertex = {pos: vector, normal: vector} + +local function foo(a: {[string]: Vertex}) + local sum = 0 + for k, v in pairs(a) do + local n = v.pos.X + sum += n + end + return sum +end +)"), + R"( +; function foo(a) line 4 +; R0: table [argument 'a'] +; R1: number from 0 to 14 [local 'sum'] +; R5: string from 5 to 11 [local 'k'] +; R6: table from 5 to 11 [local 'v'] +; R7: number from 6 to 11 [local 'n'] +; R8: vector from 8 to 10 +)"); +} + +TEST_CASE("ForInAutoAnnotationGeneric") +{ + CHECK_EQ("\n" + getCodegenHeader(R"( +type Vertex = {pos: vector, normal: vector} + +local function foo(a: {Vertex}) + local sum = 0 + for k, v in a do + local n = v.pos.X + sum += n + end + return sum +end +)"), + R"( +; function foo(a) line 4 +; R0: table [argument 'a'] +; R1: number from 0 to 13 [local 'sum'] +; R5: number from 4 to 10 [local 'k'] +; R6: table from 4 to 10 [local 'v'] +; R7: number from 5 to 10 [local 'n'] +; R8: vector from 7 to 9 +)"); +} + +// Temporary test, when we don't compile new typeinfo, but support loading it +TEST_CASE("CustomUserdataTypesTemp") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauCompileUserdataInfo, false}, {FFlag::LuauLoadUserdataInfo, true}}; + + CHECK_EQ("\n" + getCodegenHeader(R"( +local function foo(v: vec2, x: mat3) + return v.X * x +end +)"), + R"( +; function foo(v, x) line 2 +; R0: userdata [argument 'v'] +; R1: userdata [argument 'x'] +)"); +} + +TEST_CASE("CustomUserdataTypes") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauCompileUserdataInfo, true}, {FFlag::LuauLoadUserdataInfo, true}}; + + CHECK_EQ("\n" + getCodegenHeader(R"( +local function foo(v: vec2, x: mat3) + return v.X * x +end +)"), + R"( +; function foo(v, x) line 2 +; R0: vec2 [argument 'v'] +; R1: mat3 [argument 'x'] +)"); +} + +TEST_CASE("CustomUserdataPropertyAccess") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauCompileUserdataInfo, true}, {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(v: vec2) + return v.X + v.Y +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0) line 2 +; R0: vec2 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = LOAD_POINTER R0 + CHECK_USERDATA_TAG %6, 12i, exit(0) + %8 = BUFFER_READF32 %6, 0i, tuserdata + %15 = BUFFER_READF32 %6, 4i, tuserdata + %24 = ADD_NUM %8, %15 + STORE_DOUBLE R1, %24 + STORE_TAG R1, 
tnumber + INTERRUPT 5u + RETURN R1, 1i +)"); +} + +TEST_CASE("CustomUserdataPropertyAccess2") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauCompileUserdataInfo, true}, {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: mat3) + return a.Row1 * a.Row2 +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0) line 2 +; R0: mat3 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + FALLBACK_GETTABLEKS 0u, R2, R0, K0 + FALLBACK_GETTABLEKS 2u, R3, R0, K1 + CHECK_TAG R2, tvector, exit(4) + CHECK_TAG R3, tvector, exit(4) + %14 = LOAD_TVALUE R2 + %15 = LOAD_TVALUE R3 + %16 = MUL_VEC %14, %15 + %17 = TAG_VECTOR %16 + STORE_TVALUE R1, %17 + INTERRUPT 5u + RETURN R1, 1i +)"); +} + +TEST_CASE("CustomUserdataNamecall1") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauCompileUserdataInfo, true}, {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: vec2, b: vec2) + return a:Dot(b) +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0, $arg1) line 2 +; R0: vec2 [argument] +; R1: vec2 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + CHECK_TAG R1, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = LOAD_TVALUE R1 + STORE_TVALUE R4, %6 + %10 = LOAD_POINTER R0 + CHECK_USERDATA_TAG %10, 12i, exit(1) + %14 = LOAD_POINTER R4 + CHECK_USERDATA_TAG %14, 12i, exit(1) + %16 = BUFFER_READF32 %10, 0i, tuserdata + %17 = BUFFER_READF32 %14, 0i, tuserdata + %18 = MUL_NUM %16, %17 + %19 = BUFFER_READF32 %10, 4i, tuserdata + %20 = BUFFER_READF32 %14, 4i, tuserdata + %21 = MUL_NUM %19, %20 + %22 = ADD_NUM %18, %21 + STORE_DOUBLE R2, %22 + STORE_TAG R2, tnumber + ADJUST_STACK_TO_REG R2, 1i + INTERRUPT 4u + RETURN R2, -1i +)"); +} + +TEST_CASE("CustomUserdataNamecall2") +{ + // This test requires runtime component to be present + if (!Luau::CodeGen::isSupported()) + return; + + ScopedFastFlag sffs[]{{FFlag::LuauCompileUserdataInfo, true}, {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}, + {FFlag::LuauCodegenUserdataAlloc, true}}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function foo(a: vec2, b: vec2) + return a:Min(b) +end +)", + /* includeIrTypes */ true), + R"( +; function foo($arg0, $arg1) line 2 +; R0: vec2 [argument] +; R1: vec2 [argument] +bb_0: + CHECK_TAG R0, tuserdata, exit(entry) + CHECK_TAG R1, tuserdata, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %6 = LOAD_TVALUE R1 + STORE_TVALUE R4, %6 + %10 = LOAD_POINTER R0 + CHECK_USERDATA_TAG %10, 12i, exit(1) + %14 = LOAD_POINTER R4 + CHECK_USERDATA_TAG %14, 12i, exit(1) + %16 = BUFFER_READF32 %10, 0i, tuserdata + %17 = BUFFER_READF32 %14, 0i, tuserdata + %18 = MIN_NUM %16, %17 + %19 = BUFFER_READF32 %10, 4i, tuserdata + %20 = BUFFER_READF32 %14, 4i, tuserdata + %21 = MIN_NUM %19, %20 + CHECK_GC + %23 = NEW_USERDATA 8i, 12i + BUFFER_WRITEF32 %23, 0i, %18, tuserdata + BUFFER_WRITEF32 %23, 4i, %21, tuserdata + STORE_POINTER R2, %23 + STORE_TAG R2, tuserdata + ADJUST_STACK_TO_REG R2, 1i + INTERRUPT 4u + RETURN R2, -1i +)"); +} + +TEST_CASE("CustomUserdataMetamethodDirectFlow") +{ + // This test requires runtime component 
+TEST_CASE("CustomUserdataMetamethodDirectFlow")
+{
+    // This test requires runtime component to be present
+    if (!Luau::CodeGen::isSupported())
+        return;
+
+    ScopedFastFlag sffs[]{{FFlag::LuauCompileUserdataInfo, true}, {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}};
+
+    CHECK_EQ("\n" + getCodegenAssembly(R"(
+local function foo(a: mat3, b: mat3)
+    return a * b
+end
+)",
+                  /* includeIrTypes */ true),
+        R"(
+; function foo($arg0, $arg1) line 2
+; R0: mat3 [argument]
+; R1: mat3 [argument]
+bb_0:
+  CHECK_TAG R0, tuserdata, exit(entry)
+  CHECK_TAG R1, tuserdata, exit(entry)
+  JUMP bb_2
+bb_2:
+  JUMP bb_bytecode_1
+bb_bytecode_1:
+  SET_SAVEDPC 1u
+  DO_ARITH R2, R0, R1, 10i
+  INTERRUPT 1u
+  RETURN R2, 1i
+)");
+}
+
+TEST_CASE("CustomUserdataMetamethodDirectFlow2")
+{
+    // This test requires runtime component to be present
+    if (!Luau::CodeGen::isSupported())
+        return;
+
+    ScopedFastFlag sffs[]{{FFlag::LuauCompileUserdataInfo, true}, {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}};
+
+    CHECK_EQ("\n" + getCodegenAssembly(R"(
+local function foo(a: mat3)
+    return -a
+end
+)",
+                  /* includeIrTypes */ true),
+        R"(
+; function foo($arg0) line 2
+; R0: mat3 [argument]
+bb_0:
+  CHECK_TAG R0, tuserdata, exit(entry)
+  JUMP bb_2
+bb_2:
+  JUMP bb_bytecode_1
+bb_bytecode_1:
+  SET_SAVEDPC 1u
+  DO_ARITH R1, R0, R0, 15i
+  INTERRUPT 1u
+  RETURN R1, 1i
+)");
+}
+
+TEST_CASE("CustomUserdataMetamethodDirectFlow3")
+{
+    // This test requires runtime component to be present
+    if (!Luau::CodeGen::isSupported())
+        return;
+
+    ScopedFastFlag sffs[]{{FFlag::LuauCompileUserdataInfo, true}, {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true}};
+
+    CHECK_EQ("\n" + getCodegenAssembly(R"(
+local function foo(a: sequence)
+    return #a
+end
+)",
+                  /* includeIrTypes */ true),
+        R"(
+; function foo($arg0) line 2
+; R0: userdata [argument]
+bb_0:
+  CHECK_TAG R0, tuserdata, exit(entry)
+  JUMP bb_2
+bb_2:
+  JUMP bb_bytecode_1
+bb_bytecode_1:
+  SET_SAVEDPC 1u
+  DO_LEN R1, R0
+  INTERRUPT 1u
+  RETURN R1, 1i
+)");
+}
+
+TEST_CASE("CustomUserdataMetamethod")
+{
+    // This test requires runtime component to be present
+    if (!Luau::CodeGen::isSupported())
+        return;
+
+    ScopedFastFlag sffs[]{{FFlag::LuauCompileUserdataInfo, true}, {FFlag::LuauLoadUserdataInfo, true}, {FFlag::LuauCodegenUserdataOps, true},
+        {FFlag::LuauCodegenUserdataAlloc, true}};
+
+    CHECK_EQ("\n" + getCodegenAssembly(R"(
+local function foo(a: vec2, b: vec2, c: vec2)
+    return -c + a * b
+end
+)",
+                  /* includeIrTypes */ true),
+        R"(
+; function foo($arg0, $arg1, $arg2) line 2
+; R0: vec2 [argument]
+; R1: vec2 [argument]
+; R2: vec2 [argument]
+bb_0:
+  CHECK_TAG R0, tuserdata, exit(entry)
+  CHECK_TAG R1, tuserdata, exit(entry)
+  CHECK_TAG R2, tuserdata, exit(entry)
+  JUMP bb_2
+bb_2:
+  JUMP bb_bytecode_1
+bb_bytecode_1:
+  %10 = LOAD_POINTER R2
+  CHECK_USERDATA_TAG %10, 12i, exit(0)
+  %12 = BUFFER_READF32 %10, 0i, tuserdata
+  %13 = BUFFER_READF32 %10, 4i, tuserdata
+  %14 = UNM_NUM %12
+  %15 = UNM_NUM %13
+  CHECK_GC
+  %17 = NEW_USERDATA 8i, 12i
+  BUFFER_WRITEF32 %17, 0i, %14, tuserdata
+  BUFFER_WRITEF32 %17, 4i, %15, tuserdata
+  STORE_POINTER R4, %17
+  STORE_TAG R4, tuserdata
+  %26 = LOAD_POINTER R0
+  CHECK_USERDATA_TAG %26, 12i, exit(1)
+  %28 = LOAD_POINTER R1
+  CHECK_USERDATA_TAG %28, 12i, exit(1)
+  %30 = BUFFER_READF32 %26, 0i, tuserdata
+  %31 = BUFFER_READF32 %28, 0i, tuserdata
+  %32 = MUL_NUM %30, %31
+  %33 = BUFFER_READF32 %26, 4i, tuserdata
+  %34 = BUFFER_READF32 %28, 4i, tuserdata
+  %35 = MUL_NUM %33, %34
+  %37 = NEW_USERDATA 8i, 12i
+  BUFFER_WRITEF32 %37, 0i, %32, tuserdata
+  BUFFER_WRITEF32 %37, 4i, %35, tuserdata
+  STORE_POINTER R5, %37
+  STORE_TAG R5, tuserdata
+  %50 = BUFFER_READF32 %17, 0i, tuserdata
+  %51 = BUFFER_READF32 %37, 0i, tuserdata
+  %52 = ADD_NUM %50, %51
+  %53 = BUFFER_READF32 %17, 4i, tuserdata
+  %54 = BUFFER_READF32 %37, 4i, tuserdata
+  %55 = ADD_NUM %53, %54
+  %57 = NEW_USERDATA 8i, 12i
+  BUFFER_WRITEF32 %57, 0i, %52, tuserdata
+  BUFFER_WRITEF32 %57, 4i, %55, tuserdata
+  STORE_POINTER R3, %57
+  STORE_TAG R3, tuserdata
+  INTERRUPT 3u
+  RETURN R3, 1i
+)");
+}
+
 TEST_SUITE_END();
diff --git a/tests/Lexer.test.cpp b/tests/Lexer.test.cpp
index 78d1389a..e0716e4c 100644
--- a/tests/Lexer.test.cpp
+++ b/tests/Lexer.test.cpp
@@ -192,13 +192,13 @@ TEST_CASE("string_interpolation_double_brace")
 
     auto brokenInterpBegin = lexer.next();
     CHECK_EQ(brokenInterpBegin.type, Lexeme::BrokenInterpDoubleBrace);
-    CHECK_EQ(std::string(brokenInterpBegin.data, brokenInterpBegin.length), std::string("foo"));
+    CHECK_EQ(std::string(brokenInterpBegin.data, brokenInterpBegin.getLength()), std::string("foo"));
 
     CHECK_EQ(lexer.next().type, Lexeme::Name);
 
     auto interpEnd = lexer.next();
     CHECK_EQ(interpEnd.type, Lexeme::InterpStringEnd);
-    CHECK_EQ(std::string(interpEnd.data, interpEnd.length), std::string("}bar"));
+    CHECK_EQ(std::string(interpEnd.data, interpEnd.getLength()), std::string("}bar"));
 }
 
 TEST_CASE("string_interpolation_double_but_unmatched_brace")
diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp
index 5d282392..b758764d 100644
--- a/tests/Linter.test.cpp
+++ b/tests/Linter.test.cpp
@@ -8,6 +8,9 @@
 #include "doctest.h"
 
 LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
+LUAU_FASTFLAG(LuauAttributeSyntax);
+LUAU_FASTFLAG(LuauNativeAttribute);
+LUAU_FASTFLAG(LintRedundantNativeAttribute);
 
 using namespace Luau;
 
@@ -1955,4 +1958,32 @@ local _ = a <= (b == 0)
     CHECK_EQ(result.warnings[4].text, "X <= Y <= Z is equivalent to (X <= Y) <= Z; did you mean X <= Y and Y <= Z?");
 }
 
+TEST_CASE_FIXTURE(Fixture, "RedundantNativeAttribute")
+{
+    ScopedFastFlag sff[] = {{FFlag::LuauAttributeSyntax, true}, {FFlag::LuauNativeAttribute, true}, {FFlag::LintRedundantNativeAttribute, true}};
+
+    LintResult result = lint(R"(
+--!native
+
+@native
+local function f(a)
+    @native
+    local function g(b)
+        return (a + b)
+    end
+    return g
+end
+
+f(3)(4)
+)");
+
+    REQUIRE(2 == result.warnings.size());
+
+    CHECK_EQ(result.warnings[0].text, "native attribute on a function is redundant in a native module; consider removing it");
+    CHECK_EQ(result.warnings[0].location, Location(Position(3, 0), Position(3, 7)));
+
+    CHECK_EQ(result.warnings[1].text, "native attribute on a function is redundant in a native module; consider removing it");
+    CHECK_EQ(result.warnings[1].location, Location(Position(5, 4), Position(5, 11)));
+}
+
 TEST_SUITE_END();
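// Sketch: a plausible companion to the RedundantNativeAttribute test above.
// Without the --!native module comment the attribute is meaningful, so (by
// the wording of the warning itself) the lint should stay quiet. This is a
// hypothetical test, not part of the suite:
TEST_CASE_FIXTURE(Fixture, "NativeAttributeWithoutNativeCommentSketch")
{
    ScopedFastFlag sff[] = {{FFlag::LuauAttributeSyntax, true}, {FFlag::LuauNativeAttribute, true}, {FFlag::LintRedundantNativeAttribute, true}};

    LintResult result = lint(R"(
@native
local function f(a)
    return a
end

f(3)
)");

    // assumption: no redundancy warning is emitted for a non-native module
    CHECK(result.warnings.empty());
}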
diff --git a/tests/NonStrictTypeChecker.test.cpp b/tests/NonStrictTypeChecker.test.cpp
index d85e46ee..81a84722 100644
--- a/tests/NonStrictTypeChecker.test.cpp
+++ b/tests/NonStrictTypeChecker.test.cpp
@@ -15,7 +15,7 @@
 
 using namespace Luau;
 
-LUAU_FASTFLAG(LuauCheckedFunctionSyntax);
+LUAU_FASTFLAG(LuauAttributeSyntax);
 
 #define NONSTRICT_REQUIRE_ERR_AT_POS(pos, result, idx) \
     do \
@@ -69,8 +69,8 @@ struct NonStrictTypeCheckerFixture : Fixture
 
     CheckResult checkNonStrict(const std::string& code)
     {
         ScopedFastFlag flags[] = {
-            {FFlag::LuauCheckedFunctionSyntax, true},
             {FFlag::DebugLuauDeferredConstraintResolution, true},
+            {FFlag::LuauAttributeSyntax, true},
         };
         LoadDefinitionFileResult res = loadDefinition(definitions);
         LUAU_ASSERT(res.success);
@@ -80,8 +80,8 @@ struct NonStrictTypeCheckerFixture : Fixture
 
     CheckResult checkNonStrictModule(const std::string& moduleName)
     {
         ScopedFastFlag flags[] = {
-            {FFlag::LuauCheckedFunctionSyntax, true},
             {FFlag::DebugLuauDeferredConstraintResolution, true},
+            {FFlag::LuauAttributeSyntax, true},
        };
         LoadDefinitionFileResult res = loadDefinition(definitions);
         LUAU_ASSERT(res.success);
@@ -89,21 +89,21 @@ struct NonStrictTypeCheckerFixture : Fixture
     }
 
     std::string definitions = R"BUILTIN_SRC(
-declare function @checked abs(n: number): number
-declare function @checked lower(s: string): string
+@checked declare function abs(n: number): number
+@checked declare function lower(s: string): string
 declare function cond() : boolean
-declare function @checked contrived(n : Not<number>) : number
+@checked declare function contrived(n : Not<number>) : number
 -- interesting types of things that we would like to mark as checked
-declare function @checked onlyNums(...: number) : number
-declare function @checked mixedArgs(x: string, ...: number) : number
-declare function @checked optionalArg(x: string?) : number
+@checked declare function onlyNums(...: number) : number
+@checked declare function mixedArgs(x: string, ...: number) : number
+@checked declare function optionalArg(x: string?) : number
 declare foo: {
     bar: @checked (number) -> number,
 }
-declare function @checked optionalArgsAtTheEnd1(x: string, y: number?, z: number?) : number
-declare function @checked optionalArgsAtTheEnd2(x: string, y: number?, z: string) : number
+@checked declare function optionalArgsAtTheEnd1(x: string, y: number?, z: number?) : number
+@checked declare function optionalArgsAtTheEnd2(x: string, y: number?, z: string) : number
 
 type DateTypeArg = {
     year: number,
@@ -119,7 +119,7 @@ declare os : {
     time: @checked (time: DateTypeArg?) -> number
 }
 
-declare function @checked require(target : any) : any
+@checked declare function require(target : any) : any
 )BUILTIN_SRC";
 };
 
@@ -560,4 +560,26 @@ local E = require(script.Parent.A)
     LUAU_REQUIRE_NO_ERRORS(result);
 }
 
+TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "nonstrict_shouldnt_warn_on_valid_buffer_use")
+{
+    ScopedFastFlag flags[] = {
+        {FFlag::LuauAttributeSyntax, true},
+    };
+
+    loadDefinition(R"(
+declare buffer: {
+    create: @checked (size: number) -> buffer,
+    readi8: @checked (b: buffer, offset: number) -> number,
+    writef64: @checked (b: buffer, offset: number, value: number) -> (),
+}
+)");
+
+    CheckResult result = checkNonStrict(R"(
+local b = buffer.create(100)
+buffer.writef64(b, 0, 5)
+buffer.readi8(b, 0)
+)");
+    LUAU_REQUIRE_NO_ERRORS(result);
+}
+
 TEST_SUITE_END();
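// Sketch: a hypothetical negative companion to the buffer test above, not
// part of the suite. With the same @checked declarations, non-strict checking
// should flag an argument of the wrong type at the call site:
TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "nonstrict_warns_on_invalid_buffer_use_sketch")
{
    ScopedFastFlag flags[] = {
        {FFlag::LuauAttributeSyntax, true},
    };

    loadDefinition(R"(
declare buffer: {
    create: @checked (size: number) -> buffer,
    readi8: @checked (b: buffer, offset: number) -> number,
}
)");

    CheckResult result = checkNonStrict(R"(
local b = buffer.create(100)
buffer.readi8(b, "zero")
)");
    // assumption: the string offset is reported as a checked-function argument error
    LUAU_REQUIRE_ERRORS(result);
}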
diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp
index 1a9ffd65..e8a10e92 100644
--- a/tests/Normalize.test.cpp
+++ b/tests/Normalize.test.cpp
@@ -11,10 +11,8 @@
 #include "Luau/BuiltinDefinitions.h"
 
 LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
-LUAU_FASTFLAG(LuauFixNormalizeCaching)
 LUAU_FASTFLAG(LuauNormalizeNotUnknownIntersection)
-LUAU_FASTFLAG(LuauFixCyclicUnionsOfIntersections);
-
+LUAU_FASTINT(LuauTypeInferRecursionLimit)
 using namespace Luau;
 
 namespace
@@ -428,7 +426,6 @@ struct NormalizeFixture : Fixture
     UnifierSharedState unifierState{&iceHandler};
     Normalizer normalizer{&arena, builtinTypes, NotNull{&unifierState}};
     Scope globalScope{builtinTypes->anyTypePack};
-    ScopedFastFlag fixNormalizeCaching{FFlag::LuauFixNormalizeCaching, true};
 
     NormalizeFixture()
     {
@@ -801,8 +798,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_union")
 
 TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_union_of_intersection")
 {
-    ScopedFastFlag sff{FFlag::LuauFixCyclicUnionsOfIntersections, true};
-
     // t1 where t1 = (string & t1) | string
     TypeId boundTy = arena.addType(BlockedType{});
     TypeId intersectTy = arena.addType(IntersectionType{{builtinTypes->stringType, boundTy}});
@@ -816,8 +811,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_union_of_intersection")
 
 TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_intersection_of_unions")
 {
-    ScopedFastFlag sff{FFlag::LuauFixCyclicUnionsOfIntersections, true};
-
     // t1 where t1 = (string & t1) | string
     TypeId boundTy = arena.addType(BlockedType{});
     TypeId unionTy = arena.addType(UnionType{{builtinTypes->stringType, boundTy}});
@@ -962,4 +955,32 @@ TEST_CASE_FIXTURE(NormalizeFixture, "intersect_with_not_unknown")
     CHECK("never" == toString(normalizer.typeFromNormal(*normalized.get())));
 }
 
+TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_stack_overflow_1")
+{
+    ScopedFastInt sfi{FInt::LuauTypeInferRecursionLimit, 165};
+    this->unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;
+    TypeId t1 = arena.addType(TableType{});
+    TypeId t2 = arena.addType(TableType{});
+    TypeId t3 = arena.addType(IntersectionType{{t1, t2}});
+    asMutable(t1)->ty.get_if<TableType>()->props = {{"foo", Property::readonly(t2)}};
+    asMutable(t2)->ty.get_if<TableType>()->props = {{"foo", Property::readonly(t1)}};
+
+    std::shared_ptr<const NormalizedType> normalized = normalizer.normalize(t3);
+    CHECK(normalized);
+}
+
+TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_stack_overflow_2")
+{
+    ScopedFastInt sfi{FInt::LuauTypeInferRecursionLimit, 165};
+    this->unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;
+    TypeId t1 = arena.addType(TableType{});
+    TypeId t2 = arena.addType(TableType{});
+    TypeId t3 = arena.addType(IntersectionType{{t1, t2}});
+    asMutable(t1)->ty.get_if<TableType>()->props = {{"foo", Property::readonly(t3)}};
+    asMutable(t2)->ty.get_if<TableType>()->props = {{"foo", Property::readonly(t1)}};
+
+    std::shared_ptr<const NormalizedType> normalized = normalizer.normalize(t3);
+    CHECK(normalized);
+}
+
 TEST_SUITE_END();
diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp
index 10331408..972d0edd 100644
--- a/tests/Parser.test.cpp
+++ b/tests/Parser.test.cpp
@@ -11,13 +11,15 @@
 
 using namespace Luau;
 
-LUAU_FASTFLAG(LuauCheckedFunctionSyntax);
 LUAU_FASTFLAG(LuauLexerLookaheadRemembersBraceType);
 LUAU_FASTINT(LuauRecursionLimit);
 LUAU_FASTINT(LuauTypeLengthLimit);
 LUAU_FASTINT(LuauParseErrorLimit);
 LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
-LUAU_FASTFLAG(LuauReadWritePropertySyntax);
+LUAU_FASTFLAG(LuauAttributeSyntax);
+LUAU_FASTFLAG(LuauLeadingBarAndAmpersand2);
+LUAU_FASTFLAG(LuauAttributeSyntaxFunExpr);
+LUAU_FASTFLAG(LuauDeclarationExtraPropData);
 
 namespace
 {
@@ -1857,6 +1859,8 @@ function func():end
 
 TEST_CASE_FIXTURE(Fixture, "parse_declarations")
 {
+    ScopedFastFlag luauDeclarationExtraPropData{FFlag::LuauDeclarationExtraPropData, true};
+
     AstStatBlock* stat = parseEx(R"(
         declare foo: number
         declare function bar(x: number): string
@@ -1870,18 +1874,23 @@ TEST_CASE_FIXTURE(Fixture, "parse_declarations")
     AstStatDeclareGlobal* global = stat->body.data[0]->as<AstStatDeclareGlobal>();
     REQUIRE(global);
     CHECK(global->name == "foo");
+    CHECK(global->nameLocation == Location({1, 16}, {1, 19}));
     CHECK(global->type);
 
     AstStatDeclareFunction* func = stat->body.data[1]->as<AstStatDeclareFunction>();
     REQUIRE(func);
    CHECK(func->name == "bar");
+    CHECK(func->nameLocation == Location({2, 25}, {2, 28}));
     REQUIRE_EQ(func->params.types.size, 1);
     REQUIRE_EQ(func->retTypes.types.size, 1);
 
     AstStatDeclareFunction* varFunc = stat->body.data[2]->as<AstStatDeclareFunction>();
     REQUIRE(varFunc);
     CHECK(varFunc->name == "var");
+    CHECK(varFunc->nameLocation == Location({3, 25}, {3, 28}));
     CHECK(varFunc->params.tailType);
+    CHECK(varFunc->vararg);
+    CHECK(varFunc->varargLocation == Location({3, 29}, {3, 32}));
 
     matchParseError("declare function foo(x)", "All declaration parameters must be annotated");
     matchParseError("declare foo", "Expected ':' when parsing global variable declaration, got <eof>");
@@ -1889,6 +1898,8 @@ TEST_CASE_FIXTURE(Fixture, "parse_declarations")
 
 TEST_CASE_FIXTURE(Fixture, "parse_class_declarations")
 {
+    ScopedFastFlag luauDeclarationExtraPropData{FFlag::LuauDeclarationExtraPropData, true};
+
     AstStatBlock* stat = parseEx(R"(
         declare class Foo
             prop: number
@@ -1912,11 +1923,16 @@ TEST_CASE_FIXTURE(Fixture, "parse_class_declarations")
 
     AstDeclaredClassProp& prop = declaredClass->props.data[0];
     CHECK(prop.name == "prop");
+    CHECK(prop.nameLocation == Location({2, 12}, {2, 16}));
     CHECK(prop.ty->is<AstTypeReference>());
+    CHECK(prop.location == Location({2, 12}, {2, 24}));
 
     AstDeclaredClassProp& method = declaredClass->props.data[1];
     CHECK(method.name == "method");
+    CHECK(method.nameLocation == Location({3, 21}, {3, 27}));
     CHECK(method.ty->is<AstTypeFunction>());
+    CHECK(method.location == Location({3, 12}, {3, 54}));
+    CHECK(method.isMethod);
 
     AstStatDeclareClass* subclass = stat->body.data[1]->as<AstStatDeclareClass>();
     REQUIRE(subclass);
@@ -1927,7 +1943,9 @@ TEST_CASE_FIXTURE(Fixture, "parse_class_declarations")
     REQUIRE_EQ(subclass->props.size, 1);
     AstDeclaredClassProp& prop2 = subclass->props.data[0];
     CHECK(prop2.name == "prop2");
+    CHECK(prop2.nameLocation == Location({7, 12}, {7, 17}));
     CHECK(prop2.ty->is<AstTypeReference>());
+    CHECK(prop2.location == Location({7, 12}, {7, 25}));
 }
 
 TEST_CASE_FIXTURE(Fixture, "class_method_properties")
@@ -3052,10 +3070,10 @@ TEST_CASE_FIXTURE(Fixture, "parse_top_level_checked_fn")
 {
     ParseOptions opts;
     opts.allowDeclarationSyntax = true;
-    ScopedFastFlag sff{FFlag::LuauCheckedFunctionSyntax, true};
+    ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
 
     std::string src = R"BUILTIN_SRC(
-declare function @checked abs(n: number): number
+@checked declare function abs(n: number): number
 )BUILTIN_SRC";
 
     ParseResult pr = tryParse(src, opts);
@@ -3065,14 +3083,14 @@ declare function @checked abs(n: number): number
     AstStat* root = *(pr.root->body.data);
     auto func = root->as<AstStatDeclareFunction>();
     LUAU_ASSERT(func);
-    LUAU_ASSERT(func->checkedFunction);
+    LUAU_ASSERT(func->isCheckedFunction());
 }
 
 TEST_CASE_FIXTURE(Fixture, "parse_declared_table_checked_member")
 {
     ParseOptions opts;
     opts.allowDeclarationSyntax = true;
-    ScopedFastFlag sff{FFlag::LuauCheckedFunctionSyntax, true};
+    ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
 
     const std::string src = R"BUILTIN_SRC(
     declare math : {
@@ -3093,14 +3111,14 @@ TEST_CASE_FIXTURE(Fixture, "parse_declared_table_checked_member")
     auto prop = *tbl->props.data;
     auto func = prop.type->as<AstTypeFunction>();
     LUAU_ASSERT(func);
-    LUAU_ASSERT(func->checkedFunction);
+    LUAU_ASSERT(func->isCheckedFunction());
 }
 
 TEST_CASE_FIXTURE(Fixture, "parse_checked_outside_decl_fails")
 {
     ParseOptions opts;
     opts.allowDeclarationSyntax = true;
-    ScopedFastFlag sff{FFlag::LuauCheckedFunctionSyntax, true};
+    ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
 
     ParseResult pr = tryParse(R"(
     local @checked = 3
@@ -3114,11 +3132,11 @@ TEST_CASE_FIXTURE(Fixture, "parse_checked_in_and_out_of_decl_fails")
 {
     ParseOptions opts;
     opts.allowDeclarationSyntax = true;
-    ScopedFastFlag sff{FFlag::LuauCheckedFunctionSyntax, true};
+    ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
 
     auto pr = tryParse(R"(
     local @checked = 3
-    declare function @checked abs(n: number): number
+    @checked declare function abs(n: number): number
 )",
         opts);
     LUAU_ASSERT(pr.errors.size() == 2);
@@ -3130,10 +3148,10 @@ TEST_CASE_FIXTURE(Fixture, "parse_checked_as_function_name_fails")
 {
     ParseOptions opts;
     opts.allowDeclarationSyntax = true;
-    ScopedFastFlag sff{FFlag::LuauCheckedFunctionSyntax, true};
+    ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
 
     auto pr = tryParse(R"(
-    function @checked(x: number) : number
+    @checked function(x: number) : number
     end
 )",
         opts);
@@ -3144,7 +3162,7 @@ TEST_CASE_FIXTURE(Fixture, "cannot_use_@_as_variable_name")
 {
     ParseOptions opts;
     opts.allowDeclarationSyntax = true;
-    ScopedFastFlag sff{FFlag::LuauCheckedFunctionSyntax, true};
+    ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
 
     auto pr = tryParse(R"(
     local @blah = 3
@@ -3156,8 +3174,6 @@ TEST_CASE_FIXTURE(Fixture, "cannot_use_@_as_variable_name")
 
 TEST_CASE_FIXTURE(Fixture, "read_write_table_properties")
 {
-    ScopedFastFlag sff{FFlag::LuauReadWritePropertySyntax, true};
-
     auto pr = tryParse(R"(
         type A = {read x: number}
        type B = {write x: number}
@@ -3177,4 +3193,356 @@ TEST_CASE_FIXTURE(Fixture, "read_write_table_properties")
     LUAU_ASSERT(pr.errors.size() == 0);
 }
 
+TEST_CASE_FIXTURE(Fixture, "can_parse_leading_bar_unions_successfully")
+{
+    ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true};
+
+    parse(R"(type A = | "Hello" | "World")");
+}
+
+TEST_CASE_FIXTURE(Fixture, "can_parse_leading_ampersand_intersections_successfully")
+{
+    ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true};
+
+    parse(R"(type A = & { string } & { number })");
+}
+
+TEST_CASE_FIXTURE(Fixture, "mixed_leading_intersection_and_union_not_allowed")
+{
+    ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true};
+
+    matchParseError("type A = & number | string | boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses.");
+    matchParseError("type A = | number & string & boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses.");
+}
+
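// Sketch: the leading separators accepted above are mainly aimed at multiline
// union/intersection aliases; a hypothetical extra case exercising that
// layout (not in the suite) would be:
TEST_CASE_FIXTURE(Fixture, "can_parse_multiline_leading_bar_union_sketch")
{
    ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true};

    parse(R"(
type Day =
    | "Monday"
    | "Tuesday"
    | "Wednesday"
)");
}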
+void checkAttribute(const AstAttr* attr, const AstAttr::Type type, const Location& location)
+{
+    CHECK_EQ(attr->type, type);
+    CHECK_EQ(attr->location, location);
+}
+
+void checkFirstErrorForAttributes(const std::vector<ParseError>& errors, const size_t minSize, const Location& location, const std::string& message)
+{
+    LUAU_ASSERT(minSize >= 1);
+
+    CHECK_GE(errors.size(), minSize);
+    CHECK_EQ(errors[0].getLocation(), location);
+    CHECK_EQ(errors[0].getMessage(), message);
+}
+
+TEST_CASE_FIXTURE(Fixture, "parse_attribute_on_function_stat")
+{
+    ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
+
+    AstStatBlock* stat = parse(R"(
+@checked
+function hello(x, y)
+    return x + y
+end)");
+
+    LUAU_ASSERT(stat != nullptr);
+
+    AstStatFunction* statFun = stat->body.data[0]->as<AstStatFunction>();
+    LUAU_ASSERT(statFun != nullptr);
+
+    AstArray<AstAttr*> attributes = statFun->func->attributes;
+
+    CHECK_EQ(attributes.size, 1);
+
+    checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(1, 0), Position(1, 8)));
+}
+
+TEST_CASE_FIXTURE(Fixture, "parse_attribute_for_function_expression")
+{
+    ScopedFastFlag sff[] = {{FFlag::LuauAttributeSyntax, true}, {FFlag::LuauAttributeSyntaxFunExpr, true}};
+
+    AstStatBlock* stat1 = parse(R"(
+local function invoker(f)
+    return f(1)
+end
+
+invoker(@checked function(x) return (x + 2) end)
+)");
+
+    LUAU_ASSERT(stat1 != nullptr);
+
+    AstExprFunction* func1 = stat1->body.data[1]->as<AstStatExpr>()->expr->as<AstExprCall>()->args.data[0]->as<AstExprFunction>();
+    LUAU_ASSERT(func1 != nullptr);
+
+    AstArray<AstAttr*> attributes1 = func1->attributes;
+
+    CHECK_EQ(attributes1.size, 1);
+
+    checkAttribute(attributes1.data[0], AstAttr::Type::Checked, Location(Position(5, 8), Position(5, 16)));
+
+    AstStatBlock* stat2 = parse(R"(
+local f = @checked function(x) return (x + 2) end
+)");
+
+    LUAU_ASSERT(stat2 != nullptr);
+
+    AstExprFunction* func2 = stat2->body.data[0]->as<AstStatLocal>()->values.data[0]->as<AstExprFunction>();
+    LUAU_ASSERT(func2 != nullptr);
+
+    AstArray<AstAttr*> attributes2 = func2->attributes;
+
+    CHECK_EQ(attributes2.size, 1);
+
+    checkAttribute(attributes2.data[0], AstAttr::Type::Checked, Location(Position(1, 10), Position(1, 18)));
+}
+
+TEST_CASE_FIXTURE(Fixture, "parse_attribute_on_local_function_stat")
+{
+    ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
+
+    AstStatBlock* stat = parse(R"(
+    @checked
+local function hello(x, y)
+    return x + y
+end)");
+
+    LUAU_ASSERT(stat != nullptr);
+
+    AstStatLocalFunction* statFun = stat->body.data[0]->as<AstStatLocalFunction>();
+    LUAU_ASSERT(statFun != nullptr);
+
+    AstArray<AstAttr*> attributes = statFun->func->attributes;
+
+    CHECK_EQ(attributes.size, 1);
+
+    checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(1, 4), Position(1, 12)));
+}
+
+TEST_CASE_FIXTURE(Fixture, "empty_attribute_name_is_not_allowed")
+{
+    ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
+
+    ParseResult result = tryParse(R"(
+@
+function hello(x, y)
+    return x + y
+end)");
+
+    checkFirstErrorForAttributes(result.errors, 1, Location(Position(1, 0), Position(1, 1)), "Attribute name is missing");
+}
+
+TEST_CASE_FIXTURE(Fixture, "dont_parse_attributes_on_non_function_stat")
+{
+    ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
+
+    ParseResult pr1 = tryParse(R"(
+@checked
+if a<0 then a = 0 end)");
+    checkFirstErrorForAttributes(pr1.errors, 1, Location(Position(2, 0), Position(2, 2)),
+        "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'if' intead");
+
+    ParseResult pr2 = tryParse(R"(
+local i = 1
+@checked
+while a[i] do
+    print(a[i])
+    i = i + 1
+end)");
+    checkFirstErrorForAttributes(pr2.errors, 1, Location(Position(3, 0), Position(3, 5)),
+        "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'while' intead");
+
+    ParseResult pr3 = tryParse(R"(
+@checked
+do
+    local a2 = 2*a
+    local d = sqrt(b^2 - 4*a*c)
+    x1 = (-b + d)/a2
+    x2 = (-b - d)/a2
+end)");
+    checkFirstErrorForAttributes(pr3.errors, 1, Location(Position(2, 0), Position(2, 2)),
+        "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'do' intead");
+
+    ParseResult pr4 = tryParse(R"(
+@checked
+for i=1,10 do print(i) end
+)");
+    checkFirstErrorForAttributes(pr4.errors, 1, Location(Position(2, 0), Position(2, 3)),
+        "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'for' intead");
+
+    ParseResult pr5 = tryParse(R"(
+@checked
+repeat
+    line = io.read()
+until line ~= ""
+)");
+    checkFirstErrorForAttributes(pr5.errors, 1, Location(Position(2, 0), Position(2, 6)),
+        "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'repeat' intead");
+
+
+    ParseResult pr6 = tryParse(R"(
+@checked
+local x = 10
+)");
+    checkFirstErrorForAttributes(
+        pr6.errors, 1, Location(Position(2, 6), Position(2, 7)), "Expected 'function' after local declaration with attribute, but got 'x' intead");
+
+    ParseResult pr7 = tryParse(R"(
+local i = 1
+while a[i] do
+    if a[i] == v then @checked break end
+    i = i + 1
+end
+)");
+    checkFirstErrorForAttributes(pr7.errors, 1, Location(Position(3, 31), Position(3, 36)),
+        "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'break' intead");
+
+
+    ParseResult pr8 = tryParse(R"(
+function foo1 () @checked return 'a' end
+)");
+    checkFirstErrorForAttributes(pr8.errors, 1, Location(Position(1, 26), Position(1, 32)),
+        "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got 'return' intead");
+}
+
+TEST_CASE_FIXTURE(Fixture, "dont_parse_attribute_on_argument_non_function")
+{
+    ScopedFastFlag sff[] = {{FFlag::LuauAttributeSyntax, true}, {FFlag::LuauAttributeSyntaxFunExpr, true}};
+
+    ParseResult pr = tryParse(R"(
+local function invoker(f, y)
+    return f(y)
+end
+
+invoker(function(x) return (x + 2) end, @checked 1)
+)");
+
+    checkFirstErrorForAttributes(
+        pr.errors, 1, Location(Position(5, 40), Position(5, 48)), "Expected 'function' declaration after attribute, but got '1' intead");
+}
+
+TEST_CASE_FIXTURE(Fixture, "parse_attribute_on_function_type_declaration")
+{
+    ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
+
+    ParseOptions opts;
+    opts.allowDeclarationSyntax = true;
+
+    std::string src = R"(
+@checked declare function abs(n: number): number
+)";
+
+    ParseResult pr = tryParse(src, opts);
+    CHECK_EQ(pr.errors.size(), 0);
+
+    LUAU_ASSERT(pr.root->body.size == 1);
+
+    AstStat* root = *(pr.root->body.data);
+
+    auto func = root->as<AstStatDeclareFunction>();
+    LUAU_ASSERT(func != nullptr);
+
+    CHECK(func->isCheckedFunction());
+
+    AstArray<AstAttr*> attributes = func->attributes;
+
+    checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(1, 0), Position(1, 8)));
+}
+
+TEST_CASE_FIXTURE(Fixture, "parse_attributes_on_function_type_declaration_in_table")
+{
+    ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
+
+    ParseOptions opts;
+    opts.allowDeclarationSyntax = true;
+
+    std::string src = R"(
+declare bit32: {
+    band: @checked (...number) -> number
+})";
+
+    ParseResult pr = tryParse(src, opts);
+    CHECK_EQ(pr.errors.size(), 0);
+
+    LUAU_ASSERT(pr.root->body.size == 1);
+
+    AstStat* root = *(pr.root->body.data);
+
+    AstStatDeclareGlobal* glob = root->as<AstStatDeclareGlobal>();
+    LUAU_ASSERT(glob);
+
+    auto tbl = glob->type->as<AstTypeTable>();
+    LUAU_ASSERT(tbl);
+
+    LUAU_ASSERT(tbl->props.size == 1);
+    AstTableProp prop = tbl->props.data[0];
+
+    AstTypeFunction* func = prop.type->as<AstTypeFunction>();
+    LUAU_ASSERT(func);
+
+    AstArray<AstAttr*> attributes = func->attributes;
+
+    CHECK_EQ(attributes.size, 1);
+    checkAttribute(attributes.data[0], AstAttr::Type::Checked, Location(Position(2, 10), Position(2, 18)));
+}
+
+TEST_CASE_FIXTURE(Fixture, "dont_parse_attributes_on_non_function_type_declarations")
+{
+    ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
+
+    ParseOptions opts;
+    opts.allowDeclarationSyntax = true;
+
+    ParseResult pr1 = tryParse(R"(
+@checked declare foo: number
+    )",
+        opts);
+
+    checkFirstErrorForAttributes(
+        pr1.errors, 1, Location(Position(1, 17), Position(1, 20)), "Expected a function type declaration after attribute, but got 'foo' intead");
+
+    ParseResult pr2 = tryParse(R"(
+@checked declare class Foo
+    prop: number
+    function method(self, foo: number): string
+end)",
+        opts);
+
+    checkFirstErrorForAttributes(
+        pr2.errors, 1, Location(Position(1, 17), Position(1, 22)), "Expected a function type declaration after attribute, but got 'class' intead");
+
+    ParseResult pr3 = tryParse(R"(
+declare bit32: {
+    band: @checked number
+})",
+        opts);
+
+    checkFirstErrorForAttributes(
+        pr3.errors, 1, Location(Position(2, 19), Position(2, 25)), "Expected '(' when parsing function parameters, got 'number'");
+}
+
+TEST_CASE_FIXTURE(Fixture, "attributes_cannot_be_duplicated")
+{
+    ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
+
+    ParseResult result = tryParse(R"(
+@checked
+    @checked
+function hello(x, y)
+    return x + y
+end)");
+
+    checkFirstErrorForAttributes(result.errors, 1, Location(Position(2, 4), Position(2, 12)), "Cannot duplicate attribute '@checked'");
+}
+
+TEST_CASE_FIXTURE(Fixture, "unsupported_attributes_are_not_allowed")
+{
+    ScopedFastFlag luauAttributeSyntax{FFlag::LuauAttributeSyntax, true};
+
+    ParseResult result = tryParse(R"(
+@checked
+    @cool_attribute
+function hello(x, y)
+    return x + y
+end)");
+
+    checkFirstErrorForAttributes(result.errors, 1, Location(Position(2, 4), Position(2, 19)), "Invalid attribute '@cool_attribute'");
+}
+
"Mixing union and intersection types is not allowed; consider wrapping in parentheses."); + matchParseError("type A = | number & string & boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); +} + + TEST_SUITE_END(); diff --git a/tests/Repl.test.cpp b/tests/Repl.test.cpp index c22d464e..3eceea17 100644 --- a/tests/Repl.test.cpp +++ b/tests/Repl.test.cpp @@ -420,4 +420,22 @@ print(NewProxyOne.HelloICauseACrash) )"); } +TEST_CASE_FIXTURE(ReplFixture, "InteractiveStackReserve1") +{ + // Reset stack reservation + lua_resume(L, nullptr, 0); + + runCode(L, R"( +local t = {} +)"); +} + +TEST_CASE_FIXTURE(ReplFixture, "InteractiveStackReserve2") +{ + // Reset stack reservation + lua_resume(L, nullptr, 0); + + getCompletionSet("a"); +} + TEST_SUITE_END(); diff --git a/tests/Set.test.cpp b/tests/Set.test.cpp index 94de4f01..b3824bf1 100644 --- a/tests/Set.test.cpp +++ b/tests/Set.test.cpp @@ -7,8 +7,6 @@ #include #include -LUAU_FASTFLAG(LuauFixSetIter); - TEST_SUITE_BEGIN("SetTests"); TEST_CASE("empty_set_size_0") @@ -107,8 +105,6 @@ TEST_CASE("iterate_over_set_skips_erased_elements") TEST_CASE("iterate_over_set_skips_first_element_if_it_is_erased") { - ScopedFastFlag sff{FFlag::LuauFixSetIter, true}; - /* * As of this writing, in the following set, the key "y" happens to occur * before "x" in the underlying DenseHashSet. This is important because it diff --git a/tests/SharedCodeAllocator.test.cpp b/tests/SharedCodeAllocator.test.cpp index 0b142930..bba8daad 100644 --- a/tests/SharedCodeAllocator.test.cpp +++ b/tests/SharedCodeAllocator.test.cpp @@ -15,8 +15,6 @@ #pragma GCC diagnostic ignored "-Wself-assign-overloaded" #endif -LUAU_FASTFLAG(LuauCodegenContext) - using namespace Luau::CodeGen; @@ -32,8 +30,6 @@ TEST_CASE("NativeModuleRefRefcounting") if (!luau_codegen_supported()) return; - ScopedFastFlag luauCodegenContext{FFlag::LuauCodegenContext, true}; - CodeAllocator codeAllocator{kBlockSize, kMaxTotalSize}; SharedCodeAllocator allocator{&codeAllocator}; @@ -250,8 +246,6 @@ TEST_CASE("NativeProtoRefcounting") if (!luau_codegen_supported()) return; - ScopedFastFlag luauCodegenContext{FFlag::LuauCodegenContext, true}; - CodeAllocator codeAllocator{kBlockSize, kMaxTotalSize}; SharedCodeAllocator allocator{&codeAllocator}; @@ -303,8 +297,6 @@ TEST_CASE("NativeProtoState") if (!luau_codegen_supported()) return; - ScopedFastFlag luauCodegenContext{FFlag::LuauCodegenContext, true}; - CodeAllocator codeAllocator{kBlockSize, kMaxTotalSize}; SharedCodeAllocator allocator{&codeAllocator}; @@ -364,8 +356,6 @@ TEST_CASE("AnonymousModuleLifetime") if (!luau_codegen_supported()) return; - ScopedFastFlag luauCodegenContext{FFlag::LuauCodegenContext, true}; - CodeAllocator codeAllocator{kBlockSize, kMaxTotalSize}; SharedCodeAllocator allocator{&codeAllocator}; @@ -413,8 +403,6 @@ TEST_CASE("SharedAllocation") if (!luau_codegen_supported()) return; - ScopedFastFlag luauCodegenContext{FFlag::LuauCodegenContext, true}; - UniqueSharedCodeGenContext sharedCodeGenContext = createSharedCodeGenContext(); std::unique_ptr L1{luaL_newstate(), lua_close}; @@ -438,10 +426,13 @@ TEST_CASE("SharedAllocation") const ModuleId moduleId = {0x01}; + CompilationOptions options; + options.flags = CodeGen_ColdFunctions; + CompilationStats nativeStats1 = {}; CompilationStats nativeStats2 = {}; - const CompilationResult codeGenResult1 = Luau::CodeGen::compile(moduleId, L1.get(), -1, CodeGen_ColdFunctions, &nativeStats1); - const CompilationResult codeGenResult2 = 
diff --git a/tests/Simplify.test.cpp b/tests/Simplify.test.cpp
index ddddbe67..b938b5f8 100644
--- a/tests/Simplify.test.cpp
+++ b/tests/Simplify.test.cpp
@@ -214,6 +214,14 @@ TEST_CASE_FIXTURE(SimplifyFixture, "any_and_indeterminate_types")
     CHECK(errorTy == anyLhsPending->options[1]);
 }
 
+TEST_CASE_FIXTURE(SimplifyFixture, "union_where_lhs_elements_are_a_subset_of_the_rhs")
+{
+    TypeId lhs = union_(numberTy, stringTy);
+    TypeId rhs = union_(stringTy, numberTy);
+
+    CHECK("number | string" == toString(union_(lhs, rhs)));
+}
+
 TEST_CASE_FIXTURE(SimplifyFixture, "unknown_and_indeterminate_types")
 {
     CHECK(freeTy == intersect(unknownTy, freeTy));
diff --git a/tests/Subtyping.test.cpp b/tests/Subtyping.test.cpp
index 797ef389..44bf26d2 100644
--- a/tests/Subtyping.test.cpp
+++ b/tests/Subtyping.test.cpp
@@ -915,6 +915,7 @@ TEST_IS_SUBTYPE(numberToNumberType, negate(builtinTypes->classType));
 TEST_IS_NOT_SUBTYPE(numberToNumberType, negate(builtinTypes->functionType));
 
 // Negated supertypes: Primitives and singletons
+TEST_IS_NOT_SUBTYPE(builtinTypes->stringType, negate(builtinTypes->stringType));
 TEST_IS_SUBTYPE(builtinTypes->stringType, negate(builtinTypes->numberType));
 TEST_IS_SUBTYPE(str("foo"), meet(builtinTypes->stringType, negate(str("bar"))));
 TEST_IS_NOT_SUBTYPE(builtinTypes->trueType, negate(builtinTypes->booleanType));
diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp
index 4789a810..17faa2e7 100644
--- a/tests/ToString.test.cpp
+++ b/tests/ToString.test.cpp
@@ -12,8 +12,8 @@ using namespace Luau;
 
 LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction);
 LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
-LUAU_FASTFLAG(LuauCheckedFunctionSyntax);
 LUAU_FASTFLAG(DebugLuauSharedSelf);
+LUAU_FASTFLAG(LuauAttributeSyntax);
 
 TEST_SUITE_BEGIN("ToString");
 
@@ -354,21 +354,24 @@ TEST_CASE_FIXTURE(Fixture, "quit_stringifying_type_when_length_is_exceeded")
         function f2(f) return f or f1 end
         function f3(f) return f or f2 end
     )");
-    LUAU_REQUIRE_NO_ERRORS(result);
-
-    ToStringOptions o;
-    o.exhaustive = false;
     if (FFlag::DebugLuauDeferredConstraintResolution)
     {
-        o.maxTypeLength = 30;
+        LUAU_REQUIRE_NO_ERRORS(result);
+
+        ToStringOptions o;
+        o.exhaustive = false;
+        o.maxTypeLength = 20;
         CHECK_EQ(toString(requireType("f0"), o), "() -> ()");
-        CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*");
-        CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*");
-        CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*");
+        CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) ... *TRUNCATED*");
+        CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ())... *TRUNCATED*");
+        CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ())... *TRUNCATED*");
*TRUNCATED*"); } else { + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions o; + o.exhaustive = false; o.maxTypeLength = 40; CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()"); @@ -385,20 +388,25 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive") function f2(f) return f or f1 end function f3(f) return f or f2 end )"); - LUAU_REQUIRE_NO_ERRORS(result); - ToStringOptions o; - o.exhaustive = true; if (FFlag::DebugLuauDeferredConstraintResolution) { - o.maxTypeLength = 30; + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions o; + o.exhaustive = true; + o.maxTypeLength = 20; CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); - CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) ... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ())... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ())... *TRUNCATED*"); } else { + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions o; + o.exhaustive = true; o.maxTypeLength = 40; CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()"); @@ -741,7 +749,10 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") TypeId ty = requireType("map"); const FunctionType* ftv = get(follow(ty)); - CHECK_EQ("map(arr: {a}, fn: (a) -> b): {b}", toStringNamedFunction("map", *ftv)); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("map(arr: {a}, fn: (a) -> (b, ...unknown)): {b}", toStringNamedFunction("map", *ftv)); + else + CHECK_EQ("map(arr: {a}, fn: (a) -> b): {b}", toStringNamedFunction("map", *ftv)); } TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_generic_pack") @@ -965,12 +976,12 @@ Type 'string' could not be converted into 'number' in an invariant context)"; TEST_CASE_FIXTURE(Fixture, "checked_fn_toString") { ScopedFastFlag flags[] = { - {FFlag::LuauCheckedFunctionSyntax, true}, {FFlag::DebugLuauDeferredConstraintResolution, true}, + {FFlag::LuauAttributeSyntax, true}, }; auto _result = loadDefinition(R"( -declare function @checked abs(n: number) : number +@checked declare function abs(n: number) : number )"); auto result = check(Mode::Nonstrict, R"( diff --git a/tests/TypeFamily.test.cpp b/tests/TypeFamily.test.cpp index c5b3e053..068e8684 100644 --- a/tests/TypeFamily.test.cpp +++ b/tests/TypeFamily.test.cpp @@ -3,7 +3,6 @@ #include "Luau/ConstraintSolver.h" #include "Luau/NotNull.h" -#include "Luau/TxnLog.h" #include "Luau/Type.h" #include "ClassFixture.h" @@ -14,6 +13,7 @@ using namespace Luau; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_DYNAMIC_FASTINT(LuauTypeFamilyApplicationCartesianProductLimit) struct FamilyFixture : Fixture { @@ -24,7 +24,7 @@ struct FamilyFixture : Fixture { swapFamily = TypeFamily{/* name */ "Swap", /* reducer */ - [](TypeId instance, NotNull queue, const std::vector& tys, const std::vector& tps, + [](TypeId instance, const std::vector& tys, const std::vector& tps, NotNull ctx) -> TypeFamilyReductionResult { LUAU_ASSERT(tys.size() == 1); TypeId param = follow(tys.at(0)); @@ -167,15 +167,13 @@ TEST_CASE_FIXTURE(FamilyFixture, 
"table_internal_families") LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK(toString(requireType("a")) == "{string}"); CHECK(toString(requireType("b")) == "{number}"); - CHECK(toString(requireType("c")) == "{Swap}"); - CHECK(toString(result.errors[0]) == "Type family instance Swap is uninhabited"); + // FIXME: table types are constructing a trivial union here. + CHECK(toString(requireType("c")) == "{Swap}"); + CHECK(toString(result.errors[0]) == "Type family instance Swap is uninhabited"); } TEST_CASE_FIXTURE(FamilyFixture, "function_internal_families") { - // This test is broken right now, but it's not because of type families. See - // CLI-71143. - if (!FFlag::DebugLuauDeferredConstraintResolution) return; @@ -391,8 +389,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "keyof_type_family_errors_if_it_has_nontable_ // FIXME(CLI-95289): we should actually only report the type family being uninhabited error at its first use, I think? LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK(toString(result.errors[0]) == "Type family instance keyof is uninhabited"); - CHECK(toString(result.errors[1]) == "Type family instance keyof is uninhabited"); + CHECK(toString(result.errors[0]) == "Type 'MyObject | boolean' does not have keys, so 'keyof' is invalid"); + CHECK(toString(result.errors[1]) == "Type 'MyObject | boolean' does not have keys, so 'keyof' is invalid"); } TEST_CASE_FIXTURE(BuiltinsFixture, "keyof_type_family_string_indexer") @@ -517,8 +515,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "rawkeyof_type_family_errors_if_it_has_nontab // FIXME(CLI-95289): we should actually only report the type family being uninhabited error at its first use, I think? LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK(toString(result.errors[0]) == "Type family instance rawkeyof is uninhabited"); - CHECK(toString(result.errors[1]) == "Type family instance rawkeyof is uninhabited"); + CHECK(toString(result.errors[0]) == "Type 'MyObject | boolean' does not have keys, so 'rawkeyof' is invalid"); + CHECK(toString(result.errors[1]) == "Type 'MyObject | boolean' does not have keys, so 'rawkeyof' is invalid"); } TEST_CASE_FIXTURE(BuiltinsFixture, "rawkeyof_type_family_common_subset_if_union_of_differing_tables") @@ -590,8 +588,8 @@ TEST_CASE_FIXTURE(ClassFixture, "keyof_type_family_errors_if_it_has_nonclass_par // FIXME(CLI-95289): we should actually only report the type family being uninhabited error at its first use, I think? 
     LUAU_REQUIRE_ERROR_COUNT(2, result);
-    CHECK(toString(result.errors[0]) == "Type family instance keyof<BaseClass | boolean> is uninhabited");
-    CHECK(toString(result.errors[1]) == "Type family instance keyof<BaseClass | boolean> is uninhabited");
+    CHECK(toString(result.errors[0]) == "Type 'BaseClass | boolean' does not have keys, so 'keyof<BaseClass | boolean>' is invalid");
+    CHECK(toString(result.errors[1]) == "Type 'BaseClass | boolean' does not have keys, so 'keyof<BaseClass | boolean>' is invalid");
 }
 
 TEST_CASE_FIXTURE(ClassFixture, "keyof_type_family_common_subset_if_union_of_differing_classes")
@@ -701,4 +699,468 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "keyof_oss_crash_gh1161")
     CHECK(get<UninhabitedTypeFamily>(result.errors[0]));
 }
 
-TEST_SUITE_END();
+TEST_CASE_FIXTURE(FamilyFixture, "fuzzer_numeric_binop_doesnt_assert_on_generalizeFreeType")
+{
+    CheckResult result = check(R"(
+Module 'l0':
+local _ = (67108864)(_ >= _).insert
+do end
+do end
+_(...,_(_,_(_()),_()))
+(67108864)()()
+_(_ ~= _ // _,l0)(_(_({n0,})),_(_),_)
+_(setmetatable(_,{[...]=_,}))
+
+)");
+}
+
+TEST_CASE_FIXTURE(BuiltinsFixture, "cyclic_concat_family_at_work")
+{
+    if (!FFlag::DebugLuauDeferredConstraintResolution)
+        return;
+
+    CheckResult result = check(R"(
+        type T = concat<string | T, string>
+    )");
+
+    LUAU_REQUIRE_NO_ERRORS(result);
+    CHECK(toString(requireTypeAlias("T")) == "string");
+}
+
+TEST_CASE_FIXTURE(BuiltinsFixture, "exceeded_distributivity_limits")
+{
+    if (!FFlag::DebugLuauDeferredConstraintResolution)
+        return;
+
+    ScopedFastInt sfi{DFInt::LuauTypeFamilyApplicationCartesianProductLimit, 10};
+
+    loadDefinition(R"(
+        declare class A
+            function __mul(self, rhs: unknown): A
+        end
+
+        declare class B
+            function __mul(self, rhs: unknown): B
+        end
+
+        declare class C
+            function __mul(self, rhs: unknown): C
+        end
+
+        declare class D
+            function __mul(self, rhs: unknown): D
+        end
+    )");
+
+    CheckResult result = check(R"(
+        type T = mul<A | B | C | D, A | B | C | D>
+    )");
+
+    LUAU_REQUIRE_ERROR_COUNT(1, result);
+    CHECK(get<UninhabitedTypeFamily>(result.errors[0]));
+}
+
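// The cartesian-product arithmetic behind the limits used in this pair of
// tests, assuming the reconstructed application mul<A | B | C | D, A | B | C | D>
// (4 options on each side):
constexpr int lhsOptions = 4;
constexpr int rhsOptions = 4;
static_assert(lhsOptions * rhsOptions == 16, "16 pairwise family applications");
static_assert(lhsOptions * rhsOptions > 10, "exceeds the first test's limit of 10");
static_assert(lhsOptions * rhsOptions <= 20, "fits the second test's limit of 20");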
+TEST_CASE_FIXTURE(BuiltinsFixture, "didnt_quite_exceed_distributivity_limits")
+{
+    if (!FFlag::DebugLuauDeferredConstraintResolution)
+        return;
+
+    // We duplicate the test here because we want to make sure the test failed
+    // due to exceeding the limits specifically, rather than any possible reasons.
+    ScopedFastInt sfi{DFInt::LuauTypeFamilyApplicationCartesianProductLimit, 20};
+
+    loadDefinition(R"(
+        declare class A
+            function __mul(self, rhs: unknown): A
+        end
+
+        declare class B
+            function __mul(self, rhs: unknown): B
+        end
+
+        declare class C
+            function __mul(self, rhs: unknown): C
+        end
+
+        declare class D
+            function __mul(self, rhs: unknown): D
+        end
+    )");
+
+    CheckResult result = check(R"(
+        type T = mul<A | B | C | D, A | B | C | D>
+    )");
+
+    LUAU_REQUIRE_NO_ERRORS(result);
+}
+
+TEST_CASE_FIXTURE(BuiltinsFixture, "ensure_equivalence_with_distributivity")
+{
+    if (!FFlag::DebugLuauDeferredConstraintResolution)
+        return;
+
+    loadDefinition(R"(
+        declare class A
+            function __mul(self, rhs: unknown): A
+        end
+
+        declare class B
+            function __mul(self, rhs: unknown): B
+        end
+
+        declare class C
+            function __mul(self, rhs: unknown): C
+        end
+
+        declare class D
+            function __mul(self, rhs: unknown): D
+        end
+    )");
+
+    CheckResult result = check(R"(
+        type T = mul<A | B, C | D>
+        type U = mul<A, C> | mul<A, D> | mul<B, C> | mul<B, D>
+    )");
+
+    LUAU_REQUIRE_NO_ERRORS(result);
+    CHECK(toString(requireTypeAlias("T")) == "A | B");
+    CHECK(toString(requireTypeAlias("U")) == "A | A | B | B");
+}
+
+TEST_CASE_FIXTURE(BuiltinsFixture, "we_shouldnt_warn_that_a_reducible_type_family_is_uninhabited")
+{
+    if (!FFlag::DebugLuauDeferredConstraintResolution)
+        return;
+
+    CheckResult result = check(R"(
+
+local Debounce = false
+local Active = false
+
+local function Use(Mode)
+
+    if Mode ~= nil then
+
+        if Mode == false and Active == false then
+            return
+        else
+            Active = not Mode
+        end
+
+        Debounce = false
+    end
+    Active = not Active
+
+end
+    )");
+
+    LUAU_REQUIRE_NO_ERRORS(result);
+}
+
+TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_works")
+{
+    if (!FFlag::DebugLuauDeferredConstraintResolution)
+        return;
+
+    CheckResult result = check(R"(
+        type MyObject = {a: string, b: number, c: boolean}
+        type IdxAType = index<MyObject, "a">
+        type IdxBType = index<MyObject, keyof<MyObject>>
+
+        local function ok(idx: IdxAType): string return idx end
+        local function ok2(idx: IdxBType): string | number | boolean return idx end
+        local function err(idx: IdxAType): boolean return idx end
+    )");
+
+    LUAU_REQUIRE_ERROR_COUNT(1, result);
+
+    TypePackMismatch* tpm = get<TypePackMismatch>(result.errors[0]);
+    REQUIRE(tpm);
+    CHECK_EQ("boolean", toString(tpm->wantedTp));
+    CHECK_EQ("string", toString(tpm->givenTp));
+}
+
+TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_works_w_array")
+{
+    if (!FFlag::DebugLuauDeferredConstraintResolution)
+        return;
+
+    CheckResult result = check(R"(
+        local MyObject = {"hello", 1, true}
+        type IdxAType = index<typeof(MyObject), number>
+
+        local function ok(idx: IdxAType): string | number | boolean return idx end
+    )");
+
+    LUAU_REQUIRE_NO_ERRORS(result);
+}
+
+TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_works_w_generic_types")
+{
+    if (!FFlag::DebugLuauDeferredConstraintResolution)
+        return;
+
+    CheckResult result = check(R"(
+        local function access<T, K>(tbl: T & {}, key: K): index<T, K>
+            return tbl[key]
+        end
+
+        local subjects = {
+            english = "boring",
+            math = "fun"
+        }
+
+        local key: "english" = "english"
+        local a: string = access(subjects, key)
+    )");
+
+    LUAU_REQUIRE_NO_ERRORS(result);
+}
+
+TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_errors_w_bad_indexer")
+{
+    if (!FFlag::DebugLuauDeferredConstraintResolution)
+        return;
+
+    CheckResult result = check(R"(
+        type MyObject = {a: string, b: number, c: boolean}
+        type errType1 = index<MyObject, "d">
+        type errType2 = index<MyObject, boolean>
+    )");
+
+    LUAU_REQUIRE_ERROR_COUNT(2, result);
+    CHECK(toString(result.errors[0]) == "Property '\"d\"' does not exist on type 'MyObject'");
+    CHECK(toString(result.errors[1]) == "Property 'boolean' does not exist on type 'MyObject'");
+}
+
+TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_errors_w_var_indexer")
+{
+    if (!FFlag::DebugLuauDeferredConstraintResolution)
+        return;
+
+    CheckResult result = check(R"(
+        type MyObject = {a: string, b: number, c: boolean}
+        local key = "a"
+
+        type errType1 = index<MyObject, key>
+    )");
+
+    LUAU_REQUIRE_ERROR_COUNT(2, result);
+    CHECK(toString(result.errors[0]) == "Second argument to index<MyObject, _> is not a valid index type");
+    CHECK(toString(result.errors[1]) == "Unknown type 'key'");
+}
+
+TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_works_w_union_type_indexer")
+{
+    if (!FFlag::DebugLuauDeferredConstraintResolution)
+        return;
+
+    CheckResult result = check(R"(
+        type MyObject = {a: string, b: number, c: boolean}
+
+        type idxType = index<MyObject, "a" | "b">
+        local function ok(idx: idxType): string | number return idx end
+
+        type errType = index<MyObject, "a" | "d">
+    )");
+
+    LUAU_REQUIRE_ERROR_COUNT(1, result);
+    CHECK(toString(result.errors[0]) == "Property '\"a\" | \"d\"' does not exist on type 'MyObject'");
+}
+
+TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_works_w_union_type_indexee")
+{
+    if (!FFlag::DebugLuauDeferredConstraintResolution)
+        return;
+
+    CheckResult result = check(R"(
+        type MyObject = {a: string, b: number, c: boolean}
+        type MyObject2 = {a: number}
+
+        type idxTypeA = index<MyObject | MyObject2, "a">
+        local function ok(idx: idxTypeA): string | number return idx end
+
+        type errType = index<MyObject | MyObject2, "b">
+    )");
+
+    LUAU_REQUIRE_ERROR_COUNT(1, result);
+    CHECK(toString(result.errors[0]) == "Property '\"b\"' does not exist on type 'MyObject | MyObject2'");
+}
+
+TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_rfc_alternative_section")
+{
+    if (!FFlag::DebugLuauDeferredConstraintResolution)
+        return;
+
+    CheckResult result = check(R"(
+        type MyObject = {a: string}
+        type MyObject2 = {a: string, b: number}
+
+        local function edgeCase(param: MyObject)
+            type unknownType = index<typeof(param), "b">
+        end
+    )");
+
+    LUAU_REQUIRE_ERROR_COUNT(1, result);
+    CHECK(toString(result.errors[0]) == "Property '\"b\"' does not exist on type 'MyObject'");
+}
+
+TEST_CASE_FIXTURE(ClassFixture, "index_type_family_works_on_classes")
+{
+    if (!FFlag::DebugLuauDeferredConstraintResolution)
+        return;
+
+    CheckResult result = check(R"(
+        type KeysOfMyObject = index<BaseClass, "BaseField">
+
+        local function ok(idx: KeysOfMyObject): number return idx end
+    )");
+
+    LUAU_REQUIRE_NO_ERRORS(result);
+}
+
+TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_works_w_index_metatables")
+{
+    if (!FFlag::DebugLuauDeferredConstraintResolution)
+        return;
+
+    CheckResult result = check(R"(
+        local exampleClass = { Foo = "text", Bar = true }
+
+        local exampleClass2 = setmetatable({ Foo = 8 }, { __index = exampleClass })
+        type exampleTy2 = index<typeof(exampleClass2), "Foo">
+        local function ok(idx: exampleTy2): number return idx end
+
+        local exampleClass3 = setmetatable({ Bar = 5 }, { __index = exampleClass })
+        type exampleTy3 = index<typeof(exampleClass3), "Foo">
+        local function ok2(idx: exampleTy3): string return idx end
+
+        type exampleTy4 = index<typeof(exampleClass3), "Foo" | "Bar">
+        local function ok3(idx: exampleTy4): string | number return idx end
+
+        type errTy = index<typeof(exampleClass2), "Car">
+    )");
+
+    LUAU_REQUIRE_ERROR_COUNT(1, result);
+    CHECK(toString(result.errors[0]) == "Property '\"Car\"' does not exist on type 'exampleClass2'");
+}
+
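// Sketch: a hypothetical companion case (not in the suite) applying the same
// family to a homogeneous array, where indexing by number should reduce to the
// element type:
TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_family_number_key_sketch")
{
    if (!FFlag::DebugLuauDeferredConstraintResolution)
        return;

    CheckResult result = check(R"(
        local arr = {1, 2, 3}
        type Elem = index<typeof(arr), number>
        local function ok(e: Elem): number return e end
    )");

    LUAU_REQUIRE_NO_ERRORS(result);
}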
+TEST_CASE_FIXTURE(BuiltinsFixture, "rawget_type_family_works")
+{
+    if (!FFlag::DebugLuauDeferredConstraintResolution)
+        return;
+
+    CheckResult result = check(R"(
+        type MyObject = {a: string, b: number, c: boolean}
+        type RawAType = rawget<MyObject, "a">
+        type RawBType = rawget<MyObject, keyof<MyObject>>
+        local function ok(idx: RawAType): string return idx end
+        local function ok2(idx: RawBType): string | number | boolean return idx end
+        local function err(idx: RawAType): boolean return idx end
+    )");
+
+    LUAU_REQUIRE_ERROR_COUNT(1, result);
+
+    TypePackMismatch* tpm = get<TypePackMismatch>(result.errors[0]);
+    REQUIRE(tpm);
+    CHECK_EQ("boolean", toString(tpm->wantedTp));
+    CHECK_EQ("string", toString(tpm->givenTp));
+}
+
+TEST_CASE_FIXTURE(BuiltinsFixture, "rawget_type_family_works_w_array")
+{
+    if (!FFlag::DebugLuauDeferredConstraintResolution)
+        return;
+
+    CheckResult result = check(R"(
+        local MyObject = {"hello", 1, true}
+        type RawAType = rawget<typeof(MyObject), number>
+        local function ok(idx: RawAType): string | number | boolean return idx end
+    )");
+
+    LUAU_REQUIRE_NO_ERRORS(result);
+}
+
+TEST_CASE_FIXTURE(BuiltinsFixture, "rawget_type_family_errors_w_var_indexer")
+{
+    if (!FFlag::DebugLuauDeferredConstraintResolution)
+        return;
+
+    CheckResult result = check(R"(
+        type MyObject = {a: string, b: number, c: boolean}
+        local key = "a"
+        type errType1 = rawget<MyObject, key>
+    )");
+
+    LUAU_REQUIRE_ERROR_COUNT(2, result);
+    CHECK(toString(result.errors[0]) == "Second argument to rawget<MyObject, _> is not a valid index type");
+    CHECK(toString(result.errors[1]) == "Unknown type 'key'");
+}
+
+TEST_CASE_FIXTURE(BuiltinsFixture, "rawget_type_family_works_w_union_type_indexer")
+{
+    if (!FFlag::DebugLuauDeferredConstraintResolution)
+        return;
+
+    CheckResult result = check(R"(
+        type MyObject = {a: string, b: number, c: boolean}
+        type rawType = rawget<MyObject, "a" | "b">
+        local function ok(idx: rawType): string | number return idx end
+        type errType = rawget<MyObject, "a" | "d">
+    )");
+
+    LUAU_REQUIRE_ERROR_COUNT(1, result);
+    CHECK(toString(result.errors[0]) == "Property '\"a\" | \"d\"' does not exist on type 'MyObject'");
+}
+
+TEST_CASE_FIXTURE(BuiltinsFixture, "rawget_type_family_works_w_union_type_indexee")
+{
+    if (!FFlag::DebugLuauDeferredConstraintResolution)
+        return;
+
+    CheckResult result = check(R"(
+        type MyObject = {a: string, b: number, c: boolean}
+        type MyObject2 = {a: number}
+        type rawTypeA = rawget<MyObject | MyObject2, "a">
+        local function ok(idx: rawTypeA): string | number return idx end
+        type errType = rawget<MyObject | MyObject2, "b">
+    )");
+
+    LUAU_REQUIRE_ERROR_COUNT(1, result);
+    CHECK(toString(result.errors[0]) == "Property '\"b\"' does not exist on type 'MyObject | MyObject2'");
+}
+
+TEST_CASE_FIXTURE(BuiltinsFixture, "rawget_type_family_works_w_index_metatables")
+{
+    if (!FFlag::DebugLuauDeferredConstraintResolution)
+        return;
+
+    CheckResult result = check(R"(
+        local exampleClass = { Foo = "text", Bar = true }
+        local exampleClass2 = setmetatable({ Foo = 8 }, { __index = exampleClass })
+        type exampleTy2 = rawget<typeof(exampleClass2), "Foo">
+        local function ok(idx: exampleTy2): number return idx end
+        local exampleClass3 = setmetatable({ Bar = 5 }, { __index = exampleClass })
+        type errType = rawget<typeof(exampleClass3), "Foo">
+        type errType2 = rawget<typeof(exampleClass3), "Bar" | "Foo">
+    )");
+
+    LUAU_REQUIRE_ERROR_COUNT(2, result);
+    CHECK(toString(result.errors[0]) == "Property '\"Foo\"' does not exist on type 'exampleClass3'");
+    CHECK(toString(result.errors[1]) == "Property '\"Bar\" | \"Foo\"' does not exist on type 'exampleClass3'");
+}
+
+TEST_CASE_FIXTURE(ClassFixture, "rawget_type_family_errors_w_classes")
+{
+    if (!FFlag::DebugLuauDeferredConstraintResolution)
+        return;
+
+    CheckResult result = check(R"(
+        type PropsOfMyObject = rawget<BaseClass, "BaseField">
+    )");
+
+    LUAU_REQUIRE_ERROR_COUNT(1, result);
+    CHECK(toString(result.errors[0]) == "Property '\"BaseField\"' does not exist on type 'BaseClass'");
+}
+
+TEST_SUITE_END();
\ No newline at end of file
diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp
index 06e698a8..54cf1cef 100644
--- a/tests/TypeInfer.aliases.test.cpp
+++ b/tests/TypeInfer.aliases.test.cpp
@@ -9,7 +9,6 @@ using namespace Luau;
 
 LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
 LUAU_FASTFLAG(DebugLuauSharedSelf);
-LUAU_FASTFLAG(LuauForbidAliasNamedTypeof);
 
 TEST_SUITE_BEGIN("TypeAliases");
 
@@ -1065,8 +1064,6 @@ TEST_CASE_FIXTURE(Fixture, "table_types_record_the_property_locations")
 
 TEST_CASE_FIXTURE(Fixture, "typeof_is_not_a_valid_alias_name")
 {
-    ScopedFastFlag sff{FFlag::LuauForbidAliasNamedTypeof, true};
-
     CheckResult result = check(R"(
         type typeof = number
     )");
diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp
index 8d14f56b..c532c069 100644
--- a/tests/TypeInfer.anyerror.test.cpp
+++ b/tests/TypeInfer.anyerror.test.cpp
@@ -32,15 +32,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any")
 
     LUAU_REQUIRE_NO_ERRORS(result);
 
-    if (FFlag::DebugLuauDeferredConstraintResolution)
-    {
-        // Bug: We do not simplify at the right time
-        CHECK_EQ("any?", toString(requireType("a")));
-    }
-    else
-    {
-        CHECK_EQ(builtinTypes->anyType, requireType("a"));
-    }
+    CHECK(builtinTypes->anyType == requireType("a"));
 }
 
 TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any2")
@@ -58,15 +50,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any2")
 
     LUAU_REQUIRE_NO_ERRORS(result);
 
-    if (FFlag::DebugLuauDeferredConstraintResolution)
-    {
-        // Bug: We do not simplify at the right time
-        CHECK_EQ("any?", toString(requireType("a")));
-    }
-    else
-    {
-        CHECK_EQ("any", toString(requireType("a")));
-    }
+    CHECK("any" == toString(requireType("a")));
 }
 
 TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any")
@@ -82,15 +66,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any")
 
     LUAU_REQUIRE_NO_ERRORS(result);
 
-    if (FFlag::DebugLuauDeferredConstraintResolution)
-    {
-        // Bug: We do not simplify at the right time
-        CHECK_EQ("any?", toString(requireType("a")));
-    }
-    else
-    {
-        CHECK_EQ("any", toString(requireType("a")));
-    }
+    CHECK("any" == toString(requireType("a")));
 }
 
 TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2")
@@ -104,17 +80,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2")
         end
     )");
 
-    LUAU_REQUIRE_NO_ERRORS(result);
-
-    if (FFlag::DebugLuauDeferredConstraintResolution)
-    {
-        // Bug: We do not simplify at the right time
-        CHECK_EQ("any?", toString(requireType("a")));
-    }
-    else
-    {
-        CHECK_EQ("any", toString(requireType("a")));
-    }
+    CHECK("any" == toString(requireType("a")));
 }
 
 TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any_pack")
@@ -130,15 +96,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any_pack")
 
     LUAU_REQUIRE_NO_ERRORS(result);
 
-    if (FFlag::DebugLuauDeferredConstraintResolution)
-    {
-        // Bug: We do not simplify at the right time
-        CHECK_EQ("any?", toString(requireType("a")));
-    }
-    else
-    {
-        CHECK_EQ("any", toString(requireType("a")));
-    }
+    CHECK("any" == toString(requireType("a")));
 }
 
 TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error")
@@ -443,4 +401,13 @@ end
     CHECK("(any, any) -> any" == toString(requireType("foo")));
 }
 
+TEST_CASE_FIXTURE(Fixture, "cast_to_table_of_any")
+{
+    CheckResult result = check(R"(
+        local v = {true} :: {any}
+    )");
+
+    LUAU_REQUIRE_NO_ERRORS(result);
+}
+
 TEST_SUITE_END();
diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp
index 3c3af65f..90271e29 100644
--- a/tests/TypeInfer.classes.test.cpp
+++ b/tests/TypeInfer.classes.test.cpp
@@ -7,6 +7,7 @@
 
 #include "Fixture.h"
 #include "ClassFixture.h"
"ClassFixture.h" +#include "ScopedFlags.h" #include "doctest.h" using namespace Luau; @@ -17,6 +18,39 @@ LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls); TEST_SUITE_BEGIN("TypeInferClasses"); +TEST_CASE_FIXTURE(ClassFixture, "Luau.Analyze.CLI_crashes_on_this_test") +{ + CheckResult result = check(R"( + local CircularQueue = {} +CircularQueue.__index = CircularQueue + +function CircularQueue:new() + local newCircularQueue = { + head = nil, + } + setmetatable(newCircularQueue, CircularQueue) + + return newCircularQueue +end + +function CircularQueue:push() + local newListNode + + if self.head then + newListNode = { + prevNode = self.head.prevNode, + nextNode = self.head, + } + newListNode.prevNode.nextNode = newListNode + newListNode.nextNode.prevNode = newListNode + end +end + +return CircularQueue + + )"); +} + TEST_CASE_FIXTURE(ClassFixture, "call_method_of_a_class") { CheckResult result = check(R"( @@ -474,6 +508,31 @@ Type 'ChildClass' could not be converted into 'BaseClass' in an invariant contex } } +TEST_CASE_FIXTURE(ClassFixture, "optional_class_casts_work_in_new_solver") +{ + ScopedFastFlag sff{FFlag::DebugLuauDeferredConstraintResolution, true}; + + CheckResult result = check(R"( + type A = { x: ChildClass } + type B = { x: BaseClass } + + local a = { x = ChildClass.New() } :: A + local opt_a = a :: A? + local b = { x = BaseClass.New() } :: B + local opt_b = b :: B? + local b_from_a = a :: B + local b_from_opt_a = opt_a :: B + local opt_b_from_a = a :: B? + local opt_b_from_opt_a = opt_a :: B? + local a_from_b = b :: A + local a_from_opt_b = opt_b :: A + local opt_a_from_b = b :: A? + local opt_a_from_opt_b = opt_b :: A? + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(ClassFixture, "callable_classes") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index c57eab79..688f27b7 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -7,6 +7,8 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauDeclarationExtraPropData) + using namespace Luau; TEST_SUITE_BEGIN("DefinitionTests"); @@ -319,6 +321,8 @@ TEST_CASE_FIXTURE(Fixture, "definitions_documentation_symbols") TEST_CASE_FIXTURE(Fixture, "definitions_symbols_are_generated_for_recursively_referenced_types") { + ScopedFastFlag luauDeclarationExtraPropData{FFlag::LuauDeclarationExtraPropData, true}; + loadDefinition(R"( declare class MyClass function myMethod(self) @@ -330,6 +334,22 @@ TEST_CASE_FIXTURE(Fixture, "definitions_symbols_are_generated_for_recursively_re std::optional myClassTy = frontend.globals.globalScope->lookupType("MyClass"); REQUIRE(bool(myClassTy)); CHECK_EQ(myClassTy->type->documentationSymbol, "@test/globaltype/MyClass"); + + ClassType* cls = getMutable(myClassTy->type); + REQUIRE(bool(cls)); + REQUIRE_EQ(cls->props.count("myMethod"), 1); + + const auto& method = cls->props["myMethod"]; + CHECK_EQ(method.documentationSymbol, "@test/globaltype/MyClass.myMethod"); + + FunctionType* function = getMutable(method.type()); + REQUIRE(function); + + REQUIRE(function->definition.has_value()); + CHECK(function->definition->definitionModuleName == "@test"); + CHECK(function->definition->definitionLocation == Location({2, 12}, {2, 35})); + CHECK(!function->definition->varargLocation.has_value()); + CHECK(function->definition->originalNameLocation == Location({2, 21}, {2, 29})); } TEST_CASE_FIXTURE(Fixture, "documentation_symbols_dont_attach_to_persistent_types") diff --git 
a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 4fb3d58b..410a9859 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -1582,7 +1582,7 @@ TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_th if (!result.errors.empty()) { for (const auto& e : result.errors) - printf("%s %s: %s\n", e.moduleName.c_str(), toString(e.location).c_str(), toString(e).c_str()); + MESSAGE(e.moduleName << " " << toString(e.location) << ": " << toString(e)); } } @@ -2298,10 +2298,14 @@ end if (FFlag::DebugLuauDeferredConstraintResolution) { LUAU_REQUIRE_ERROR_COUNT(4, result); - CHECK(toString(result.errors[0]) == "Type family instance sub is uninhabited"); - CHECK(toString(result.errors[1]) == "Type family instance sub is uninhabited"); - CHECK(toString(result.errors[2]) == "Type family instance sub is uninhabited"); - CHECK(toString(result.errors[3]) == "Type family instance sub is uninhabited"); + CHECK(toString(result.errors[0]) == + "Operator '-' could not be applied to operands of types unknown and number; there is no corresponding overload for __sub"); + CHECK(toString(result.errors[1]) == + "Operator '-' could not be applied to operands of types unknown and number; there is no corresponding overload for __sub"); + CHECK(toString(result.errors[2]) == + "Operator '-' could not be applied to operands of types unknown and number; there is no corresponding overload for __sub"); + CHECK(toString(result.errors[3]) == + "Operator '-' could not be applied to operands of types unknown and number; there is no corresponding overload for __sub"); } else { @@ -2351,8 +2355,9 @@ end LUAU_REQUIRE_ERRORS(result); auto err = get(result.errors.back()); LUAU_ASSERT(err); - CHECK("false | number" == toString(err->recommendedReturn)); - CHECK(err->recommendedArgs.size() == 0); + CHECK("number" == toString(err->recommendedReturn)); + REQUIRE(1 == err->recommendedArgs.size()); + CHECK("number" == toString(err->recommendedArgs[0].second)); } TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_arg_type") @@ -2374,6 +2379,28 @@ end CHECK("number" == toString(err->recommendedArgs[1].second)); } +TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_arg_type_2") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + // Make sure the error types are cloned to module interface + frontend.options.retainFullTypeGraphs = false; + + CheckResult result = check(R"( +local function escape_fslash(pre) + return (#pre % 2 == 0 and '\\' or '') .. pre .. '.' 
+end +)"); + + LUAU_REQUIRE_ERRORS(result); + auto err = get(result.errors.back()); + LUAU_ASSERT(err); + CHECK("unknown" == toString(err->recommendedReturn)); + REQUIRE(err->recommendedArgs.size() == 1); + CHECK("a" == toString(err->recommendedArgs[0].second)); +} + TEST_CASE_FIXTURE(Fixture, "local_function_fwd_decl_doesnt_crash") { CheckResult result = check(R"( @@ -2673,4 +2700,52 @@ TEST_CASE_FIXTURE(Fixture, "captured_local_is_assigned_a_function") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "error_suppression_propagates_through_function_calls") +{ + CheckResult result = check(R"( + function first(x: any) + return pairs(x)(x) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("(any) -> (any?, any)" == toString(requireType("first"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "fuzzer_normalizer_out_of_resources") +{ + // This luau code should finish typechecking, not segfault upon dereferencing + // the normalized type + CheckResult result = check(R"( + Module 'l0': +local _ = true,...,_ +if ... then +while _:_(_._G) do +do end +_ = _ and _ +_ = 0 and {# _,} +local _ = "CCCCCCCCCCCCCCCCCCCCCCCCCCC" +local l0 = require(module0) +end +local function l0() +end +elseif _ then +l0 = _ +end +do end +while _ do +_ = if _ then _ elseif _ then _,if _ then _ else _ +_ = _() +do end +do end +if _ then +end +end +_ = _,{} + + )"); +} + + TEST_SUITE_END(); diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index d1716f5d..a58fb638 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -433,9 +433,53 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "varlist_declared_by_for_in_loop_should_be_fr end )"); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + auto err = get(result.errors[0]); + CHECK(err != nullptr); + } + else + { + LUAU_REQUIRE_NO_ERRORS(result); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "iter_constraint_before_loop_body") +{ + CheckResult result = check(R"( + local T = { + fields = {}, + } + + function f() + for u, v in pairs(T.fields) do + T.fields[u] = nil + end + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "rbxl_place_file_crash_for_wrong_constraints") +{ + CheckResult result = check(R"( +local VehicleParameters = { + -- These are default values in the case the package structure is broken + StrutSpringStiffnessFront = 28000, +} + +local function updateFromConfiguration() + for property, value in pairs(VehicleParameters) do + VehicleParameters[property] = value + end +end +)"); + LUAU_REQUIRE_NO_ERRORS(result); +} + + TEST_CASE_FIXTURE(BuiltinsFixture, "properly_infer_iteratee_is_a_free_table") { // In this case, we cannot know the element type of the table {}. It could be anything. @@ -652,13 +696,7 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_basic") if (FFlag::DebugLuauDeferredConstraintResolution) { TypeId keyTy = requireType("key"); - - const UnionType* ut = get(keyTy); - REQUIRE(ut); - - REQUIRE(ut->options.size() == 2); - CHECK_EQ(builtinTypes->nilType, ut->options[0]); - CHECK_EQ(*builtinTypes->numberType, *ut->options[1]); + CHECK("number?" 
== toString(keyTy)); } else CHECK_EQ(*builtinTypes->numberType, *requireType("key")); @@ -1010,7 +1048,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "iterate_over_properties_nonstrict") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(BuiltinsFixture, "pairs_should_not_add_an_indexer") +TEST_CASE_FIXTURE(BuiltinsFixture, "pairs_should_not_retroactively_add_an_indexer") { CheckResult result = check(R"( --!strict @@ -1025,7 +1063,12 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "pairs_should_not_add_an_indexer") )"); if (FFlag::DebugLuauDeferredConstraintResolution) - LUAU_REQUIRE_ERROR_COUNT(2, result); + { + // We regress a little here: The old solver would typecheck the first + // access to prices.wwwww on a table that had no indexer, and the second + // on a table that does. + LUAU_REQUIRE_ERROR_COUNT(0, result); + } else LUAU_REQUIRE_ERROR_COUNT(1, result); } @@ -1114,4 +1157,30 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "forin_metatable_iter_mm") CHECK_EQ("number", toString(requireTypeAtPosition({6, 21}))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "iteration_preserves_error_suppression") +{ + CheckResult result = check(R"( + function first(x: any) + for k, v in pairs(x) do + print(k, v) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("any" == toString(requireTypeAtPosition({3, 22}))); + CHECK("any" == toString(requireTypeAtPosition({3, 25}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "tryDispatchIterableFunction_under_constrained_loop_should_not_assert") +{ + CheckResult result = check(R"( +local function foo(Instance) + for _, Child in next, Instance:GetChildren() do + end +end + )"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 56548608..b8bb9795 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -396,9 +396,17 @@ TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_result") s += 10 )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(result.errors[0], (TypeError{Location{{2, 8}, {2, 9}}, TypeMismatch{builtinTypes->numberType, builtinTypes->stringType}})); - CHECK_EQ(result.errors[1], (TypeError{Location{{2, 8}, {2, 15}}, TypeMismatch{builtinTypes->stringType, builtinTypes->numberType}})); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(result.errors[0], (TypeError{Location{{2, 8}, {2, 9}}, TypeMismatch{builtinTypes->numberType, builtinTypes->stringType}})); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(result.errors[0], (TypeError{Location{{2, 8}, {2, 9}}, TypeMismatch{builtinTypes->numberType, builtinTypes->stringType}})); + CHECK_EQ(result.errors[1], (TypeError{Location{{2, 8}, {2, 15}}, TypeMismatch{builtinTypes->stringType, builtinTypes->numberType}})); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_metatable") @@ -423,6 +431,33 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_metatable") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_metatable_with_changing_return_type") +{ + ScopedFastFlag sff{FFlag::DebugLuauDeferredConstraintResolution, true}; + + CheckResult result = check(R"( + --!strict + type T = { x: number } + local MT = {} + + function MT:__add(other): number + return 112 + end + + local t = setmetatable({x = 2}, MT) + local u = t + 3 + t += 3 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + + CHECK("t" == toString(tm->wantedType)); + CHECK("number" == 
toString(tm->givenType)); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_result_must_be_compatible_with_var") { CheckResult result = check(R"( @@ -732,7 +767,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "and_binexps_dont_unify") end )"); - LUAU_REQUIRE_NO_ERRORS(result); + // This infers a type for `t` of `{unknown}`, and so it makes sense that `t[1].test` would error. + if (FFlag::DebugLuauDeferredConstraintResolution) + LUAU_REQUIRE_ERROR_COUNT(1, result); + else + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators") diff --git a/tests/TypeInfer.primitives.test.cpp b/tests/TypeInfer.primitives.test.cpp index 640e693b..37f891cb 100644 --- a/tests/TypeInfer.primitives.test.cpp +++ b/tests/TypeInfer.primitives.test.cpp @@ -101,4 +101,14 @@ TEST_CASE("singleton_types") CHECK(result.errors.empty()); } +TEST_CASE_FIXTURE(BuiltinsFixture, "property_of_buffers") +{ + CheckResult result = check(R"( + local b = buffer.create(100) + print(b.foo) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 8e81b0cc..3072169c 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -312,7 +312,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "bail_early_if_unification_is_too_complicated } } -// FIXME: Move this test to another source file when removing FFlag::LuauLowerBoundsCalculation TEST_CASE_FIXTURE(Fixture, "do_not_ice_when_trying_to_pick_first_of_generic_type_pack") { // In-place quantification causes these types to have the wrong types but only because of nasty interaction with prototyping. @@ -1230,4 +1229,45 @@ TEST_CASE_FIXTURE(Fixture, "table_containing_non_final_type_is_erroneously_cache CHECK(n1 == n2); } +// This is doable with the new solver, but there are some problems we have to work out first. +// CLI-111113 +TEST_CASE_FIXTURE(Fixture, "we_cannot_infer_functions_that_return_inconsistently") +{ + CheckResult result = check(R"( + function find_first(tbl: {T}, el) + for i, e in tbl do + if e == el then + return i + end + end + return nil + end + )"); + +#if 0 + // This #if block describes what should happen. + LUAU_CHECK_NO_ERRORS(result); + + // The second argument has type unknown because the == operator does not + // constrain the type of el. + CHECK("({T}, unknown) -> number?" == toString(requireType("find_first"))); +#else + // This is what actually happens right now. + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_CHECK_ERROR_COUNT(2, result); + + // The second argument should be unknown. 
CLI-111111 + CHECK("({T}, 'b) -> number" == toString(requireType("find_first"))); + } + else + { + LUAU_CHECK_ERROR_COUNT(1, result); + + CHECK("({T}, b) -> number" == toString(requireType("find_first"))); + } +#endif +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 485a18c6..e089c7be 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -44,6 +44,20 @@ TEST_CASE_FIXTURE(Fixture, "string_singletons") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "string_singleton_function_call") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local x = "a" + function f(x: "a") end + f(x) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "bool_singletons_mismatch") { CheckResult result = check(R"( @@ -562,15 +576,13 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "singletons_stick_around_under_assignment") local foo = (nil :: any) :: Foo - print(foo.kind == "Bar") -- TypeError: Type "Foo" cannot be compared with "Bar" + print(foo.kind == "Bar") -- type of equality refines to `false` local kind = foo.kind - print(kind == "Bar") -- SHOULD BE: TypeError: Type "Foo" cannot be compared with "Bar" + print(kind == "Bar") -- type of equality refines to `false` )"); - // FIXME: Under the new solver, we get both the errors we expect, but they're - // duplicated because of how we are currently running type family reduction. if (FFlag::DebugLuauDeferredConstraintResolution) - LUAU_REQUIRE_ERROR_COUNT(4, result); + LUAU_REQUIRE_NO_ERRORS(result); else LUAU_REQUIRE_ERROR_COUNT(1, result); } diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 4cc07fba..6f8b4f50 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -15,19 +15,35 @@ using namespace Luau; -LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauInstantiateInSubtyping); LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls); LUAU_FASTFLAG(LuauFixIndexerSubtypingOrdering); LUAU_FASTFLAG(DebugLuauSharedSelf); -LUAU_FASTFLAG(LuauReadWritePropertySyntax); -LUAU_FASTFLAG(LuauMetatableInstantiationCloneCheck); LUAU_DYNAMIC_FASTFLAG(LuauImproveNonFunctionCallError) TEST_SUITE_BEGIN("TableTests"); +TEST_CASE_FIXTURE(BuiltinsFixture, "generalization_shouldnt_seal_table_in_len_family_fn") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + CheckResult result = check(R"( +local t = {} +for i = #t, 2, -1 do + t[i] = t[i + 1] +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + const TableType* tType = get(requireType("t")); + REQUIRE(tType != nullptr); + REQUIRE(tType->indexer); + CHECK_EQ(tType->indexer->indexType, builtinTypes->numberType); + CHECK_EQ(follow(tType->indexer->indexResultType), builtinTypes->unknownType); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "LUAU_ASSERT_arg_exprs_doesnt_trigger_assert") { CheckResult result = check(R"( @@ -2463,10 +2479,7 @@ local x: {number} | number | string local y = #x )"); - if (FFlag::DebugLuauDeferredConstraintResolution) - LUAU_REQUIRE_ERROR_COUNT(2, result); - else - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR_COUNT(1, result); } TEST_CASE_FIXTURE(BuiltinsFixture, "dont_hang_when_trying_to_look_up_in_cyclic_metatable_index") @@ -2729,7 +2742,9 @@ TEST_CASE_FIXTURE(Fixture, "tables_get_names_from_their_locals") TEST_CASE_FIXTURE(Fixture, "should_not_unblock_table_type_twice") { - 
ScopedFastFlag sff(FFlag::DebugLuauDeferredConstraintResolution, true); + // don't run this when the DCR flag isn't set + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; check(R"( local timer = peek(timerQueue) @@ -2972,7 +2987,7 @@ c = b const TableType* ttv = get(*ty); REQUIRE(ttv); - CHECK(ttv->instantiatedTypeParams.empty()); + CHECK(0 == ttv->instantiatedTypeParams.size()); } TEST_CASE_FIXTURE(Fixture, "table_indexing_error_location") @@ -3155,7 +3170,7 @@ TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch") LUAU_REQUIRE_ERROR_COUNT(1, result); if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("Value of type '{ x: number? }?' could be nil", toString(result.errors[0])); + CHECK_EQ("Type 'nil' does not have key 'x'", toString(result.errors[0])); else CHECK_EQ("Value of type '{| x: number? |}?' could be nil", toString(result.errors[0])); CHECK_EQ("boolean", toString(requireType("u"))); @@ -3241,7 +3256,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dont_leak_free_table_props") TEST_CASE_FIXTURE(Fixture, "inferred_return_type_of_free_table") { ScopedFastFlag sff[] = { - // {FFlag::LuauLowerBoundsCalculation, true}, {FFlag::DebugLuauSharedSelf, true}, }; @@ -4014,7 +4028,6 @@ TEST_CASE_FIXTURE(Fixture, "identify_all_problematic_table_fields") TEST_CASE_FIXTURE(Fixture, "read_and_write_only_table_properties_are_unsupported") { ScopedFastFlag sff[] = { - {FFlag::LuauReadWritePropertySyntax, true}, {FFlag::DebugLuauDeferredConstraintResolution, false}, }; @@ -4040,8 +4053,6 @@ TEST_CASE_FIXTURE(Fixture, "read_and_write_only_table_properties_are_unsupported TEST_CASE_FIXTURE(Fixture, "read_ond_write_only_indexers_are_unsupported") { - ScopedFastFlag sff{FFlag::LuauReadWritePropertySyntax, true}; - CheckResult result = check(R"( type T = {read [string]: number} type U = {write [string]: boolean} @@ -4155,7 +4166,7 @@ TEST_CASE_FIXTURE(Fixture, "write_annotations_are_unsupported_even_with_the_new_ TEST_CASE_FIXTURE(Fixture, "read_and_write_only_table_properties_are_unsupported") { - ScopedFastFlag sff[] = {{FFlag::LuauReadWritePropertySyntax, true}, {FFlag::DebugLuauDeferredConstraintResolution, false}}; + ScopedFastFlag sff[] = {{FFlag::DebugLuauDeferredConstraintResolution, false}}; CheckResult result = check(R"( type W = {read x: number} @@ -4179,7 +4190,7 @@ TEST_CASE_FIXTURE(Fixture, "read_and_write_only_table_properties_are_unsupported TEST_CASE_FIXTURE(Fixture, "read_ond_write_only_indexers_are_unsupported") { - ScopedFastFlag sff[] = {{FFlag::LuauReadWritePropertySyntax, true}, {FFlag::DebugLuauDeferredConstraintResolution, false}}; + ScopedFastFlag sff[] = {{FFlag::DebugLuauDeferredConstraintResolution, false}}; CheckResult result = check(R"( type T = {read [string]: number} @@ -4199,7 +4210,7 @@ TEST_CASE_FIXTURE(Fixture, "table_writes_introduce_write_properties") if (!FFlag::DebugLuauDeferredConstraintResolution) return; - ScopedFastFlag sff[] = {{FFlag::LuauReadWritePropertySyntax, true}, {FFlag::DebugLuauDeferredConstraintResolution, true}}; + ScopedFastFlag sff[] = {{FFlag::DebugLuauDeferredConstraintResolution, true}}; CheckResult result = check(R"( function oc(player, speaker) @@ -4351,23 +4362,8 @@ TEST_CASE_FIXTURE(Fixture, "mymovie_read_write_tables_bug_2") LUAU_REQUIRE_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "setindexer_always_transmute") -{ - ScopedFastFlag sff{FFlag::DebugLuauDeferredConstraintResolution, true}; - - CheckResult result = check(R"( - function f(x) - (5)[5] = x - end - )"); - - CHECK_EQ("(*error-type*) -> ()", 
toString(requireType("f"))); -} - TEST_CASE_FIXTURE(BuiltinsFixture, "instantiated_metatable_frozen_table_clone_mutation") { - ScopedFastFlag luauMetatableInstantiationCloneCheck{FFlag::LuauMetatableInstantiationCloneCheck, true}; - fileResolver.source["game/worker"] = R"( type WorkerImpl = { destroy: (self: Worker) -> boolean, @@ -4408,6 +4404,21 @@ TEST_CASE_FIXTURE(Fixture, "setprop_on_a_mutating_local_in_both_loops_and_functi LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "cant_index_this") +{ + CheckResult result = check(R"( + local a: number = 9 + a[18] = "tomfoolery" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + NotATable* notATable = get(result.errors[0]); + REQUIRE(notATable); + + CHECK("number" == toString(notATable->ty)); +} + TEST_CASE_FIXTURE(Fixture, "setindexer_multiple_tables_intersection") { ScopedFastFlag sff{FFlag::DebugLuauDeferredConstraintResolution, true}; @@ -4419,8 +4430,135 @@ TEST_CASE_FIXTURE(Fixture, "setindexer_multiple_tables_intersection") end )"); + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK("({ [string]: number } & { [thread]: boolean }, never) -> ()" == toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "insert_a_and_f_of_a_into_table_res_in_a_loop") +{ + CheckResult result = check(R"( + local function f(t) + local res = {} + + for k, a in t do + res[k] = f(a) + res[k] = a + end + end + )"); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(get(result.errors[0])); + } + else + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "ipairs_adds_an_unbounded_indexer") +{ + CheckResult result = check(R"( + --!strict + + local a = {} + ipairs(a) + )"); + + // The old solver erroneously leaves a free type dangling here. The new + // solver does better. 
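+    // (Concretely: the old solver leaves the element type as a dangling free type 'a', printed as {a}, +    // while the new solver generalizes the empty table's element type to unknown, printed as {unknown}.)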
+ if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK("{unknown}" == toString(requireType("a"), {true})); + else + CHECK("{a}" == toString(requireType("a"), {true})); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "index_results_compare_to_nil") +{ + CheckResult result = check(R"( + --!strict + + function foo(tbl: {number}) + if tbl[2] == nil then + print("foo") + end + + if tbl[3] ~= nil then + print("bar") + end + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); - CHECK("({ [string]: number } & { [thread]: boolean }, boolean | number) -> ()" == toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "fuzzer_normalization_preserves_tbl_scopes") +{ + CheckResult result = check(R"( +Module 'l0': +do end + +Module 'l1': +local _ = {n0=nil,} +if if nil then _ then +if nil and (_)._ ~= (_)._ then +do end +while _ do +_ = _ +do end +end +end +do end +end +local l0 +while _ do +_ = nil +(_[_])._ %= `{# _}{bit32.extract(# _,1)}` +end + +)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_literal_inference_assert") +{ + CheckResult result = check(R"( + local buttons = { + buttons = {}; + } + + buttons.Button = { + call = nil; + lightParts = nil; + litPropertyOverrides = nil; + model = nil; + pivot = nil; + unlitPropertyOverrides = nil; + } + buttons.Button.__index = buttons.Button + + local lightFuncs: { (self: types.Button, lit: boolean) -> nil } = { + ['\x00'] = function(self: types.Button, lit: boolean) + end; + } + )"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_table_assertion_crash") +{ + CheckResult result = check(R"( + local NexusInstance = {} + function NexusInstance:__InitMetaMethods(): () + local Metatable = {} + local OriginalIndexTable = getmetatable(self).__index + setmetatable(self, Metatable) + + Metatable.__newindex = function(_, Index: string, Value: any): () + --Return if the new and old values are the same. + if self[Index] == Value then + end + end + end + )"); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 140f462a..1d1dd999 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -19,6 +19,7 @@ LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauInstantiateInSubtyping); +LUAU_FASTFLAG(LuauLeadingBarAndAmpersand2) LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTINT(LuauNormalizeCacheLimit); LUAU_FASTINT(LuauRecursionLimit); @@ -782,7 +783,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "no_heap_use_after_free_error") end )"); - LUAU_REQUIRE_ERRORS(result); + if (FFlag::DebugLuauDeferredConstraintResolution) + LUAU_REQUIRE_NO_ERRORS(result); + else + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "infer_type_assertion_value_type") @@ -1532,7 +1536,7 @@ TEST_CASE_FIXTURE(Fixture, "typeof_cannot_refine_builtin_alias") freeze(arena); - (void) check(R"( + (void)check(R"( function foo(x) if typeof(x) == 'GlobalTable' then end @@ -1540,19 +1544,91 @@ TEST_CASE_FIXTURE(Fixture, "typeof_cannot_refine_builtin_alias") )"); } -/* - * We had an issue where we tripped the canMutate() check when binding one - * blocked type to another. 
- */ -TEST_CASE_FIXTURE(Fixture, "delay_setIndexer_constraint_if_the_indexers_type_is_blocked") +TEST_CASE_FIXTURE(BuiltinsFixture, "bad_iter_metamethod") { - (void) check(R"( - local SG = GetService(true) - local lines: { [string]: typeof(SG.ScreenGui) } = {} - lines[deadline] = nil -- This line + CheckResult result = check(R"( + function iter(): unknown + return nil + end + + local a = {__iter = iter} + setmetatable(a, a) + + for i in a do + end )"); - // As long as type inference doesn't trip an assert or crash, we're good! + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CannotCallNonFunction* ccnf = get(result.errors[0]); + REQUIRE(ccnf); + + CHECK("unknown" == toString(ccnf->ty)); + } + else + { + LUAU_REQUIRE_NO_ERRORS(result); + } +} + +TEST_CASE_FIXTURE(Fixture, "leading_bar") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true}; + CheckResult result = check(R"( + type Bar = | number + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("number" == toString(requireTypeAlias("Bar"))); +} + +TEST_CASE_FIXTURE(Fixture, "leading_bar_question_mark") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true}; + CheckResult result = check(R"( + type Bar = |? + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK("Expected type, got '?'" == toString(result.errors[0])); + CHECK("*error-type*?" == toString(requireTypeAlias("Bar"))); +} + +TEST_CASE_FIXTURE(Fixture, "leading_ampersand") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true}; + CheckResult result = check(R"( + type Amp = & string + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("string" == toString(requireTypeAlias("Amp"))); +} + +TEST_CASE_FIXTURE(Fixture, "leading_bar_no_type") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true}; + CheckResult result = check(R"( + type Bar = | + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK("Expected type, got " == toString(result.errors[0])); + CHECK("*error-type*" == toString(requireTypeAlias("Bar"))); +} + +TEST_CASE_FIXTURE(Fixture, "leading_ampersand_no_type") +{ + ScopedFastFlag sff{FFlag::LuauLeadingBarAndAmpersand2, true}; + CheckResult result = check(R"( + type Amp = & + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK("Expected type, got " == toString(result.errors[0])); + CHECK("*error-type*" == toString(requireTypeAlias("Amp"))); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 58ccea89..92f07c43 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -13,6 +13,7 @@ using namespace Luau; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls); +LUAU_FASTFLAG(LuauUnifierRecursionOnRestart); struct TryUnifyFixture : Fixture { @@ -480,4 +481,34 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "unifying_two_unions_under_dcr_does_not_creat } } +TEST_CASE_FIXTURE(BuiltinsFixture, "table_unification_full_restart_recursion") +{ + ScopedFastFlag luauUnifierRecursionOnRestart{FFlag::LuauUnifierRecursionOnRestart, true}; + + CheckResult result = check(R"( +local A, B, C, D + +E = function(a, b) + local mt = getmetatable(b) + if mt.tm:bar(A) == nil and mt.tm:bar(B) == nil then end + if mt.foo == true then D(b, 3) end + mt.foo:call(false, b) +end + +A = function(a, b) + local mt = getmetatable(b) + if mt.foo == true then D(b, 3) end + C(mt, 3) +end + +B = function(a, b) + local mt = getmetatable(b) + if mt.foo == true then D(b, 3) end + C(mt, 3) +end 
+ )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.typestates.test.cpp b/tests/TypeInfer.typestates.test.cpp index dbb9815d..19117447 100644 --- a/tests/TypeInfer.typestates.test.cpp +++ b/tests/TypeInfer.typestates.test.cpp @@ -406,6 +406,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "prototyped_recursive_functions_but_has_futur )"); LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK("((() -> ()) | number)?" == toString(requireType("f"))); } @@ -490,5 +491,34 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typestates_do_not_apply_to_the_initial_local CHECK("number" == toString(requireTypeAtPosition({5, 14}), {true})); } +TEST_CASE_FIXTURE(Fixture, "typestate_globals") +{ + ScopedFastFlag sff{FFlag::DebugLuauDeferredConstraintResolution, true}; + + loadDefinition(R"( + declare foo: string | number + declare function f(x: string): () + )"); + + CheckResult result = check(R"( + foo = "a" + f(foo) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "typestate_unknown_global") +{ + ScopedFastFlag sff{FFlag::DebugLuauDeferredConstraintResolution, true}; + + CheckResult result = check(R"( + x = 5 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(get(result.errors[0])); +} TEST_SUITE_END(); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 5f4d2a0e..539b8592 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -606,13 +606,7 @@ TEST_CASE_FIXTURE(Fixture, "indexing_into_a_cyclic_union_doesnt_crash") end )"); - // The old solver has a bug: It doesn't consider this goofy thing to be a - // table. It's not really important. What's important is that we don't - // crash, hang, or ICE. - if (FFlag::DebugLuauDeferredConstraintResolution) - LUAU_REQUIRE_NO_ERRORS(result); - else - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR_COUNT(1, result); } TEST_CASE_FIXTURE(BuiltinsFixture, "table_union_write_indirect") diff --git a/tests/TypeInfer.unknownnever.test.cpp b/tests/TypeInfer.unknownnever.test.cpp index 8ec70d11..f1924b1c 100644 --- a/tests/TypeInfer.unknownnever.test.cpp +++ b/tests/TypeInfer.unknownnever.test.cpp @@ -396,4 +396,15 @@ TEST_CASE_FIXTURE(Fixture, "lti_permit_explicit_never_annotation") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "cast_from_never_does_not_error") +{ + CheckResult result = check(R"( + local function f(x: never): number + return x :: number + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index 17f4497a..98f8000e 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -726,16 +726,15 @@ assert((function() return sum end)() == 15) --- the reason why this test is interesting is that the table created here has arraysize=0 and a single hash element with key = 1.0 --- ipairs must iterate through that +-- ipairs will not iterate through hash part assert((function() - local arr = { [1] = 42 } + local arr = { [1] = 1, [42] = 42, x = 10 } local sum = 0 for i,v in ipairs(arr) do sum = sum + v end return sum -end)() == 42) +end)() == 1) -- the reason why this test is interesting is it ensures we do correct mutability analysis for locals local function chainTest(n) diff --git a/tests/conformance/bitwise.lua b/tests/conformance/bitwise.lua index f394dc5b..c2536508 100644 --- a/tests/conformance/bitwise.lua +++ b/tests/conformance/bitwise.lua @@ -72,6 +72,7 @@ for _, b in pairs(c) do assert(bit32.bxor(b, b) == 
0) assert(bit32.bxor(b, 0) == b) assert(bit32.bxor(b, b, b) == b) + assert(bit32.bxor(b, b, b, b) == 0) assert(bit32.bnot(b) ~= b) assert(bit32.bnot(bit32.bnot(b)) == b) assert(bit32.bnot(b) == 2^32 - 1 - b) diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index 9262f4ea..98d5b317 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -268,10 +268,43 @@ assert(math.min(1) == 1) assert(math.min(1, 2) == 1) assert(math.min(1, 2, -1) == -1) assert(math.min(1, -1, 2) == -1) +assert(math.min(1, -1, 2, -2) == -2) assert(math.max(1) == 1) assert(math.max(1, 2) == 2) assert(math.max(1, 2, -1) == 2) assert(math.max(1, -1, 2) == 2) +assert(math.max(1, -1, 2, -2) == 2) + +local ma, mb, mc, md + +assert(pcall(function() + ma = 1 + mb = -1 + mc = 2 + md = -2 +end) == true) + +-- min/max without constant-folding +assert(math.min(ma) == 1) +assert(math.min(ma, mc) == 1) +assert(math.min(ma, mc, mb) == -1) +assert(math.min(ma, mb, mc) == -1) +assert(math.min(ma, mb, mc, md) == -2) +assert(math.max(ma) == 1) +assert(math.max(ma, mc) == 2) +assert(math.max(ma, mc, mb) == 2) +assert(math.max(ma, mb, mc) == 2) +assert(math.max(ma, mb, mc, md) == 2) + +local inf = math.huge * 2 +local nan = 0 / 0 + +assert(math.min(nan, 2) ~= math.min(nan, 2)) +assert(math.min(1, nan) == 1) +assert(math.max(nan, 2) ~= math.max(nan, 2)) +assert(math.max(1, nan) == 1) + +local function noinline(x, ...) local s, r = pcall(function(y) return y end, x) return r end -- noise assert(math.noise(0.5) == 0) @@ -279,8 +312,10 @@ assert(math.noise(0.5, 0.5) == -0.25) assert(math.noise(0.5, 0.5, -0.5) == 0.125) assert(math.noise(455.7204209769105, 340.80410508750134, 121.80087666537628) == 0.5010709762573242) -local inf = math.huge * 2 -local nan = 0 / 0 +assert(math.noise(noinline(0.5)) == 0) +assert(math.noise(noinline(0.5), 0.5) == -0.25) +assert(math.noise(noinline(0.5), 0.5, -0.5) == 0.125) +assert(math.noise(noinline(455.7204209769105), 340.80410508750134, 121.80087666537628) == 0.5010709762573242) -- sign assert(math.sign(0) == 0) @@ -290,10 +325,12 @@ assert(math.sign(inf) == 1) assert(math.sign(-inf) == -1) assert(math.sign(nan) == 0) -assert(math.min(nan, 2) ~= math.min(nan, 2)) -assert(math.min(1, nan) == 1) -assert(math.max(nan, 2) ~= math.max(nan, 2)) -assert(math.max(1, nan) == 1) +assert(math.sign(noinline(0)) == 0) +assert(math.sign(noinline(42)) == 1) +assert(math.sign(noinline(-42)) == -1) +assert(math.sign(noinline(inf)) == 1) +assert(math.sign(noinline(-inf)) == -1) +assert(math.sign(noinline(nan)) == 0) -- clamp assert(math.clamp(-1, 0, 1) == 0) @@ -301,6 +338,11 @@ assert(math.clamp(0.5, 0, 1) == 0.5) assert(math.clamp(2, 0, 1) == 1) assert(math.clamp(4, 0, 0) == 0) +assert(math.clamp(noinline(-1), 0, 1) == 0) +assert(math.clamp(noinline(0.5), 0, 1) == 0.5) +assert(math.clamp(noinline(2), 0, 1) == 1) +assert(math.clamp(noinline(4), 0, 0) == 0) + -- round assert(math.round(0) == 0) assert(math.round(0.4) == 0) @@ -313,19 +355,58 @@ assert(math.round(math.huge) == math.huge) assert(math.round(0.49999999999999994) == 0) assert(math.round(-0.49999999999999994) == 0) +assert(math.round(noinline(0)) == 0) +assert(math.round(noinline(0.4)) == 0) +assert(math.round(noinline(0.5)) == 1) +assert(math.round(noinline(3.5)) == 4) +assert(math.round(noinline(-0.4)) == 0) +assert(math.round(noinline(-0.5)) == -1) +assert(math.round(noinline(-3.5)) == -4) +assert(math.round(noinline(math.huge)) == math.huge) +assert(math.round(noinline(0.49999999999999994)) == 0)
+assert(math.round(noinline(-0.49999999999999994)) == 0) + -- fmod assert(math.fmod(3, 2) == 1) assert(math.fmod(-3, 2) == -1) assert(math.fmod(3, -2) == 1) assert(math.fmod(-3, -2) == -1) +assert(math.fmod(noinline(3), 2) == 1) +assert(math.fmod(noinline(-3), 2) == -1) +assert(math.fmod(noinline(3), -2) == 1) +assert(math.fmod(noinline(-3), -2) == -1) + -- pow assert(math.pow(2, 0) == 1) assert(math.pow(2, 2) == 4) assert(math.pow(4, 0.5) == 2) assert(math.pow(-2, 2) == 4) + +assert(math.pow(noinline(2), 0) == 1) +assert(math.pow(noinline(2), 2) == 4) +assert(math.pow(noinline(4), 0.5) == 2) +assert(math.pow(noinline(-2), 2) == 4) + assert(tostring(math.pow(-2, 0.5)) == "nan") +-- test that fastcalls return correct number of results +assert(select('#', math.floor(1.4)) == 1) +assert(select('#', math.ceil(1.6)) == 1) +assert(select('#', math.sqrt(9)) == 1) +assert(select('#', math.deg(9)) == 1) +assert(select('#', math.rad(9)) == 1) +assert(select('#', math.sin(1.5)) == 1) +assert(select('#', math.atan2(1.5, 0.5)) == 1) +assert(select('#', math.modf(1.5)) == 2) +assert(select('#', math.frexp(1.5)) == 2) + +-- test that fastcalls that return variadic results return them correctly in variadic position +assert(select(1, math.modf(1.5)) == 1) +assert(select(2, math.modf(1.5)) == 0.5) +assert(select(1, math.frexp(1.5)) == 0.75) +assert(select(2, math.frexp(1.5)) == 1) + -- most of the tests above go through fastcall path -- to make sure the basic implementations are also correct we test these functions with string->number coercions assert(math.abs("-4") == 4) @@ -370,21 +451,4 @@ assert(math.sign("-2") == -1) assert(math.sign("0") == 0) assert(math.round("1.8") == 2) --- test that fastcalls return correct number of results -assert(select('#', math.floor(1.4)) == 1) -assert(select('#', math.ceil(1.6)) == 1) -assert(select('#', math.sqrt(9)) == 1) -assert(select('#', math.deg(9)) == 1) -assert(select('#', math.rad(9)) == 1) -assert(select('#', math.sin(1.5)) == 1) -assert(select('#', math.atan2(1.5, 0.5)) == 1) -assert(select('#', math.modf(1.5)) == 2) -assert(select('#', math.frexp(1.5)) == 2) - --- test that fastcalls that return variadic results return them correctly in variadic position -assert(select(1, math.modf(1.5)) == 1) -assert(select(2, math.modf(1.5)) == 0.5) -assert(select(1, math.frexp(1.5)) == 0.75) -assert(select(2, math.frexp(1.5)) == 1) - return('OK') diff --git a/tests/conformance/move.lua b/tests/conformance/move.lua index 9518219f..bb613157 100644 --- a/tests/conformance/move.lua +++ b/tests/conformance/move.lua @@ -65,30 +65,6 @@ do a = table.move({[minI] = 100}, minI, minI, maxI) eqT(a, {[minI] = 100, [maxI] = 100}) - -- moving small amount of elements (array/hash) using a wide range - a = {} - table.move({1, 2, 3, 4, 5}, -100000000, 100000000, -100000000, a) - eqT(a, {1, 2, 3, 4, 5}) - - a = {} - table.move({1, 2}, -100000000, 100000000, 0, a) - eqT(a, {[100000001] = 1, [100000002] = 2}) - - -- hash part copy - a = {} - table.move({[-1000000] = 1, [-100] = 2, [100] = 3, [100000] = 4}, -100000000, 100000000, 0, a) - eqT(a, {[99000000] = 1, [99999900] = 2, [100000100] = 3, [100100000] = 4}) - - -- precise hash part bounds - a = {} - table.move({[-100000000 - 1] = -1, [-100000000] = 1, [-100] = 2, [100] = 3, [100000000] = 4, [100000000 + 1] = -1}, -100000000, 100000000, 0, a) - eqT(a, {[0] = 1, [99999900] = 2, [100000100] = 3, [200000000] = 4}) - - -- no integer undeflow in corner hash part case - a = {} - table.move({[minI] = 100, [-100] = 2}, minI, minI + 100000000, minI, 
a) - eqT(a, {[minI] = 100}) - -- hash part skips array slice a = {} table.move({[-1] = 1, [0] = 2, [1] = 3, [2] = 4}, -1, 3, 1, a) @@ -97,6 +73,19 @@ do a = {} table.move({[-1] = 1, [0] = 2, [1] = 3, [2] = 4, [10] = 5, [100] = 6, [1000] = 7}, -1, 3, 1, a) eqT(a, {[1] = 1, [2] = 2, [3] = 3, [4] = 4}) + + -- moving ranges containing nil values into tables with values + a = {1, 2, 3, 4, 5} + table.move({10}, 1, 3, 2, a) + eqT(a, {1, 10, nil, nil, 5}) + + a = {1, 2, 3, 4, 5} + table.move({10}, -1, 1, 2, a) + eqT(a, {1, nil, nil, 10, 5}) + + a = {[-1000] = 1, [1000] = 2, [1] = 3} + table.move({10}, -1000, 1000, -1000, a) + eqT(a, {10}) end checkerror("too many", table.move, {}, 0, maxI, 1) diff --git a/tests/conformance/native.lua b/tests/conformance/native.lua index 094e6b83..03845013 100644 --- a/tests/conformance/native.lua +++ b/tests/conformance/native.lua @@ -208,6 +208,35 @@ end assert(pcall(fuzzfail21) == false) +local function fuzzfail22(...) + local _ = {false,},true,...,l0 + while _ do + _ = true,{unpack(0,_),},l0 + _.n126 = nil + _ = {not _,_=not _,n0=_,_,n0=not _,},_ < _ + return _ > _ + end + return `""` +end + +assert(pcall(fuzzfail22) == false) + +local function fuzzfail23(...) + local _ = {false,},_,...,l0 + while _ do + _ = true,{unpack(_),},l0 + _ = {{[_]=nil,_=not _,_,true,_=nil,},not _,not _,_,bxor=- _,} + do end + break + end + do end + local _ = _,true + do end + local _ = _,true +end + +assert(pcall(fuzzfail23) == false) + local function arraySizeInv1() local t = {1, 2, nil, nil, nil, nil, nil, nil, nil, true} diff --git a/tests/conformance/native_userdata.lua b/tests/conformance/native_userdata.lua new file mode 100644 index 00000000..b1b2a103 --- /dev/null +++ b/tests/conformance/native_userdata.lua @@ -0,0 +1,42 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print('testing userdata') + +function ecall(fn, ...) + local ok, err = pcall(fn, ...) 
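+    -- fn is expected to fail; return only the message portion of the error (the text after the first ': ')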
+ assert(not ok) + return err:sub((err:find(": ") or -1) + 2, #err) +end + +local function realmad(a: vec2, b: vec2, c: vec2): vec2 + return -c + a * b; +end + +local function dm(s: vec2, t: vec2, u: vec2) + local x = s:Dot(t) + assert(x == 13) + + local t = u:Min(s) + assert(t.X == 5) + assert(t.Y == 4) +end + +local s: vec2 = vec2(5, 4) +local t: vec2 = vec2(1, 2) +local u: vec2 = vec2(10, 20) + +local x: vec2 = realmad(s, t, u) + +assert(x.X == -5) +assert(x.Y == -12) + +dm(s, t, u) + +local function mu(v: vec2) + assert(v.Magnitude == 2) + assert(v.Unit.X == 0) + assert(v.Unit.Y == 1) +end + +mu(vec2(0, 2)) + +return 'OK' diff --git a/tests/conformance/tables.lua b/tests/conformance/tables.lua index 03b46396..c739f555 100644 --- a/tests/conformance/tables.lua +++ b/tests/conformance/tables.lua @@ -306,10 +306,14 @@ end assert(table.maxn{} == 0) +assert(table.maxn{[-100] = 1} == 0) assert(table.maxn{["1000"] = true} == 0) assert(table.maxn{["1000"] = true, [24.5] = 3} == 24.5) assert(table.maxn{[1000] = true} == 1000) assert(table.maxn{[10] = true, [100*math.pi] = print} == 100*math.pi) +a = {[10] = 1, [20] = 2} +a[20] = nil +assert(table.maxn(a) == 10) -- int overflow @@ -408,8 +412,36 @@ do assert(table.find({false, true}, true) == 2) - -- make sure table.find checks the hash portion as well by constructing a table literal that forces the value into the hash part - assert(table.find({[(1)] = true}, true) == 1) + -- make sure table.find checks the hash portion as well + assert(table.find({[(2)] = true}, true, 2) == 2) +end + +-- test table.concat +do + -- regular usage + assert(table.concat({}) == "") + assert(table.concat({}, ",") == "") + assert(table.concat({"a", "b", "c"}, ",") == "a,b,c") + assert(table.concat({"a", "b", "c"}, ",", 2) == "b,c") + assert(table.concat({"a", "b", "c"}, ",", 1, 2) == "a,b") + + -- hash elements + local t = {} + t[123] = "a" + t[124] = "b" + + assert(table.concat(t) == "") + assert(table.concat(t, ",", 123, 124) == "a,b") + assert(table.concat(t, ",", 123, 123) == "a") + + -- numeric values + assert(table.concat({1, 2, 3}, ",") == "1,2,3") + assert(table.concat({"a", 2, "c"}, ",") == "a,2,c") + + -- error cases + assert(pcall(table.concat, "") == false) + assert(pcall(table.concat, t, false) == false) + assert(pcall(table.concat, t, ",", 1, 100) == false) end -- test indexing with strings that have zeroes embedded in them diff --git a/tests/conformance/vector.lua b/tests/conformance/vector.lua index 9be88f69..7e4a9a3e 100644 --- a/tests/conformance/vector.lua +++ b/tests/conformance/vector.lua @@ -4,6 +4,12 @@ print('testing vectors') -- detect vector size local vector_size = if pcall(function() return vector(0, 0, 0).w end) then 4 else 3 +function ecall(fn, ...) + local ok, err = pcall(fn, ...) 
+ assert(not ok) + return err:sub((err:find(": ") or -1) + 2, #err) +end + -- equality assert(vector(1, 2, 3) == vector(1, 2, 3)) assert(vector(0, 1, 2) == vector(-0, 1, 2)) @@ -92,9 +98,29 @@ assert(nanv ~= nanv); -- __index assert(vector(1, 2, 2).Magnitude == 3) assert(vector(0, 0, 0)['Dot'](vector(1, 2, 4), vector(5, 6, 7)) == 45) +assert(vector(2, 0, 0).Unit == vector(1, 0, 0)) -- __namecall assert(vector(1, 2, 4):Dot(vector(5, 6, 7)) == 45) +assert(ecall(function() vector(1, 2, 4):Dot() end) == "missing argument #2 (vector expected)") +assert(ecall(function() vector(1, 2, 4):Dot("a") end) == "invalid argument #2 (vector expected, got string)") + +local function doDot1(a: vector, b) + return a:Dot(b) +end + +local function doDot2(a: vector, b) + return (a:Dot(b)) +end + +local v124 = vector(1, 2, 4) + +assert(doDot1(v124, vector(5, 6, 7)) == 45) +assert(doDot2(v124, vector(5, 6, 7)) == 45) +assert(ecall(function() doDot1(v124, "a") end) == "invalid argument #2 (vector expected, got string)") +assert(ecall(function() doDot2(v124, "a") end) == "invalid argument #2 (vector expected, got string)") +assert(select("#", doDot1(v124, vector(5, 6, 7))) == 1) +assert(select("#", doDot2(v124, vector(5, 6, 7))) == 1) -- can't use vector with NaN components as table key assert(pcall(function() local t = {} t[vector(0/0, 2, 3)] = 1 end) == false) @@ -102,6 +128,9 @@ assert(pcall(function() local t = {} t[vector(1, 0/0, 3)] = 1 end) == false) assert(pcall(function() local t = {} t[vector(1, 2, 0/0)] = 1 end) == false) assert(pcall(function() local t = {} rawset(t, vector(0/0, 2, 3), 1) end) == false) +assert(vector(1, 0, 0):Cross(vector(0, 1, 0)) == vector(0, 0, 1)) +assert(vector(0, 1, 0):Cross(vector(1, 0, 0)) == vector(0, 0, -1)) + -- make sure we cover both builtin and C impl assert(vector(1, 2, 4) == vector("1", "2", "4")) diff --git a/tests/main.cpp b/tests/main.cpp index 5d1ee6a6..4de391b6 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -18,7 +18,7 @@ #include // IsDebuggerPresent #endif -#if defined(__x86_64__) || defined(_M_X64) +#if defined(CODEGEN_TARGET_X64) #include #endif @@ -330,7 +330,7 @@ static void setFastFlags(const std::vector& flags) // This function performs system/architecture specific initialization prior to running tests. static void initSystem() { -#if defined(__x86_64__) || defined(_M_X64) +#if defined(CODEGEN_TARGET_X64) // Some unit tests make use of denormalized numbers. So flags to flush to zero or treat denormals as zero // must be disabled for expected behavior. 
_MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF); diff --git a/tools/faillist.txt b/tools/faillist.txt index 469e3a84..834b24c0 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -32,12 +32,10 @@ BuiltinTests.string_format_report_all_type_errors_at_correct_positions BuiltinTests.string_format_use_correct_argument2 BuiltinTests.table_freeze_is_generic BuiltinTests.tonumber_returns_optional_number_type -DefinitionTests.class_definition_overload_metamethods Differ.metatable_metamissing_left Differ.metatable_metamissing_right Differ.metatable_metanormal Differ.negation -FrontendTest.accumulate_cached_errors_in_consistent_order FrontendTest.environments FrontendTest.imported_table_modification_2 FrontendTest.it_should_be_safe_to_stringify_errors_when_full_type_graph_is_discarded @@ -46,7 +44,6 @@ FrontendTest.trace_requires_in_nonstrict_mode GenericsTests.apply_type_function_nested_generics1 GenericsTests.better_mismatch_error_messages GenericsTests.bound_tables_do_not_clone_original_fields -GenericsTests.correctly_instantiate_polymorphic_member_functions GenericsTests.do_not_always_instantiate_generic_intersection_types GenericsTests.do_not_infer_generic_functions GenericsTests.dont_substitute_bound_types @@ -71,7 +68,6 @@ GenericsTests.no_stack_overflow_from_quantifying GenericsTests.properties_can_be_instantiated_polytypes GenericsTests.quantify_functions_even_if_they_have_an_explicit_generic GenericsTests.self_recursive_instantiated_param -IntersectionTypes.CLI-44817 IntersectionTypes.error_detailed_intersection_all IntersectionTypes.error_detailed_intersection_part IntersectionTypes.intersect_bool_and_false @@ -134,9 +130,9 @@ RefinementTest.call_an_incompatible_function_after_using_typeguard RefinementTest.dataflow_analysis_can_tell_refinements_when_its_appropriate_to_refine_into_nil_or_never RefinementTest.discriminate_from_isa_of_x RefinementTest.discriminate_from_truthiness_of_x -RefinementTest.function_call_with_colon_after_refining_not_to_be_nil RefinementTest.globals_can_be_narrowed_too RefinementTest.isa_type_refinement_must_be_known_ahead_of_time +RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true RefinementTest.not_t_or_some_prop_of_t RefinementTest.refine_a_param_that_got_resolved_during_constraint_solving_stage RefinementTest.refine_a_property_of_some_global @@ -178,19 +174,18 @@ TableTests.generalize_table_argument TableTests.generic_table_instantiation_potential_regression TableTests.indexer_on_sealed_table_must_unify_with_free_table TableTests.indexers_get_quantified_too -TableTests.infer_array TableTests.infer_indexer_from_array_like_table TableTests.infer_indexer_from_its_variable_type_and_unifiable TableTests.inferred_return_type_of_free_table TableTests.invariant_table_properties_means_instantiating_tables_in_assignment_is_unsound TableTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound -TableTests.length_operator_union TableTests.less_exponential_blowup_please TableTests.meta_add TableTests.meta_add_inferred TableTests.metatable_mismatch_should_fail TableTests.missing_metatable_for_sealed_tables_do_not_get_inferred TableTests.mixed_tables_with_implicit_numbered_keys +TableTests.nil_assign_doesnt_hit_indexer TableTests.ok_to_provide_a_subtype_during_construction TableTests.ok_to_set_nil_even_on_non_lvalue_base_expr TableTests.okay_to_add_property_to_unsealed_tables_by_assignment @@ -198,7 +193,6 @@ TableTests.okay_to_add_property_to_unsealed_tables_by_function_call TableTests.only_ascribe_synthetic_names_at_module_scope 
TableTests.open_table_unification_2 TableTests.parameter_was_set_an_indexer_and_bounded_by_another_parameter -TableTests.parameter_was_set_an_indexer_and_bounded_by_string TableTests.pass_a_union_of_tables_to_a_function_that_requires_a_table TableTests.pass_a_union_of_tables_to_a_function_that_requires_a_table_2 TableTests.persistent_sealed_table_is_immutable @@ -223,12 +217,10 @@ TableTests.table_subtyping_with_extra_props_dont_report_multiple_errors TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors2 TableTests.table_unification_4 TableTests.table_unifies_into_map -TableTests.table_writes_introduce_write_properties TableTests.type_mismatch_on_massive_table_is_cut_short TableTests.used_colon_instead_of_dot TableTests.used_dot_instead_of_colon TableTests.when_augmenting_an_unsealed_table_with_an_indexer_apply_the_correct_scope_to_the_indexer_type -TableTests.wrong_assign_does_hit_indexer ToDot.function ToString.exhaustive_toString_of_cyclic_table ToString.free_types @@ -258,7 +250,6 @@ TypeAliases.type_alias_of_an_imported_recursive_generic_type TypeFamilyTests.add_family_at_work TypeFamilyTests.family_as_fn_arg TypeFamilyTests.internal_families_raise_errors -TypeFamilyTests.mul_family_with_union_of_multiplicatives_2 TypeFamilyTests.unsolvable_family TypeInfer.be_sure_to_use_active_txnlog_when_evaluating_a_variadic_overload TypeInfer.check_type_infer_recursion_count @@ -267,9 +258,9 @@ TypeInfer.cli_50041_committing_txnlog_in_apollo_client_error TypeInfer.dont_ice_when_failing_the_occurs_check TypeInfer.dont_report_type_errors_within_an_AstExprError TypeInfer.dont_report_type_errors_within_an_AstStatError +TypeInfer.follow_on_new_types_in_substitution TypeInfer.globals TypeInfer.globals2 -TypeInfer.globals_are_banned_in_strict_mode TypeInfer.infer_through_group_expr TypeInfer.no_stack_overflow_from_isoptional TypeInfer.recursive_function_that_invokes_itself_with_a_refinement_of_its_parameter @@ -279,9 +270,11 @@ TypeInfer.type_infer_recursion_limit_no_ice TypeInfer.type_infer_recursion_limit_normalizer TypeInfer.unify_nearly_identical_recursive_types TypeInferAnyError.can_subscript_any -TypeInferAnyError.for_in_loop_iterator_is_error -TypeInferAnyError.for_in_loop_iterator_is_error2 -TypeInferAnyError.metatable_of_any_can_be_a_table +TypeInferAnyError.for_in_loop_iterator_is_any +TypeInferAnyError.for_in_loop_iterator_is_any2 +TypeInferAnyError.for_in_loop_iterator_is_any_pack +TypeInferAnyError.for_in_loop_iterator_returns_any +TypeInferAnyError.for_in_loop_iterator_returns_any2 TypeInferAnyError.replace_every_free_type_when_unifying_a_complex_function_with_any TypeInferClasses.callable_classes TypeInferClasses.cannot_unify_class_instance_with_primitive @@ -312,10 +305,7 @@ TypeInferFunctions.function_does_not_return_enough_values TypeInferFunctions.function_exprs_are_generalized_at_signature_scope_not_enclosing TypeInferFunctions.function_is_supertype_of_concrete_functions TypeInferFunctions.function_statement_sealed_table_assignment_through_indexer -TypeInferFunctions.fuzzer_missing_follow_in_ast_stat_fun TypeInferFunctions.generic_packs_are_not_variadic -TypeInferFunctions.higher_order_function_2 -TypeInferFunctions.higher_order_function_3 TypeInferFunctions.higher_order_function_4 TypeInferFunctions.improved_function_arg_mismatch_error_nonstrict TypeInferFunctions.improved_function_arg_mismatch_errors @@ -335,6 +325,7 @@ TypeInferFunctions.param_1_and_2_both_takes_the_same_generic_but_their_arguments 
TypeInferFunctions.param_1_and_2_both_takes_the_same_generic_but_their_arguments_are_incompatible_2 TypeInferFunctions.report_exiting_without_return_nonstrict TypeInferFunctions.return_type_by_overload +TypeInferFunctions.simple_unannotated_mutual_recursion TypeInferFunctions.too_few_arguments_variadic TypeInferFunctions.too_few_arguments_variadic_generic TypeInferFunctions.too_few_arguments_variadic_generic2 @@ -342,6 +333,7 @@ TypeInferFunctions.too_many_arguments TypeInferFunctions.too_many_arguments_error_location TypeInferFunctions.too_many_return_values_in_parentheses TypeInferFunctions.too_many_return_values_no_function +TypeInferFunctions.unifier_should_not_bind_free_types TypeInferLoops.cli_68448_iterators_need_not_accept_nil TypeInferLoops.dcr_iteration_on_never_gives_never TypeInferLoops.dcr_xpath_candidates @@ -353,7 +345,6 @@ TypeInferLoops.for_in_loop_on_non_function TypeInferLoops.for_in_loop_with_next TypeInferLoops.for_loop TypeInferLoops.ipairs_produces_integral_indices -TypeInferLoops.iterate_over_free_table TypeInferLoops.iterate_over_properties TypeInferLoops.iteration_regression_issue_69967_alt TypeInferLoops.loop_iter_metamethod_nil @@ -364,19 +355,14 @@ TypeInferLoops.loop_iter_trailing_nil TypeInferLoops.loop_typecheck_crash_on_empty_optional TypeInferLoops.properly_infer_iteratee_is_a_free_table TypeInferLoops.repeat_loop -TypeInferLoops.varlist_declared_by_for_in_loop_should_be_free TypeInferLoops.while_loop -TypeInferModules.custom_require_global -TypeInferModules.do_not_modify_imported_types_5 TypeInferModules.require TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2 TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory -TypeInferOOP.methods_are_topologically_sorted TypeInferOOP.promise_type_error_too_complex TypeInferOperators.add_type_family_works TypeInferOperators.cli_38355_recursive_union -TypeInferOperators.compound_assign_mismatch_metatable TypeInferOperators.compound_assign_result_must_be_compatible_with_var TypeInferOperators.concat_op_on_free_lhs_and_string_rhs TypeInferOperators.concat_op_on_string_lhs_and_free_rhs @@ -391,7 +377,6 @@ TypeInferOperators.typecheck_unary_len_error TypeInferOperators.typecheck_unary_minus_error TypeInferOperators.UnknownGlobalCompoundAssign TypeInferPrimitives.CheckMethodsOfNumber -TypeInferPrimitives.string_index TypeInferUnknownNever.assign_to_local_which_is_never TypeInferUnknownNever.index_on_union_of_tables_for_properties_that_is_never TypeInferUnknownNever.index_on_union_of_tables_for_properties_that_is_sorta_never @@ -417,7 +402,6 @@ UnionTypes.generic_function_with_optional_arg UnionTypes.index_on_a_union_type_with_missing_property UnionTypes.less_greedy_unification_with_union_types UnionTypes.optional_arguments_table -UnionTypes.optional_length_error UnionTypes.optional_union_functions UnionTypes.optional_union_members UnionTypes.optional_union_methods diff --git a/tools/stackdbg.py b/tools/stackdbg.py new file mode 100644 index 00000000..de656c60 --- /dev/null +++ b/tools/stackdbg.py @@ -0,0 +1,94 @@ +#!/usr/bin/python3 +""" +To use this command, simply run the command: +`command script import /path/to/your/game-engine/Client/Luau/tools/stackdbg.py` +in the `lldb` interpreter. You can also add it to your .lldbinit file to have it be +automatically imported. + +If using vscode, you can add the above command to your launch.json under `preRunCommands` for the appropriate target.
For example: +{ + "name": "Luau.UnitTest", + "type": "lldb", + "request": "launch", + "program": "${workspaceFolder}/build/ninja/common-tests/noopt/Luau/Luau.UnitTest", + "preRunCommands": [ + "command script import ${workspaceFolder}/Client/Luau/tools/stackdbg.py" + ], +} + +Once this is loaded, +`(lldb) help stack` +or +`(lldb) stack -h` +or +`(lldb) stack --help` + +can get you started +""" + +import lldb +import functools +import argparse +import shlex + +# Dumps the collected frame data +def dump(collected): + for (frame_name, size_in_kb, live_size_kb, variables) in collected: + print(f'{frame_name}, locals: {size_in_kb}kb, fp-sp: {live_size_kb}kb') + for (var_name, var_size, variable_obj) in variables: + print(f' {var_name}, {var_size} bytes') + +def dbg_stack_pressure(frame, frames_to_show = 5, sort_frames = False, vars_to_show = 5, sort_vars = True): + totalKb = 0 + collect = [] + for f in frame.thread: + frame_name = f.GetFunctionName() + variables = [ (v.GetName(), v.GetByteSize(), v) for v in f.get_locals() ] + if sort_vars: + variables.sort(key = lambda x: x[1], reverse = True) + size_in_kb = functools.reduce(lambda x,y : x + y[1], variables, 0) / 1024 + + fp = f.GetFP() + sp = f.GetSP() + live_size_kb = round((fp - sp) / 1024, 2) + + size_in_kb = round(size_in_kb, 2) + totalKb += size_in_kb + collect.append((frame_name, size_in_kb, live_size_kb, variables[:vars_to_show])) + if sort_frames: + collect.sort(key = lambda x: x[1], reverse = True) + + print("******************** Report Stack Usage ********************") + totalMb = round(totalKb / 1024, 2) + print(f'{len(frame.thread)} stack frames used {totalMb}MB') + dump(collect[:frames_to_show]) + +def stack(debugger, command, result, internal_dict): + """ + usage: [-h] [-f FRAMES] [-fd] [-v VARS] [-vd] + + optional arguments: + -h, --help show this help message and exit + -f FRAMES, --frames FRAMES + How many stack frames to display + -fd, --sort_frames Sort frames + -v VARS, --vars VARS How many variables per frame to display + -vd, --sort_vars Sort locals + """ + + frame = debugger.GetSelectedTarget().GetProcess().GetSelectedThread().GetSelectedFrame() + args = shlex.split(command) + argparser = argparse.ArgumentParser(allow_abbrev = True) + argparser.add_argument("-f", "--frames", required=False, help="How many stack frames to display", default=5, type=int) + argparser.add_argument("-fd", "--sort_frames", required=False, help="Sort frames in descending order of stack usage", action="store_true", default=False) + argparser.add_argument("-v", "--vars", required=False, help="How many variables per frame to display", default=5, type=int) + argparser.add_argument("-vd", "--sort_vars", required=False, help="Sort locals in descending order of stack usage", action="store_true", default=False) + + args = argparser.parse_args(args) + dbg_stack_pressure(frame, frames_to_show=args.frames, sort_frames=args.sort_frames, vars_to_show=args.vars, sort_vars=args.sort_vars) + +# Initialization code to add commands +def __lldb_init_module(debugger, internal_dict): + debugger.HandleCommand('command script add -f stackdbg.stack stack') + print("The 'stack' python command has been installed and is ready for use.")
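+ +# Example session (hypothetical output; the shape follows the print calls above): +#   (lldb) stack -f 2 -fd +#   ******************** Report Stack Usage ******************** +#   24 stack frames used 0.35MB +#   Luau::check(...), locals: 18.2kb, fp-sp: 19.1kb +#     result, 4096 bytes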