diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index db2f6712..7dc38835 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -65,10 +65,7 @@ TypeId makeFunction( // Polymorphic bool checked = false ); -void attachMagicFunction(TypeId ty, MagicFunction fn); -void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn); -void attachDcrMagicRefinement(TypeId ty, DcrMagicRefinement fn); -void attachDcrMagicFunctionTypeCheck(TypeId ty, DcrMagicFunctionTypeCheck fn); +void attachMagicFunction(TypeId ty, std::shared_ptr fn); Property makeProperty(TypeId ty, std::optional documentationSymbol = std::nullopt); void assignPropDocumentationSymbols(TableType::Props& props, const std::string& baseName); diff --git a/Analysis/include/Luau/Clone.h b/Analysis/include/Luau/Clone.h index b0c8fd17..7d5ce892 100644 --- a/Analysis/include/Luau/Clone.h +++ b/Analysis/include/Luau/Clone.h @@ -4,6 +4,7 @@ #include #include "Luau/TypeArena.h" #include "Luau/Type.h" +#include "Luau/Scope.h" #include @@ -26,13 +27,17 @@ struct CloneState * while `clone` will make a deep copy of the entire type and its every component. * * Be mindful about which behavior you actually _want_. + * + * Persistent types are not cloned as an optimization. + * If a type is cloned in order to mutate it, 'ignorePersistent' has to be set */ -TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState); -TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState); +TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState, bool ignorePersistent = false); +TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState, bool ignorePersistent = false); TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState); TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState); TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState); +Binding clone(const Binding& binding, TypeArena& dest, CloneState& cloneState); } // namespace Luau diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index ceb9cab4..b8eaac56 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -166,7 +166,7 @@ struct ConstraintSolver **/ void finalizeTypeFunctions(); - bool isDone(); + bool isDone() const; private: /** @@ -298,10 +298,10 @@ public: // FIXME: This use of a boolean for the return result is an appalling // interface. bool blockOnPendingTypes(TypeId target, NotNull constraint); - bool blockOnPendingTypes(TypePackId target, NotNull constraint); + bool blockOnPendingTypes(TypePackId targetPack, NotNull constraint); void unblock(NotNull progressed); - void unblock(TypeId progressed, Location location); + void unblock(TypeId ty, Location location); void unblock(TypePackId progressed, Location location); void unblock(const std::vector& types, Location location); void unblock(const std::vector& packs, Location location); @@ -336,7 +336,7 @@ public: * @param location the location where the require is taking place; used for * error locations. **/ - TypeId resolveModule(const ModuleInfo& module, const Location& location); + TypeId resolveModule(const ModuleInfo& info, const Location& location); void reportError(TypeErrorData&& data, const Location& location); void reportError(TypeError e); @@ -420,6 +420,11 @@ public: void throwUserCancelError() const; ToStringOptions opts; + + void fillInDiscriminantTypes( + NotNull constraint, + const std::vector>& discriminantTypes + ); }; void dump(NotNull rootScope, struct ToStringOptions& opts); diff --git a/Analysis/include/Luau/DataFlowGraph.h b/Analysis/include/Luau/DataFlowGraph.h index 83dfa4b7..7c0e81ac 100644 --- a/Analysis/include/Luau/DataFlowGraph.h +++ b/Analysis/include/Luau/DataFlowGraph.h @@ -6,6 +6,7 @@ #include "Luau/ControlFlow.h" #include "Luau/DenseHash.h" #include "Luau/Def.h" +#include "Luau/NotNull.h" #include "Luau/Symbol.h" #include "Luau/TypedAllocator.h" @@ -48,13 +49,13 @@ struct DataFlowGraph const RefinementKey* getRefinementKey(const AstExpr* expr) const; private: - DataFlowGraph() = default; + DataFlowGraph(NotNull defArena, NotNull keyArena); DataFlowGraph(const DataFlowGraph&) = delete; DataFlowGraph& operator=(const DataFlowGraph&) = delete; - DefArena defArena; - RefinementKeyArena keyArena; + NotNull defArena; + NotNull keyArena; DenseHashMap astDefs{nullptr}; @@ -110,30 +111,22 @@ using ScopeStack = std::vector; struct DataFlowGraphBuilder { - static DataFlowGraph build(AstStatBlock* root, NotNull handle); - - /** - * This method is identical to the build method above, but returns a pair of dfg, scopes as the data flow graph - * here is intended to live on the module between runs of typechecking. Before, the DFG only needed to live as - * long as the typecheck, but in a world with incremental typechecking, we need the information on the dfg to incrementally - * typecheck small fragments of code. - * @param block - pointer to the ast to build the dfg for - * @param handle - for raising internal errors while building the dfg - */ - static std::pair, std::vector>> buildShared( + static DataFlowGraph build( AstStatBlock* block, - NotNull handle + NotNull defArena, + NotNull keyArena, + NotNull handle ); private: - DataFlowGraphBuilder() = default; + DataFlowGraphBuilder(NotNull defArena, NotNull keyArena); DataFlowGraphBuilder(const DataFlowGraphBuilder&) = delete; DataFlowGraphBuilder& operator=(const DataFlowGraphBuilder&) = delete; DataFlowGraph graph; - NotNull defArena{&graph.defArena}; - NotNull keyArena{&graph.keyArena}; + NotNull defArena; + NotNull keyArena; struct InternalErrorReporter* handle = nullptr; diff --git a/Analysis/include/Luau/EqSatSimplificationImpl.h b/Analysis/include/Luau/EqSatSimplificationImpl.h index 24e8777a..e021baa8 100644 --- a/Analysis/include/Luau/EqSatSimplificationImpl.h +++ b/Analysis/include/Luau/EqSatSimplificationImpl.h @@ -53,7 +53,7 @@ LUAU_EQSAT_NODE_SET(Intersection); LUAU_EQSAT_NODE_ARRAY(Negation, 1); -LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(TTypeFun, const TypeFunction*); +LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(TTypeFun, std::shared_ptr); LUAU_EQSAT_UNIT(TNoRefine); LUAU_EQSAT_UNIT(Invalid); @@ -105,6 +105,9 @@ private: std::vector storage; }; +template +using Node = EqSat::Node; + using EType = EqSat::Language< TNil, TBoolean, @@ -146,7 +149,7 @@ using EType = EqSat::Language< struct StringCache { Allocator allocator; - DenseHashMap strings{{}}; + DenseHashMap strings{{}}; std::vector views; StringId add(std::string_view s); @@ -171,6 +174,9 @@ struct Subst Id eclass; Id newClass; + // The node into eclass which is boring, if any + std::optional boringIndex; + std::string desc; Subst(Id eclass, Id newClass, std::string desc = ""); @@ -211,6 +217,7 @@ struct Simplifier void subst(Id from, Id to); void subst(Id from, Id to, const std::string& ruleName); void subst(Id from, Id to, const std::string& ruleName, const std::unordered_map& forceNodes); + void subst(Id from, size_t boringIndex, Id to, const std::string& ruleName, const std::unordered_map& forceNodes); void unionClasses(std::vector& hereParts, Id there); @@ -218,6 +225,7 @@ struct Simplifier void simplifyUnion(Id id); void uninhabitedIntersection(Id id); void intersectWithNegatedClass(Id id); + void intersectWithNegatedAtom(Id id); void intersectWithNoRefine(Id id); void cyclicIntersectionOfUnion(Id id); void cyclicUnionOfIntersection(Id id); @@ -228,6 +236,7 @@ struct Simplifier void unneededTableModification(Id id); void builtinTypeFunctions(Id id); void iffyTypeFunctions(Id id); + void strictMetamethods(Id id); }; template @@ -293,13 +302,13 @@ QueryIterator::QueryIterator(EGraph* egraph_, Id eclass) for (const auto& enode : ecl.nodes) { - if (enode.index() < idx) + if (enode.node.index() < idx) ++index; else break; } - if (index >= ecl.nodes.size() || ecl.nodes[index].index() != idx) + if (index >= ecl.nodes.size() || ecl.nodes[index].node.index() != idx) { egraph = nullptr; index = 0; @@ -329,7 +338,7 @@ std::pair QueryIterator::operator*() const EGraph::EClassT& ecl = (*egraph)[eclass]; LUAU_ASSERT(index < ecl.nodes.size()); - auto& enode = ecl.nodes[index]; + auto& enode = ecl.nodes[index].node; Tag* result = enode.template get(); LUAU_ASSERT(result); return {result, index}; @@ -341,12 +350,16 @@ QueryIterator& QueryIterator::operator++() { const auto& ecl = (*egraph)[eclass]; - ++index; - if (index >= ecl.nodes.size() || ecl.nodes[index].index() != EType::VariantTy::getTypeId()) + do { - egraph = nullptr; - index = 0; - } + ++index; + if (index >= ecl.nodes.size() || ecl.nodes[index].node.index() != EType::VariantTy::getTypeId()) + { + egraph = nullptr; + index = 0; + break; + } + } while (ecl.nodes[index].boring); return *this; } diff --git a/Analysis/include/Luau/FragmentAutocomplete.h b/Analysis/include/Luau/FragmentAutocomplete.h index 2bbba6e6..bf67b8b6 100644 --- a/Analysis/include/Luau/FragmentAutocomplete.h +++ b/Analysis/include/Luau/FragmentAutocomplete.h @@ -15,6 +15,12 @@ namespace Luau { struct FrontendOptions; +enum class FragmentTypeCheckStatus +{ + SkipAutocomplete, + Success, +}; + struct FragmentAutocompleteAncestryResult { DenseHashMap localMap{AstName()}; @@ -29,6 +35,7 @@ struct FragmentParseResult AstStatBlock* root = nullptr; std::vector ancestry; AstStat* nearestStatement = nullptr; + std::vector commentLocations; std::unique_ptr alloc = std::make_unique(); }; @@ -49,14 +56,14 @@ struct FragmentAutocompleteResult FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos); -FragmentParseResult parseFragment( +std::optional parseFragment( const SourceModule& srcModule, std::string_view src, const Position& cursorPos, std::optional fragmentEndPosition ); -FragmentTypeCheckResult typecheckFragment( +std::pair typecheckFragment( Frontend& frontend, const ModuleName& moduleName, const Position& cursorPos, diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 272ee52a..dc443777 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -7,6 +7,7 @@ #include "Luau/ModuleResolver.h" #include "Luau/RequireTracer.h" #include "Luau/Scope.h" +#include "Luau/Set.h" #include "Luau/TypeCheckLimits.h" #include "Luau/Variant.h" #include "Luau/AnyTypeSummary.h" @@ -56,13 +57,32 @@ struct SourceNode return forAutocomplete ? dirtyModuleForAutocomplete : dirtyModule; } + bool hasInvalidModuleDependency(bool forAutocomplete) const + { + return forAutocomplete ? invalidModuleDependencyForAutocomplete : invalidModuleDependency; + } + + void setInvalidModuleDependency(bool value, bool forAutocomplete) + { + if (forAutocomplete) + invalidModuleDependencyForAutocomplete = value; + else + invalidModuleDependency = value; + } + ModuleName name; std::string humanReadableName; DenseHashSet requireSet{{}}; std::vector> requireLocations; + Set dependents{{}}; + bool dirtySourceModule = true; bool dirtyModule = true; bool dirtyModuleForAutocomplete = true; + + bool invalidModuleDependency = true; + bool invalidModuleDependencyForAutocomplete = true; + double autocompleteLimitsMult = 1.0; }; @@ -117,7 +137,7 @@ struct FrontendModuleResolver : ModuleResolver std::optional resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) override; std::string getHumanReadableModuleName(const ModuleName& moduleName) const override; - void setModule(const ModuleName& moduleName, ModulePtr module); + bool setModule(const ModuleName& moduleName, ModulePtr module); void clearModules(); private: @@ -151,9 +171,13 @@ struct Frontend // Parse and typecheck module graph CheckResult check(const ModuleName& name, std::optional optionOverride = {}); // new shininess + bool allModuleDependenciesValid(const ModuleName& name, bool forAutocomplete = false) const; + bool isDirty(const ModuleName& name, bool forAutocomplete = false) const; void markDirty(const ModuleName& name, std::vector* markedDirty = nullptr); + void traverseDependents(const ModuleName& name, std::function processSubtree); + /** Borrow a pointer into the SourceModule cache. * * Returns nullptr if we don't have it. This could mean that the script diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 49b4ae02..7346a422 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -16,6 +16,8 @@ #include #include +LUAU_FASTFLAG(LuauIncrementalAutocompleteCommentDetection) + namespace Luau { @@ -55,6 +57,7 @@ struct SourceModule } }; +bool isWithinComment(const std::vector& commentLocations, Position pos); bool isWithinComment(const SourceModule& sourceModule, Position pos); bool isWithinComment(const ParseResult& result, Position pos); @@ -136,6 +139,11 @@ struct Module TypePackId returnType = nullptr; std::unordered_map exportedTypeBindings; + // Arenas related to the DFG must persist after the DFG no longer exists, as + // Module objects maintain raw pointers to objects in these arenas. + DefArena defArena; + RefinementKeyArena keyArena; + bool hasModuleScope() const; ScopePtr getModuleScope() const; diff --git a/Analysis/include/Luau/NonStrictTypeChecker.h b/Analysis/include/Luau/NonStrictTypeChecker.h index 6229a932..880d487f 100644 --- a/Analysis/include/Luau/NonStrictTypeChecker.h +++ b/Analysis/include/Luau/NonStrictTypeChecker.h @@ -15,6 +15,7 @@ struct TypeCheckLimits; void checkNonStrict( NotNull builtinTypes, + NotNull simplifier, NotNull typeFunctionRuntime, NotNull ice, NotNull unifierState, diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 97d13a60..f014c433 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/EqSatSimplification.h" #include "Luau/NotNull.h" #include "Luau/Set.h" #include "Luau/TypeFwd.h" @@ -21,8 +22,22 @@ struct Scope; using ModulePtr = std::shared_ptr; -bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice); -bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice); +bool isSubtype( + TypeId subTy, + TypeId superTy, + NotNull scope, + NotNull builtinTypes, + NotNull simplifier, + InternalErrorReporter& ice +); +bool isSubtype( + TypePackId subPack, + TypePackId superPack, + NotNull scope, + NotNull builtinTypes, + NotNull simplifier, + InternalErrorReporter& ice +); class TypeIds { diff --git a/Analysis/include/Luau/OverloadResolution.h b/Analysis/include/Luau/OverloadResolution.h index 83a33215..d85d769e 100644 --- a/Analysis/include/Luau/OverloadResolution.h +++ b/Analysis/include/Luau/OverloadResolution.h @@ -2,12 +2,13 @@ #pragma once #include "Luau/Ast.h" -#include "Luau/InsertionOrderedMap.h" -#include "Luau/NotNull.h" -#include "Luau/TypeFwd.h" -#include "Luau/Location.h" +#include "Luau/EqSatSimplification.h" #include "Luau/Error.h" +#include "Luau/InsertionOrderedMap.h" +#include "Luau/Location.h" +#include "Luau/NotNull.h" #include "Luau/Subtyping.h" +#include "Luau/TypeFwd.h" namespace Luau { @@ -34,6 +35,7 @@ struct OverloadResolver OverloadResolver( NotNull builtinTypes, NotNull arena, + NotNull simplifier, NotNull normalizer, NotNull typeFunctionRuntime, NotNull scope, @@ -44,6 +46,7 @@ struct OverloadResolver NotNull builtinTypes; NotNull arena; + NotNull simplifier; NotNull normalizer; NotNull typeFunctionRuntime; NotNull scope; @@ -110,6 +113,7 @@ struct SolveResult SolveResult solveFunctionCall( NotNull arena, NotNull builtinTypes, + NotNull simplifier, NotNull normalizer, NotNull typeFunctionRuntime, NotNull iceReporter, diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index 0e6eff56..4604a2e1 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -85,12 +85,18 @@ struct Scope void inheritAssignments(const ScopePtr& childScope); void inheritRefinements(const ScopePtr& childScope); + // Track globals that should emit warnings during type checking. + DenseHashSet globalsToWarn{""}; + bool shouldWarnGlobal(std::string name) const; + // For mutually recursive type aliases, it's important that // they use the same types for the same names. // For instance, in `type Tree { data: T, children: Forest } type Forest = {Tree}` // we need that the generic type `T` in both cases is the same, so we use a cache. std::unordered_map typeAliasTypeParameters; std::unordered_map typeAliasTypePackParameters; + + std::optional> interiorFreeTypes; }; // Returns true iff the left scope encloses the right scope. A Scope* equal to diff --git a/Analysis/include/Luau/Simplify.h b/Analysis/include/Luau/Simplify.h index 5b363e96..aab37876 100644 --- a/Analysis/include/Luau/Simplify.h +++ b/Analysis/include/Luau/Simplify.h @@ -19,10 +19,10 @@ struct SimplifyResult DenseHashSet blockedTypes; }; -SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, TypeId ty, TypeId discriminant); +SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, TypeId left, TypeId right); SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, std::set parts); -SimplifyResult simplifyUnion(NotNull builtinTypes, NotNull arena, TypeId ty, TypeId discriminant); +SimplifyResult simplifyUnion(NotNull builtinTypes, NotNull arena, TypeId left, TypeId right); enum class Relation { diff --git a/Analysis/include/Luau/Subtyping.h b/Analysis/include/Luau/Subtyping.h index 1e781056..26c4553e 100644 --- a/Analysis/include/Luau/Subtyping.h +++ b/Analysis/include/Luau/Subtyping.h @@ -1,13 +1,14 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/DenseHash.h" +#include "Luau/EqSatSimplification.h" #include "Luau/Set.h" +#include "Luau/TypeCheckLimits.h" +#include "Luau/TypeFunction.h" #include "Luau/TypeFwd.h" #include "Luau/TypePairHash.h" #include "Luau/TypePath.h" -#include "Luau/TypeFunction.h" -#include "Luau/TypeCheckLimits.h" -#include "Luau/DenseHash.h" #include #include @@ -134,6 +135,7 @@ struct Subtyping { NotNull builtinTypes; NotNull arena; + NotNull simplifier; NotNull normalizer; NotNull typeFunctionRuntime; NotNull iceReporter; @@ -155,6 +157,7 @@ struct Subtyping Subtyping( NotNull builtinTypes, NotNull typeArena, + NotNull simplifier, NotNull normalizer, NotNull typeFunctionRuntime, NotNull iceReporter diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 85957bed..5c268f67 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -69,12 +69,16 @@ using Name = std::string; // A free type is one whose exact shape has yet to be fully determined. struct FreeType { + // New constructors + explicit FreeType(TypeLevel level, TypeId lowerBound, TypeId upperBound); + // This one got promoted to explicit + explicit FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound); + explicit FreeType(Scope* scope, TypeLevel level, TypeId lowerBound, TypeId upperBound); + // Old constructors explicit FreeType(TypeLevel level); explicit FreeType(Scope* scope); FreeType(Scope* scope, TypeLevel level); - FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound); - int index; TypeLevel level; Scope* scope = nullptr; @@ -131,14 +135,14 @@ struct BlockedType BlockedType(); int index; - Constraint* getOwner() const; - void setOwner(Constraint* newOwner); - void replaceOwner(Constraint* newOwner); + const Constraint* getOwner() const; + void setOwner(const Constraint* newOwner); + void replaceOwner(const Constraint* newOwner); private: // The constraint that is intended to unblock this type. Other constraints // should block on this constraint if present. - Constraint* owner = nullptr; + const Constraint* owner = nullptr; }; struct PrimitiveType @@ -279,9 +283,6 @@ struct WithPredicate } }; -using MagicFunction = std::function>(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate)>; - struct MagicFunctionCallContext { NotNull solver; @@ -291,7 +292,6 @@ struct MagicFunctionCallContext TypePackId result; }; -using DcrMagicFunction = std::function; struct MagicRefinementContext { NotNull scope; @@ -308,8 +308,29 @@ struct MagicFunctionTypeCheckContext NotNull checkScope; }; -using DcrMagicRefinement = void (*)(const MagicRefinementContext&); -using DcrMagicFunctionTypeCheck = std::function; +struct MagicFunction +{ + virtual std::optional> handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) = 0; + + // Callback to allow custom typechecking of builtin function calls whose argument types + // will only be resolved after constraint solving. For example, the arguments to string.format + // have types that can only be decided after parsing the format string and unifying + // with the passed in values, but the correctness of the call can only be decided after + // all the types have been finalized. + virtual bool infer(const MagicFunctionCallContext&) = 0; + virtual void refine(const MagicRefinementContext&) {} + + // If a magic function needs to do its own special typechecking, do it here. + // Returns true if magic typechecking was performed. Return false if the + // default typechecking logic should run. + virtual bool typeCheck(const MagicFunctionTypeCheckContext&) + { + return false; + } + + virtual ~MagicFunction() {} +}; + struct FunctionType { // Global monomorphic function @@ -367,16 +388,7 @@ struct FunctionType Scope* scope = nullptr; TypePackId argTypes; TypePackId retTypes; - MagicFunction magicFunction = nullptr; - DcrMagicFunction dcrMagicFunction = nullptr; - DcrMagicRefinement dcrMagicRefinement = nullptr; - - // Callback to allow custom typechecking of builtin function calls whose argument types - // will only be resolved after constraint solving. For example, the arguments to string.format - // have types that can only be decided after parsing the format string and unifying - // with the passed in values, but the correctness of the call can only be decided after - // all the types have been finalized. - DcrMagicFunctionTypeCheck dcrMagicTypeCheck = nullptr; + std::shared_ptr magic = nullptr; bool hasSelf; // `hasNoFreeOrGenericTypes` should be true if and only if the type does not have any free or generic types present inside it. @@ -608,7 +620,8 @@ struct UserDefinedFunctionData // References to AST elements are owned by the Module allocator which also stores this type AstStatTypeFunction* definition = nullptr; - DenseHashMap environment{""}; + DenseHashMap> environment{""}; + DenseHashMap environment_DEPRECATED{""}; }; /** @@ -625,7 +638,7 @@ struct TypeFunctionInstanceType std::vector typeArguments; std::vector packArguments; - std::optional userFuncName; // Name of the user-defined type function; only available for UDTFs + std::optional userFuncName; // Name of the user-defined type function; only available for UDTFs UserDefinedFunctionData userFuncData; TypeFunctionInstanceType( diff --git a/Analysis/include/Luau/TypeArena.h b/Analysis/include/Luau/TypeArena.h index 4f8aea87..ebefa41f 100644 --- a/Analysis/include/Luau/TypeArena.h +++ b/Analysis/include/Luau/TypeArena.h @@ -32,9 +32,13 @@ struct TypeArena TypeId addTV(Type&& tv); - TypeId freshType(TypeLevel level); - TypeId freshType(Scope* scope); - TypeId freshType(Scope* scope, TypeLevel level); + TypeId freshType(NotNull builtins, TypeLevel level); + TypeId freshType(NotNull builtins, Scope* scope); + TypeId freshType(NotNull builtins, Scope* scope, TypeLevel level); + + TypeId freshType_DEPRECATED(TypeLevel level); + TypeId freshType_DEPRECATED(Scope* scope); + TypeId freshType_DEPRECATED(Scope* scope, TypeLevel level); TypePackId freshTypePack(Scope* scope); diff --git a/Analysis/include/Luau/TypeChecker2.h b/Analysis/include/Luau/TypeChecker2.h index 3ede5ca7..871471a4 100644 --- a/Analysis/include/Luau/TypeChecker2.h +++ b/Analysis/include/Luau/TypeChecker2.h @@ -2,15 +2,16 @@ #pragma once -#include "Luau/Error.h" -#include "Luau/NotNull.h" #include "Luau/Common.h" -#include "Luau/TypeUtils.h" +#include "Luau/EqSatSimplification.h" +#include "Luau/Error.h" +#include "Luau/Normalize.h" +#include "Luau/NotNull.h" +#include "Luau/Subtyping.h" #include "Luau/Type.h" #include "Luau/TypeFwd.h" #include "Luau/TypeOrPack.h" -#include "Luau/Normalize.h" -#include "Luau/Subtyping.h" +#include "Luau/TypeUtils.h" namespace Luau { @@ -60,8 +61,9 @@ struct Reasonings void check( NotNull builtinTypes, + NotNull simplifier, NotNull typeFunctionRuntime, - NotNull sharedState, + NotNull unifierState, NotNull limits, DcrLogger* logger, const SourceModule& sourceModule, @@ -71,6 +73,7 @@ void check( struct TypeChecker2 { NotNull builtinTypes; + NotNull simplifier; NotNull typeFunctionRuntime; DcrLogger* logger; const NotNull limits; @@ -90,6 +93,7 @@ struct TypeChecker2 TypeChecker2( NotNull builtinTypes, + NotNull simplifier, NotNull typeFunctionRuntime, NotNull unifierState, NotNull limits, @@ -112,14 +116,14 @@ private: std::optional pushStack(AstNode* node); void checkForInternalTypeFunction(TypeId ty, Location location); TypeId checkForTypeFunctionInhabitance(TypeId instance, Location location); - TypePackId lookupPack(AstExpr* expr); + TypePackId lookupPack(AstExpr* expr) const; TypeId lookupType(AstExpr* expr); TypeId lookupAnnotation(AstType* annotation); - std::optional lookupPackAnnotation(AstTypePack* annotation); - TypeId lookupExpectedType(AstExpr* expr); - TypePackId lookupExpectedPack(AstExpr* expr, TypeArena& arena); + std::optional lookupPackAnnotation(AstTypePack* annotation) const; + TypeId lookupExpectedType(AstExpr* expr) const; + TypePackId lookupExpectedPack(AstExpr* expr, TypeArena& arena) const; TypePackId reconstructPack(AstArray exprs, TypeArena& arena); - Scope* findInnermostScope(Location location); + Scope* findInnermostScope(Location location) const; void visit(AstStat* stat); void visit(AstStatIf* ifStatement); void visit(AstStatWhile* whileStatement); @@ -156,7 +160,7 @@ private: void visit(AstExprVarargs* expr); void visitCall(AstExprCall* call); void visit(AstExprCall* call); - std::optional tryStripUnionFromNil(TypeId ty); + std::optional tryStripUnionFromNil(TypeId ty) const; TypeId stripFromNilAndReport(TypeId ty, const Location& location); void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context, TypeId astIndexExprTy); void visit(AstExprIndexName* indexName, ValueContext context); @@ -213,6 +217,9 @@ private: std::vector& errors ); + // Avoid duplicate warnings being emitted for the same global variable. + DenseHashSet warnedGlobals{""}; + void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data) const; bool isErrorSuppressing(Location loc, TypeId ty); bool isErrorSuppressing(Location loc1, TypeId ty1, Location loc2, TypeId ty2); diff --git a/Analysis/include/Luau/TypeFunction.h b/Analysis/include/Luau/TypeFunction.h index df696b62..1c97550f 100644 --- a/Analysis/include/Luau/TypeFunction.h +++ b/Analysis/include/Luau/TypeFunction.h @@ -2,6 +2,7 @@ #pragma once #include "Luau/Constraint.h" +#include "Luau/EqSatSimplification.h" #include "Luau/Error.h" #include "Luau/NotNull.h" #include "Luau/TypeCheckLimits.h" @@ -41,9 +42,15 @@ struct TypeFunctionRuntime StateRef state; + // Set of functions which have their environment table initialized + DenseHashSet initialized{nullptr}; + // Evaluation of type functions should only be performed in the absence of parse errors in the source module bool allowEvaluation = true; + // Output created by 'print' function + std::vector messages; + private: void prepareState(); }; @@ -53,6 +60,7 @@ struct TypeFunctionContext NotNull arena; NotNull builtins; NotNull scope; + NotNull simplifier; NotNull normalizer; NotNull typeFunctionRuntime; NotNull ice; @@ -63,7 +71,7 @@ struct TypeFunctionContext // The constraint being reduced in this run of the reduction const Constraint* constraint; - std::optional userFuncName; // Name of the user-defined type function; only available for UDTFs + std::optional userFuncName; // Name of the user-defined type function; only available for UDTFs TypeFunctionContext(NotNull cs, NotNull scope, NotNull constraint); @@ -71,6 +79,7 @@ struct TypeFunctionContext NotNull arena, NotNull builtins, NotNull scope, + NotNull simplifier, NotNull normalizer, NotNull typeFunctionRuntime, NotNull ice, @@ -79,6 +88,7 @@ struct TypeFunctionContext : arena(arena) , builtins(builtins) , scope(scope) + , simplifier(simplifier) , normalizer(normalizer) , typeFunctionRuntime(typeFunctionRuntime) , ice(ice) @@ -91,19 +101,31 @@ struct TypeFunctionContext NotNull pushConstraint(ConstraintV&& c) const; }; +enum class Reduction +{ + // The type function is either known to be reducible or the determination is blocked. + MaybeOk, + // The type function is known to be irreducible, but maybe not be erroneous, e.g. when it's over generics or free types. + Irreducible, + // The type function is known to be irreducible, and is definitely erroneous. + Erroneous, +}; + /// Represents a reduction result, which may have successfully reduced the type, /// may have concretely failed to reduce the type, or may simply be stuck /// without more information. template struct TypeFunctionReductionResult { + /// The result of the reduction, if any. If this is nullopt, the type function /// could not be reduced. std::optional result; - /// Whether the result is uninhabited: whether we know, unambiguously and - /// permanently, whether this type function reduction results in an - /// uninhabitable type. This will trigger an error to be reported. - bool uninhabited; + /// Indicates the status of this reduction: is `Reduction::Irreducible` if + /// the this result indicates the type function is irreducible, and + /// `Reduction::Erroneous` if this result indicates the type function is + /// erroneous. `Reduction::MaybeOk` otherwise. + Reduction reductionStatus; /// Any types that need to be progressed or mutated before the reduction may /// proceed. std::vector blockedTypes; @@ -112,6 +134,8 @@ struct TypeFunctionReductionResult std::vector blockedPacks; /// A runtime error message from user-defined type functions std::optional error; + /// Messages printed out from user-defined type functions + std::vector messages; }; template @@ -145,6 +169,7 @@ struct TypePackFunction struct FunctionGraphReductionResult { ErrorVec errors; + ErrorVec messages; DenseHashSet blockedTypes{nullptr}; DenseHashSet blockedPacks{nullptr}; DenseHashSet reducedTypes{nullptr}; @@ -216,6 +241,9 @@ struct BuiltinTypeFunctions TypeFunction indexFunc; TypeFunction rawgetFunc; + TypeFunction setmetatableFunc; + TypeFunction getmetatableFunc; + void addToScope(NotNull arena, NotNull scope) const; }; diff --git a/Analysis/include/Luau/TypeFunctionRuntime.h b/Analysis/include/Luau/TypeFunctionRuntime.h index 356d34a5..d715ccd3 100644 --- a/Analysis/include/Luau/TypeFunctionRuntime.h +++ b/Analysis/include/Luau/TypeFunctionRuntime.h @@ -119,7 +119,14 @@ struct TypeFunctionVariadicTypePack TypeFunctionTypeId type; }; -using TypeFunctionTypePackVariant = Variant; +struct TypeFunctionGenericTypePack +{ + bool isNamed = false; + + std::string name; +}; + +using TypeFunctionTypePackVariant = Variant; struct TypeFunctionTypePackVar { @@ -135,6 +142,9 @@ struct TypeFunctionTypePackVar struct TypeFunctionFunctionType { + std::vector generics; + std::vector genericPacks; + TypeFunctionTypePackId argTypes; TypeFunctionTypePackId retTypes; }; @@ -210,6 +220,14 @@ struct TypeFunctionClassType std::string name; }; +struct TypeFunctionGenericType +{ + bool isNamed = false; + bool isPack = false; + + std::string name; +}; + using TypeFunctionTypeVariant = Luau::Variant< TypeFunctionPrimitiveType, TypeFunctionAnyType, @@ -221,7 +239,8 @@ using TypeFunctionTypeVariant = Luau::Variant< TypeFunctionNegationType, TypeFunctionFunctionType, TypeFunctionTableType, - TypeFunctionClassType>; + TypeFunctionClassType, + TypeFunctionGenericType>; struct TypeFunctionType { diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index de9660ef..c3bed421 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -40,7 +40,7 @@ struct InConditionalContext TypeContext* typeContext; TypeContext oldValue; - InConditionalContext(TypeContext* c) + explicit InConditionalContext(TypeContext* c) : typeContext(c) , oldValue(*c) { @@ -269,8 +269,8 @@ bool isLiteral(const AstExpr* expr); std::vector findBlockedTypesIn(AstExprTable* expr, NotNull> astTypes); /** - * Given a function call and a mapping from expression to type, determine - * whether the type of any argument in said call in depends on a blocked types. + * Given a function call and a mapping from expression to type, determine + * whether the type of any argument in said call in depends on a blocked types. * This is used as a precondition for bidirectional inference: be warned that * the behavior of this algorithm is tightly coupled to that of bidirectional * inference. @@ -280,4 +280,13 @@ std::vector findBlockedTypesIn(AstExprTable* expr, NotNull findBlockedArgTypesIn(AstExprCall* expr, NotNull> astTypes); +/** + * Given a scope and a free type, find the closest parent that has a present + * `interiorFreeTypes` and append the given type to said list. This list will + * be generalized when the requiste `GeneralizationConstraint` is resolved. + * @param scope Initial scope this free type was attached to + * @param ty Free type to track. + */ +void trackInteriorFreeType(Scope* scope, TypeId ty); + } // namespace Luau diff --git a/Analysis/include/Luau/VisitType.h b/Analysis/include/Luau/VisitType.h index 0e5475a7..a9685462 100644 --- a/Analysis/include/Luau/VisitType.h +++ b/Analysis/include/Luau/VisitType.h @@ -85,6 +85,8 @@ struct GenericTypeVisitor { } + virtual ~GenericTypeVisitor() {} + virtual void cycle(TypeId) {} virtual void cycle(TypePackId) {} diff --git a/Analysis/src/AnyTypeSummary.cpp b/Analysis/src/AnyTypeSummary.cpp index e82592df..db50e3e9 100644 --- a/Analysis/src/AnyTypeSummary.cpp +++ b/Analysis/src/AnyTypeSummary.cpp @@ -177,7 +177,6 @@ void AnyTypeSummary::visit(const Scope* scope, AstStatReturn* ret, const Module* } } } - } void AnyTypeSummary::visit(const Scope* scope, AstStatLocal* local, const Module* module, NotNull builtinTypes) diff --git a/Analysis/src/AstJsonEncoder.cpp b/Analysis/src/AstJsonEncoder.cpp index b1fd18ac..dbc1b5d8 100644 --- a/Analysis/src/AstJsonEncoder.cpp +++ b/Analysis/src/AstJsonEncoder.cpp @@ -1161,6 +1161,19 @@ struct AstJsonEncoder : public AstVisitor ); } + bool visit(class AstTypeGroup* node) override + { + writeNode( + node, + "AstTypeGroup", + [&]() + { + write("type", node->type); + } + ); + return false; + } + bool visit(class AstTypeSingletonBool* node) override { writeNode( diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index 93dabeae..815164d8 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -13,7 +13,7 @@ LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAGVARIABLE(LuauDocumentationAtPosition) +LUAU_FASTFLAG(LuauExtendStatEndPosWithSemicolon) namespace Luau { @@ -43,11 +43,26 @@ struct AutocompleteNodeFinder : public AstVisitor bool visit(AstStat* stat) override { - if (stat->location.begin < pos && pos <= stat->location.end) + if (FFlag::LuauExtendStatEndPosWithSemicolon) { - ancestry.push_back(stat); - return true; + // Consider 'local myLocal = 4;|' and 'local myLocal = 4', where '|' is the cursor position. In both cases, the cursor position is equal + // to `AstStatLocal.location.end`. However, in the first case (semicolon), we are starting a new statement, whilst in the second case + // (no semicolon) we are still part of the AstStatLocal, hence the different comparison check. + if (stat->location.begin < pos && (stat->hasSemicolon ? pos < stat->location.end : pos <= stat->location.end)) + { + ancestry.push_back(stat); + return true; + } } + else + { + if (stat->location.begin < pos && pos <= stat->location.end) + { + ancestry.push_back(stat); + return true; + } + } + return false; } @@ -518,7 +533,6 @@ static std::optional getMetatableDocumentation( const AstName& index ) { - LUAU_ASSERT(FFlag::LuauDocumentationAtPosition); auto indexIt = mtable->props.find("__index"); if (indexIt == mtable->props.end()) return std::nullopt; @@ -575,26 +589,7 @@ std::optional getDocumentationSymbolAtPosition(const Source } else if (const ClassType* ctv = get(parentTy)) { - if (FFlag::LuauDocumentationAtPosition) - { - while (ctv) - { - if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end()) - { - if (FFlag::LuauSolverV2) - { - if (auto ty = propIt->second.readTy) - return checkOverloadedDocumentationSymbol(module, *ty, parentExpr, propIt->second.documentationSymbol); - } - else - return checkOverloadedDocumentationSymbol( - module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol - ); - } - ctv = ctv->parent ? Luau::get(*ctv->parent) : nullptr; - } - } - else + while (ctv) { if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end()) { @@ -608,17 +603,15 @@ std::optional getDocumentationSymbolAtPosition(const Source module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol ); } + ctv = ctv->parent ? Luau::get(*ctv->parent) : nullptr; } } - else if (FFlag::LuauDocumentationAtPosition) + else if (const PrimitiveType* ptv = get(parentTy); ptv && ptv->metatable) { - if (const PrimitiveType* ptv = get(parentTy); ptv && ptv->metatable) + if (auto mtable = get(*ptv->metatable)) { - if (auto mtable = get(*ptv->metatable)) - { - if (std::optional docSymbol = getMetatableDocumentation(module, parentExpr, mtable, indexName->index)) - return docSymbol; - } + if (std::optional docSymbol = getMetatableDocumentation(module, parentExpr, mtable, indexName->index)) + return docSymbol; } } } diff --git a/Analysis/src/AutocompleteCore.cpp b/Analysis/src/AutocompleteCore.cpp index 3e231acf..f7f19826 100644 --- a/Analysis/src/AutocompleteCore.cpp +++ b/Analysis/src/AutocompleteCore.cpp @@ -25,6 +25,7 @@ LUAU_FASTINT(LuauTypeInferIterationLimit) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAGVARIABLE(LuauAutocompleteRefactorsForIncrementalAutocomplete) +LUAU_FASTFLAGVARIABLE(LuauAutocompleteUseLimits) static const std::unordered_set kStatementStartingKeywords = {"while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -150,6 +151,7 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, T { InternalErrorReporter iceReporter; UnifierSharedState unifierState(&iceReporter); + SimplifierPtr simplifier = newSimplifier(NotNull{typeArena}, builtinTypes); Normalizer normalizer{typeArena, builtinTypes, NotNull{&unifierState}}; if (FFlag::LuauSolverV2) @@ -162,7 +164,9 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, T unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit; - Subtyping subtyping{builtinTypes, NotNull{typeArena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter}}; + Subtyping subtyping{ + builtinTypes, NotNull{typeArena}, NotNull{simplifier.get()}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter} + }; return subtyping.isSubtype(subTy, superTy, scope).isSubtype; } @@ -174,6 +178,12 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, T unifier.normalize = false; unifier.checkInhabited = false; + if (FFlag::LuauAutocompleteUseLimits) + { + unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; + unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit; + } + return unifier.canUnify(subTy, superTy).empty(); } } diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 6306b5b1..7aee25ce 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -29,50 +29,81 @@ */ LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAGVARIABLE(LuauTypestateBuiltins2) -LUAU_FASTFLAGVARIABLE(LuauStringFormatArityFix) +LUAU_FASTFLAGVARIABLE(LuauStringFormatErrorSuppression) LUAU_FASTFLAG(AutocompleteRequirePathSuggestions2) -LUAU_FASTFLAG(LuauVectorDefinitionsExtra) +LUAU_FASTFLAGVARIABLE(LuauTableCloneClonesType3) +LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope) +LUAU_FASTFLAGVARIABLE(LuauFreezeIgnorePersistent) namespace Luau { -static std::optional> magicFunctionSelect( - TypeChecker& typechecker, - const ScopePtr& scope, - const AstExprCall& expr, - WithPredicate withPredicate -); -static std::optional> magicFunctionSetMetaTable( - TypeChecker& typechecker, - const ScopePtr& scope, - const AstExprCall& expr, - WithPredicate withPredicate -); -static std::optional> magicFunctionAssert( - TypeChecker& typechecker, - const ScopePtr& scope, - const AstExprCall& expr, - WithPredicate withPredicate -); -static std::optional> magicFunctionPack( - TypeChecker& typechecker, - const ScopePtr& scope, - const AstExprCall& expr, - WithPredicate withPredicate -); -static std::optional> magicFunctionRequire( - TypeChecker& typechecker, - const ScopePtr& scope, - const AstExprCall& expr, - WithPredicate withPredicate -); +struct MagicSelect final : MagicFunction +{ + std::optional> handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; +struct MagicSetMetatable final : MagicFunction +{ + std::optional> handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; -static bool dcrMagicFunctionSelect(MagicFunctionCallContext context); -static bool dcrMagicFunctionRequire(MagicFunctionCallContext context); -static bool dcrMagicFunctionPack(MagicFunctionCallContext context); -static bool dcrMagicFunctionFreeze(MagicFunctionCallContext context); +struct MagicAssert final : MagicFunction +{ + std::optional> handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; + +struct MagicPack final : MagicFunction +{ + std::optional> handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; + +struct MagicRequire final : MagicFunction +{ + std::optional> handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; + +struct MagicClone final : MagicFunction +{ + std::optional> handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; + +struct MagicFreeze final : MagicFunction +{ + std::optional> handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; + +struct MagicFormat final : MagicFunction +{ + std::optional> handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; + bool typeCheck(const MagicFunctionTypeCheckContext& ctx) override; +}; + +struct MagicMatch final : MagicFunction +{ + std::optional> handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; + +struct MagicGmatch final : MagicFunction +{ + std::optional> handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; + +struct MagicFind final : MagicFunction +{ + std::optional> handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; TypeId makeUnion(TypeArena& arena, std::vector&& types) { @@ -167,34 +198,10 @@ TypeId makeFunction( return arena.addType(std::move(ftv)); } -void attachMagicFunction(TypeId ty, MagicFunction fn) +void attachMagicFunction(TypeId ty, std::shared_ptr magic) { if (auto ftv = getMutable(ty)) - ftv->magicFunction = fn; - else - LUAU_ASSERT(!"Got a non functional type"); -} - -void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn) -{ - if (auto ftv = getMutable(ty)) - ftv->dcrMagicFunction = fn; - else - LUAU_ASSERT(!"Got a non functional type"); -} - -void attachDcrMagicRefinement(TypeId ty, DcrMagicRefinement fn) -{ - if (auto ftv = getMutable(ty)) - ftv->dcrMagicRefinement = fn; - else - LUAU_ASSERT(!"Got a non functional type"); -} - -void attachDcrMagicFunctionTypeCheck(TypeId ty, DcrMagicFunctionTypeCheck fn) -{ - if (auto ftv = getMutable(ty)) - ftv->dcrMagicTypeCheck = fn; + ftv->magic = std::move(magic); else LUAU_ASSERT(!"Got a non functional type"); } @@ -301,28 +308,25 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC addGlobalBinding(globals, "string", it->second.type(), "@luau"); // Setup 'vector' metatable - if (FFlag::LuauVectorDefinitionsExtra) + if (auto it = globals.globalScope->exportedTypeBindings.find("vector"); it != globals.globalScope->exportedTypeBindings.end()) { - if (auto it = globals.globalScope->exportedTypeBindings.find("vector"); it != globals.globalScope->exportedTypeBindings.end()) - { - TypeId vectorTy = it->second.type; - ClassType* vectorCls = getMutable(vectorTy); + TypeId vectorTy = it->second.type; + ClassType* vectorCls = getMutable(vectorTy); - vectorCls->metatable = arena.addType(TableType{{}, std::nullopt, TypeLevel{}, TableState::Sealed}); - TableType* metatableTy = Luau::getMutable(vectorCls->metatable); + vectorCls->metatable = arena.addType(TableType{{}, std::nullopt, TypeLevel{}, TableState::Sealed}); + TableType* metatableTy = Luau::getMutable(vectorCls->metatable); - metatableTy->props["__add"] = {makeFunction(arena, vectorTy, {vectorTy}, {vectorTy})}; - metatableTy->props["__sub"] = {makeFunction(arena, vectorTy, {vectorTy}, {vectorTy})}; - metatableTy->props["__unm"] = {makeFunction(arena, vectorTy, {}, {vectorTy})}; + metatableTy->props["__add"] = {makeFunction(arena, vectorTy, {vectorTy}, {vectorTy})}; + metatableTy->props["__sub"] = {makeFunction(arena, vectorTy, {vectorTy}, {vectorTy})}; + metatableTy->props["__unm"] = {makeFunction(arena, vectorTy, {}, {vectorTy})}; - std::initializer_list mulOverloads{ - makeFunction(arena, vectorTy, {vectorTy}, {vectorTy}), - makeFunction(arena, vectorTy, {builtinTypes->numberType}, {vectorTy}), - }; - metatableTy->props["__mul"] = {makeIntersection(arena, mulOverloads)}; - metatableTy->props["__div"] = {makeIntersection(arena, mulOverloads)}; - metatableTy->props["__idiv"] = {makeIntersection(arena, mulOverloads)}; - } + std::initializer_list mulOverloads{ + makeFunction(arena, vectorTy, {vectorTy}, {vectorTy}), + makeFunction(arena, vectorTy, {builtinTypes->numberType}, {vectorTy}), + }; + metatableTy->props["__mul"] = {makeIntersection(arena, mulOverloads)}; + metatableTy->props["__div"] = {makeIntersection(arena, mulOverloads)}; + metatableTy->props["__idiv"] = {makeIntersection(arena, mulOverloads)}; } // next(t: Table, i: K?) -> (K?, V) @@ -395,7 +399,7 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC } } - attachMagicFunction(getGlobalBinding(globals, "assert"), magicFunctionAssert); + attachMagicFunction(getGlobalBinding(globals, "assert"), std::make_shared()); if (FFlag::LuauSolverV2) { @@ -411,9 +415,8 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC addGlobalBinding(globals, "assert", assertTy, "@luau"); } - attachMagicFunction(getGlobalBinding(globals, "setmetatable"), magicFunctionSetMetaTable); - attachMagicFunction(getGlobalBinding(globals, "select"), magicFunctionSelect); - attachDcrMagicFunction(getGlobalBinding(globals, "select"), dcrMagicFunctionSelect); + attachMagicFunction(getGlobalBinding(globals, "setmetatable"), std::make_shared()); + attachMagicFunction(getGlobalBinding(globals, "select"), std::make_shared()); if (TableType* ttv = getMutable(getGlobalBinding(globals, "table"))) { @@ -444,23 +447,21 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC ttv->props["foreach"].deprecated = true; ttv->props["foreachi"].deprecated = true; - attachMagicFunction(ttv->props["pack"].type(), magicFunctionPack); - attachDcrMagicFunction(ttv->props["pack"].type(), dcrMagicFunctionPack); - if (FFlag::LuauTypestateBuiltins2) - attachDcrMagicFunction(ttv->props["freeze"].type(), dcrMagicFunctionFreeze); + attachMagicFunction(ttv->props["pack"].type(), std::make_shared()); + if (FFlag::LuauTableCloneClonesType3) + attachMagicFunction(ttv->props["clone"].type(), std::make_shared()); + attachMagicFunction(ttv->props["freeze"].type(), std::make_shared()); } if (FFlag::AutocompleteRequirePathSuggestions2) { TypeId requireTy = getGlobalBinding(globals, "require"); attachTag(requireTy, kRequireTagName); - attachMagicFunction(requireTy, magicFunctionRequire); - attachDcrMagicFunction(requireTy, dcrMagicFunctionRequire); + attachMagicFunction(requireTy, std::make_shared()); } else { - attachMagicFunction(getGlobalBinding(globals, "require"), magicFunctionRequire); - attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire); + attachMagicFunction(getGlobalBinding(globals, "require"), std::make_shared()); } } @@ -500,7 +501,7 @@ static std::vector parseFormatString(NotNull builtinTypes, return result; } -std::optional> magicFunctionFormat( +std::optional> MagicFormat::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -550,7 +551,7 @@ std::optional> magicFunctionFormat( return WithPredicate{arena.addTypePack({typechecker.stringType})}; } -static bool dcrMagicFunctionFormat(MagicFunctionCallContext context) +bool MagicFormat::infer(const MagicFunctionCallContext& context) { TypeArena* arena = context.solver->arena; @@ -594,7 +595,7 @@ static bool dcrMagicFunctionFormat(MagicFunctionCallContext context) return true; } -static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext context) +bool MagicFormat::typeCheck(const MagicFunctionTypeCheckContext& context) { AstExprConstantString* fmt = nullptr; if (auto index = context.callSite->func->as(); index && context.callSite->self) @@ -610,9 +611,8 @@ static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext contex if (!fmt) { - if (FFlag::LuauStringFormatArityFix) - context.typechecker->reportError(CountMismatch{1, std::nullopt, 0, CountMismatch::Arg, true, "string.format"}, context.callSite->location); - return; + context.typechecker->reportError(CountMismatch{1, std::nullopt, 0, CountMismatch::Arg, true, "string.format"}, context.callSite->location); + return true; } std::vector expected = parseFormatString(context.builtinTypes, fmt->value.data, fmt->value.size); @@ -629,12 +629,33 @@ static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext contex Location location = context.callSite->args.data[i + (calledWithSelf ? 0 : paramOffset)]->location; // use subtyping instead here SubtypingResult result = context.typechecker->subtyping->isSubtype(actualTy, expectedTy, context.checkScope); + if (!result.isSubtype) { - Reasonings reasonings = context.typechecker->explainReasonings(actualTy, expectedTy, location, result); - context.typechecker->reportError(TypeMismatch{expectedTy, actualTy, reasonings.toString()}, location); + if (FFlag::LuauStringFormatErrorSuppression) + { + switch (shouldSuppressErrors(NotNull{&context.typechecker->normalizer}, actualTy)) + { + case ErrorSuppression::Suppress: + break; + case ErrorSuppression::NormalizationFailed: + break; + case ErrorSuppression::DoNotSuppress: + Reasonings reasonings = context.typechecker->explainReasonings(actualTy, expectedTy, location, result); + + if (!reasonings.suppressed) + context.typechecker->reportError(TypeMismatch{expectedTy, actualTy, reasonings.toString()}, location); + } + } + else + { + Reasonings reasonings = context.typechecker->explainReasonings(actualTy, expectedTy, location, result); + context.typechecker->reportError(TypeMismatch{expectedTy, actualTy, reasonings.toString()}, location); + } } } + + return true; } static std::vector parsePatternString(NotNull builtinTypes, const char* data, size_t size) @@ -697,7 +718,7 @@ static std::vector parsePatternString(NotNull builtinTypes return result; } -static std::optional> magicFunctionGmatch( +std::optional> MagicGmatch::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -733,7 +754,7 @@ static std::optional> magicFunctionGmatch( return WithPredicate{arena.addTypePack({iteratorType})}; } -static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context) +bool MagicGmatch::infer(const MagicFunctionCallContext& context) { const auto& [params, tail] = flatten(context.arguments); @@ -766,7 +787,7 @@ static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context) return true; } -static std::optional> magicFunctionMatch( +std::optional> MagicMatch::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -806,7 +827,7 @@ static std::optional> magicFunctionMatch( return WithPredicate{returnList}; } -static bool dcrMagicFunctionMatch(MagicFunctionCallContext context) +bool MagicMatch::infer(const MagicFunctionCallContext& context) { const auto& [params, tail] = flatten(context.arguments); @@ -842,7 +863,7 @@ static bool dcrMagicFunctionMatch(MagicFunctionCallContext context) return true; } -static std::optional> magicFunctionFind( +std::optional> MagicFind::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -900,7 +921,7 @@ static std::optional> magicFunctionFind( return WithPredicate{returnList}; } -static bool dcrMagicFunctionFind(MagicFunctionCallContext context) +bool MagicFind::infer(const MagicFunctionCallContext& context) { const auto& [params, tail] = flatten(context.arguments); @@ -977,11 +998,9 @@ TypeId makeStringMetatable(NotNull builtinTypes) FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack}; - formatFTV.magicFunction = &magicFunctionFormat; formatFTV.isCheckedFunction = true; const TypeId formatFn = arena->addType(formatFTV); - attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); - attachDcrMagicFunctionTypeCheck(formatFn, dcrMagicFunctionTypeCheckFormat); + attachMagicFunction(formatFn, std::make_shared()); const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ true); @@ -995,16 +1014,14 @@ TypeId makeStringMetatable(NotNull builtinTypes) makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}, /* checked */ false); const TypeId gmatchFunc = makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}, /* checked */ true); - attachMagicFunction(gmatchFunc, magicFunctionGmatch); - attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch); + attachMagicFunction(gmatchFunc, std::make_shared()); FunctionType matchFuncTy{ arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}}) }; matchFuncTy.isCheckedFunction = true; const TypeId matchFunc = arena->addType(matchFuncTy); - attachMagicFunction(matchFunc, magicFunctionMatch); - attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch); + attachMagicFunction(matchFunc, std::make_shared()); FunctionType findFuncTy{ arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), @@ -1012,8 +1029,7 @@ TypeId makeStringMetatable(NotNull builtinTypes) }; findFuncTy.isCheckedFunction = true; const TypeId findFunc = arena->addType(findFuncTy); - attachMagicFunction(findFunc, magicFunctionFind); - attachDcrMagicFunction(findFunc, dcrMagicFunctionFind); + attachMagicFunction(findFunc, std::make_shared()); // string.byte : string -> number? -> number? -> ...number FunctionType stringDotByte{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList}; @@ -1074,7 +1090,7 @@ TypeId makeStringMetatable(NotNull builtinTypes) return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); } -static std::optional> magicFunctionSelect( +std::optional> MagicSelect::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -1119,7 +1135,7 @@ static std::optional> magicFunctionSelect( return std::nullopt; } -static bool dcrMagicFunctionSelect(MagicFunctionCallContext context) +bool MagicSelect::infer(const MagicFunctionCallContext& context) { if (context.callSite->args.size <= 0) { @@ -1164,7 +1180,7 @@ static bool dcrMagicFunctionSelect(MagicFunctionCallContext context) return false; } -static std::optional> magicFunctionSetMetaTable( +std::optional> MagicSetMetatable::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -1246,7 +1262,12 @@ static std::optional> magicFunctionSetMetaTable( return WithPredicate{arena.addTypePack({target})}; } -static std::optional> magicFunctionAssert( +bool MagicSetMetatable::infer(const MagicFunctionCallContext&) +{ + return false; +} + +std::optional> MagicAssert::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -1280,7 +1301,12 @@ static std::optional> magicFunctionAssert( return WithPredicate{arena.addTypePack(TypePack{std::move(head), tail})}; } -static std::optional> magicFunctionPack( +bool MagicAssert::infer(const MagicFunctionCallContext&) +{ + return false; +} + +std::optional> MagicPack::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -1323,7 +1349,7 @@ static std::optional> magicFunctionPack( return WithPredicate{arena.addTypePack({packedTable})}; } -static bool dcrMagicFunctionPack(MagicFunctionCallContext context) +bool MagicPack::infer(const MagicFunctionCallContext& context) { TypeArena* arena = context.solver->arena; @@ -1363,7 +1389,74 @@ static bool dcrMagicFunctionPack(MagicFunctionCallContext context) return true; } -static std::optional freezeTable(TypeId inputType, MagicFunctionCallContext& context) +std::optional> MagicClone::handleOldSolver( + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +) +{ + LUAU_ASSERT(FFlag::LuauTableCloneClonesType3); + + auto [paramPack, _predicates] = withPredicate; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + const auto& [paramTypes, paramTail] = flatten(paramPack); + if (paramTypes.empty() || expr.args.size == 0) + { + typechecker.reportError(expr.argLocation, CountMismatch{1, std::nullopt, 0}); + return std::nullopt; + } + + TypeId inputType = follow(paramTypes[0]); + + if (!get(inputType)) + return std::nullopt; + + CloneState cloneState{typechecker.builtinTypes}; + TypeId resultType = shallowClone(inputType, arena, cloneState); + + TypePackId clonedTypePack = arena.addTypePack({resultType}); + return WithPredicate{clonedTypePack}; +} + +bool MagicClone::infer(const MagicFunctionCallContext& context) +{ + LUAU_ASSERT(FFlag::LuauTableCloneClonesType3); + + TypeArena* arena = context.solver->arena; + + const auto& [paramTypes, paramTail] = flatten(context.arguments); + if (paramTypes.empty() || context.callSite->args.size == 0) + { + context.solver->reportError(CountMismatch{1, std::nullopt, 0}, context.callSite->argLocation); + return false; + } + + TypeId inputType = follow(paramTypes[0]); + + if (!get(inputType)) + return false; + + CloneState cloneState{context.solver->builtinTypes}; + TypeId resultType = shallowClone(inputType, *arena, cloneState, /* ignorePersistent */ FFlag::LuauFreezeIgnorePersistent); + + if (auto tableType = getMutable(resultType)) + { + tableType->scope = context.constraint->scope.get(); + } + + if (FFlag::LuauTrackInteriorFreeTypesOnScope) + trackInteriorFreeType(context.constraint->scope.get(), resultType); + + TypePackId clonedTypePack = arena->addTypePack({resultType}); + asMutable(context.result)->ty.emplace(clonedTypePack); + + return true; +} + +static std::optional freezeTable(TypeId inputType, const MagicFunctionCallContext& context) { TypeArena* arena = context.solver->arena; @@ -1383,7 +1476,7 @@ static std::optional freezeTable(TypeId inputType, MagicFunctionCallCont { // Clone the input type, this will become our final result type after we mutate it. CloneState cloneState{context.solver->builtinTypes}; - TypeId resultType = shallowClone(inputType, *arena, cloneState); + TypeId resultType = shallowClone(inputType, *arena, cloneState, /* ignorePersistent */ FFlag::LuauFreezeIgnorePersistent); auto tableTy = getMutable(resultType); // `clone` should not break this. LUAU_ASSERT(tableTy); @@ -1408,10 +1501,13 @@ static std::optional freezeTable(TypeId inputType, MagicFunctionCallCont return std::nullopt; } -static bool dcrMagicFunctionFreeze(MagicFunctionCallContext context) +std::optional> MagicFreeze::handleOldSolver(struct TypeChecker &, const std::shared_ptr &, const class AstExprCall &, WithPredicate) { - LUAU_ASSERT(FFlag::LuauTypestateBuiltins2); + return std::nullopt; +} +bool MagicFreeze::infer(const MagicFunctionCallContext& context) +{ TypeArena* arena = context.solver->arena; const DataFlowGraph* dfg = context.solver->dfg.get(); Scope* scope = context.constraint->scope.get(); @@ -1469,7 +1565,7 @@ static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr) return good; } -static std::optional> magicFunctionRequire( +std::optional> MagicRequire::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -1515,7 +1611,7 @@ static bool checkRequirePathDcr(NotNull solver, AstExpr* expr) return good; } -static bool dcrMagicFunctionRequire(MagicFunctionCallContext context) +bool MagicRequire::infer(const MagicFunctionCallContext& context) { if (context.callSite->args.size != 1) { diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 98397fa3..6309fa7c 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -7,6 +7,7 @@ #include "Luau/Unifiable.h" LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauFreezeIgnorePersistent) // For each `Luau::clone` call, we will clone only up to N amount of types _and_ packs, as controlled by this limit. LUAU_FASTINTVARIABLE(LuauTypeCloneIterationLimit, 100'000) @@ -38,14 +39,26 @@ class TypeCloner NotNull types; NotNull packs; + TypeId forceTy = nullptr; + TypePackId forceTp = nullptr; + int steps = 0; public: - TypeCloner(NotNull arena, NotNull builtinTypes, NotNull types, NotNull packs) + TypeCloner( + NotNull arena, + NotNull builtinTypes, + NotNull types, + NotNull packs, + TypeId forceTy, + TypePackId forceTp + ) : arena(arena) , builtinTypes(builtinTypes) , types(types) , packs(packs) + , forceTy(forceTy) + , forceTp(forceTp) { } @@ -112,7 +125,7 @@ private: ty = follow(ty, FollowOption::DisableLazyTypeThunks); if (auto it = types->find(ty); it != types->end()) return it->second; - else if (ty->persistent) + else if (ty->persistent && (!FFlag::LuauFreezeIgnorePersistent || ty != forceTy)) return ty; return std::nullopt; } @@ -122,7 +135,7 @@ private: tp = follow(tp); if (auto it = packs->find(tp); it != packs->end()) return it->second; - else if (tp->persistent) + else if (tp->persistent && (!FFlag::LuauFreezeIgnorePersistent || tp != forceTp)) return tp; return std::nullopt; } @@ -148,7 +161,7 @@ public: if (auto clone = find(ty)) return *clone; - else if (ty->persistent) + else if (ty->persistent && (!FFlag::LuauFreezeIgnorePersistent || ty != forceTy)) return ty; TypeId target = arena->addType(ty->ty); @@ -174,7 +187,7 @@ public: if (auto clone = find(tp)) return *clone; - else if (tp->persistent) + else if (tp->persistent && (!FFlag::LuauFreezeIgnorePersistent || tp != forceTp)) return tp; TypePackId target = arena->addTypePack(tp->ty); @@ -458,21 +471,37 @@ private: } // namespace -TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState) +TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState, bool ignorePersistent) { - if (tp->persistent) + if (tp->persistent && (!FFlag::LuauFreezeIgnorePersistent || !ignorePersistent)) return tp; - TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + TypeCloner cloner{ + NotNull{&dest}, + cloneState.builtinTypes, + NotNull{&cloneState.seenTypes}, + NotNull{&cloneState.seenTypePacks}, + nullptr, + FFlag::LuauFreezeIgnorePersistent && ignorePersistent ? tp : nullptr + }; + return cloner.shallowClone(tp); } -TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState) +TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState, bool ignorePersistent) { - if (typeId->persistent) + if (typeId->persistent && (!FFlag::LuauFreezeIgnorePersistent || !ignorePersistent)) return typeId; - TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + TypeCloner cloner{ + NotNull{&dest}, + cloneState.builtinTypes, + NotNull{&cloneState.seenTypes}, + NotNull{&cloneState.seenTypePacks}, + FFlag::LuauFreezeIgnorePersistent && ignorePersistent ? typeId : nullptr, + nullptr + }; + return cloner.shallowClone(typeId); } @@ -481,7 +510,7 @@ TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState) if (tp->persistent) return tp; - TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr}; return cloner.clone(tp); } @@ -490,13 +519,13 @@ TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState) if (typeId->persistent) return typeId; - TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr}; return cloner.clone(typeId); } TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState) { - TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr}; TypeFun copy = typeFun; @@ -521,4 +550,18 @@ TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState) return copy; } +Binding clone(const Binding& binding, TypeArena& dest, CloneState& cloneState) +{ + TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr}; + + Binding b; + b.deprecated = binding.deprecated; + b.deprecatedSuggestion = binding.deprecatedSuggestion; + b.documentationSymbol = binding.documentationSymbol; + b.location = binding.location; + b.typeId = cloner.clone(binding.typeId); + + return b; +} + } // namespace Luau diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp index a0b5fcf4..b0f7c432 100644 --- a/Analysis/src/Constraint.cpp +++ b/Analysis/src/Constraint.cpp @@ -3,8 +3,6 @@ #include "Luau/Constraint.h" #include "Luau/VisitType.h" -LUAU_FASTFLAGVARIABLE(LuauDontRefCountTypesInTypeFunctions) - namespace Luau { @@ -60,9 +58,8 @@ struct ReferenceCountInitializer : TypeOnceVisitor // // The default behavior here is `true` for "visit the child types" // of this type, hence: - return !FFlag::LuauDontRefCountTypesInTypeFunctions; + return false; } - }; bool isReferenceCountedType(const TypeId typ) diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index ed3d8a6d..f77d7944 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -31,13 +31,14 @@ LUAU_FASTINT(LuauCheckRecursionLimit) LUAU_FASTFLAG(DebugLuauLogSolverToJson) LUAU_FASTFLAG(DebugLuauMagicTypes) -LUAU_FASTFLAG(DebugLuauEqSatSimplification) -LUAU_FASTFLAG(LuauTypestateBuiltins2) +LUAU_FASTFLAG(LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) -LUAU_FASTFLAGVARIABLE(LuauNewSolverVisitErrorExprLvalues) LUAU_FASTFLAGVARIABLE(LuauNewSolverPrePopulateClasses) -LUAU_FASTFLAGVARIABLE(LuauUserTypeFunExportedAndLocal) LUAU_FASTFLAGVARIABLE(LuauNewSolverPopulateTableLocations) +LUAU_FASTFLAGVARIABLE(LuauTrackInteriorFreeTypesOnScope) + +LUAU_FASTFLAGVARIABLE(InferGlobalTypes) +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) namespace Luau { @@ -229,8 +230,17 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block) Checkpoint end = checkpoint(this); TypeId result = arena->addType(BlockedType{}); - NotNull genConstraint = - addConstraint(scope, block->location, GeneralizationConstraint{result, moduleFnTy, std::move(interiorTypes.back())}); + NotNull genConstraint = addConstraint( + scope, + block->location, + GeneralizationConstraint{ + result, moduleFnTy, FFlag::LuauTrackInteriorFreeTypesOnScope ? std::vector{} : std::move(interiorTypes.back()) + } + ); + + if (FFlag::LuauTrackInteriorFreeTypesOnScope) + scope->interiorFreeTypes = std::move(interiorTypes.back()); + getMutable(result)->setOwner(genConstraint); forEachConstraint( start, @@ -299,9 +309,19 @@ void ConstraintGenerator::visitFragmentRoot(const ScopePtr& resumeScope, AstStat } } + TypeId ConstraintGenerator::freshType(const ScopePtr& scope) { - return Luau::freshType(arena, builtinTypes, scope.get()); + if (FFlag::LuauTrackInteriorFreeTypesOnScope) + { + auto ft = Luau::freshType(arena, builtinTypes, scope.get()); + interiorTypes.back().push_back(ft); + return ft; + } + else + { + return Luau::freshType(arena, builtinTypes, scope.get()); + } } TypePackId ConstraintGenerator::freshTypePack(const ScopePtr& scope) @@ -720,12 +740,6 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc continue; } - if (!FFlag::LuauUserTypeFunExportedAndLocal && scope->parent != globalScope) - { - reportError(function->location, GenericError{"Local user-defined functions are not supported yet"}); - continue; - } - ScopePtr defnScope = childScope(function, scope); // Create TypeFunctionInstanceType @@ -751,11 +765,8 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc UserDefinedFunctionData udtfData; - if (FFlag::LuauUserTypeFunExportedAndLocal) - { - udtfData.owner = module; - udtfData.definition = function; - } + udtfData.owner = module; + udtfData.definition = function; TypeId typeFunctionTy = arena->addType( TypeFunctionInstanceType{NotNull{&builtinTypeFunctions().userFunc}, std::move(typeParams), {}, function->name, udtfData} @@ -764,7 +775,7 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc TypeFun typeFunction{std::move(quantifiedTypeParams), typeFunctionTy}; // Set type bindings and definition locations for this user-defined type function - if (FFlag::LuauUserTypeFunExportedAndLocal && function->exported) + if (function->exported) scope->exportedTypeBindings[function->name.value] = std::move(typeFunction); else scope->privateTypeBindings[function->name.value] = std::move(typeFunction); @@ -799,49 +810,74 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc } } - if (FFlag::LuauUserTypeFunExportedAndLocal) + // Additional pass for user-defined type functions to fill in their environments completely + for (AstStat* stat : block->body) { - // Additional pass for user-defined type functions to fill in their environments completely - for (AstStat* stat : block->body) + if (auto function = stat->as()) { - if (auto function = stat->as()) + // Find the type function we have already created + TypeFunctionInstanceType* mainTypeFun = nullptr; + + if (auto it = scope->privateTypeBindings.find(function->name.value); it != scope->privateTypeBindings.end()) + mainTypeFun = getMutable(it->second.type); + + if (!mainTypeFun) { - // Find the type function we have already created - TypeFunctionInstanceType* mainTypeFun = nullptr; - - if (auto it = scope->privateTypeBindings.find(function->name.value); it != scope->privateTypeBindings.end()) + if (auto it = scope->exportedTypeBindings.find(function->name.value); it != scope->exportedTypeBindings.end()) mainTypeFun = getMutable(it->second.type); + } - if (!mainTypeFun) + // Fill it with all visible type functions + if (mainTypeFun) + { + UserDefinedFunctionData& userFuncData = mainTypeFun->userFuncData; + size_t level = 0; + + for (Scope* curr = scope.get(); curr; curr = curr->parent.get()) { - if (auto it = scope->exportedTypeBindings.find(function->name.value); it != scope->exportedTypeBindings.end()) - mainTypeFun = getMutable(it->second.type); - } - - // Fill it with all visible type functions - if (mainTypeFun) - { - UserDefinedFunctionData& userFuncData = mainTypeFun->userFuncData; - - for (Scope* curr = scope.get(); curr; curr = curr->parent.get()) + for (auto& [name, tf] : curr->privateTypeBindings) { - for (auto& [name, tf] : curr->privateTypeBindings) - { - if (userFuncData.environment.find(name)) - continue; + if (userFuncData.environment.find(name)) + continue; - if (auto ty = get(tf.type); ty && ty->userFuncData.definition) - userFuncData.environment[name] = ty->userFuncData.definition; - } + if (auto ty = get(tf.type); ty && ty->userFuncData.definition) + userFuncData.environment[name] = std::make_pair(ty->userFuncData.definition, level); + } - for (auto& [name, tf] : curr->exportedTypeBindings) - { - if (userFuncData.environment.find(name)) - continue; + for (auto& [name, tf] : curr->exportedTypeBindings) + { + if (userFuncData.environment.find(name)) + continue; - if (auto ty = get(tf.type); ty && ty->userFuncData.definition) - userFuncData.environment[name] = ty->userFuncData.definition; - } + if (auto ty = get(tf.type); ty && ty->userFuncData.definition) + userFuncData.environment[name] = std::make_pair(ty->userFuncData.definition, level); + } + + level++; + } + } + else if (mainTypeFun) + { + UserDefinedFunctionData& userFuncData = mainTypeFun->userFuncData; + + for (Scope* curr = scope.get(); curr; curr = curr->parent.get()) + { + for (auto& [name, tf] : curr->privateTypeBindings) + { + if (userFuncData.environment_DEPRECATED.find(name)) + continue; + + if (auto ty = get(tf.type); ty && ty->userFuncData.definition) + userFuncData.environment_DEPRECATED[name] = ty->userFuncData.definition; + } + + for (auto& [name, tf] : curr->exportedTypeBindings) + { + if (userFuncData.environment_DEPRECATED.find(name)) + continue; + + if (auto ty = get(tf.type); ty && ty->userFuncData.definition) + userFuncData.environment_DEPRECATED[name] = ty->userFuncData.definition; } } } @@ -1053,18 +1089,8 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); else if (const AstExprCall* call = value->as()) { - if (FFlag::LuauTypestateBuiltins2) - { - if (matchSetMetatable(*call)) - addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); - } - else - { - if (const AstExprGlobal* global = call->func->as(); global && global->name == "setmetatable") - { - addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); - } - } + if (matchSetMetatable(*call)) + addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); } } @@ -1571,20 +1597,6 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatTypeAlias* ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatTypeFunction* function) { - // If a type function with the same name was already defined, we skip over - auto bindingIt = scope->privateTypeBindings.find(function->name.value); - if (bindingIt == scope->privateTypeBindings.end()) - return ControlFlow::None; - - TypeFun typeFunction = bindingIt->second; - - // Adding typeAliasExpansionConstraint on user-defined type function for the constraint solver - if (auto typeFunctionTy = get(follow(typeFunction.type))) - { - TypeId expansionTy = arena->addType(PendingExpansionType{{}, function->name, typeFunctionTy->typeArguments, typeFunctionTy->packArguments}); - addConstraint(scope, function->location, TypeAliasExpansionConstraint{/* target */ expansionTy}); - } - return ControlFlow::None; } @@ -2047,7 +2059,7 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall* return InferencePack{arena->addTypePack({resultTy}), {refinementArena.variadic(returnRefinements)}}; } - if (FFlag::LuauTypestateBuiltins2 && shouldTypestateForFirstArgument(*call) && call->args.size > 0 && isLValue(call->args.data[0])) + if (shouldTypestateForFirstArgument(*call) && call->args.size > 0 && isLValue(call->args.data[0])) { AstExpr* targetExpr = call->args.data[0]; auto resultTy = arena->addType(BlockedType{}); @@ -2196,7 +2208,8 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprConstantStrin if (forceSingleton) return Inference{arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}})}; - FreeType ft = FreeType{scope.get()}; + FreeType ft = + FFlag::LuauFreeTypesMustHaveBounds ? FreeType{scope.get(), builtinTypes->neverType, builtinTypes->unknownType} : FreeType{scope.get()}; ft.lowerBound = arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}}); ft.upperBound = builtinTypes->stringType; const TypeId freeTy = arena->addType(ft); @@ -2210,7 +2223,8 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprConstantBool* if (forceSingleton) return Inference{singletonType}; - FreeType ft = FreeType{scope.get()}; + FreeType ft = + FFlag::LuauFreeTypesMustHaveBounds ? FreeType{scope.get(), builtinTypes->neverType, builtinTypes->unknownType} : FreeType{scope.get()}; ft.lowerBound = singletonType; ft.upperBound = builtinTypes->booleanType; const TypeId freeTy = arena->addType(ft); @@ -2372,8 +2386,17 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprFunction* fun Checkpoint endCheckpoint = checkpoint(this); TypeId generalizedTy = arena->addType(BlockedType{}); - NotNull gc = - addConstraint(sig.signatureScope, func->location, GeneralizationConstraint{generalizedTy, sig.signature, std::move(interiorTypes.back())}); + NotNull gc = addConstraint( + sig.signatureScope, + func->location, + GeneralizationConstraint{ + generalizedTy, sig.signature, FFlag::LuauTrackInteriorFreeTypesOnScope ? std::vector{} : std::move(interiorTypes.back()) + } + ); + + if (FFlag::LuauTrackInteriorFreeTypesOnScope) + sig.signatureScope->interiorFreeTypes = std::move(interiorTypes.back()); + getMutable(generalizedTy)->setOwner(gc); interiorTypes.pop_back(); @@ -2721,15 +2744,12 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExpr* expr, Type visitLValue(scope, e, rhsType); else if (auto e = expr->as()) { - if (FFlag::LuauNewSolverVisitErrorExprLvalues) + // If we end up with some sort of error expression in an lvalue + // position, at least go and check the expressions so that when + // we visit them later, there aren't any invalid assumptions. + for (auto subExpr : e->expressions) { - // If we end up with some sort of error expression in an lvalue - // position, at least go and check the expressions so that when - // we visit them later, there aren't any invalid assumptions. - for (auto subExpr : e->expressions) - { - check(scope, subExpr); - } + check(scope, subExpr); } } else @@ -2790,6 +2810,14 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprGlobal* glob DefId def = dfg->getDef(global); rootScope->lvalueTypes[def] = rhsType; + if (FFlag::InferGlobalTypes) + { + // Sketchy: We're specifically looking for BlockedTypes that were + // initially created by ConstraintGenerator::prepopulateGlobalScope. + if (auto bt = get(follow(*annotatedTy)); bt && !bt->getOwner()) + emplaceType(asMutable(*annotatedTy), rhsType); + } + addConstraint(scope, global->location, SubtypeConstraint{rhsType, *annotatedTy}); } } @@ -2931,11 +2959,11 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr, ty, expr, toBlock - ); - // The visitor we ran prior should ensure that there are no - // blocked types that we would encounter while matching on - // this expression. - LUAU_ASSERT(toBlock.empty()); + ); + // The visitor we ran prior should ensure that there are no + // blocked types that we would encounter while matching on + // this expression. + LUAU_ASSERT(toBlock.empty()); } } @@ -3182,9 +3210,8 @@ TypeId ConstraintGenerator::resolveReferenceType( if (alias.has_value()) { - // If the alias is not generic, we don't need to set up a blocked - // type and an instantiation constraint. - if (alias.has_value() && alias->typeParams.empty() && alias->typePackParams.empty()) + // If the alias is not generic, we don't need to set up a blocked type and an instantiation constraint + if (alias.has_value() && alias->typeParams.empty() && alias->typePackParams.empty() && !ref->hasParameterList) { result = alias->type; } @@ -3393,6 +3420,12 @@ TypeId ConstraintGenerator::resolveType(const ScopePtr& scope, AstType* ty, bool } else if (auto unionAnnotation = ty->as()) { + if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) + { + if (unionAnnotation->types.size == 1) + return resolveType(scope, unionAnnotation->types.data[0], inTypeArguments); + } + std::vector parts; for (AstType* part : unionAnnotation->types) { @@ -3403,6 +3436,12 @@ TypeId ConstraintGenerator::resolveType(const ScopePtr& scope, AstType* ty, bool } else if (auto intersectionAnnotation = ty->as()) { + if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) + { + if (intersectionAnnotation->types.size == 1) + return resolveType(scope, intersectionAnnotation->types.data[0], inTypeArguments); + } + std::vector parts; for (AstType* part : intersectionAnnotation->types) { @@ -3411,6 +3450,10 @@ TypeId ConstraintGenerator::resolveType(const ScopePtr& scope, AstType* ty, bool result = arena->addType(IntersectionType{parts}); } + else if (auto typeGroupAnnotation = ty->as()) + { + result = resolveType(scope, typeGroupAnnotation->type, inTypeArguments); + } else if (auto boolAnnotation = ty->as()) { if (boolAnnotation->value) @@ -3694,6 +3737,26 @@ struct GlobalPrepopulator : AstVisitor return true; } + bool visit(AstStatAssign* assign) override + { + if (FFlag::InferGlobalTypes) + { + for (const Luau::AstExpr* expr : assign->vars) + { + if (const AstExprGlobal* g = expr->as()) + { + if (!globalScope->lookup(g->name)) + globalScope->globalsToWarn.insert(g->name.value); + + TypeId bt = arena->addType(BlockedType{}); + globalScope->bindings[g->name] = Binding{bt, g->location}; + } + } + } + + return true; + } + bool visit(AstStatFunction* function) override { if (AstExprGlobal* g = function->name->as()) @@ -3877,20 +3940,7 @@ TypeId ConstraintGenerator::createTypeFunctionInstance( TypeId ConstraintGenerator::simplifyUnion(const ScopePtr& scope, Location location, TypeId left, TypeId right) { - if (FFlag::DebugLuauEqSatSimplification) - { - TypeId ty = arena->addType(UnionType{{left, right}}); - std::optional res = eqSatSimplify(simplifier, ty); - if (!res) - return ty; - - for (TypeId tyFun : res->newTypeFunctions) - addConstraint(scope, location, ReduceConstraint{tyFun}); - - return res->result; - } - else - return ::Luau::simplifyUnion(builtinTypes, arena, left, right).result; + return ::Luau::simplifyUnion(builtinTypes, arena, left, right).result; } std::vector> borrowConstraints(const std::vector& constraints) diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index d18c61cb..cb2f6bbf 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -31,10 +31,12 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverIncludeDependencies) LUAU_FASTFLAGVARIABLE(DebugLuauLogBindings) LUAU_FASTINTVARIABLE(LuauSolverRecursionLimit, 500) -LUAU_FASTFLAGVARIABLE(LuauRemoveNotAnyHack) LUAU_FASTFLAGVARIABLE(DebugLuauEqSatSimplification) LUAU_FASTFLAG(LuauNewSolverPopulateTableLocations) LUAU_FASTFLAGVARIABLE(LuauAllowNilAssignmentToIndexer) +LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope) +LUAU_FASTFLAGVARIABLE(LuauAlwaysFillInFunctionCallDiscriminantTypes) +LUAU_FASTFLAGVARIABLE(LuauTrackInteriorFreeTablesOnScope) namespace Luau { @@ -72,7 +74,7 @@ size_t HashBlockedConstraintId::operator()(const BlockedConstraintId& bci) const { if (auto blocked = get(ty)) { - Constraint* owner = blocked->getOwner(); + const Constraint* owner = blocked->getOwner(); LUAU_ASSERT(owner); return owner == constraint; } @@ -443,7 +445,7 @@ void ConstraintSolver::run() if (success) { unblock(c); - unsolvedConstraints.erase(unsolvedConstraints.begin() + i); + unsolvedConstraints.erase(unsolvedConstraints.begin() + ptrdiff_t(i)); // decrement the referenced free types for this constraint if we dispatched successfully! for (auto ty : c->getMaybeMutatedFreeTypes()) @@ -550,7 +552,7 @@ void ConstraintSolver::finalizeTypeFunctions() } } -bool ConstraintSolver::isDone() +bool ConstraintSolver::isDone() const { return unsolvedConstraints.empty(); } @@ -723,8 +725,20 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNullerrorRecoveryType()); } - for (TypeId ty : c.interiorTypes) - generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, ty, /* avoidSealingTables */ false); + if (FFlag::LuauTrackInteriorFreeTypesOnScope) + { + // We check if this member is initialized and then access it, but + // clang-tidy doesn't understand this is safe. + if (constraint->scope->interiorFreeTypes) + for (TypeId ty : *constraint->scope->interiorFreeTypes) // NOLINT(bugprone-unchecked-optional-access) + generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, ty, /* avoidSealingTables */ false); + } + else + { + for (TypeId ty : c.interiorTypes) + generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, ty, /* avoidSealingTables */ false); + } + return true; } @@ -800,9 +814,17 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNullscope); TypeId valueTy = freshType(arena, builtinTypes, constraint->scope); + if (FFlag::LuauTrackInteriorFreeTypesOnScope) + { + trackInteriorFreeType(constraint->scope, keyTy); + trackInteriorFreeType(constraint->scope, valueTy); + } TypeId tableTy = arena->addType(TableType{TableType::Props{}, TableIndexer{keyTy, valueTy}, TypeLevel{}, constraint->scope, TableState::Free}); + if (FFlag::LuauTrackInteriorFreeTypesOnScope && FFlag::LuauTrackInteriorFreeTablesOnScope) + trackInteriorFreeType(constraint->scope, tableTy); + unify(constraint, nextTy, tableTy); auto it = begin(c.variables); @@ -939,14 +961,6 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul if (auto typeFn = get(follow(tf->type))) pushConstraint(NotNull(constraint->scope.get()), constraint->location, ReduceConstraint{tf->type}); - // If there are no parameters to the type function we can just use the type - // directly. - if (tf->typeParams.empty() && tf->typePackParams.empty()) - { - bindResult(tf->type); - return true; - } - // Due to how pending expansion types and TypeFun's are created // If this check passes, we have created a cyclic / corecursive type alias // of size 0 @@ -959,6 +973,13 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul return true; } + // If there are no parameters to the type function we can just use the type directly + if (tf->typeParams.empty() && tf->typePackParams.empty()) + { + bindResult(tf->type); + return true; + } + auto [typeArguments, packArguments] = saturateArguments(arena, builtinTypes, *tf, petv->typeArguments, petv->packArguments); bool sameTypes = std::equal( @@ -1122,6 +1143,28 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul return true; } +void ConstraintSolver::fillInDiscriminantTypes( + NotNull constraint, + const std::vector>& discriminantTypes +) +{ + for (std::optional ty : discriminantTypes) + { + if (!ty) + continue; + + // If the discriminant type has been transmuted, we need to unblock them. + if (!isBlocked(*ty)) + { + unblock(*ty, constraint->location); + continue; + } + + // We bind any unused discriminants to the `*no-refine*` type indicating that it can be safely ignored. + emplaceType(asMutable(follow(*ty)), builtinTypes->noRefineType); + } +} + bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull constraint) { TypeId fn = follow(c.fn); @@ -1137,6 +1180,8 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull(asMutable(c.result), builtinTypes->anyTypePack); unblock(c.result, constraint->location); + if (FFlag::LuauAlwaysFillInFunctionCallDiscriminantTypes) + fillInDiscriminantTypes(constraint, c.discriminantTypes); return true; } @@ -1144,12 +1189,16 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull(fn)) { bind(constraint, c.result, builtinTypes->errorRecoveryTypePack()); + if (FFlag::LuauAlwaysFillInFunctionCallDiscriminantTypes) + fillInDiscriminantTypes(constraint, c.discriminantTypes); return true; } if (get(fn)) { bind(constraint, c.result, builtinTypes->neverTypePack); + if (FFlag::LuauAlwaysFillInFunctionCallDiscriminantTypes) + fillInDiscriminantTypes(constraint, c.discriminantTypes); return true; } @@ -1219,50 +1268,46 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulldcrMagicFunction) - usedMagic = ftv->dcrMagicFunction(MagicFunctionCallContext{NotNull{this}, constraint, c.callSite, c.argsPack, result}); - - if (ftv->dcrMagicRefinement) - ftv->dcrMagicRefinement(MagicRefinementContext{constraint->scope, c.callSite, c.discriminantTypes}); + if (ftv->magic) + { + usedMagic = ftv->magic->infer(MagicFunctionCallContext{NotNull{this}, constraint, c.callSite, c.argsPack, result}); + ftv->magic->refine(MagicRefinementContext{constraint->scope, c.callSite, c.discriminantTypes}); + } } if (!usedMagic) emplace(constraint, c.result, constraint->scope); } - for (std::optional ty : c.discriminantTypes) + if (FFlag::LuauAlwaysFillInFunctionCallDiscriminantTypes) { - if (!ty) - continue; - - // If the discriminant type has been transmuted, we need to unblock them. - if (!isBlocked(*ty)) + fillInDiscriminantTypes(constraint, c.discriminantTypes); + } + else + { + // NOTE: This is the body of the `fillInDiscriminantTypes` helper. + for (std::optional ty : c.discriminantTypes) { - unblock(*ty, constraint->location); - continue; - } + if (!ty) + continue; + + // If the discriminant type has been transmuted, we need to unblock them. + if (!isBlocked(*ty)) + { + unblock(*ty, constraint->location); + continue; + } - if (FFlag::LuauRemoveNotAnyHack) - { // We bind any unused discriminants to the `*no-refine*` type indicating that it can be safely ignored. emplaceType(asMutable(follow(*ty)), builtinTypes->noRefineType); } - else - { - // We use `any` here because the discriminant type may be pointed at by both branches, - // where the discriminant type is not negated, and the other where it is negated, i.e. - // `unknown ~ unknown` and `~unknown ~ never`, so `T & unknown ~ T` and `T & ~unknown ~ never` - // v.s. - // `any ~ any` and `~any ~ any`, so `T & any ~ T` and `T & ~any ~ T` - // - // In practice, users cannot negate `any`, so this is an implementation detail we can always change. - emplaceType(asMutable(follow(*ty)), builtinTypes->anyType); - } } + OverloadResolver resolver{ builtinTypes, NotNull{arena}, + simplifier, normalizer, typeFunctionRuntime, constraint->scope, @@ -1618,7 +1663,7 @@ bool ConstraintSolver::tryDispatchHasIndexer( for (TypeId part : parts) { TypeId r = arena->addType(BlockedType{}); - getMutable(r)->setOwner(const_cast(constraint.get())); + getMutable(r)->setOwner(constraint.get()); bool ok = tryDispatchHasIndexer(recursionDepth, constraint, part, indexType, r, seen); // If we've cut a recursive loop short, skip it. @@ -1650,7 +1695,7 @@ bool ConstraintSolver::tryDispatchHasIndexer( for (TypeId part : parts) { TypeId r = arena->addType(BlockedType{}); - getMutable(r)->setOwner(const_cast(constraint.get())); + getMutable(r)->setOwner(constraint.get()); bool ok = tryDispatchHasIndexer(recursionDepth, constraint, part, indexType, r, seen); // If we've cut a recursive loop short, skip it. @@ -1770,6 +1815,10 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNulladdType(TableType{TableState::Free, TypeLevel{}, constraint->scope}); + + if (FFlag::LuauTrackInteriorFreeTypesOnScope && FFlag::LuauTrackInteriorFreeTablesOnScope) + trackInteriorFreeType(constraint->scope, newUpperBound); + TableType* upperTable = getMutable(newUpperBound); LUAU_ASSERT(upperTable); @@ -2048,6 +2097,8 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNullscope); + if (FFlag::LuauTrackInteriorFreeTypesOnScope) + trackInteriorFreeType(constraint->scope, f); shiftReferences(resultTy, f); emplaceType(asMutable(resultTy), f); } @@ -2103,6 +2154,11 @@ bool ConstraintSolver::tryDispatch(const ReduceConstraint& c, NotNullscope); TypeId valueTy = freshType(arena, builtinTypes, constraint->scope); + if (FFlag::LuauTrackInteriorFreeTypesOnScope) + { + trackInteriorFreeType(constraint->scope, keyTy); + trackInteriorFreeType(constraint->scope, valueTy); + } TypeId tableTy = arena->addType(TableType{TableState::Sealed, {}, constraint->scope}); getMutable(tableTy)->indexer = TableIndexer{keyTy, valueTy}; @@ -2434,6 +2495,8 @@ TablePropLookupResult ConstraintSolver::lookupTableProp( if (ttv->state == TableState::Free) { TypeId result = freshType(arena, builtinTypes, ttv->scope); + if (FFlag::LuauTrackInteriorFreeTypesOnScope) + trackInteriorFreeType(ttv->scope, result); switch (context) { case ValueContext::RValue: @@ -2539,10 +2602,17 @@ TablePropLookupResult ConstraintSolver::lookupTableProp( NotNull scope{ft->scope}; const TypeId newUpperBound = arena->addType(TableType{TableState::Free, TypeLevel{}, scope}); + + if (FFlag::LuauTrackInteriorFreeTypesOnScope && FFlag::LuauTrackInteriorFreeTablesOnScope) + trackInteriorFreeType(constraint->scope, newUpperBound); + TableType* tt = getMutable(newUpperBound); LUAU_ASSERT(tt); TypeId propType = freshType(arena, builtinTypes, scope); + if (FFlag::LuauTrackInteriorFreeTypesOnScope) + trackInteriorFreeType(scope, propType); + switch (context) { case ValueContext::RValue: @@ -2773,10 +2843,10 @@ bool ConstraintSolver::blockOnPendingTypes(TypeId target, NotNull constraint) +bool ConstraintSolver::blockOnPendingTypes(TypePackId targetPack, NotNull constraint) { Blocker blocker{NotNull{this}, constraint}; - blocker.traverse(pack); + blocker.traverse(targetPack); return !blocker.blocked; } diff --git a/Analysis/src/DataFlowGraph.cpp b/Analysis/src/DataFlowGraph.cpp index 9925f29c..cff87858 100644 --- a/Analysis/src/DataFlowGraph.cpp +++ b/Analysis/src/DataFlowGraph.cpp @@ -13,7 +13,6 @@ LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauTypestateBuiltins2) namespace Luau { @@ -62,6 +61,12 @@ const RefinementKey* RefinementKeyArena::node(const RefinementKey* parent, DefId return allocator.allocate(RefinementKey{parent, def, propName}); } +DataFlowGraph::DataFlowGraph(NotNull defArena, NotNull keyArena) + : defArena{defArena} + , keyArena{keyArena} +{ +} + DefId DataFlowGraph::getDef(const AstExpr* expr) const { auto def = astDefs.find(expr); @@ -178,11 +183,23 @@ bool DfgScope::canUpdateDefinition(DefId def, const std::string& key) const return true; } -DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull handle) +DataFlowGraphBuilder::DataFlowGraphBuilder(NotNull defArena, NotNull keyArena) + : graph{defArena, keyArena} + , defArena{defArena} + , keyArena{keyArena} +{ +} + +DataFlowGraph DataFlowGraphBuilder::build( + AstStatBlock* block, + NotNull defArena, + NotNull keyArena, + NotNull handle +) { LUAU_TIMETRACE_SCOPE("DataFlowGraphBuilder::build", "Typechecking"); - DataFlowGraphBuilder builder; + DataFlowGraphBuilder builder(defArena, keyArena); builder.handle = handle; DfgScope* moduleScope = builder.makeChildScope(); PushScope ps{builder.scopeStack, moduleScope}; @@ -198,30 +215,6 @@ DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull, std::vector>> DataFlowGraphBuilder::buildShared( - AstStatBlock* block, - NotNull handle -) -{ - - LUAU_TIMETRACE_SCOPE("DataFlowGraphBuilder::build", "Typechecking"); - - DataFlowGraphBuilder builder; - builder.handle = handle; - DfgScope* moduleScope = builder.makeChildScope(); - PushScope ps{builder.scopeStack, moduleScope}; - builder.visitBlockWithoutChildScope(block); - builder.resolveCaptures(); - - if (FFlag::DebugLuauFreezeArena) - { - builder.defArena->allocator.freeze(); - builder.keyArena->allocator.freeze(); - } - - return {std::make_shared(std::move(builder.graph)), std::move(builder.scopes)}; -} - void DataFlowGraphBuilder::resolveCaptures() { for (const auto& [_, capture] : captures) @@ -885,7 +878,7 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprCall* c) { visitExpr(c->func); - if (FFlag::LuauTypestateBuiltins2 && shouldTypestateForFirstArgument(*c) && c->args.size > 1 && isLValue(*c->args.begin())) + if (shouldTypestateForFirstArgument(*c) && c->args.size > 1 && isLValue(*c->args.begin())) { AstExpr* firstArg = *c->args.begin(); @@ -1176,6 +1169,8 @@ void DataFlowGraphBuilder::visitType(AstType* t) return; // ok else if (auto s = t->as()) return; // ok + else if (auto g = t->as()) + return visitType(g->type); else handle->ice("Unknown AstType in DataFlowGraphBuilder::visitType"); } diff --git a/Analysis/src/Differ.cpp b/Analysis/src/Differ.cpp index 7debae9c..e6222067 100644 --- a/Analysis/src/Differ.cpp +++ b/Analysis/src/Differ.cpp @@ -13,7 +13,6 @@ namespace Luau { - std::string DiffPathNode::toString() const { switch (kind) @@ -945,14 +944,12 @@ std::vector>::const_reverse_iterator DifferEnvironment return visitingStack.crend(); } - DifferResult diff(TypeId ty1, TypeId ty2) { DifferEnvironment differEnv{ty1, ty2, std::nullopt, std::nullopt}; return diffUsingEnv(differEnv, ty1, ty2); } - DifferResult diffWithSymbols(TypeId ty1, TypeId ty2, std::optional symbol1, std::optional symbol2) { DifferEnvironment differEnv{ty1, ty2, symbol1, symbol2}; diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 828fc7ed..0042d6fb 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -1,15 +1,13 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" -LUAU_FASTFLAG(LuauMathMap) - -LUAU_FASTFLAGVARIABLE(LuauVectorDefinitions) -LUAU_FASTFLAGVARIABLE(LuauVectorDefinitionsExtra) +LUAU_FASTFLAG(LuauBufferBitMethods2) +LUAU_FASTFLAGVARIABLE(LuauMathMapDefinition) +LUAU_FASTFLAG(LuauVector2Constructor) namespace Luau { -// TODO: there has to be a better way, like splitting up per library static const std::string kBuiltinDefinitionLuaSrcChecked_DEPRECATED = R"BUILTIN_SRC( declare bit32: { @@ -30,227 +28,6 @@ declare bit32: { byteswap: @checked (n: number) -> number, } -declare math: { - frexp: @checked (n: number) -> (number, number), - ldexp: @checked (s: number, e: number) -> number, - fmod: @checked (x: number, y: number) -> number, - modf: @checked (n: number) -> (number, number), - pow: @checked (x: number, y: number) -> number, - exp: @checked (n: number) -> number, - - ceil: @checked (n: number) -> number, - floor: @checked (n: number) -> number, - abs: @checked (n: number) -> number, - sqrt: @checked (n: number) -> number, - - log: @checked (n: number, base: number?) -> number, - log10: @checked (n: number) -> number, - - rad: @checked (n: number) -> number, - deg: @checked (n: number) -> number, - - sin: @checked (n: number) -> number, - cos: @checked (n: number) -> number, - tan: @checked (n: number) -> number, - sinh: @checked (n: number) -> number, - cosh: @checked (n: number) -> number, - tanh: @checked (n: number) -> number, - atan: @checked (n: number) -> number, - acos: @checked (n: number) -> number, - asin: @checked (n: number) -> number, - atan2: @checked (y: number, x: number) -> number, - - min: @checked (number, ...number) -> number, - max: @checked (number, ...number) -> number, - - pi: number, - huge: number, - - randomseed: @checked (seed: number) -> (), - random: @checked (number?, number?) -> number, - - sign: @checked (n: number) -> number, - clamp: @checked (n: number, min: number, max: number) -> number, - noise: @checked (x: number, y: number?, z: number?) -> number, - round: @checked (n: number) -> number, -} - -type DateTypeArg = { - year: number, - month: number, - day: number, - hour: number?, - min: number?, - sec: number?, - isdst: boolean?, -} - -type DateTypeResult = { - year: number, - month: number, - wday: number, - yday: number, - day: number, - hour: number, - min: number, - sec: number, - isdst: boolean, -} - -declare os: { - time: (time: DateTypeArg?) -> number, - date: ((formatString: "*t" | "!*t", time: number?) -> DateTypeResult) & ((formatString: string?, time: number?) -> string), - difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number, - clock: () -> number, -} - -@checked declare function require(target: any): any - -@checked declare function getfenv(target: any): { [string]: any } - -declare _G: any -declare _VERSION: string - -declare function gcinfo(): number - -declare function print(...: T...) - -declare function type(value: T): string -declare function typeof(value: T): string - --- `assert` has a magic function attached that will give more detailed type information -declare function assert(value: T, errorMessage: string?): T -declare function error(message: T, level: number?): never - -declare function tostring(value: T): string -declare function tonumber(value: T, radix: number?): number? - -declare function rawequal(a: T1, b: T2): boolean -declare function rawget(tab: {[K]: V}, k: K): V -declare function rawset(tab: {[K]: V}, k: K, v: V): {[K]: V} -declare function rawlen(obj: {[K]: V} | string): number - -declare function setfenv(target: number | (T...) -> R..., env: {[string]: any}): ((T...) -> R...)? - -declare function ipairs(tab: {V}): (({V}, number) -> (number?, V), {V}, number) - -declare function pcall(f: (A...) -> R..., ...: A...): (boolean, R...) - --- FIXME: The actual type of `xpcall` is: --- (f: (A...) -> R1..., err: (E) -> R2..., A...) -> (true, R1...) | (false, R2...) --- Since we can't represent the return value, we use (boolean, R1...). -declare function xpcall(f: (A...) -> R1..., err: (E) -> R2..., ...: A...): (boolean, R1...) - --- `select` has a magic function attached to provide more detailed type information -declare function select(i: string | number, ...: A...): ...any - --- FIXME: This type is not entirely correct - `loadstring` returns a function or --- (nil, string). -declare function loadstring(src: string, chunkname: string?): (((A...) -> any)?, string?) - -@checked declare function newproxy(mt: boolean?): any - -declare coroutine: { - create: (f: (A...) -> R...) -> thread, - resume: (co: thread, A...) -> (boolean, R...), - running: () -> thread, - status: @checked (co: thread) -> "dead" | "running" | "normal" | "suspended", - wrap: (f: (A...) -> R...) -> ((A...) -> R...), - yield: (A...) -> R..., - isyieldable: () -> boolean, - close: @checked (co: thread) -> (boolean, any) -} - -declare table: { - concat: (t: {V}, sep: string?, i: number?, j: number?) -> string, - insert: ((t: {V}, value: V) -> ()) & ((t: {V}, pos: number, value: V) -> ()), - maxn: (t: {V}) -> number, - remove: (t: {V}, number?) -> V?, - sort: (t: {V}, comp: ((V, V) -> boolean)?) -> (), - create: (count: number, value: V?) -> {V}, - find: (haystack: {V}, needle: V, init: number?) -> number?, - - unpack: (list: {V}, i: number?, j: number?) -> ...V, - pack: (...V) -> { n: number, [number]: V }, - - getn: (t: {V}) -> number, - foreach: (t: {[K]: V}, f: (K, V) -> ()) -> (), - foreachi: ({V}, (number, V) -> ()) -> (), - - move: (src: {V}, a: number, b: number, t: number, dst: {V}?) -> {V}, - clear: (table: {[K]: V}) -> (), - - isfrozen: (t: {[K]: V}) -> boolean, -} - -declare debug: { - info: ((thread: thread, level: number, options: string) -> R...) & ((level: number, options: string) -> R...) & ((func: (A...) -> R1..., options: string) -> R2...), - traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string), -} - -declare utf8: { - char: @checked (...number) -> string, - charpattern: string, - codes: @checked (str: string) -> ((string, number) -> (number, number), string, number), - codepoint: @checked (str: string, i: number?, j: number?) -> ...number, - len: @checked (s: string, i: number?, j: number?) -> (number?, number?), - offset: @checked (s: string, n: number?, i: number?) -> number, -} - --- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. -declare function unpack(tab: {V}, i: number?, j: number?): ...V - - ---- Buffer API -declare buffer: { - create: @checked (size: number) -> buffer, - fromstring: @checked (str: string) -> buffer, - tostring: @checked (b: buffer) -> string, - len: @checked (b: buffer) -> number, - copy: @checked (target: buffer, targetOffset: number, source: buffer, sourceOffset: number?, count: number?) -> (), - fill: @checked (b: buffer, offset: number, value: number, count: number?) -> (), - readi8: @checked (b: buffer, offset: number) -> number, - readu8: @checked (b: buffer, offset: number) -> number, - readi16: @checked (b: buffer, offset: number) -> number, - readu16: @checked (b: buffer, offset: number) -> number, - readi32: @checked (b: buffer, offset: number) -> number, - readu32: @checked (b: buffer, offset: number) -> number, - readf32: @checked (b: buffer, offset: number) -> number, - readf64: @checked (b: buffer, offset: number) -> number, - writei8: @checked (b: buffer, offset: number, value: number) -> (), - writeu8: @checked (b: buffer, offset: number, value: number) -> (), - writei16: @checked (b: buffer, offset: number, value: number) -> (), - writeu16: @checked (b: buffer, offset: number, value: number) -> (), - writei32: @checked (b: buffer, offset: number, value: number) -> (), - writeu32: @checked (b: buffer, offset: number, value: number) -> (), - writef32: @checked (b: buffer, offset: number, value: number) -> (), - writef64: @checked (b: buffer, offset: number, value: number) -> (), - readstring: @checked (b: buffer, offset: number, count: number) -> string, - writestring: @checked (b: buffer, offset: number, value: string, count: number?) -> (), -} - -)BUILTIN_SRC"; - -static const std::string kBuiltinDefinitionLuaSrcChecked = R"BUILTIN_SRC( - -declare bit32: { - band: @checked (...number) -> number, - bor: @checked (...number) -> number, - bxor: @checked (...number) -> number, - btest: @checked (number, ...number) -> boolean, - rrotate: @checked (x: number, disp: number) -> number, - lrotate: @checked (x: number, disp: number) -> number, - lshift: @checked (x: number, disp: number) -> number, - arshift: @checked (x: number, disp: number) -> number, - rshift: @checked (x: number, disp: number) -> number, - bnot: @checked (x: number) -> number, - extract: @checked (n: number, field: number, width: number?) -> number, - replace: @checked (n: number, v: number, field: number, width: number?) -> number, - countlz: @checked (n: number) -> number, - countrz: @checked (n: number) -> number, - byteswap: @checked (n: number) -> number, -} - declare math: { frexp: @checked (n: number) -> (number, number), ldexp: @checked (s: number, e: number) -> number, @@ -422,7 +199,231 @@ declare utf8: { -- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. declare function unpack(tab: {V}, i: number?, j: number?): ...V +)BUILTIN_SRC"; +static const std::string kBuiltinDefinitionBaseSrc = R"BUILTIN_SRC( + +@checked declare function require(target: any): any + +@checked declare function getfenv(target: any): { [string]: any } + +declare _G: any +declare _VERSION: string + +declare function gcinfo(): number + +declare function print(...: T...) + +declare function type(value: T): string +declare function typeof(value: T): string + +-- `assert` has a magic function attached that will give more detailed type information +declare function assert(value: T, errorMessage: string?): T +declare function error(message: T, level: number?): never + +declare function tostring(value: T): string +declare function tonumber(value: T, radix: number?): number? + +declare function rawequal(a: T1, b: T2): boolean +declare function rawget(tab: {[K]: V}, k: K): V +declare function rawset(tab: {[K]: V}, k: K, v: V): {[K]: V} +declare function rawlen(obj: {[K]: V} | string): number + +declare function setfenv(target: number | (T...) -> R..., env: {[string]: any}): ((T...) -> R...)? + +declare function ipairs(tab: {V}): (({V}, number) -> (number?, V), {V}, number) + +declare function pcall(f: (A...) -> R..., ...: A...): (boolean, R...) + +-- FIXME: The actual type of `xpcall` is: +-- (f: (A...) -> R1..., err: (E) -> R2..., A...) -> (true, R1...) | (false, R2...) +-- Since we can't represent the return value, we use (boolean, R1...). +declare function xpcall(f: (A...) -> R1..., err: (E) -> R2..., ...: A...): (boolean, R1...) + +-- `select` has a magic function attached to provide more detailed type information +declare function select(i: string | number, ...: A...): ...any + +-- FIXME: This type is not entirely correct - `loadstring` returns a function or +-- (nil, string). +declare function loadstring(src: string, chunkname: string?): (((A...) -> any)?, string?) + +@checked declare function newproxy(mt: boolean?): any + +-- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. +declare function unpack(tab: {V}, i: number?, j: number?): ...V + +)BUILTIN_SRC"; + +static const std::string kBuiltinDefinitionBit32Src = R"BUILTIN_SRC( + +declare bit32: { + band: @checked (...number) -> number, + bor: @checked (...number) -> number, + bxor: @checked (...number) -> number, + btest: @checked (number, ...number) -> boolean, + rrotate: @checked (x: number, disp: number) -> number, + lrotate: @checked (x: number, disp: number) -> number, + lshift: @checked (x: number, disp: number) -> number, + arshift: @checked (x: number, disp: number) -> number, + rshift: @checked (x: number, disp: number) -> number, + bnot: @checked (x: number) -> number, + extract: @checked (n: number, field: number, width: number?) -> number, + replace: @checked (n: number, v: number, field: number, width: number?) -> number, + countlz: @checked (n: number) -> number, + countrz: @checked (n: number) -> number, + byteswap: @checked (n: number) -> number, +} + +)BUILTIN_SRC"; + +static const std::string kBuiltinDefinitionMathSrc = R"BUILTIN_SRC( + +declare math: { + frexp: @checked (n: number) -> (number, number), + ldexp: @checked (s: number, e: number) -> number, + fmod: @checked (x: number, y: number) -> number, + modf: @checked (n: number) -> (number, number), + pow: @checked (x: number, y: number) -> number, + exp: @checked (n: number) -> number, + + ceil: @checked (n: number) -> number, + floor: @checked (n: number) -> number, + abs: @checked (n: number) -> number, + sqrt: @checked (n: number) -> number, + + log: @checked (n: number, base: number?) -> number, + log10: @checked (n: number) -> number, + + rad: @checked (n: number) -> number, + deg: @checked (n: number) -> number, + + sin: @checked (n: number) -> number, + cos: @checked (n: number) -> number, + tan: @checked (n: number) -> number, + sinh: @checked (n: number) -> number, + cosh: @checked (n: number) -> number, + tanh: @checked (n: number) -> number, + atan: @checked (n: number) -> number, + acos: @checked (n: number) -> number, + asin: @checked (n: number) -> number, + atan2: @checked (y: number, x: number) -> number, + + min: @checked (number, ...number) -> number, + max: @checked (number, ...number) -> number, + + pi: number, + huge: number, + + randomseed: @checked (seed: number) -> (), + random: @checked (number?, number?) -> number, + + sign: @checked (n: number) -> number, + clamp: @checked (n: number, min: number, max: number) -> number, + noise: @checked (x: number, y: number?, z: number?) -> number, + round: @checked (n: number) -> number, + map: @checked (x: number, inmin: number, inmax: number, outmin: number, outmax: number) -> number, + lerp: @checked (a: number, b: number, t: number) -> number, +} + +)BUILTIN_SRC"; + +static const std::string kBuiltinDefinitionOsSrc = R"BUILTIN_SRC( + +type DateTypeArg = { + year: number, + month: number, + day: number, + hour: number?, + min: number?, + sec: number?, + isdst: boolean?, +} + +type DateTypeResult = { + year: number, + month: number, + wday: number, + yday: number, + day: number, + hour: number, + min: number, + sec: number, + isdst: boolean, +} + +declare os: { + time: (time: DateTypeArg?) -> number, + date: ((formatString: "*t" | "!*t", time: number?) -> DateTypeResult) & ((formatString: string?, time: number?) -> string), + difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number, + clock: () -> number, +} + +)BUILTIN_SRC"; + +static const std::string kBuiltinDefinitionCoroutineSrc = R"BUILTIN_SRC( + +declare coroutine: { + create: (f: (A...) -> R...) -> thread, + resume: (co: thread, A...) -> (boolean, R...), + running: () -> thread, + status: @checked (co: thread) -> "dead" | "running" | "normal" | "suspended", + wrap: (f: (A...) -> R...) -> ((A...) -> R...), + yield: (A...) -> R..., + isyieldable: () -> boolean, + close: @checked (co: thread) -> (boolean, any) +} + +)BUILTIN_SRC"; + +static const std::string kBuiltinDefinitionTableSrc = R"BUILTIN_SRC( + +declare table: { + concat: (t: {V}, sep: string?, i: number?, j: number?) -> string, + insert: ((t: {V}, value: V) -> ()) & ((t: {V}, pos: number, value: V) -> ()), + maxn: (t: {V}) -> number, + remove: (t: {V}, number?) -> V?, + sort: (t: {V}, comp: ((V, V) -> boolean)?) -> (), + create: (count: number, value: V?) -> {V}, + find: (haystack: {V}, needle: V, init: number?) -> number?, + + unpack: (list: {V}, i: number?, j: number?) -> ...V, + pack: (...V) -> { n: number, [number]: V }, + + getn: (t: {V}) -> number, + foreach: (t: {[K]: V}, f: (K, V) -> ()) -> (), + foreachi: ({V}, (number, V) -> ()) -> (), + + move: (src: {V}, a: number, b: number, t: number, dst: {V}?) -> {V}, + clear: (table: {[K]: V}) -> (), + + isfrozen: (t: {[K]: V}) -> boolean, +} + +)BUILTIN_SRC"; + +static const std::string kBuiltinDefinitionDebugSrc = R"BUILTIN_SRC( + +declare debug: { + info: ((thread: thread, level: number, options: string) -> R...) & ((level: number, options: string) -> R...) & ((func: (A...) -> R1..., options: string) -> R2...), + traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string), +} + +)BUILTIN_SRC"; + +static const std::string kBuiltinDefinitionUtf8Src = R"BUILTIN_SRC( + +declare utf8: { + char: @checked (...number) -> string, + charpattern: string, + codes: @checked (str: string) -> ((string, number) -> (number, number), string, number), + codepoint: @checked (str: string, i: number?, j: number?) -> ...number, + len: @checked (s: string, i: number?, j: number?) -> (number?, number?), + offset: @checked (s: string, n: number?, i: number?) -> number, +} + +)BUILTIN_SRC"; + +static const std::string kBuiltinDefinitionBufferSrc_DEPRECATED = R"BUILTIN_SRC( --- Buffer API declare buffer: { create: @checked (size: number) -> buffer, @@ -453,10 +454,47 @@ declare buffer: { )BUILTIN_SRC"; -static const std::string kBuiltinDefinitionVectorSrc_DEPRECATED = R"BUILTIN_SRC( +static const std::string kBuiltinDefinitionBufferSrc = R"BUILTIN_SRC( +--- Buffer API +declare buffer: { + create: @checked (size: number) -> buffer, + fromstring: @checked (str: string) -> buffer, + tostring: @checked (b: buffer) -> string, + len: @checked (b: buffer) -> number, + copy: @checked (target: buffer, targetOffset: number, source: buffer, sourceOffset: number?, count: number?) -> (), + fill: @checked (b: buffer, offset: number, value: number, count: number?) -> (), + readi8: @checked (b: buffer, offset: number) -> number, + readu8: @checked (b: buffer, offset: number) -> number, + readi16: @checked (b: buffer, offset: number) -> number, + readu16: @checked (b: buffer, offset: number) -> number, + readi32: @checked (b: buffer, offset: number) -> number, + readu32: @checked (b: buffer, offset: number) -> number, + readf32: @checked (b: buffer, offset: number) -> number, + readf64: @checked (b: buffer, offset: number) -> number, + writei8: @checked (b: buffer, offset: number, value: number) -> (), + writeu8: @checked (b: buffer, offset: number, value: number) -> (), + writei16: @checked (b: buffer, offset: number, value: number) -> (), + writeu16: @checked (b: buffer, offset: number, value: number) -> (), + writei32: @checked (b: buffer, offset: number, value: number) -> (), + writeu32: @checked (b: buffer, offset: number, value: number) -> (), + writef32: @checked (b: buffer, offset: number, value: number) -> (), + writef64: @checked (b: buffer, offset: number, value: number) -> (), + readstring: @checked (b: buffer, offset: number, count: number) -> string, + writestring: @checked (b: buffer, offset: number, value: string, count: number?) -> (), + readbits: @checked (b: buffer, bitOffset: number, bitCount: number) -> number, + writebits: @checked (b: buffer, bitOffset: number, bitCount: number, value: number) -> (), +} --- TODO: this will be replaced with a built-in primitive type -declare class vector end +)BUILTIN_SRC"; + +static const std::string kBuiltinDefinitionVectorSrc_NoVector2Ctor_DEPRECATED = R"BUILTIN_SRC( + +-- While vector would have been better represented as a built-in primitive type, type solver class handling covers most of the properties +declare class vector + x: number + y: number + z: number +end declare vector: { create: @checked (x: number, y: number, z: number) -> vector, @@ -489,7 +527,7 @@ declare class vector end declare vector: { - create: @checked (x: number, y: number, z: number) -> vector, + create: @checked (x: number, y: number, z: number?) -> vector, magnitude: @checked (vec: vector) -> number, normalize: @checked (vec: vector) -> vector, cross: @checked (vec1: vector, vec2: vector) -> vector, @@ -511,12 +549,25 @@ declare vector: { std::string getBuiltinDefinitionSource() { - std::string result = FFlag::LuauMathMap ? kBuiltinDefinitionLuaSrcChecked : kBuiltinDefinitionLuaSrcChecked_DEPRECATED; + std::string result = FFlag::LuauMathMapDefinition ? kBuiltinDefinitionBaseSrc : kBuiltinDefinitionLuaSrcChecked_DEPRECATED; - if (FFlag::LuauVectorDefinitionsExtra) + if (FFlag::LuauMathMapDefinition) + { + result += kBuiltinDefinitionBit32Src; + result += kBuiltinDefinitionMathSrc; + result += kBuiltinDefinitionOsSrc; + result += kBuiltinDefinitionCoroutineSrc; + result += kBuiltinDefinitionTableSrc; + result += kBuiltinDefinitionDebugSrc; + result += kBuiltinDefinitionUtf8Src; + } + + result += FFlag::LuauBufferBitMethods2 ? kBuiltinDefinitionBufferSrc : kBuiltinDefinitionBufferSrc_DEPRECATED; + + if (FFlag::LuauVector2Constructor) result += kBuiltinDefinitionVectorSrc; - else if (FFlag::LuauVectorDefinitions) - result += kBuiltinDefinitionVectorSrc_DEPRECATED; + else + result += kBuiltinDefinitionVectorSrc_NoVector2Ctor_DEPRECATED; return result; } diff --git a/Analysis/src/EqSatSimplification.cpp b/Analysis/src/EqSatSimplification.cpp index 9e69baf5..5927c773 100644 --- a/Analysis/src/EqSatSimplification.cpp +++ b/Analysis/src/EqSatSimplification.cpp @@ -92,18 +92,24 @@ size_t TTable::Hash::operator()(const TTable& value) const return hash; } -uint32_t StringCache::add(std::string_view s) +StringId StringCache::add(std::string_view s) { - size_t hash = std::hash()(s); - if (uint32_t* it = strings.find(hash)) + /* Important subtlety: This use of DenseHashMap + * is okay because std::hash works solely on the bytes + * referred by the string_view. + * + * In other words, two string views which contain the same bytes will have + * the same hash whether or not their addresses are the same. + */ + if (StringId* it = strings.find(s)) return *it; char* storage = static_cast(allocator.allocate(s.size())); memcpy(storage, s.data(), s.size()); - uint32_t result = uint32_t(views.size()); + StringId result = StringId(views.size()); views.emplace_back(storage, s.size()); - strings[hash] = result; + strings[s] = result; return result; } @@ -143,6 +149,61 @@ static bool isTerminal(const EType& node) node.get() || node.get(); } +static bool areTerminalAndDefinitelyDisjoint(const EType& lhs, const EType& rhs) +{ + // If either node is non-terminal, then we early exit: we're not going to + // do a state space search for whether something like: + // (A | B | C | D) & (E | F | G | H) + // ... is a disjoint intersection. + if (!isTerminal(lhs) || !isTerminal(rhs)) + return false; + + // Special case some types that aren't strict, disjoint subsets. + if (lhs.get() || lhs.get()) + return !(rhs.get() || rhs.get()); + + // Handling strings / booleans: these are the types for which we + // expect something like: + // + // "foo" & ~"bar" + // + // ... to simplify to "foo". + if (lhs.get()) + return !(rhs.get() || rhs.get()); + + if (lhs.get()) + return !(rhs.get() || rhs.get()); + + if (auto lhsSString = lhs.get()) + { + auto rhsSString = rhs.get(); + if (!rhsSString) + return !rhs.get(); + return lhsSString->value() != rhsSString->value(); + } + + if (auto lhsSBoolean = lhs.get()) + { + auto rhsSBoolean = rhs.get(); + if (!rhsSBoolean) + return !rhs.get(); + return lhsSBoolean->value() != rhsSBoolean->value(); + } + + // At this point: + // - We know both nodes are terminal + // - We know that the LHS is not any boolean, string, or class + // At this point, we have two classes of checks left: + // - Whether the two enodes are exactly the same set (now that the static + // sets have been covered). + // - Whether one of the enodes is a large semantic set such as TAny, + // TUnknown, or TError. + return !( + lhs.index() == rhs.index() || lhs.get() || rhs.get() || lhs.get() || rhs.get() || lhs.get() || + rhs.get() || lhs.get() || rhs.get() || lhs.get() || rhs.get() + ); +} + static bool isTerminal(const EGraph& egraph, Id eclass) { const auto& nodes = egraph[eclass].nodes; @@ -151,7 +212,7 @@ static bool isTerminal(const EGraph& egraph, Id eclass) nodes.end(), [](auto& a) { - return isTerminal(a); + return isTerminal(a.node); } ); } @@ -335,11 +396,31 @@ Id toId( { LUAU_ASSERT(tfun->packArguments.empty()); + if (tfun->userFuncName) { + // TODO: User defined type functions are pseudo-effectful: error + // reporting is done via the `print` statement, so running a + // UDTF multiple times may end up double erroring. egraphs + // currently may induce type functions to be reduced multiple + // times. We should probably opt _not_ to process user defined + // type functions at all. + return egraph.add(TOpaque{ty}); + } + std::vector parts; + parts.reserve(tfun->typeArguments.size()); for (TypeId part : tfun->typeArguments) parts.push_back(toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, part)); - return cache(egraph.add(TTypeFun{tfun->function.get(), std::move(parts)})); + // This looks sily, but we're making a copy of the specific + // `TypeFunctionInstanceType` outside of the provided arena so that + // we can access the members without fear of the specific TFIT being + // overwritten with a bound type. + return cache(egraph.add(TTypeFun{ + std::make_shared( + tfun->function, tfun->typeArguments, tfun->packArguments, tfun->userFuncName, tfun->userFuncData + ), + std::move(parts) + })); } else if (get(ty)) return egraph.add(TNoRefine{}); @@ -399,7 +480,7 @@ static size_t computeCost(std::unordered_map& bestNodes, const EGrap if (auto it = costs.find(id); it != costs.end()) return it->second; - const std::vector& nodes = egraph[id].nodes; + const std::vector>& nodes = egraph[id].nodes; size_t minCost = std::numeric_limits::max(); size_t bestNode = std::numeric_limits::max(); @@ -416,7 +497,7 @@ static size_t computeCost(std::unordered_map& bestNodes, const EGrap // First, quickly scan for a terminal type. If we can find one, it is obviously the best. for (size_t index = 0; index < nodes.size(); ++index) { - if (isTerminal(nodes[index])) + if (isTerminal(nodes[index].node)) { minCost = 1; bestNode = index; @@ -468,44 +549,44 @@ static size_t computeCost(std::unordered_map& bestNodes, const EGrap { const auto& node = nodes[index]; - if (node.get()) + if (node.node.get()) updateCost(BOUND_PENALTY, index); // TODO: This could probably be an assert now that we don't need rewrite rules to handle TBound. - else if (node.get()) + else if (node.node.get()) { minCost = 1; bestNode = index; } - else if (auto tbl = node.get()) + else if (auto tbl = node.node.get()) { // TODO: We could make the penalty a parameter to computeChildren. std::optional maybeCost = computeChildren(tbl->operands(), minCost); if (maybeCost) updateCost(TABLE_TYPE_PENALTY + *maybeCost, index); } - else if (node.get()) + else if (node.node.get()) { minCost = IMPORTED_TABLE_PENALTY; bestNode = index; } - else if (auto u = node.get()) + else if (auto u = node.node.get()) { std::optional maybeCost = computeChildren(u->operands(), minCost); if (maybeCost) updateCost(SET_TYPE_PENALTY + *maybeCost, index); } - else if (auto i = node.get()) + else if (auto i = node.node.get()) { std::optional maybeCost = computeChildren(i->operands(), minCost); if (maybeCost) updateCost(SET_TYPE_PENALTY + *maybeCost, index); } - else if (auto negation = node.get()) + else if (auto negation = node.node.get()) { std::optional maybeCost = computeChildren(negation->operands(), minCost); if (maybeCost) updateCost(NEGATION_PENALTY + *maybeCost, index); } - else if (auto tfun = node.get()) + else if (auto tfun = node.node.get()) { std::optional maybeCost = computeChildren(tfun->operands(), minCost); if (maybeCost) @@ -574,28 +655,34 @@ TypeId flattenTableNode( // If a TTable is its own basis, it must be the case that some other // node on this eclass is a TImportedTable. Let's use that. + bool found = false; + for (size_t i = 0; i < eclass.nodes.size(); ++i) { - if (eclass.nodes[i].get()) + if (eclass.nodes[i].node.get()) { + found = true; index = i; break; } } - // If we couldn't find one, we don't know what to do. Use ErrorType. - LUAU_ASSERT(0); - return builtinTypes->errorType; + if (!found) + { + // If we couldn't find one, we don't know what to do. Use ErrorType. + LUAU_ASSERT(0); + return builtinTypes->errorType; + } } const auto& node = eclass.nodes[index]; - if (const TTable* ttable = node.get()) + if (const TTable* ttable = node.node.get()) { stack.push_back(ttable); id = ttable->getBasis(); continue; } - else if (const TImportedTable* ti = node.get()) + else if (const TImportedTable* ti = node.node.get()) { importedTable = ti; break; @@ -622,7 +709,8 @@ TypeId flattenTableNode( StringId propName = t->propNames[i]; const Id propType = t->propTypes()[i]; - resultTable.props[strings.asString(propName)] = Property{fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, propType)}; + resultTable.props[strings.asString(propName)] = + Property{fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, propType)}; } } @@ -646,7 +734,7 @@ TypeId fromId( size_t index = bestNodes.at(rootId); LUAU_ASSERT(index <= egraph[rootId].nodes.size()); - const EType& node = egraph[rootId].nodes[index]; + const EType& node = egraph[rootId].nodes[index].node; if (node.get()) return builtinTypes->nilType; @@ -703,7 +791,20 @@ TypeId fromId( if (parts.empty()) return builtinTypes->neverType; else if (parts.size() == 1) - return fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, parts[0]); + { + TypeId placeholder = arena->addType(BlockedType{}); + seen[rootId] = placeholder; + auto result = fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, parts[0]); + if (follow(result) == placeholder) + { + emplaceType(asMutable(placeholder), "EGRAPH-SINGLETON-CYCLE"); + } + else + { + emplaceType(asMutable(placeholder), result); + } + return result; + } else { TypeId res = arena->addType(BlockedType{}); @@ -768,7 +869,11 @@ TypeId fromId( for (Id part : tfun->operands()) args.push_back(fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, part)); - asMutable(res)->ty.emplace(*tfun->value(), std::move(args)); + auto oldInstance = tfun->value(); + + asMutable(res)->ty.emplace( + oldInstance->function, std::move(args), std::vector(), oldInstance->userFuncName, oldInstance->userFuncData + ); newTypeFunctions.push_back(res); @@ -848,12 +953,20 @@ std::string mkDesc( const int RULE_PADDING = 35; const std::string rulePadding(std::max(0, RULE_PADDING - rule.size()), ' '); const std::string fromIdStr = ""; // "(" + std::to_string(uint32_t(from)) + ") "; - const std::string toIdStr = ""; // "(" + std::to_string(uint32_t(to)) + ") "; + const std::string toIdStr = ""; // "(" + std::to_string(uint32_t(to)) + ") "; return rule + ":" + rulePadding + fromIdStr + toString(fromTy, opts) + " <=> " + toIdStr + toString(toTy, opts); } -std::string mkDesc(EGraph& egraph, const StringCache& strings, NotNull arena, NotNull builtinTypes, Id from, Id to, const std::string& rule) +std::string mkDesc( + EGraph& egraph, + const StringCache& strings, + NotNull arena, + NotNull builtinTypes, + Id from, + Id to, + const std::string& rule +) { if (!FFlag::DebugLuauLogSimplification) return ""; @@ -906,7 +1019,7 @@ static std::string getNodeName(const StringCache& strings, const EType& node) else if (node.get()) return "never"; else if (auto tfun = node.get()) - return "tfun " + tfun->value()->name; + return "tfun " + tfun->value()->function->name; else if (node.get()) return "~"; else if (node.get()) @@ -928,8 +1041,9 @@ std::string toDot(const StringCache& strings, const EGraph& egraph) for (const auto& [id, eclass] : egraph.getAllClasses()) { - for (const auto& node : eclass.nodes) + for (const auto& n : eclass.nodes) { + const EType& node = n.node; if (!node.operands().empty()) populated.insert(id); for (Id op : node.operands()) @@ -950,7 +1064,7 @@ std::string toDot(const StringCache& strings, const EGraph& egraph) for (size_t index = 0; index < eclass.nodes.size(); ++index) { - const auto& node = eclass.nodes[index]; + const auto& node = eclass.nodes[index].node; const std::string label = getNodeName(strings, node); const std::string nodeName = "n" + std::to_string(uint32_t(id)) + "_" + std::to_string(index); @@ -965,7 +1079,7 @@ std::string toDot(const StringCache& strings, const EGraph& egraph) { for (size_t index = 0; index < eclass.nodes.size(); ++index) { - const auto& node = eclass.nodes[index]; + const auto& node = eclass.nodes[index].node; const std::string label = getNodeName(strings, node); const std::string nodeName = "n" + std::to_string(uint32_t(egraph.find(id))) + "_" + std::to_string(index); @@ -1001,7 +1115,7 @@ static Tag const* isTag(const EGraph& egraph, Id id) { for (const auto& node : egraph[id].nodes) { - if (auto n = isTag(node)) + if (auto n = isTag(node.node)) return n; } return nullptr; @@ -1037,7 +1151,7 @@ protected: { for (const auto& node : (*egraph)[id].nodes) { - if (auto n = node.get()) + if (auto n = node.node.get()) return n; } return nullptr; @@ -1225,8 +1339,10 @@ const EType* findSubtractableClass(const EGraph& egraph, std::unordered_set& const EType* bestUnion = nullptr; std::optional unionSize; - for (const auto& node : egraph[id].nodes) + for (const auto& n : egraph[id].nodes) { + const EType& node = n.node; + if (isTerminal(node)) return &node; @@ -1342,14 +1458,14 @@ bool subtract(EGraph& egraph, CanonicalizedType& ct, Id part) return true; } -Id fromCanonicalized(EGraph& egraph, CanonicalizedType& ct) +static std::pair fromCanonicalized(EGraph& egraph, CanonicalizedType& ct) { if (ct.isUnknown()) { if (ct.errorPart) - return egraph.add(TAny{}); + return {egraph.add(TAny{}), 1}; else - return egraph.add(TUnknown{}); + return {egraph.add(TUnknown{}), 1}; } std::vector parts; @@ -1387,7 +1503,12 @@ Id fromCanonicalized(EGraph& egraph, CanonicalizedType& ct) parts.insert(parts.end(), ct.functionParts.begin(), ct.functionParts.end()); parts.insert(parts.end(), ct.otherParts.begin(), ct.otherParts.end()); - return mkUnion(egraph, std::move(parts)); + std::sort(parts.begin(), parts.end()); + auto it = std::unique(parts.begin(), parts.end()); + parts.erase(it, parts.end()); + + const size_t size = parts.size(); + return {mkUnion(egraph, std::move(parts)), size}; } void addChildren(const EGraph& egraph, const EType* enode, VecDeque& worklist) @@ -1433,7 +1554,7 @@ const Tag* Simplifier::isTag(Id id) const { for (const auto& node : get(id).nodes) { - if (const Tag* ty = node.get()) + if (const Tag* ty = node.node.get()) return ty; } @@ -1467,6 +1588,16 @@ void Simplifier::subst(Id from, Id to, const std::string& ruleName, const std::u substs.emplace_back(from, to, desc); } +void Simplifier::subst(Id from, size_t boringIndex, Id to, const std::string& ruleName, const std::unordered_map& forceNodes) +{ + std::string desc; + if (FFlag::DebugLuauLogSimplification) + desc = mkDesc(egraph, stringCache, arena, builtinTypes, from, to, forceNodes, ruleName); + + egraph.markBoring(from, boringIndex); + substs.emplace_back(from, to, desc); +} + void Simplifier::unionClasses(std::vector& hereParts, Id there) { if (1 == hereParts.size() && isTag(hereParts[0])) @@ -1517,9 +1648,12 @@ void Simplifier::simplifyUnion(Id id) for (Id part : u->operands()) unionWithType(egraph, canonicalized, find(part)); - Id resultId = fromCanonicalized(egraph, canonicalized); + const auto [resultId, newSize] = fromCanonicalized(egraph, canonicalized); - subst(id, resultId, "simplifyUnion", {{id, unionIndex}}); + if (newSize < u->operands().size()) + subst(id, unionIndex, resultId, "simplifyUnion", {{id, unionIndex}}); + else + subst(id, resultId, "simplifyUnion", {{id, unionIndex}}); } } @@ -1552,11 +1686,6 @@ std::optional intersectOne(EGraph& egraph, Id hereId, const EType* hereNo thereNode->get() || thereNode->get() || hereNode->get() || thereNode->get()) return std::nullopt; - if (hereNode->get()) - return *thereNode; - if (thereNode->get()) - return *hereNode; - if (hereNode->get()) return *thereNode; if (thereNode->get()) @@ -1732,7 +1861,7 @@ void Simplifier::uninhabitedIntersection(Id id) const auto& partNodes = egraph[partId].nodes; for (size_t partIndex = 0; partIndex < partNodes.size(); ++partIndex) { - const EType& N = partNodes[partIndex]; + const EType& N = partNodes[partIndex].node; if (std::optional intersection = intersectOne(egraph, accumulator, &accumulatorNode, partId, &N)) { if (isTag(*intersection)) @@ -1755,9 +1884,14 @@ void Simplifier::uninhabitedIntersection(Id id) if ((unsimplified.empty() || !isTag(accumulator)) && find(accumulator) != id) unsimplified.push_back(accumulator); + const bool isSmaller = unsimplified.size() < parts.size(); + const Id result = mkIntersection(egraph, std::move(unsimplified)); - subst(id, result, "uninhabitedIntersection", {{id, index}}); + if (isSmaller) + subst(id, index, result, "uninhabitedIntersection", {{id, index}}); + else + subst(id, result, "uninhabitedIntersection", {{id, index}}); } } @@ -1788,14 +1922,19 @@ void Simplifier::intersectWithNegatedClass(Id id) const auto& iNodes = egraph[iId].nodes; for (size_t iIndex = 0; iIndex < iNodes.size(); ++iIndex) { - const EType& iNode = iNodes[iIndex]; + const EType& iNode = iNodes[iIndex].node; if (isTag(iNode) || isTag(iNode) || isTag(iNode) || isTag(iNode) || isTag(iNode) || isTag(iNode) || // isTag(iNode) || // I'm not sure about this one. isTag(iNode) || isTag(iNode) || isTag(iNode) || isTag(iNode)) { // eg string & ~SomeClass - subst(id, iId, "intersectClassWithNegatedClass", {{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}}); + subst( + id, + iId, + "intersectClassWithNegatedClass", + {{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}} + ); return; } @@ -1803,27 +1942,37 @@ void Simplifier::intersectWithNegatedClass(Id id) { switch (relateClasses(class_, negatedClass)) { - case LeftSuper: - // eg Instance & ~Part - // This cannot be meaningfully reduced. - continue; - case RightSuper: - subst(id, egraph.add(TNever{}), "intersectClassWithNegatedClass", {{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}}); - return; - case Unrelated: - // Part & ~Folder == Part + case LeftSuper: + // eg Instance & ~Part + // This cannot be meaningfully reduced. + continue; + case RightSuper: + subst( + id, + egraph.add(TNever{}), + "intersectClassWithNegatedClass", + {{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}} + ); + return; + case Unrelated: + // Part & ~Folder == Part + { + std::vector newParts; + newParts.reserve(intersection->operands().size() - 1); + for (Id part : intersection->operands()) { - std::vector newParts; - newParts.reserve(intersection->operands().size() - 1); - for (Id part : intersection->operands()) - { - if (part != jId) - newParts.push_back(part); - } - - Id substId = egraph.add(Intersection{newParts.begin(), newParts.end()}); - subst(id, substId, "intersectClassWithNegatedClass", {{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}}); + if (part != jId) + newParts.push_back(part); } + + Id substId = mkIntersection(egraph, newParts); + subst( + id, + substId, + "intersectClassWithNegatedClass", + {{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}} + ); + } } } } @@ -1839,6 +1988,74 @@ void Simplifier::intersectWithNegatedClass(Id id) } } +void Simplifier::intersectWithNegatedAtom(Id id) +{ + // Let I and ~J be two arbitrary distinct operands of an intersection where + // I and J are terminal but are not type variables. (free, generic, or + // otherwise opaque) + // + // If I and J are equal, then the whole intersection is equivalent to never. + // + // If I and J are inequal, then J & ~I == J + + for (const auto [intersection, intersectionIndex] : Query(&egraph, id)) + { + const Slice& intersectionOperands = intersection->operands(); + for (size_t i = 0; i < intersectionOperands.size(); ++i) + { + for (const auto [negation, negationIndex] : Query(&egraph, intersectionOperands[i])) + { + for (size_t negationOperandIndex = 0; negationOperandIndex < egraph[negation->operands()[0]].nodes.size(); ++negationOperandIndex) + { + const EType* negationOperand = &egraph[negation->operands()[0]].nodes[negationOperandIndex].node; + if (!isTerminal(*negationOperand) || negationOperand->get()) + continue; + + for (size_t j = 0; j < intersectionOperands.size(); ++j) + { + if (j == i) + continue; + + for (size_t jNodeIndex = 0; jNodeIndex < egraph[intersectionOperands[j]].nodes.size(); ++jNodeIndex) + { + const EType* jNode = &egraph[intersectionOperands[j]].nodes[jNodeIndex].node; + if (!isTerminal(*jNode) || jNode->get()) + continue; + + if (*negationOperand == *jNode) + { + // eg "Hello" & ~"Hello" + // or boolean & ~boolean + subst( + id, + egraph.add(TNever{}), + "intersectWithNegatedAtom", + {{id, intersectionIndex}, {intersectionOperands[i], negationIndex}, {intersectionOperands[j], jNodeIndex}} + ); + return; + } + else if (areTerminalAndDefinitelyDisjoint(*jNode, *negationOperand)) + { + // eg "Hello" & ~"World" + // or boolean & ~string + std::vector newOperands(intersectionOperands.begin(), intersectionOperands.end()); + newOperands.erase(newOperands.begin() + std::vector::difference_type(i)); + + subst( + id, + mkIntersection(egraph, std::move(newOperands)), + "intersectWithNegatedAtom", + {{id, intersectionIndex}, {intersectionOperands[i], negationIndex}, {intersectionOperands[j], jNodeIndex}} + ); + } + } + } + } + } + } + } +} + void Simplifier::intersectWithNoRefine(Id id) { for (const auto pair : Query(&egraph, id)) @@ -2003,7 +2220,7 @@ void Simplifier::expandNegation(Id id) if (!ok) continue; - subst(id, fromCanonicalized(egraph, canonicalized), "expandNegation", {{id, index}}); + subst(id, fromCanonicalized(egraph, canonicalized).first, "expandNegation", {{id, index}}); } } @@ -2160,7 +2377,7 @@ void Simplifier::intersectTableProperty(Id id) subst( id, - egraph.add(Intersection{std::move(newIntersectionParts)}), + mkIntersection(egraph, std::move(newIntersectionParts)), "intersectTableProperty", {{id, intersectionIndex}, {iId, table1Index}, {jId, table2Index}} ); @@ -2250,7 +2467,7 @@ void Simplifier::builtinTypeFunctions(Id id) if (args.size() != 2) continue; - const std::string& name = tfun->value()->name; + const std::string& name = tfun->value()->function->name; if (name == "add" || name == "sub" || name == "mul" || name == "div" || name == "idiv" || name == "pow" || name == "mod") { if (isTag(args[0]) && isTag(args[1])) @@ -2272,15 +2489,43 @@ void Simplifier::iffyTypeFunctions(Id id) { const Slice& args = tfun->operands(); - const std::string& name = tfun->value()->name; + const std::string& name = tfun->value()->function->name; if (name == "union") subst(id, add(Union{std::vector(args.begin(), args.end())}), "iffyTypeFunctions", {{id, index}}); - else if (name == "intersect" || name == "refine") + else if (name == "intersect") subst(id, add(Intersection{std::vector(args.begin(), args.end())}), "iffyTypeFunctions", {{id, index}}); } } +// Replace instances of `lt` and `le` when either X or Y is `number` +// or `string` with `boolean`. Lua semantics are that if we see the expression: +// +// x < y +// +// ... we error if `x` and `y` don't have the same type. We know that for +// `string` and `number`, comparisons will always return a boolean. So if either +// of the arguments to `lt<>` are equivalent to `number` or `string`, then the +// type is effectively `boolean`: either the other type is equivalent, in which +// case we eval to `boolean`, or we diverge (raise an error). +void Simplifier::strictMetamethods(Id id) +{ + for (const auto [tfun, index] : Query(&egraph, id)) + { + const Slice& args = tfun->operands(); + + const std::string& name = tfun->value()->function->name; + + if (!(name == "lt" || name == "le") || args.size() != 2) + continue; + + if (isTag(args[0]) || isTag(args[0]) || isTag(args[1]) || isTag(args[1])) + { + subst(id, add(TBoolean{}), __FUNCTION__, {{id, index}}); + } + } +} + static void deleteSimplifier(Simplifier* s) { delete s; @@ -2308,6 +2553,7 @@ std::optional eqSatSimplify(NotNull simpl &Simplifier::simplifyUnion, &Simplifier::uninhabitedIntersection, &Simplifier::intersectWithNegatedClass, + &Simplifier::intersectWithNegatedAtom, &Simplifier::intersectWithNoRefine, &Simplifier::cyclicIntersectionOfUnion, &Simplifier::cyclicUnionOfIntersection, @@ -2318,6 +2564,7 @@ std::optional eqSatSimplify(NotNull simpl &Simplifier::unneededTableModification, &Simplifier::builtinTypeFunctions, &Simplifier::iffyTypeFunctions, + &Simplifier::strictMetamethods, }; std::unordered_set seen; @@ -2371,9 +2618,9 @@ std::optional eqSatSimplify(NotNull simpl // try to run any rules on it. bool shouldAbort = false; - for (const EType& enode : egraph[id].nodes) + for (const auto& enode : egraph[id].nodes) { - if (isTerminal(enode)) + if (isTerminal(enode.node)) { shouldAbort = true; break; @@ -2383,8 +2630,8 @@ std::optional eqSatSimplify(NotNull simpl if (shouldAbort) continue; - for (const EType& enode : egraph[id].nodes) - addChildren(egraph, &enode, worklist); + for (const auto& enode : egraph[id].nodes) + addChildren(egraph, &enode.node, worklist); for (Simplifier::RewriteRuleFn rule : rules) (simplifier.get()->*rule)(id); diff --git a/Analysis/src/FragmentAutocomplete.cpp b/Analysis/src/FragmentAutocomplete.cpp index 5819d309..bc82d750 100644 --- a/Analysis/src/FragmentAutocomplete.cpp +++ b/Analysis/src/FragmentAutocomplete.cpp @@ -6,6 +6,7 @@ #include "Luau/Autocomplete.h" #include "Luau/Common.h" #include "Luau/EqSatSimplification.h" +#include "Luau/ModuleResolver.h" #include "Luau/Parser.h" #include "Luau/ParseOptions.h" #include "Luau/Module.h" @@ -19,16 +20,21 @@ #include "Luau/Parser.h" #include "Luau/ParseOptions.h" #include "Luau/Module.h" - +#include "Luau/Clone.h" #include "AutocompleteCore.h" - LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferIterationLimit); LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauAllowFragmentParsing); LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete) +LUAU_FASTFLAGVARIABLE(LuauIncrementalAutocompleteBugfixes) +LUAU_FASTFLAG(LuauReferenceAllocatorInNewSolver) +LUAU_FASTFLAGVARIABLE(LuauMixedModeDefFinderTraversesTypeOf) +LUAU_FASTFLAG(LuauBetterReverseDependencyTracking) +LUAU_FASTFLAGVARIABLE(LuauCloneIncrementalModule) + namespace { template @@ -49,6 +55,96 @@ void copyModuleMap(Luau::DenseHashMap& result, const Luau::DenseHashMap +void cloneModuleMap(TypeArena& destArena, CloneState& cloneState, const Luau::DenseHashMap& source, Luau::DenseHashMap& dest) +{ + for (auto [k, v] : source) + { + dest[k] = Luau::clone(v, destArena, cloneState); + } +} + +struct MixedModeIncrementalTCDefFinder : public AstVisitor +{ + bool visit(AstExprLocal* local) override + { + referencedLocalDefs.emplace_back(local->local, local); + return true; + } + + bool visit(AstTypeTypeof* node) override + { + // We need to traverse typeof expressions because they may refer to locals that we need + // to populate the local environment for fragment typechecking. For example, `typeof(m)` + // requires that we find the local/global `m` and place it in the environment. + // The default behaviour here is to return false, and have individual visitors override + // the specific behaviour they need. + return FFlag::LuauMixedModeDefFinderTraversesTypeOf; + } + + // ast defs is just a mapping from expr -> def in general + // will get built up by the dfg builder + + // localDefs, we need to copy over + std::vector> referencedLocalDefs; +}; + +void cloneAndSquashScopes( + CloneState& cloneState, + const Scope* staleScope, + const ModulePtr& staleModule, + NotNull destArena, + NotNull dfg, + AstStatBlock* program, + Scope* destScope +) +{ + std::vector scopes; + for (const Scope* current = staleScope; current; current = current->parent.get()) + { + scopes.emplace_back(current); + } + + // in reverse order (we need to clone the parents and override defs as we go down the list) + for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) + { + const Scope* curr = *it; + // Clone the lvalue types + for (const auto& [def, ty] : curr->lvalueTypes) + destScope->lvalueTypes[def] = Luau::clone(ty, *destArena, cloneState); + // Clone the rvalueRefinements + for (const auto& [def, ty] : curr->rvalueRefinements) + destScope->rvalueRefinements[def] = Luau::clone(ty, *destArena, cloneState); + for (const auto& [n, m] : curr->importedTypeBindings) + { + std::unordered_map importedBindingTypes; + for (const auto& [v, tf] : m) + importedBindingTypes[v] = Luau::clone(tf, *destArena, cloneState); + destScope->importedTypeBindings[n] = m; + } + + // Finally, clone up the bindings + for (const auto& [s, b] : curr->bindings) + { + destScope->bindings[s] = Luau::clone(b, *destArena, cloneState); + } + } + + // The above code associates defs with TypeId's in the scope + // so that lookup to locals will succeed. + MixedModeIncrementalTCDefFinder finder; + program->visit(&finder); + std::vector> locals = std::move(finder.referencedLocalDefs); + for (auto [loc, expr] : locals) + { + if (std::optional binding = staleScope->linearSearchForBinding(loc->name.value, true)) + { + destScope->lvalueTypes[dfg->getDef(expr)] = Luau::clone(binding->typeId, *destArena, cloneState); + } + } + return; +} + static FrontendModuleResolver& getModuleResolver(Frontend& frontend, std::optional options) { if (FFlag::LuauSolverV2 || !options) @@ -200,7 +296,7 @@ ScopePtr findClosestScope(const ModulePtr& module, const AstStat* nearestStateme return closest; } -FragmentParseResult parseFragment( +std::optional parseFragment( const SourceModule& srcModule, std::string_view src, const Position& cursorPos, @@ -245,6 +341,9 @@ FragmentParseResult parseFragment( opts.captureComments = true; opts.parseFragment = FragmentParseResumeSettings{std::move(result.localMap), std::move(result.localStack), startPos}; ParseResult p = Luau::Parser::parse(srcStart, parseLength, *nameTbl, *fragmentResult.alloc.get(), opts); + // This means we threw a ParseError and we should decline to offer autocomplete here. + if (p.root == nullptr) + return std::nullopt; std::vector fabricatedAncestry = std::move(result.ancestry); @@ -258,16 +357,39 @@ FragmentParseResult parseFragment( fragmentResult.root = std::move(p.root); fragmentResult.ancestry = std::move(fabricatedAncestry); fragmentResult.nearestStatement = nearestStatement; + fragmentResult.commentLocations = std::move(p.commentLocations); return fragmentResult; } +ModulePtr cloneModule(CloneState& cloneState, const ModulePtr& source, std::unique_ptr alloc) +{ + freeze(source->internalTypes); + freeze(source->interfaceTypes); + ModulePtr incremental = std::make_shared(); + incremental->name = source->name; + incremental->humanReadableName = source->humanReadableName; + incremental->allocator = std::move(alloc); + // Clone types + cloneModuleMap(incremental->internalTypes, cloneState, source->astTypes, incremental->astTypes); + cloneModuleMap(incremental->internalTypes, cloneState, source->astTypePacks, incremental->astTypePacks); + cloneModuleMap(incremental->internalTypes, cloneState, source->astExpectedTypes, incremental->astExpectedTypes); + + cloneModuleMap(incremental->internalTypes, cloneState, source->astOverloadResolvedTypes, incremental->astOverloadResolvedTypes); + + cloneModuleMap(incremental->internalTypes, cloneState, source->astForInNextTypes, incremental->astForInNextTypes); + + copyModuleMap(incremental->astScopes, source->astScopes); + + return incremental; +} + ModulePtr copyModule(const ModulePtr& result, std::unique_ptr alloc) { - freeze(result->internalTypes); - freeze(result->interfaceTypes); ModulePtr incrementalModule = std::make_shared(); incrementalModule->name = result->name; - incrementalModule->humanReadableName = result->humanReadableName; + incrementalModule->humanReadableName = "Incremental$" + result->humanReadableName; + incrementalModule->internalTypes.owningModule = incrementalModule.get(); + incrementalModule->interfaceTypes.owningModule = incrementalModule.get(); incrementalModule->allocator = std::move(alloc); // Don't need to keep this alive (it's already on the source module) copyModuleVec(incrementalModule->scopes, result->scopes); @@ -286,21 +408,6 @@ ModulePtr copyModule(const ModulePtr& result, std::unique_ptr alloc) return incrementalModule; } -struct MixedModeIncrementalTCDefFinder : public AstVisitor -{ - bool visit(AstExprLocal* local) override - { - referencedLocalDefs.push_back({local->local, local}); - return true; - } - - // ast defs is just a mapping from expr -> def in general - // will get built up by the dfg builder - - // localDefs, we need to copy over - std::vector> referencedLocalDefs; -}; - void mixedModeCompatibility( const ScopePtr& bottomScopeStale, const ScopePtr& myFakeScope, @@ -339,7 +446,9 @@ FragmentTypeCheckResult typecheckFragment_( { freeze(stale->internalTypes); freeze(stale->interfaceTypes); - ModulePtr incrementalModule = copyModule(stale, std::move(astAllocator)); + CloneState cloneState{frontend.builtinTypes}; + ModulePtr incrementalModule = + FFlag::LuauCloneIncrementalModule ? cloneModule(cloneState, stale, std::move(astAllocator)) : copyModule(stale, std::move(astAllocator)); incrementalModule->checkedInNewSolver = true; unfreeze(incrementalModule->internalTypes); unfreeze(incrementalModule->interfaceTypes); @@ -366,7 +475,8 @@ FragmentTypeCheckResult typecheckFragment_( TypeFunctionRuntime typeFunctionRuntime(iceHandler, NotNull{&limits}); /// Create a DataFlowGraph just for the surrounding context - auto dfg = DataFlowGraphBuilder::build(root, iceHandler); + DataFlowGraph dfg = DataFlowGraphBuilder::build(root, NotNull{&incrementalModule->defArena}, NotNull{&incrementalModule->keyArena}, iceHandler); + SimplifierPtr simplifier = newSimplifier(NotNull{&incrementalModule->internalTypes}, frontend.builtinTypes); FrontendModuleResolver& resolver = getModuleResolver(frontend, opts); @@ -386,25 +496,34 @@ FragmentTypeCheckResult typecheckFragment_( NotNull{&dfg}, {} }; + std::shared_ptr freshChildOfNearestScope = nullptr; + if (FFlag::LuauCloneIncrementalModule) + { + freshChildOfNearestScope = std::make_shared(closestScope); + incrementalModule->scopes.emplace_back(root->location, freshChildOfNearestScope); + cg.rootScope = freshChildOfNearestScope.get(); - cg.rootScope = stale->getModuleScope().get(); - // Any additions to the scope must occur in a fresh scope - auto freshChildOfNearestScope = std::make_shared(closestScope); - incrementalModule->scopes.emplace_back(root->location, freshChildOfNearestScope); - - // Update freshChildOfNearestScope with the appropriate lvalueTypes - mixedModeCompatibility(closestScope, freshChildOfNearestScope, stale, NotNull{&dfg}, root); - - // closest Scope -> children = { ...., freshChildOfNearestScope} - // We need to trim nearestChild from the scope hierarcy - closestScope->children.push_back(NotNull{freshChildOfNearestScope.get()}); - // Visit just the root - we know the scope it should be in - cg.visitFragmentRoot(freshChildOfNearestScope, root); - // Trim nearestChild from the closestScope - Scope* back = closestScope->children.back().get(); - LUAU_ASSERT(back == freshChildOfNearestScope.get()); - closestScope->children.pop_back(); - + cloneAndSquashScopes( + cloneState, closestScope.get(), stale, NotNull{&incrementalModule->internalTypes}, NotNull{&dfg}, root, freshChildOfNearestScope.get() + ); + cg.visitFragmentRoot(freshChildOfNearestScope, root); + } + else + { + // Any additions to the scope must occur in a fresh scope + cg.rootScope = stale->getModuleScope().get(); + freshChildOfNearestScope = std::make_shared(closestScope); + incrementalModule->scopes.emplace_back(root->location, freshChildOfNearestScope); + mixedModeCompatibility(closestScope, freshChildOfNearestScope, stale, NotNull{&dfg}, root); + // closest Scope -> children = { ...., freshChildOfNearestScope} + // We need to trim nearestChild from the scope hierarcy + closestScope->children.emplace_back(freshChildOfNearestScope.get()); + cg.visitFragmentRoot(freshChildOfNearestScope, root); + // Trim nearestChild from the closestScope + Scope* back = closestScope->children.back().get(); + LUAU_ASSERT(back == freshChildOfNearestScope.get()); + closestScope->children.pop_back(); + } /// Initialize the constraint solver and run it ConstraintSolver cs{ @@ -444,7 +563,7 @@ FragmentTypeCheckResult typecheckFragment_( } -FragmentTypeCheckResult typecheckFragment( +std::pair typecheckFragment( Frontend& frontend, const ModuleName& moduleName, const Position& cursorPos, @@ -453,6 +572,13 @@ FragmentTypeCheckResult typecheckFragment( std::optional fragmentEndPosition ) { + + if (FFlag::LuauBetterReverseDependencyTracking) + { + if (!frontend.allModuleDependenciesValid(moduleName, opts && opts->forAutocomplete)) + return {FragmentTypeCheckStatus::SkipAutocomplete, {}}; + } + const SourceModule* sourceModule = frontend.getSourceModule(moduleName); if (!sourceModule) { @@ -468,13 +594,30 @@ FragmentTypeCheckResult typecheckFragment( return {}; } - FragmentParseResult parseResult = parseFragment(*sourceModule, src, cursorPos, fragmentEndPosition); + if (FFlag::LuauIncrementalAutocompleteBugfixes && FFlag::LuauReferenceAllocatorInNewSolver) + { + if (sourceModule->allocator.get() != module->allocator.get()) + { + return {FragmentTypeCheckStatus::SkipAutocomplete, {}}; + } + } + + auto tryParse = parseFragment(*sourceModule, src, cursorPos, fragmentEndPosition); + + if (!tryParse) + return {FragmentTypeCheckStatus::SkipAutocomplete, {}}; + + FragmentParseResult& parseResult = *tryParse; + + if (isWithinComment(parseResult.commentLocations, fragmentEndPosition.value_or(cursorPos))) + return {FragmentTypeCheckStatus::SkipAutocomplete, {}}; + FrontendOptions frontendOptions = opts.value_or(frontend.options); const ScopePtr& closestScope = findClosestScope(module, parseResult.nearestStatement); FragmentTypeCheckResult result = typecheckFragment_(frontend, parseResult.root, module, closestScope, cursorPos, std::move(parseResult.alloc), frontendOptions); result.ancestry = std::move(parseResult.ancestry); - return result; + return {FragmentTypeCheckStatus::Success, result}; } @@ -498,7 +641,14 @@ FragmentAutocompleteResult fragmentAutocomplete( return {}; } - auto tcResult = typecheckFragment(frontend, moduleName, cursorPosition, opts, src, fragmentEndPosition); + // If the cursor is within a comment in the stale source module we should avoid providing a recommendation + if (isWithinComment(*sourceModule, fragmentEndPosition.value_or(cursorPosition))) + return {}; + + auto [tcStatus, tcResult] = typecheckFragment(frontend, moduleName, cursorPosition, opts, src, fragmentEndPosition); + if (tcStatus == FragmentTypeCheckStatus::SkipAutocomplete) + return {}; + auto globalScope = (opts && opts->forAutocomplete) ? frontend.globalsForAutocomplete.globalScope.get() : frontend.globals.globalScope.get(); TypeArena arenaForFragmentAutocomplete; diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 053e99c2..0292726b 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -13,6 +13,7 @@ #include "Luau/EqSatSimplification.h" #include "Luau/FileResolver.h" #include "Luau/NonStrictTypeChecker.h" +#include "Luau/NotNull.h" #include "Luau/Parser.h" #include "Luau/Scope.h" #include "Luau/StringUtils.h" @@ -38,7 +39,6 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3) -LUAU_FASTFLAGVARIABLE(LuauStoreCommentsForDefinitionFiles) LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJsonFile) @@ -47,9 +47,14 @@ LUAU_FASTFLAGVARIABLE(DebugLuauForceStrictMode) LUAU_FASTFLAGVARIABLE(DebugLuauForceNonStrictMode) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauRunCustomModuleChecks, false) +LUAU_FASTFLAGVARIABLE(LuauBetterReverseDependencyTracking) + LUAU_FASTFLAG(StudioReportLuauAny2) LUAU_FASTFLAGVARIABLE(LuauStoreSolverTypeOnModule) +LUAU_FASTFLAGVARIABLE(LuauReferenceAllocatorInNewSolver) +LUAU_FASTFLAGVARIABLE(LuauSelectivelyRetainDFGArena) + namespace Luau { @@ -135,7 +140,7 @@ static ParseResult parseSourceForModule(std::string_view source, Luau::SourceMod sourceModule.root = parseResult.root; sourceModule.mode = Mode::Definition; - if (FFlag::LuauStoreCommentsForDefinitionFiles && options.captureComments) + if (options.captureComments) { sourceModule.hotcomments = parseResult.hotcomments; sourceModule.commentLocations = parseResult.commentLocations; @@ -817,6 +822,16 @@ bool Frontend::parseGraph( topseen = Permanent; buildQueue.push_back(top->name); + + if (FFlag::LuauBetterReverseDependencyTracking) + { + // at this point we know all valid dependencies are processed into SourceNodes + for (const ModuleName& dep : top->requireSet) + { + if (auto it = sourceNodes.find(dep); it != sourceNodes.end()) + it->second->dependents.insert(top->name); + } + } } else { @@ -1046,6 +1061,11 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item) freeze(module->interfaceTypes); module->internalTypes.clear(); + if (FFlag::LuauSelectivelyRetainDFGArena) + { + module->defArena.allocator.clear(); + module->keyArena.allocator.clear(); + } module->astTypes.clear(); module->astTypePacks.clear(); @@ -1099,15 +1119,49 @@ void Frontend::recordItemResult(const BuildQueueItem& item) if (item.exception) std::rethrow_exception(item.exception); - if (item.options.forAutocomplete) + if (FFlag::LuauBetterReverseDependencyTracking) { - moduleResolverForAutocomplete.setModule(item.name, item.module); - item.sourceNode->dirtyModuleForAutocomplete = false; + bool replacedModule = false; + if (item.options.forAutocomplete) + { + replacedModule = moduleResolverForAutocomplete.setModule(item.name, item.module); + item.sourceNode->dirtyModuleForAutocomplete = false; + } + else + { + replacedModule = moduleResolver.setModule(item.name, item.module); + item.sourceNode->dirtyModule = false; + } + + if (replacedModule) + { + LUAU_TIMETRACE_SCOPE("Frontend::invalidateDependentModules", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", item.name.c_str()); + traverseDependents( + item.name, + [forAutocomplete = item.options.forAutocomplete](SourceNode& sourceNode) + { + bool traverseSubtree = !sourceNode.hasInvalidModuleDependency(forAutocomplete); + sourceNode.setInvalidModuleDependency(true, forAutocomplete); + return traverseSubtree; + } + ); + } + + item.sourceNode->setInvalidModuleDependency(false, item.options.forAutocomplete); } else { - moduleResolver.setModule(item.name, item.module); - item.sourceNode->dirtyModule = false; + if (item.options.forAutocomplete) + { + moduleResolverForAutocomplete.setModule(item.name, item.module); + item.sourceNode->dirtyModuleForAutocomplete = false; + } + else + { + moduleResolver.setModule(item.name, item.module); + item.sourceNode->dirtyModule = false; + } } stats.timeCheck += item.stats.timeCheck; @@ -1144,6 +1198,13 @@ ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config return result; } +bool Frontend::allModuleDependenciesValid(const ModuleName& name, bool forAutocomplete) const +{ + LUAU_ASSERT(FFlag::LuauBetterReverseDependencyTracking); + auto it = sourceNodes.find(name); + return it != sourceNodes.end() && !it->second->hasInvalidModuleDependency(forAutocomplete); +} + bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const { auto it = sourceNodes.find(name); @@ -1158,16 +1219,80 @@ bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const */ void Frontend::markDirty(const ModuleName& name, std::vector* markedDirty) { + LUAU_TIMETRACE_SCOPE("Frontend::markDirty", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + + if (FFlag::LuauBetterReverseDependencyTracking) + { + traverseDependents( + name, + [markedDirty](SourceNode& sourceNode) + { + if (markedDirty) + markedDirty->push_back(sourceNode.name); + + if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete) + return false; + + sourceNode.dirtySourceModule = true; + sourceNode.dirtyModule = true; + sourceNode.dirtyModuleForAutocomplete = true; + + return true; + } + ); + } + else + { + if (sourceNodes.count(name) == 0) + return; + + std::unordered_map> reverseDeps; + for (const auto& module : sourceNodes) + { + for (const auto& dep : module.second->requireSet) + reverseDeps[dep].push_back(module.first); + } + + std::vector queue{name}; + + while (!queue.empty()) + { + ModuleName next = std::move(queue.back()); + queue.pop_back(); + + LUAU_ASSERT(sourceNodes.count(next) > 0); + SourceNode& sourceNode = *sourceNodes[next]; + + if (markedDirty) + markedDirty->push_back(next); + + if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete) + continue; + + sourceNode.dirtySourceModule = true; + sourceNode.dirtyModule = true; + sourceNode.dirtyModuleForAutocomplete = true; + + if (0 == reverseDeps.count(next)) + continue; + + sourceModules.erase(next); + + const std::vector& dependents = reverseDeps[next]; + queue.insert(queue.end(), dependents.begin(), dependents.end()); + } + } +} + +void Frontend::traverseDependents(const ModuleName& name, std::function processSubtree) +{ + LUAU_ASSERT(FFlag::LuauBetterReverseDependencyTracking); + LUAU_TIMETRACE_SCOPE("Frontend::traverseDependents", "Frontend"); + if (sourceNodes.count(name) == 0) return; - std::unordered_map> reverseDeps; - for (const auto& module : sourceNodes) - { - for (const auto& dep : module.second->requireSet) - reverseDeps[dep].push_back(module.first); - } - std::vector queue{name}; while (!queue.empty()) @@ -1178,22 +1303,10 @@ void Frontend::markDirty(const ModuleName& name, std::vector* marked LUAU_ASSERT(sourceNodes.count(next) > 0); SourceNode& sourceNode = *sourceNodes[next]; - if (markedDirty) - markedDirty->push_back(next); - - if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete) + if (!processSubtree(sourceNode)) continue; - sourceNode.dirtySourceModule = true; - sourceNode.dirtyModule = true; - sourceNode.dirtyModuleForAutocomplete = true; - - if (0 == reverseDeps.count(next)) - continue; - - sourceModules.erase(next); - - const std::vector& dependents = reverseDeps[next]; + const Set& dependents = sourceNode.dependents; queue.insert(queue.end(), dependents.begin(), dependents.end()); } } @@ -1317,6 +1430,11 @@ ModulePtr check( result->mode = mode; result->internalTypes.owningModule = result.get(); result->interfaceTypes.owningModule = result.get(); + if (FFlag::LuauReferenceAllocatorInNewSolver) + { + result->allocator = sourceModule.allocator; + result->names = sourceModule.names; + } iceHandler->moduleName = sourceModule.name; @@ -1331,7 +1449,7 @@ ModulePtr check( } } - DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, iceHandler); + DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, NotNull{&result->defArena}, NotNull{&result->keyArena}, iceHandler); UnifierSharedState unifierState{iceHandler}; unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; @@ -1427,6 +1545,7 @@ ModulePtr check( case Mode::Nonstrict: Luau::checkNonStrict( builtinTypes, + NotNull{simplifier.get()}, NotNull{&typeFunctionRuntime}, iceHandler, NotNull{&unifierState}, @@ -1440,7 +1559,14 @@ ModulePtr check( // fallthrough intentional case Mode::Strict: Luau::check( - builtinTypes, NotNull{&typeFunctionRuntime}, NotNull{&unifierState}, NotNull{&limits}, logger.get(), sourceModule, result.get() + builtinTypes, + NotNull{simplifier.get()}, + NotNull{&typeFunctionRuntime}, + NotNull{&unifierState}, + NotNull{&limits}, + logger.get(), + sourceModule, + result.get() ); break; case Mode::NoCheck: @@ -1622,6 +1748,17 @@ std::pair Frontend::getSourceNode(const ModuleName& sourceNode->name = sourceModule->name; sourceNode->humanReadableName = sourceModule->humanReadableName; + + if (FFlag::LuauBetterReverseDependencyTracking) + { + // clear all prior dependents. we will re-add them after parsing the rest of the graph + for (const auto& [moduleName, _] : sourceNode->requireLocations) + { + if (auto depIt = sourceNodes.find(moduleName); depIt != sourceNodes.end()) + depIt->second->dependents.erase(sourceNode->name); + } + } + sourceNode->requireSet.clear(); sourceNode->requireLocations.clear(); sourceNode->dirtySourceModule = false; @@ -1743,11 +1880,21 @@ std::string FrontendModuleResolver::getHumanReadableModuleName(const ModuleName& return frontend->fileResolver->getHumanReadableModuleName(moduleName); } -void FrontendModuleResolver::setModule(const ModuleName& moduleName, ModulePtr module) +bool FrontendModuleResolver::setModule(const ModuleName& moduleName, ModulePtr module) { std::scoped_lock lock(moduleMutex); - modules[moduleName] = std::move(module); + if (FFlag::LuauBetterReverseDependencyTracking) + { + bool replaced = modules.count(moduleName) > 0; + modules[moduleName] = std::move(module); + return replaced; + } + else + { + modules[moduleName] = std::move(module); + return false; + } } void FrontendModuleResolver::clearModules() diff --git a/Analysis/src/Generalization.cpp b/Analysis/src/Generalization.cpp index 3eb14fda..ceffc307 100644 --- a/Analysis/src/Generalization.cpp +++ b/Analysis/src/Generalization.cpp @@ -10,12 +10,14 @@ #include "Luau/VisitType.h" LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete) +LUAU_FASTFLAGVARIABLE(LuauGeneralizationRemoveRecursiveUpperBound) namespace Luau { struct MutatingGeneralizer : TypeOnceVisitor { + NotNull arena; NotNull builtinTypes; NotNull scope; @@ -29,6 +31,7 @@ struct MutatingGeneralizer : TypeOnceVisitor bool avoidSealingTables = false; MutatingGeneralizer( + NotNull arena, NotNull builtinTypes, NotNull scope, NotNull> cachedTypes, @@ -37,6 +40,7 @@ struct MutatingGeneralizer : TypeOnceVisitor bool avoidSealingTables ) : TypeOnceVisitor(/* skipBoundTypes */ true) + , arena(arena) , builtinTypes(builtinTypes) , scope(scope) , cachedTypes(cachedTypes) @@ -229,6 +233,53 @@ struct MutatingGeneralizer : TypeOnceVisitor else { TypeId ub = follow(ft->upperBound); + if (FFlag::LuauGeneralizationRemoveRecursiveUpperBound) + { + + // If the upper bound is a union type or an intersection type, + // and one of it's members is the free type we're + // generalizing, don't include it in the upper bound. For a + // free type such as: + // + // t1 where t1 = D <: 'a <: (A | B | C | t1) + // + // Naively replacing it with it's upper bound creates: + // + // t1 where t1 = A | B | C | t1 + // + // It makes sense to just optimize this and exclude the + // recursive component by semantic subtyping rules. + + if (auto itv = get(ub)) + { + std::vector newIds; + newIds.reserve(itv->parts.size()); + for (auto part : itv) + { + if (part != ty) + newIds.push_back(part); + } + if (newIds.size() == 1) + ub = newIds[0]; + else if (newIds.size() > 0) + ub = arena->addType(IntersectionType{std::move(newIds)}); + } + else if (auto utv = get(ub)) + { + std::vector newIds; + newIds.reserve(utv->options.size()); + for (auto part : utv) + { + if (part != ty) + newIds.push_back(part); + } + if (newIds.size() == 1) + ub = newIds[0]; + else if (newIds.size() > 0) + ub = arena->addType(UnionType{std::move(newIds)}); + } + } + if (FreeType* upperFree = getMutable(ub); upperFree && upperFree->lowerBound == ty) upperFree->lowerBound = builtinTypes->neverType; else @@ -926,7 +977,8 @@ struct TypeCacher : TypeOnceVisitor return false; } - bool visit(TypePackId tp, const BoundTypePack& btp) override { + bool visit(TypePackId tp, const BoundTypePack& btp) override + { traverse(btp.boundTo); if (isUncacheable(btp.boundTo)) markUncacheable(tp); @@ -969,7 +1021,7 @@ std::optional generalize( FreeTypeSearcher fts{scope, cachedTypes}; fts.traverse(ty); - MutatingGeneralizer gen{builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), std::move(fts.negativeTypes), avoidSealingTables}; + MutatingGeneralizer gen{arena, builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), std::move(fts.negativeTypes), avoidSealingTables}; gen.traverse(ty); diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index 4b6d1115..79b7f03e 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -11,6 +11,7 @@ #include LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) namespace Luau { @@ -61,9 +62,7 @@ TypeId Instantiation::clean(TypeId ty) LUAU_ASSERT(ftv); FunctionType clone = FunctionType{level, scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; - clone.magicFunction = ftv->magicFunction; - clone.dcrMagicFunction = ftv->dcrMagicFunction; - clone.dcrMagicRefinement = ftv->dcrMagicRefinement; + clone.magic = ftv->magic; clone.tags = ftv->tags; clone.argNames = ftv->argNames; TypeId result = addType(std::move(clone)); @@ -165,7 +164,7 @@ TypeId ReplaceGenerics::clean(TypeId ty) } else { - return addType(FreeType{scope, level}); + return FFlag::LuauFreeTypesMustHaveBounds ? arena->freshType(builtinTypes, scope, level) : addType(FreeType{scope, level}); } } diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index cd133ba0..3209fd08 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -15,11 +15,12 @@ #include LUAU_FASTFLAG(LuauSolverV2); +LUAU_FASTFLAGVARIABLE(LuauIncrementalAutocompleteCommentDetection) namespace Luau { -static bool contains(Position pos, Comment comment) +static bool contains_DEPRECATED(Position pos, Comment comment) { if (comment.location.contains(pos)) return true; @@ -32,7 +33,22 @@ static bool contains(Position pos, Comment comment) return false; } -static bool isWithinComment(const std::vector& commentLocations, Position pos) +static bool contains(Position pos, Comment comment) +{ + if (comment.location.contains(pos)) + return true; + else if (comment.type == Lexeme::BrokenComment && comment.location.begin <= pos) // Broken comments are broken specifically because they don't + // have an end + return true; + // comments actually span the whole line - in incremental mode, we could pass a cursor outside of the current parsed comment range span, but it + // would still be 'within' the comment So, the cursor must be on the same line and the comment itself must come strictly after the `begin` + else if (comment.type == Lexeme::Comment && comment.location.end.line == pos.line && comment.location.begin <= pos) + return true; + else + return false; +} + +bool isWithinComment(const std::vector& commentLocations, Position pos) { auto iter = std::lower_bound( commentLocations.begin(), @@ -40,6 +56,11 @@ static bool isWithinComment(const std::vector& commentLocations, Positi Comment{Lexeme::Comment, Location{pos, pos}}, [](const Comment& a, const Comment& b) { + if (FFlag::LuauIncrementalAutocompleteCommentDetection) + { + if (a.type == Lexeme::Comment) + return a.location.end.line < b.location.end.line; + } return a.location.end < b.location.end; } ); @@ -47,7 +68,7 @@ static bool isWithinComment(const std::vector& commentLocations, Positi if (iter == commentLocations.end()) return false; - if (contains(pos, *iter)) + if (FFlag::LuauIncrementalAutocompleteCommentDetection ? contains(pos, *iter) : contains_DEPRECATED(pos, *iter)) return true; // Due to the nature of std::lower_bound, it is possible that iter points at a comment that ends @@ -149,9 +170,9 @@ struct ClonePublicInterface : Substitution freety->scope->location, module->name, InternalError{"Free type is escaping its module; please report this bug at " - "https://github.com/luau-lang/luau/issues"} - ); - result = builtinTypes->errorRecoveryType(); + "https://github.com/luau-lang/luau/issues"} + ); + result = builtinTypes->errorRecoveryType(); } else if (auto genericty = getMutable(result)) { @@ -173,8 +194,8 @@ struct ClonePublicInterface : Substitution ftp->scope->location, module->name, InternalError{"Free type pack is escaping its module; please report this bug at " - "https://github.com/luau-lang/luau/issues"} - ); + "https://github.com/luau-lang/luau/issues"} + ); clonedTp = builtinTypes->errorRecoveryTypePack(); } else if (auto gtp = getMutable(clonedTp)) diff --git a/Analysis/src/NonStrictTypeChecker.cpp b/Analysis/src/NonStrictTypeChecker.cpp index b4a5eaf6..0645e4e2 100644 --- a/Analysis/src/NonStrictTypeChecker.cpp +++ b/Analysis/src/NonStrictTypeChecker.cpp @@ -19,8 +19,8 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauUserTypeFunNonstrict) LUAU_FASTFLAGVARIABLE(LuauCountSelfCallsNonstrict) +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) namespace Luau { @@ -158,6 +158,7 @@ private: struct NonStrictTypeChecker { NotNull builtinTypes; + NotNull simplifier; NotNull typeFunctionRuntime; const NotNull ice; NotNull arena; @@ -174,6 +175,7 @@ struct NonStrictTypeChecker NonStrictTypeChecker( NotNull arena, NotNull builtinTypes, + NotNull simplifier, NotNull typeFunctionRuntime, const NotNull ice, NotNull unifierState, @@ -182,12 +184,13 @@ struct NonStrictTypeChecker Module* module ) : builtinTypes(builtinTypes) + , simplifier(simplifier) , typeFunctionRuntime(typeFunctionRuntime) , ice(ice) , arena(arena) , module(module) , normalizer{arena, builtinTypes, unifierState, /* cache inhabitance */ true} - , subtyping{builtinTypes, arena, NotNull(&normalizer), typeFunctionRuntime, ice} + , subtyping{builtinTypes, arena, simplifier, NotNull(&normalizer), typeFunctionRuntime, ice} , dfg(dfg) , limits(limits) { @@ -209,7 +212,7 @@ struct NonStrictTypeChecker return *fst; else if (auto ftp = get(pack)) { - TypeId result = arena->addType(FreeType{ftp->scope}); + TypeId result = FFlag::LuauFreeTypesMustHaveBounds ? arena->freshType(builtinTypes, ftp->scope) : arena->addType(FreeType{ftp->scope}); TypePackId freeTail = arena->addTypePack(FreeTypePack{ftp->scope}); TypePack* resultPack = emplaceTypePack(asMutable(pack)); @@ -232,13 +235,14 @@ struct NonStrictTypeChecker if (noTypeFunctionErrors.find(instance)) return instance; - ErrorVec errors = reduceTypeFunctions( - instance, - location, - TypeFunctionContext{arena, builtinTypes, stack.back(), NotNull{&normalizer}, typeFunctionRuntime, ice, limits}, - true - ) - .errors; + ErrorVec errors = + reduceTypeFunctions( + instance, + location, + TypeFunctionContext{arena, builtinTypes, stack.back(), simplifier, NotNull{&normalizer}, typeFunctionRuntime, ice, limits}, + true + ) + .errors; if (errors.empty()) noTypeFunctionErrors.insert(instance); @@ -424,9 +428,6 @@ struct NonStrictTypeChecker NonStrictContext visit(AstStatTypeFunction* typeFunc) { - if (!FFlag::LuauUserTypeFunNonstrict) - reportError(GenericError{"This syntax is not supported"}, typeFunc->location); - return {}; } @@ -888,6 +889,7 @@ private: void checkNonStrict( NotNull builtinTypes, + NotNull simplifier, NotNull typeFunctionRuntime, NotNull ice, NotNull unifierState, @@ -899,7 +901,9 @@ void checkNonStrict( { LUAU_TIMETRACE_SCOPE("checkNonStrict", "Typechecking"); - NonStrictTypeChecker typeChecker{NotNull{&module->internalTypes}, builtinTypes, typeFunctionRuntime, ice, unifierState, dfg, limits, module}; + NonStrictTypeChecker typeChecker{ + NotNull{&module->internalTypes}, builtinTypes, simplifier, typeFunctionRuntime, ice, unifierState, dfg, limits, module + }; typeChecker.visit(sourceModule.root); unfreeze(module->interfaceTypes); copyErrors(module->errors, module->interfaceTypes, builtinTypes); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index a2b440b9..9aa6fb97 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -17,12 +17,11 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant) -LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); -LUAU_FASTFLAG(LuauSolverV2); - +LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000) LUAU_FASTINTVARIABLE(LuauNormalizeIntersectionLimit, 200) -LUAU_FASTFLAGVARIABLE(LuauNormalizationTracksCyclicPairsThroughInhabitance); -LUAU_FASTFLAGVARIABLE(LuauIntersectNormalsNeedsToTrackResourceLimits); +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAGVARIABLE(LuauFixInfiniteRecursionInNormalization) +LUAU_FASTFLAGVARIABLE(LuauFixNormalizedIntersectionOfNegatedClass) namespace Luau { @@ -1809,7 +1808,8 @@ NormalizationResult Normalizer::unionNormalWithTy( } else if (get(here.tops)) return NormalizationResult::True; - else if (get(there) || get(there) || get(there) || get(there) || get(there)) + else if (get(there) || get(there) || get(there) || get(there) || + get(there)) { if (tyvarIndex(there) <= ignoreSmallerTyvars) return NormalizationResult::True; @@ -2284,9 +2284,24 @@ void Normalizer::intersectClassesWithClass(NormalizedClassType& heres, TypeId th else if (isSubclass(there, hereTy)) { TypeIds negations = std::move(hereNegations); + bool emptyIntersectWithNegation = false; for (auto nIt = negations.begin(); nIt != negations.end();) { + if (FFlag::LuauFixNormalizedIntersectionOfNegatedClass && isSubclass(there, *nIt)) + { + // Hitting this block means that the incoming class is a + // subclass of this type, _and_ one of its negations is a + // superclass of this type, e.g.: + // + // Dog & ~Animal + // + // Clearly this intersects to never, so we mark this class as + // being removed from the normalized class type. + emptyIntersectWithNegation = true; + break; + } + if (!isSubclass(*nIt, there)) { nIt = negations.erase(nIt); @@ -2299,7 +2314,8 @@ void Normalizer::intersectClassesWithClass(NormalizedClassType& heres, TypeId th it = heres.ordering.erase(it); heres.classes.erase(hereTy); - heres.pushPair(there, std::move(negations)); + if (!emptyIntersectWithNegation) + heres.pushPair(there, std::move(negations)); break; } // If the incoming class is a superclass of the current class, we don't @@ -2584,11 +2600,31 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there { if (tprop.readTy.has_value()) { - // if the intersection of the read types of a property is uninhabited, the whole table is `never`. - // We've seen these table prop elements before and we're about to ask if their intersection - // is inhabited - if (FFlag::LuauNormalizationTracksCyclicPairsThroughInhabitance) + if (FFlag::LuauFixInfiniteRecursionInNormalization) { + TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result; + + // If any property is going to get mapped to `never`, we can just call the entire table `never`. + // Since this check is syntactic, we may sometimes miss simplifying tables with complex uninhabited properties. + // Prior versions of this code attempted to do this semantically using the normalization machinery, but this + // mistakenly causes infinite loops when giving more complex recursive table types. As it stands, this approach + // will continue to scale as simplification is improved, but we may wish to reintroduce the semantic approach + // once we have revisited the usage of seen sets systematically (and possibly with some additional guarding to recognize + // when types are infinitely-recursive with non-pointer identical instances of them, or some guard to prevent that + // construction altogether). See also: `gh1632_no_infinite_recursion_in_normalization` + if (get(ty)) + return {builtinTypes->neverType}; + + prop.readTy = ty; + hereSubThere &= (ty == hprop.readTy); + thereSubHere &= (ty == tprop.readTy); + } + else + { + // if the intersection of the read types of a property is uninhabited, the whole table is `never`. + // We've seen these table prop elements before and we're about to ask if their intersection + // is inhabited + auto pair1 = std::pair{*hprop.readTy, *tprop.readTy}; auto pair2 = std::pair{*tprop.readTy, *hprop.readTy}; if (seenTablePropPairs.contains(pair1) || seenTablePropPairs.contains(pair2)) @@ -2603,6 +2639,8 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there seenTablePropPairs.insert(pair2); } + // FIXME(ariel): this is being added in a flag removal, so not changing the semantics here, but worth noting that this + // fresh `seenSet` is definitely a bug. we already have `seenSet` from the parameter that _should_ have been used here. Set seenSet{nullptr}; NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy, seenTablePropPairs, seenSet); @@ -2616,34 +2654,6 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there hereSubThere &= (ty == hprop.readTy); thereSubHere &= (ty == tprop.readTy); } - else - { - - if (seenSet.contains(*hprop.readTy) && seenSet.contains(*tprop.readTy)) - { - seenSet.erase(*hprop.readTy); - seenSet.erase(*tprop.readTy); - return {builtinTypes->neverType}; - } - else - { - seenSet.insert(*hprop.readTy); - seenSet.insert(*tprop.readTy); - } - - NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy); - - seenSet.erase(*hprop.readTy); - seenSet.erase(*tprop.readTy); - - if (NormalizationResult::True != res) - return {builtinTypes->neverType}; - - TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result; - prop.readTy = ty; - hereSubThere &= (ty == hprop.readTy); - thereSubHere &= (ty == tprop.readTy); - } } else { @@ -3042,12 +3052,9 @@ NormalizationResult Normalizer::intersectTyvarsWithTy( // See above for an explaination of `ignoreSmallerTyvars`. NormalizationResult Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars) { - if (FFlag::LuauIntersectNormalsNeedsToTrackResourceLimits) - { - RecursionCounter _rc(&sharedState->counters.recursionCount); - if (!withinResourceLimits()) - return NormalizationResult::HitLimits; - } + RecursionCounter _rc(&sharedState->counters.recursionCount); + if (!withinResourceLimits()) + return NormalizationResult::HitLimits; if (!get(there.tops)) { @@ -3162,7 +3169,8 @@ NormalizationResult Normalizer::intersectNormalWithTy( } return NormalizationResult::True; } - else if (get(there) || get(there) || get(there) || get(there) || get(there)) + else if (get(there) || get(there) || get(there) || get(there) || + get(there)) { NormalizedType thereNorm{builtinTypes}; NormalizedType topNorm{builtinTypes}; @@ -3465,7 +3473,14 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) return arena->addType(UnionType{std::move(result)}); } -bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) +bool isSubtype( + TypeId subTy, + TypeId superTy, + NotNull scope, + NotNull builtinTypes, + NotNull simplifier, + InternalErrorReporter& ice +) { UnifierSharedState sharedState{&ice}; TypeArena arena; @@ -3478,7 +3493,7 @@ bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull scope, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) +bool isSubtype( + TypePackId subPack, + TypePackId superPack, + NotNull scope, + NotNull builtinTypes, + NotNull simplifier, + InternalErrorReporter& ice +) { UnifierSharedState sharedState{&ice}; TypeArena arena; @@ -3504,7 +3526,7 @@ bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull scope, N // Subtyping under DCR is not implemented using unification! if (FFlag::LuauSolverV2) { - Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&ice}}; + Subtyping subtyping{builtinTypes, NotNull{&arena}, simplifier, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&ice}}; return subtyping.isSubtype(subPack, superPack, scope).isSubtype; } diff --git a/Analysis/src/OverloadResolution.cpp b/Analysis/src/OverloadResolution.cpp index f5557f2d..32858cd1 100644 --- a/Analysis/src/OverloadResolution.cpp +++ b/Analysis/src/OverloadResolution.cpp @@ -16,6 +16,7 @@ namespace Luau OverloadResolver::OverloadResolver( NotNull builtinTypes, NotNull arena, + NotNull simplifier, NotNull normalizer, NotNull typeFunctionRuntime, NotNull scope, @@ -25,12 +26,13 @@ OverloadResolver::OverloadResolver( ) : builtinTypes(builtinTypes) , arena(arena) + , simplifier(simplifier) , normalizer(normalizer) , typeFunctionRuntime(typeFunctionRuntime) , scope(scope) , ice(reporter) , limits(limits) - , subtyping({builtinTypes, arena, normalizer, typeFunctionRuntime, ice}) + , subtyping({builtinTypes, arena, simplifier, normalizer, typeFunctionRuntime, ice}) , callLoc(callLocation) { } @@ -202,7 +204,7 @@ std::pair OverloadResolver::checkOverload_ ) { FunctionGraphReductionResult result = reduceTypeFunctions( - fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, normalizer, typeFunctionRuntime, ice, limits}, /*force=*/true + fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, simplifier, normalizer, typeFunctionRuntime, ice, limits}, /*force=*/true ); if (!result.errors.empty()) return {OverloadIsNonviable, result.errors}; @@ -404,9 +406,10 @@ void OverloadResolver::add(Analysis analysis, TypeId ty, ErrorVec&& errors) // we wrap calling the overload resolver in a separate function to reduce overall stack pressure in `solveFunctionCall`. // this limits the lifetime of `OverloadResolver`, a large type, to only as long as it is actually needed. -std::optional selectOverload( +static std::optional selectOverload( NotNull builtinTypes, NotNull arena, + NotNull simplifier, NotNull normalizer, NotNull typeFunctionRuntime, NotNull scope, @@ -417,7 +420,8 @@ std::optional selectOverload( TypePackId argsPack ) { - auto resolver = std::make_unique(builtinTypes, arena, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location); + auto resolver = + std::make_unique(builtinTypes, arena, simplifier, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location); auto [status, overload] = resolver->selectOverload(fn, argsPack); if (status == OverloadResolver::Analysis::Ok) @@ -432,6 +436,7 @@ std::optional selectOverload( SolveResult solveFunctionCall( NotNull arena, NotNull builtinTypes, + NotNull simplifier, NotNull normalizer, NotNull typeFunctionRuntime, NotNull iceReporter, @@ -443,7 +448,7 @@ SolveResult solveFunctionCall( ) { std::optional overloadToUse = - selectOverload(builtinTypes, arena, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location, fn, argsPack); + selectOverload(builtinTypes, arena, simplifier, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location, fn, argsPack); if (!overloadToUse) return {SolveResult::NoMatchingOverload}; diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index 27894505..db99d827 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -211,6 +211,16 @@ void Scope::inheritRefinements(const ScopePtr& childScope) } } +bool Scope::shouldWarnGlobal(std::string name) const +{ + for (const Scope* current = this; current; current = current->parent.get()) + { + if (current->globalsToWarn.contains(name)) + return true; + } + return false; +} + bool subsumesStrict(Scope* left, Scope* right) { while (right) diff --git a/Analysis/src/Simplify.cpp b/Analysis/src/Simplify.cpp index 6cb511eb..8a0483e6 100644 --- a/Analysis/src/Simplify.cpp +++ b/Analysis/src/Simplify.cpp @@ -31,16 +31,16 @@ struct TypeSimplifier int recursionDepth = 0; - TypeId mkNegation(TypeId ty); + TypeId mkNegation(TypeId ty) const; TypeId intersectFromParts(std::set parts); - TypeId intersectUnionWithType(TypeId unionTy, TypeId right); + TypeId intersectUnionWithType(TypeId left, TypeId right); TypeId intersectUnions(TypeId left, TypeId right); - TypeId intersectNegatedUnion(TypeId unionTy, TypeId right); + TypeId intersectNegatedUnion(TypeId left, TypeId right); - TypeId intersectTypeWithNegation(TypeId a, TypeId b); - TypeId intersectNegations(TypeId a, TypeId b); + TypeId intersectTypeWithNegation(TypeId left, TypeId right); + TypeId intersectNegations(TypeId left, TypeId right); TypeId intersectIntersectionWithType(TypeId left, TypeId right); @@ -48,8 +48,8 @@ struct TypeSimplifier // unions, intersections, or negations. std::optional basicIntersect(TypeId left, TypeId right); - TypeId intersect(TypeId ty, TypeId discriminant); - TypeId union_(TypeId ty, TypeId discriminant); + TypeId intersect(TypeId left, TypeId right); + TypeId union_(TypeId left, TypeId right); TypeId simplify(TypeId ty); TypeId simplify(TypeId ty, DenseHashSet& seen); @@ -573,7 +573,7 @@ Relation relate(TypeId left, TypeId right) return relate(left, right, seen); } -TypeId TypeSimplifier::mkNegation(TypeId ty) +TypeId TypeSimplifier::mkNegation(TypeId ty) const { TypeId result = nullptr; diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index e8357f48..e00f0d3d 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -98,9 +98,7 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a FunctionType clone = FunctionType{a.level, a.scope, a.argTypes, a.retTypes, a.definition, a.hasSelf}; clone.generics = a.generics; clone.genericPacks = a.genericPacks; - clone.magicFunction = a.magicFunction; - clone.dcrMagicFunction = a.dcrMagicFunction; - clone.dcrMagicRefinement = a.dcrMagicRefinement; + clone.magic = a.magic; clone.tags = a.tags; clone.argNames = a.argNames; clone.isCheckedFunction = a.isCheckedFunction; diff --git a/Analysis/src/Subtyping.cpp b/Analysis/src/Subtyping.cpp index a93d910b..e4985a02 100644 --- a/Analysis/src/Subtyping.cpp +++ b/Analysis/src/Subtyping.cpp @@ -22,7 +22,6 @@ #include LUAU_FASTFLAGVARIABLE(DebugLuauSubtypingCheckPathValidity) -LUAU_FASTFLAGVARIABLE(LuauRetrySubtypingWithoutHiddenPack) namespace Luau { @@ -396,12 +395,14 @@ TypePackId* SubtypingEnvironment::getMappedPackBounds(TypePackId tp) Subtyping::Subtyping( NotNull builtinTypes, NotNull typeArena, + NotNull simplifier, NotNull normalizer, NotNull typeFunctionRuntime, NotNull iceReporter ) : builtinTypes(builtinTypes) , arena(typeArena) + , simplifier(simplifier) , normalizer(normalizer) , typeFunctionRuntime(typeFunctionRuntime) , iceReporter(iceReporter) @@ -1472,15 +1473,14 @@ SubtypingResult Subtyping::isCovariantWith( // If subtyping failed in the argument packs, we should check if there's a hidden variadic tail and try ignoring it. // This might cause subtyping correctly because the sub type here may not have a hidden variadic tail or equivalent. - if (FFlag::LuauRetrySubtypingWithoutHiddenPack && !result.isSubtype) + if (!result.isSubtype) { auto [arguments, tail] = flatten(superFunction->argTypes); if (auto variadic = get(tail); variadic && variadic->hidden) { - result.orElse( - isContravariantWith(env, subFunction->argTypes, arena->addTypePack(TypePack{arguments}), scope).withBothComponent(TypePath::PackField::Arguments) - ); + result.orElse(isContravariantWith(env, subFunction->argTypes, arena->addTypePack(TypePack{arguments}), scope) + .withBothComponent(TypePath::PackField::Arguments)); } } } @@ -1861,7 +1861,7 @@ TypeId Subtyping::makeAggregateType(const Container& container, TypeId orElse) std::pair Subtyping::handleTypeFunctionReductionResult(const TypeFunctionInstanceType* functionInstance, NotNull scope) { - TypeFunctionContext context{arena, builtinTypes, scope, normalizer, typeFunctionRuntime, iceReporter, NotNull{&limits}}; + TypeFunctionContext context{arena, builtinTypes, scope, simplifier, normalizer, typeFunctionRuntime, iceReporter, NotNull{&limits}}; TypeId function = arena->addType(*functionInstance); FunctionGraphReductionResult result = reduceTypeFunctions(function, {}, context, true); ErrorVec errors; diff --git a/Analysis/src/TableLiteralInference.cpp b/Analysis/src/TableLiteralInference.cpp index c50cd16e..5c7ea4d9 100644 --- a/Analysis/src/TableLiteralInference.cpp +++ b/Analysis/src/TableLiteralInference.cpp @@ -9,6 +9,8 @@ #include "Luau/TypeUtils.h" #include "Luau/Unifier2.h" +LUAU_FASTFLAGVARIABLE(LuauDontInPlaceMutateTableType) + namespace Luau { @@ -236,6 +238,8 @@ TypeId matchLiteralType( return exprType; } + DenseHashSet keysToDelete{nullptr}; + for (const AstExprTable::Item& item : exprTable->items) { if (isRecord(item)) @@ -280,7 +284,10 @@ TypeId matchLiteralType( else tableTy->indexer = TableIndexer{expectedTableTy->indexer->indexType, matchedType}; - tableTy->props.erase(keyStr); + if (FFlag::LuauDontInPlaceMutateTableType) + keysToDelete.insert(item.key->as()); + else + tableTy->props.erase(keyStr); } // If it's just an extra property and the expected type @@ -387,6 +394,16 @@ TypeId matchLiteralType( LUAU_ASSERT(!"Unexpected"); } + if (FFlag::LuauDontInPlaceMutateTableType) + { + for (const auto& key: keysToDelete) + { + const AstArray& s = key->value; + std::string keyStr{s.data, s.data + s.size}; + tableTy->props.erase(keyStr); + } + } + // Keys that the expectedType says we should have, but that aren't // specified by the AST fragment. // diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index a42882ed..eee41f24 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -10,6 +10,9 @@ #include #include +LUAU_FASTFLAG(LuauStoreCSTData) +LUAU_FASTFLAG(LuauExtendStatEndPosWithSemicolon) +LUAU_FASTFLAG(LuauAstTypeGroup) namespace { @@ -45,11 +48,13 @@ struct Writer virtual void space() = 0; virtual void maybeSpace(const Position& newPos, int reserve) = 0; virtual void write(std::string_view) = 0; + virtual void writeMultiline(std::string_view) = 0; virtual void identifier(std::string_view name) = 0; virtual void keyword(std::string_view) = 0; virtual void symbol(std::string_view) = 0; virtual void literal(std::string_view) = 0; virtual void string(std::string_view) = 0; + virtual void sourceString(std::string_view, CstExprConstantString::QuoteStyle quoteStyle, unsigned int blockDepth) = 0; }; struct StringWriter : Writer @@ -93,6 +98,32 @@ struct StringWriter : Writer lastChar = ' '; } + void writeMultiline(std::string_view s) override + { + if (s.empty()) + return; + + ss.append(s.data(), s.size()); + lastChar = s[s.size() - 1]; + + size_t index = 0; + size_t numLines = 0; + while (true) + { + auto newlinePos = s.find('\n', index); + if (newlinePos == std::string::npos) + break; + numLines++; + index = newlinePos + 1; + } + + pos.line += unsigned(numLines); + if (numLines > 0) + pos.column = unsigned(s.size()) - unsigned(index); + else + pos.column += unsigned(s.size()); + } + void write(std::string_view s) override { if (s.empty()) @@ -134,10 +165,17 @@ struct StringWriter : Writer void symbol(std::string_view s) override { - if (isDigit(lastChar) && s[0] == '.') - space(); + if (FFlag::LuauStoreCSTData) + { + write(s); + } + else + { + if (isDigit(lastChar) && s[0] == '.') + space(); - write(s); + write(s); + } } void literal(std::string_view s) override @@ -161,14 +199,54 @@ struct StringWriter : Writer write(escape(s)); write(quote); } + + void sourceString(std::string_view s, CstExprConstantString::QuoteStyle quoteStyle, unsigned int blockDepth) override + { + if (quoteStyle == CstExprConstantString::QuotedRaw) + { + auto blocks = std::string(blockDepth, '='); + write('['); + write(blocks); + write('['); + writeMultiline(s); + write(']'); + write(blocks); + write(']'); + } + else + { + LUAU_ASSERT(blockDepth == 0); + + char quote = '"'; + switch (quoteStyle) + { + case CstExprConstantString::QuotedDouble: + quote = '"'; + break; + case CstExprConstantString::QuotedSingle: + quote = '\''; + break; + case CstExprConstantString::QuotedInterp: + quote = '`'; + break; + default: + LUAU_ASSERT(!"Unhandled quote type"); + } + + write(quote); + writeMultiline(s); + write(quote); + } + } }; class CommaSeparatorInserter { public: - CommaSeparatorInserter(Writer& w) + explicit CommaSeparatorInserter(Writer& w, const Position* commaPosition = nullptr) : first(true) , writer(w) + , commaPosition(commaPosition) { } void operator()() @@ -176,17 +254,25 @@ public: if (first) first = !first; else + { + if (FFlag::LuauStoreCSTData && commaPosition) + { + writer.advance(*commaPosition); + commaPosition++; + } writer.symbol(","); + } } private: bool first; Writer& writer; + const Position* commaPosition; }; -struct Printer +struct Printer_DEPRECATED { - explicit Printer(Writer& writer) + explicit Printer_DEPRECATED(Writer& writer) : writer(writer) { } @@ -242,7 +328,8 @@ struct Printer } else if (typeCount == 1) { - if (unconditionallyParenthesize) + bool shouldParenthesize = unconditionallyParenthesize && (list.types.size == 0 || !list.types.data[0]->is()); + if (FFlag::LuauAstTypeGroup ? shouldParenthesize : unconditionallyParenthesize) writer.symbol("("); // Only variadic tail @@ -255,7 +342,7 @@ struct Printer visualizeTypeAnnotation(*list.types.data[0]); } - if (unconditionallyParenthesize) + if (FFlag::LuauAstTypeGroup ? shouldParenthesize : unconditionallyParenthesize) writer.symbol(")"); } else @@ -433,6 +520,7 @@ struct Printer visualize(*item.value); } + // Decrement endPos column so that we advance to before the closing `}` brace before writing, rather than after it Position endPos = expr.location.end; if (endPos.column > 0) --endPos.column; @@ -1164,6 +1252,12 @@ struct Printer writer.symbol(")"); } } + else if (const auto& a = typeAnnotation.as()) + { + writer.symbol("("); + visualizeTypeAnnotation(*a->type); + writer.symbol(")"); + } else if (const auto& a = typeAnnotation.as()) { writer.keyword(a->value ? "true" : "false"); @@ -1183,20 +1277,1349 @@ struct Printer } }; +struct Printer +{ + explicit Printer(Writer& writer, CstNodeMap cstNodeMap) + : writer(writer) + , cstNodeMap(std::move(cstNodeMap)) + { + } + + bool writeTypes = false; + Writer& writer; + CstNodeMap cstNodeMap; + + template + T* lookupCstNode(AstNode* astNode) + { + if (const auto cstNode = cstNodeMap[astNode]) + return cstNode->as(); + return nullptr; + } + + void visualize(const AstLocal& local) + { + advance(local.location.begin); + + writer.identifier(local.name.value); + if (writeTypes && local.annotation) + { + // TODO: handle spacing for type annotation + writer.symbol(":"); + visualizeTypeAnnotation(*local.annotation); + } + } + + void visualizeTypePackAnnotation(const AstTypePack& annotation, bool forVarArg) + { + advance(annotation.location.begin); + if (const AstTypePackVariadic* variadicTp = annotation.as()) + { + if (!forVarArg) + writer.symbol("..."); + + visualizeTypeAnnotation(*variadicTp->variadicType); + } + else if (const AstTypePackGeneric* genericTp = annotation.as()) + { + writer.symbol(genericTp->genericName.value); + writer.symbol("..."); + } + else if (const AstTypePackExplicit* explicitTp = annotation.as()) + { + LUAU_ASSERT(!forVarArg); + visualizeTypeList(explicitTp->typeList, true); + } + else + { + LUAU_ASSERT(!"Unknown TypePackAnnotation kind"); + } + } + + void visualizeTypeList(const AstTypeList& list, bool unconditionallyParenthesize) + { + size_t typeCount = list.types.size + (list.tailType != nullptr ? 1 : 0); + if (typeCount == 0) + { + writer.symbol("("); + writer.symbol(")"); + } + else if (typeCount == 1) + { + bool shouldParenthesize = unconditionallyParenthesize && (list.types.size == 0 || !list.types.data[0]->is()); + if (FFlag::LuauAstTypeGroup ? shouldParenthesize : unconditionallyParenthesize) + writer.symbol("("); + + // Only variadic tail + if (list.types.size == 0) + { + visualizeTypePackAnnotation(*list.tailType, false); + } + else + { + visualizeTypeAnnotation(*list.types.data[0]); + } + + if (FFlag::LuauAstTypeGroup ? shouldParenthesize : unconditionallyParenthesize) + writer.symbol(")"); + } + else + { + writer.symbol("("); + + bool first = true; + for (const auto& el : list.types) + { + if (first) + first = false; + else + writer.symbol(","); + + visualizeTypeAnnotation(*el); + } + + if (list.tailType) + { + writer.symbol(","); + visualizeTypePackAnnotation(*list.tailType, false); + } + + writer.symbol(")"); + } + } + + bool isIntegerish(double d) + { + if (d <= std::numeric_limits::max() && d >= std::numeric_limits::min()) + return double(int(d)) == d && !(d == 0.0 && signbit(d)); + else + return false; + } + + void visualize(AstExpr& expr) + { + advance(expr.location.begin); + + if (const auto& a = expr.as()) + { + writer.symbol("("); + visualize(*a->expr); + advance(Position{a->location.end.line, a->location.end.column - 1}); + writer.symbol(")"); + } + else if (expr.is()) + { + writer.keyword("nil"); + } + else if (const auto& a = expr.as()) + { + if (a->value) + writer.keyword("true"); + else + writer.keyword("false"); + } + else if (const auto& a = expr.as()) + { + if (const auto cstNode = lookupCstNode(a)) + { + writer.literal(std::string_view(cstNode->value.data, cstNode->value.size)); + } + else + { + if (isinf(a->value)) + { + if (a->value > 0) + writer.literal("1e500"); + else + writer.literal("-1e500"); + } + else if (isnan(a->value)) + writer.literal("0/0"); + else + { + if (isIntegerish(a->value)) + writer.literal(std::to_string(int(a->value))); + else + { + char buffer[100]; + size_t len = snprintf(buffer, sizeof(buffer), "%.17g", a->value); + writer.literal(std::string_view{buffer, len}); + } + } + } + } + else if (const auto& a = expr.as()) + { + if (const auto cstNode = lookupCstNode(a)) + { + writer.sourceString( + std::string_view(cstNode->sourceString.data, cstNode->sourceString.size), cstNode->quoteStyle, cstNode->blockDepth + ); + } + else + writer.string(std::string_view(a->value.data, a->value.size)); + } + else if (const auto& a = expr.as()) + { + writer.identifier(a->local->name.value); + } + else if (const auto& a = expr.as()) + { + writer.identifier(a->name.value); + } + else if (expr.is()) + { + writer.symbol("..."); + } + else if (const auto& a = expr.as()) + { + visualize(*a->func); + + const auto cstNode = lookupCstNode(a); + + if (cstNode) + { + if (cstNode->openParens) + { + advance(*cstNode->openParens); + writer.symbol("("); + } + } + else + { + writer.symbol("("); + } + + CommaSeparatorInserter comma(writer, cstNode ? cstNode->commaPositions.begin() : nullptr); + for (const auto& arg : a->args) + { + comma(); + visualize(*arg); + } + + if (cstNode) + { + if (cstNode->closeParens) + { + advance(*cstNode->closeParens); + writer.symbol(")"); + } + } + else + { + writer.symbol(")"); + } + } + else if (const auto& a = expr.as()) + { + visualize(*a->expr); + advance(a->opPosition); + writer.symbol(std::string(1, a->op)); + advance(a->indexLocation.begin); + writer.write(a->index.value); + } + else if (const auto& a = expr.as()) + { + const auto cstNode = lookupCstNode(a); + visualize(*a->expr); + if (cstNode) + advance(cstNode->openBracketPosition); + writer.symbol("["); + visualize(*a->index); + if (cstNode) + advance(cstNode->closeBracketPosition); + writer.symbol("]"); + } + else if (const auto& a = expr.as()) + { + writer.keyword("function"); + visualizeFunctionBody(*a); + } + else if (const auto& a = expr.as()) + { + writer.symbol("{"); + + const CstExprTable::Item* cstItem = nullptr; + if (const auto cstNode = lookupCstNode(a)) + { + LUAU_ASSERT(cstNode->items.size == a->items.size); + cstItem = cstNode->items.begin(); + } + + bool first = true; + + for (const auto& item : a->items) + { + if (!cstItem) + { + if (first) + first = false; + else + writer.symbol(","); + } + + switch (item.kind) + { + case AstExprTable::Item::List: + break; + + case AstExprTable::Item::Record: + { + const auto& value = item.key->as()->value; + advance(item.key->location.begin); + writer.identifier(std::string_view(value.data, value.size)); + if (cstItem) + advance(*cstItem->equalsPosition); + else + writer.maybeSpace(item.value->location.begin, 1); + writer.symbol("="); + } + break; + + case AstExprTable::Item::General: + { + if (cstItem) + advance(*cstItem->indexerOpenPosition); + writer.symbol("["); + visualize(*item.key); + if (cstItem) + advance(*cstItem->indexerClosePosition); + writer.symbol("]"); + if (cstItem) + advance(*cstItem->equalsPosition); + else + writer.maybeSpace(item.value->location.begin, 1); + writer.symbol("="); + } + break; + + default: + LUAU_ASSERT(!"Unknown table item kind"); + } + + advance(item.value->location.begin); + visualize(*item.value); + + if (cstItem) + { + if (cstItem->separator) + { + LUAU_ASSERT(cstItem->separatorPosition); + advance(*cstItem->separatorPosition); + if (cstItem->separator == CstExprTable::Comma) + writer.symbol(","); + else if (cstItem->separator == CstExprTable::Semicolon) + writer.symbol(";"); + } + cstItem++; + } + } + + Position endPos = expr.location.end; + if (endPos.column > 0) + --endPos.column; + + advance(endPos); + + writer.symbol("}"); + advance(expr.location.end); + } + else if (const auto& a = expr.as()) + { + if (const auto cstNode = lookupCstNode(a)) + advance(cstNode->opPosition); + + switch (a->op) + { + case AstExprUnary::Not: + writer.keyword("not"); + break; + case AstExprUnary::Minus: + writer.symbol("-"); + break; + case AstExprUnary::Len: + writer.symbol("#"); + break; + } + visualize(*a->expr); + } + else if (const auto& a = expr.as()) + { + visualize(*a->left); + + if (const auto cstNode = lookupCstNode(a)) + advance(cstNode->opPosition); + else + { + switch (a->op) + { + case AstExprBinary::Add: + case AstExprBinary::Sub: + case AstExprBinary::Mul: + case AstExprBinary::Div: + case AstExprBinary::FloorDiv: + case AstExprBinary::Mod: + case AstExprBinary::Pow: + case AstExprBinary::CompareLt: + case AstExprBinary::CompareGt: + writer.maybeSpace(a->right->location.begin, 2); + break; + case AstExprBinary::Concat: + case AstExprBinary::CompareNe: + case AstExprBinary::CompareEq: + case AstExprBinary::CompareLe: + case AstExprBinary::CompareGe: + case AstExprBinary::Or: + writer.maybeSpace(a->right->location.begin, 3); + break; + case AstExprBinary::And: + writer.maybeSpace(a->right->location.begin, 4); + break; + default: + LUAU_ASSERT(!"Unknown Op"); + } + } + + writer.symbol(toString(a->op)); + + visualize(*a->right); + } + else if (const auto& a = expr.as()) + { + visualize(*a->expr); + + if (writeTypes) + { + writer.maybeSpace(a->annotation->location.begin, 2); + writer.symbol("::"); + visualizeTypeAnnotation(*a->annotation); + } + } + else if (const auto& a = expr.as()) + { + writer.keyword("if"); + visualizeElseIfExpr(*a); + } + else if (const auto& a = expr.as()) + { + const auto* cstNode = lookupCstNode(a); + + writer.symbol("`"); + + size_t index = 0; + + for (const auto& string : a->strings) + { + if (cstNode) + { + if (index > 0) + { + advance(cstNode->stringPositions.data[index]); + writer.symbol("}"); + } + const AstArray sourceString = cstNode->sourceStrings.data[index]; + writer.writeMultiline(std::string_view(sourceString.data, sourceString.size)); + } + else + { + writer.write(escape(std::string_view(string.data, string.size), /* escapeForInterpString = */ true)); + } + + if (index < a->expressions.size) + { + writer.symbol("{"); + visualize(*a->expressions.data[index]); + if (!cstNode) + writer.symbol("}"); + } + + index++; + } + + writer.symbol("`"); + } + else if (const auto& a = expr.as()) + { + writer.symbol("(error-expr"); + + for (size_t i = 0; i < a->expressions.size; i++) + { + writer.symbol(i == 0 ? ": " : ", "); + visualize(*a->expressions.data[i]); + } + + writer.symbol(")"); + } + else + { + LUAU_ASSERT(!"Unknown AstExpr"); + } + } + + void writeEnd(const Location& loc) + { + Position endPos = loc.end; + if (endPos.column >= 3) + endPos.column -= 3; + advance(endPos); + writer.keyword("end"); + } + + void advance(const Position& newPos) + { + writer.advance(newPos); + } + + void visualize(AstStat& program) + { + advance(program.location.begin); + + if (const auto& block = program.as()) + { + writer.keyword("do"); + for (const auto& s : block->body) + visualize(*s); + if (const auto cstNode = lookupCstNode(block)) + { + advance(cstNode->endPosition); + writer.keyword("end"); + } + else + { + writer.advance(block->location.end); + writeEnd(program.location); + } + } + else if (const auto& a = program.as()) + { + writer.keyword("if"); + visualizeElseIf(*a); + } + else if (const auto& a = program.as()) + { + writer.keyword("while"); + visualize(*a->condition); + // TODO: what if 'hasDo = false'? + advance(a->doLocation.begin); + writer.keyword("do"); + visualizeBlock(*a->body); + advance(a->body->location.end); + writer.keyword("end"); + } + else if (const auto& a = program.as()) + { + writer.keyword("repeat"); + visualizeBlock(*a->body); + if (const auto cstNode = lookupCstNode(a)) + writer.advance(cstNode->untilPosition); + else if (a->condition->location.begin.column > 5) + writer.advance(Position{a->condition->location.begin.line, a->condition->location.begin.column - 6}); + writer.keyword("until"); + visualize(*a->condition); + } + else if (program.is()) + writer.keyword("break"); + else if (program.is()) + writer.keyword("continue"); + else if (const auto& a = program.as()) + { + const auto cstNode = lookupCstNode(a); + + writer.keyword("return"); + + CommaSeparatorInserter comma(writer, cstNode ? cstNode->commaPositions.begin() : nullptr); + for (const auto& expr : a->list) + { + comma(); + visualize(*expr); + } + } + else if (const auto& a = program.as()) + { + visualize(*a->expr); + } + else if (const auto& a = program.as()) + { + const auto cstNode = lookupCstNode(a); + + writer.keyword("local"); + + CommaSeparatorInserter varComma(writer, cstNode ? cstNode->varsCommaPositions.begin() : nullptr); + for (const auto& local : a->vars) + { + varComma(); + visualize(*local); + } + + if (a->equalsSignLocation) + { + advance(a->equalsSignLocation->begin); + writer.symbol("="); + } + + + CommaSeparatorInserter valueComma(writer, cstNode ? cstNode->valuesCommaPositions.begin() : nullptr); + for (const auto& value : a->values) + { + valueComma(); + visualize(*value); + } + } + else if (const auto& a = program.as()) + { + const auto cstNode = lookupCstNode(a); + + writer.keyword("for"); + + visualize(*a->var); + if (cstNode) + advance(cstNode->equalsPosition); + writer.symbol("="); + visualize(*a->from); + if (cstNode) + advance(cstNode->endCommaPosition); + writer.symbol(","); + visualize(*a->to); + if (a->step) + { + if (cstNode && cstNode->stepCommaPosition) + advance(*cstNode->stepCommaPosition); + writer.symbol(","); + visualize(*a->step); + } + advance(a->doLocation.begin); + writer.keyword("do"); + visualizeBlock(*a->body); + + advance(a->body->location.end); + writer.keyword("end"); + } + else if (const auto& a = program.as()) + { + const auto cstNode = lookupCstNode(a); + + writer.keyword("for"); + + CommaSeparatorInserter varComma(writer, cstNode ? cstNode->varsCommaPositions.begin() : nullptr); + for (const auto& var : a->vars) + { + varComma(); + visualize(*var); + } + + advance(a->inLocation.begin); + writer.keyword("in"); + + CommaSeparatorInserter valComma(writer, cstNode ? cstNode->valuesCommaPositions.begin() : nullptr); + + for (const auto& val : a->values) + { + valComma(); + visualize(*val); + } + + advance(a->doLocation.begin); + writer.keyword("do"); + + visualizeBlock(*a->body); + + advance(a->body->location.end); + writer.keyword("end"); + } + else if (const auto& a = program.as()) + { + const auto cstNode = lookupCstNode(a); + + CommaSeparatorInserter varComma(writer, cstNode ? cstNode->varsCommaPositions.begin() : nullptr); + for (const auto& var : a->vars) + { + varComma(); + visualize(*var); + } + + if (cstNode) + advance(cstNode->equalsPosition); + else + writer.space(); + writer.symbol("="); + + CommaSeparatorInserter valueComma(writer, cstNode ? cstNode->valuesCommaPositions.begin() : nullptr); + for (const auto& value : a->values) + { + valueComma(); + visualize(*value); + } + } + else if (const auto& a = program.as()) + { + const auto cstNode = lookupCstNode(a); + + visualize(*a->var); + + if (cstNode) + advance(cstNode->opPosition); + + switch (a->op) + { + case AstExprBinary::Add: + if (!cstNode) + writer.maybeSpace(a->value->location.begin, 2); + writer.symbol("+="); + break; + case AstExprBinary::Sub: + if (!cstNode) + writer.maybeSpace(a->value->location.begin, 2); + writer.symbol("-="); + break; + case AstExprBinary::Mul: + if (!cstNode) + writer.maybeSpace(a->value->location.begin, 2); + writer.symbol("*="); + break; + case AstExprBinary::Div: + if (!cstNode) + writer.maybeSpace(a->value->location.begin, 2); + writer.symbol("/="); + break; + case AstExprBinary::FloorDiv: + if (!cstNode) + writer.maybeSpace(a->value->location.begin, 3); + writer.symbol("//="); + break; + case AstExprBinary::Mod: + if (!cstNode) + writer.maybeSpace(a->value->location.begin, 2); + writer.symbol("%="); + break; + case AstExprBinary::Pow: + if (!cstNode) + writer.maybeSpace(a->value->location.begin, 2); + writer.symbol("^="); + break; + case AstExprBinary::Concat: + if (!cstNode) + writer.maybeSpace(a->value->location.begin, 3); + writer.symbol("..="); + break; + default: + LUAU_ASSERT(!"Unexpected compound assignment op"); + } + + visualize(*a->value); + } + else if (const auto& a = program.as()) + { + writer.keyword("function"); + visualize(*a->name); + visualizeFunctionBody(*a->func); + } + else if (const auto& a = program.as()) + { + const auto cstNode = lookupCstNode(a); + + writer.keyword("local"); + + if (cstNode) + advance(cstNode->functionKeywordPosition); + else + writer.space(); + + writer.keyword("function"); + advance(a->name->location.begin); + writer.identifier(a->name->name.value); + visualizeFunctionBody(*a->func); + } + else if (const auto& a = program.as()) + { + if (writeTypes) + { + if (a->exported) + writer.keyword("export"); + + writer.keyword("type"); + writer.identifier(a->name.value); + if (a->generics.size > 0 || a->genericPacks.size > 0) + { + writer.symbol("<"); + CommaSeparatorInserter comma(writer); + + for (auto o : a->generics) + { + comma(); + + writer.advance(o.location.begin); + writer.identifier(o.name.value); + + if (o.defaultValue) + { + writer.maybeSpace(o.defaultValue->location.begin, 2); + writer.symbol("="); + visualizeTypeAnnotation(*o.defaultValue); + } + } + + for (auto o : a->genericPacks) + { + comma(); + + writer.advance(o.location.begin); + writer.identifier(o.name.value); + writer.symbol("..."); + + if (o.defaultValue) + { + writer.maybeSpace(o.defaultValue->location.begin, 2); + writer.symbol("="); + visualizeTypePackAnnotation(*o.defaultValue, false); + } + } + + writer.symbol(">"); + } + writer.maybeSpace(a->type->location.begin, 2); + writer.symbol("="); + visualizeTypeAnnotation(*a->type); + } + } + else if (const auto& t = program.as()) + { + if (writeTypes) + { + writer.keyword("type function"); + writer.identifier(t->name.value); + visualizeFunctionBody(*t->body); + } + } + else if (const auto& a = program.as()) + { + writer.symbol("(error-stat"); + + for (size_t i = 0; i < a->expressions.size; i++) + { + writer.symbol(i == 0 ? ": " : ", "); + visualize(*a->expressions.data[i]); + } + + for (size_t i = 0; i < a->statements.size; i++) + { + writer.symbol(i == 0 && a->expressions.size == 0 ? ": " : ", "); + visualize(*a->statements.data[i]); + } + + writer.symbol(")"); + } + else + { + LUAU_ASSERT(!"Unknown AstStat"); + } + + if (program.hasSemicolon) + { + if (FFlag::LuauStoreCSTData) + advance(Position{program.location.end.line, program.location.end.column - 1}); + writer.symbol(";"); + } + } + + void visualizeFunctionBody(AstExprFunction& func) + { + if (func.generics.size > 0 || func.genericPacks.size > 0) + { + CommaSeparatorInserter comma(writer); + writer.symbol("<"); + for (const auto& o : func.generics) + { + comma(); + + writer.advance(o.location.begin); + writer.identifier(o.name.value); + } + for (const auto& o : func.genericPacks) + { + comma(); + + writer.advance(o.location.begin); + writer.identifier(o.name.value); + writer.symbol("..."); + } + writer.symbol(">"); + } + + writer.symbol("("); + CommaSeparatorInserter comma(writer); + + for (size_t i = 0; i < func.args.size; ++i) + { + AstLocal* local = func.args.data[i]; + + comma(); + + advance(local->location.begin); + writer.identifier(local->name.value); + if (writeTypes && local->annotation) + { + writer.symbol(":"); + visualizeTypeAnnotation(*local->annotation); + } + } + + if (func.vararg) + { + comma(); + advance(func.varargLocation.begin); + writer.symbol("..."); + + if (func.varargAnnotation) + { + writer.symbol(":"); + visualizeTypePackAnnotation(*func.varargAnnotation, true); + } + } + + writer.symbol(")"); + + if (writeTypes && func.returnAnnotation) + { + writer.symbol(":"); + writer.space(); + + visualizeTypeList(*func.returnAnnotation, false); + } + + visualizeBlock(*func.body); + advance(func.body->location.end); + writer.keyword("end"); + } + + void visualizeBlock(AstStatBlock& block) + { + for (const auto& s : block.body) + visualize(*s); + writer.advance(block.location.end); + } + + void visualizeBlock(AstStat& stat) + { + if (AstStatBlock* block = stat.as()) + visualizeBlock(*block); + else + LUAU_ASSERT(!"visualizeBlock was expecting an AstStatBlock"); + } + + void visualizeElseIf(AstStatIf& elseif) + { + visualize(*elseif.condition); + if (elseif.thenLocation) + advance(elseif.thenLocation->begin); + writer.keyword("then"); + visualizeBlock(*elseif.thenbody); + + if (elseif.elsebody == nullptr) + { + advance(elseif.thenbody->location.end); + writer.keyword("end"); + } + else if (auto elseifelseif = elseif.elsebody->as()) + { + if (elseif.elseLocation) + advance(elseif.elseLocation->begin); + writer.keyword("elseif"); + visualizeElseIf(*elseifelseif); + } + else + { + if (elseif.elseLocation) + advance(elseif.elseLocation->begin); + writer.keyword("else"); + + visualizeBlock(*elseif.elsebody); + advance(elseif.elsebody->location.end); + writer.keyword("end"); + } + } + + void visualizeElseIfExpr(AstExprIfElse& elseif) + { + const auto cstNode = lookupCstNode(&elseif); + + visualize(*elseif.condition); + if (cstNode) + advance(cstNode->thenPosition); + writer.keyword("then"); + visualize(*elseif.trueExpr); + + if (elseif.falseExpr) + { + if (cstNode) + advance(cstNode->elsePosition); + if (auto elseifelseif = elseif.falseExpr->as(); elseifelseif && (!cstNode || cstNode->isElseIf)) + { + writer.keyword("elseif"); + visualizeElseIfExpr(*elseifelseif); + } + else + { + writer.keyword("else"); + visualize(*elseif.falseExpr); + } + } + } + + void visualizeTypeAnnotation(AstType& typeAnnotation) + { + advance(typeAnnotation.location.begin); + if (const auto& a = typeAnnotation.as()) + { + const auto cstNode = lookupCstNode(a); + + if (a->prefix) + { + writer.write(a->prefix->value); + if (cstNode) + advance(*cstNode->prefixPointPosition); + writer.symbol("."); + } + + advance(a->nameLocation.begin); + writer.write(a->name.value); + if (a->parameters.size > 0 || a->hasParameterList) + { + CommaSeparatorInserter comma(writer, cstNode ? cstNode->parametersCommaPositions.begin() : nullptr); + if (cstNode) + advance(cstNode->openParametersPosition); + writer.symbol("<"); + for (auto o : a->parameters) + { + comma(); + + if (o.type) + visualizeTypeAnnotation(*o.type); + else + visualizeTypePackAnnotation(*o.typePack, false); + } + if (cstNode) + advance(cstNode->closeParametersPosition); + writer.symbol(">"); + } + } + else if (const auto& a = typeAnnotation.as()) + { + if (a->generics.size > 0 || a->genericPacks.size > 0) + { + CommaSeparatorInserter comma(writer); + writer.symbol("<"); + for (const auto& o : a->generics) + { + comma(); + + writer.advance(o.location.begin); + writer.identifier(o.name.value); + } + for (const auto& o : a->genericPacks) + { + comma(); + + writer.advance(o.location.begin); + writer.identifier(o.name.value); + writer.symbol("..."); + } + writer.symbol(">"); + } + + { + visualizeTypeList(a->argTypes, true); + } + + writer.symbol("->"); + visualizeTypeList(a->returnTypes, true); + } + else if (const auto& a = typeAnnotation.as()) + { + AstTypeReference* indexType = a->indexer ? a->indexer->indexType->as() : nullptr; + + writer.symbol("{"); + + const auto cstNode = lookupCstNode(a); + if (cstNode) + { + if (cstNode->isArray) + { + LUAU_ASSERT(a->props.size == 0 && indexType && indexType->name == "number"); + if (a->indexer->accessLocation) + { + LUAU_ASSERT(a->indexer->access != AstTableAccess::ReadWrite); + advance(a->indexer->accessLocation->begin); + writer.keyword(a->indexer->access == AstTableAccess::Read ? "read" : "write"); + } + visualizeTypeAnnotation(*a->indexer->resultType); + } + else + { + const AstTableProp* prop = a->props.begin(); + + for (size_t i = 0; i < cstNode->items.size; ++i) + { + CstTypeTable::Item item = cstNode->items.data[i]; + // we store indexer as part of items to preserve property ordering + if (item.kind == CstTypeTable::Item::Kind::Indexer) + { + LUAU_ASSERT(a->indexer); + + if (a->indexer->accessLocation) + { + LUAU_ASSERT(a->indexer->access != AstTableAccess::ReadWrite); + advance(a->indexer->accessLocation->begin); + writer.keyword(a->indexer->access == AstTableAccess::Read ? "read" : "write"); + } + + advance(item.indexerOpenPosition); + writer.symbol("["); + visualizeTypeAnnotation(*a->indexer->indexType); + advance(item.indexerClosePosition); + writer.symbol("]"); + advance(item.colonPosition); + writer.symbol(":"); + visualizeTypeAnnotation(*a->indexer->resultType); + + if (item.separator) + { + LUAU_ASSERT(item.separatorPosition); + advance(*item.separatorPosition); + if (item.separator == CstExprTable::Comma) + writer.symbol(","); + else if (item.separator == CstExprTable::Semicolon) + writer.symbol(";"); + } + } + else + { + if (prop->accessLocation) + { + LUAU_ASSERT(prop->access != AstTableAccess::ReadWrite); + advance(prop->accessLocation->begin); + writer.keyword(prop->access == AstTableAccess::Read ? "read" : "write"); + } + + if (item.kind == CstTypeTable::Item::Kind::StringProperty) + { + advance(item.indexerOpenPosition); + writer.symbol("["); + writer.sourceString( + std::string_view(item.stringInfo->sourceString.data, item.stringInfo->sourceString.size), + item.stringInfo->quoteStyle, + item.stringInfo->blockDepth + ); + advance(item.indexerClosePosition); + writer.symbol("]"); + } + else + { + advance(prop->location.begin); + writer.identifier(prop->name.value); + } + + advance(item.colonPosition); + writer.symbol(":"); + visualizeTypeAnnotation(*prop->type); + + if (item.separator) + { + LUAU_ASSERT(item.separatorPosition); + advance(*item.separatorPosition); + if (item.separator == CstExprTable::Comma) + writer.symbol(","); + else if (item.separator == CstExprTable::Semicolon) + writer.symbol(";"); + } + + ++prop; + } + } + } + } + else + { + if (a->props.size == 0 && indexType && indexType->name == "number") + { + visualizeTypeAnnotation(*a->indexer->resultType); + } + else + { + CommaSeparatorInserter comma(writer); + + for (size_t i = 0; i < a->props.size; ++i) + { + comma(); + advance(a->props.data[i].location.begin); + writer.identifier(a->props.data[i].name.value); + if (a->props.data[i].type) + { + writer.symbol(":"); + visualizeTypeAnnotation(*a->props.data[i].type); + } + } + if (a->indexer) + { + comma(); + writer.symbol("["); + visualizeTypeAnnotation(*a->indexer->indexType); + writer.symbol("]"); + writer.symbol(":"); + visualizeTypeAnnotation(*a->indexer->resultType); + } + } + } + + Position endPos = a->location.end; + if (endPos.column > 0) + --endPos.column; + advance(endPos); + + writer.symbol("}"); + } + else if (auto a = typeAnnotation.as()) + { + const auto cstNode = lookupCstNode(a); + writer.keyword("typeof"); + if (cstNode) + advance(cstNode->openPosition); + writer.symbol("("); + visualize(*a->expr); + if (cstNode) + advance(cstNode->closePosition); + writer.symbol(")"); + } + else if (const auto& a = typeAnnotation.as()) + { + if (a->types.size == 2) + { + AstType* l = a->types.data[0]; + AstType* r = a->types.data[1]; + + auto lta = l->as(); + if (lta && lta->name == "nil") + std::swap(l, r); + + // it's still possible that we had a (T | U) or (T | nil) and not (nil | T) + auto rta = r->as(); + if (rta && rta->name == "nil") + { + bool wrap = l->as() || l->as(); + + if (wrap) + writer.symbol("("); + + visualizeTypeAnnotation(*l); + + if (wrap) + writer.symbol(")"); + + writer.symbol("?"); + return; + } + } + + for (size_t i = 0; i < a->types.size; ++i) + { + if (i > 0) + { + writer.maybeSpace(a->types.data[i]->location.begin, 2); + writer.symbol("|"); + } + + bool wrap = a->types.data[i]->as() || a->types.data[i]->as(); + + if (wrap) + writer.symbol("("); + + visualizeTypeAnnotation(*a->types.data[i]); + + if (wrap) + writer.symbol(")"); + } + } + else if (const auto& a = typeAnnotation.as()) + { + for (size_t i = 0; i < a->types.size; ++i) + { + if (i > 0) + { + writer.maybeSpace(a->types.data[i]->location.begin, 2); + writer.symbol("&"); + } + + bool wrap = a->types.data[i]->as() || a->types.data[i]->as(); + + if (wrap) + writer.symbol("("); + + visualizeTypeAnnotation(*a->types.data[i]); + + if (wrap) + writer.symbol(")"); + } + } + else if (const auto& a = typeAnnotation.as()) + { + writer.symbol("("); + visualizeTypeAnnotation(*a->type); + writer.symbol(")"); + } + else if (const auto& a = typeAnnotation.as()) + { + writer.keyword(a->value ? "true" : "false"); + } + else if (const auto& a = typeAnnotation.as()) + { + if (const auto cstNode = lookupCstNode(a)) + { + writer.sourceString( + std::string_view(cstNode->sourceString.data, cstNode->sourceString.size), cstNode->quoteStyle, cstNode->blockDepth + ); + } + else + writer.string(std::string_view(a->value.data, a->value.size)); + } + else if (typeAnnotation.is()) + { + writer.symbol("%error-type%"); + } + else + { + LUAU_ASSERT(!"Unknown AstType"); + } + } +}; + std::string toString(AstNode* node) { StringWriter writer; writer.pos = node->location.begin; - Printer printer(writer); - printer.writeTypes = true; + if (FFlag::LuauStoreCSTData) + { + Printer printer(writer, CstNodeMap{nullptr}); + printer.writeTypes = true; - if (auto statNode = node->asStat()) - printer.visualize(*statNode); - else if (auto exprNode = node->asExpr()) - printer.visualize(*exprNode); - else if (auto typeNode = node->asType()) - printer.visualizeTypeAnnotation(*typeNode); + if (auto statNode = node->asStat()) + printer.visualize(*statNode); + else if (auto exprNode = node->asExpr()) + printer.visualize(*exprNode); + else if (auto typeNode = node->asType()) + printer.visualizeTypeAnnotation(*typeNode); + } + else + { + Printer_DEPRECATED printer(writer); + printer.writeTypes = true; + + if (auto statNode = node->asStat()) + printer.visualize(*statNode); + else if (auto exprNode = node->asExpr()) + printer.visualize(*exprNode); + else if (auto typeNode = node->asType()) + printer.visualizeTypeAnnotation(*typeNode); + } return writer.str(); } @@ -1206,24 +2629,48 @@ void dump(AstNode* node) printf("%s\n", toString(node).c_str()); } -std::string transpile(AstStatBlock& block) +std::string transpile(AstStatBlock& block, const CstNodeMap& cstNodeMap) { StringWriter writer; - Printer(writer).visualizeBlock(block); + if (FFlag::LuauStoreCSTData) + { + Printer(writer, cstNodeMap).visualizeBlock(block); + } + else + { + Printer_DEPRECATED(writer).visualizeBlock(block); + } + return writer.str(); +} + +std::string transpileWithTypes(AstStatBlock& block, const CstNodeMap& cstNodeMap) +{ + StringWriter writer; + if (FFlag::LuauStoreCSTData) + { + Printer printer(writer, cstNodeMap); + printer.writeTypes = true; + printer.visualizeBlock(block); + } + else + { + Printer_DEPRECATED printer(writer); + printer.writeTypes = true; + printer.visualizeBlock(block); + } return writer.str(); } std::string transpileWithTypes(AstStatBlock& block) { - StringWriter writer; - Printer printer(writer); - printer.writeTypes = true; - printer.visualizeBlock(block); - return writer.str(); + // TODO: remove this interface? + return transpileWithTypes(block, CstNodeMap{nullptr}); } TranspileResult transpile(std::string_view source, ParseOptions options, bool withTypes) { + options.storeCstData = true; + auto allocator = Allocator{}; auto names = AstNameTable{allocator}; ParseResult parseResult = Parser::parse(source.data(), source.size(), names, allocator, options); @@ -1241,9 +2688,9 @@ TranspileResult transpile(std::string_view source, ParseOptions options, bool wi return TranspileResult{"", {}, "Internal error: Parser yielded empty parse tree"}; if (withTypes) - return TranspileResult{transpileWithTypes(*parseResult.root)}; + return TranspileResult{transpileWithTypes(*parseResult.root, parseResult.cstNodeMap)}; - return TranspileResult{transpile(*parseResult.root)}; + return TranspileResult{transpile(*parseResult.root, parseResult.cstNodeMap)}; } } // namespace Luau diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index a5298ee5..bb08856c 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -27,6 +27,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauInstantiateInSubtyping) +LUAU_FASTFLAGVARIABLE(LuauFreeTypesMustHaveBounds) namespace Luau { @@ -478,24 +479,12 @@ bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) return false; } -FreeType::FreeType(TypeLevel level) +// New constructors +FreeType::FreeType(TypeLevel level, TypeId lowerBound, TypeId upperBound) : index(Unifiable::freshIndex()) , level(level) - , scope(nullptr) -{ -} - -FreeType::FreeType(Scope* scope) - : index(Unifiable::freshIndex()) - , level{} - , scope(scope) -{ -} - -FreeType::FreeType(Scope* scope, TypeLevel level) - : index(Unifiable::freshIndex()) - , level(level) - , scope(scope) + , lowerBound(lowerBound) + , upperBound(upperBound) { } @@ -507,6 +496,40 @@ FreeType::FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound) { } +FreeType::FreeType(Scope* scope, TypeLevel level, TypeId lowerBound, TypeId upperBound) + : index(Unifiable::freshIndex()) + , level(level) + , scope(scope) + , lowerBound(lowerBound) + , upperBound(upperBound) +{ +} + +// Old constructors +FreeType::FreeType(TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , scope(nullptr) +{ + LUAU_ASSERT(!FFlag::LuauFreeTypesMustHaveBounds); +} + +FreeType::FreeType(Scope* scope) + : index(Unifiable::freshIndex()) + , level{} + , scope(scope) +{ + LUAU_ASSERT(!FFlag::LuauFreeTypesMustHaveBounds); +} + +FreeType::FreeType(Scope* scope, TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , scope(scope) +{ + LUAU_ASSERT(!FFlag::LuauFreeTypesMustHaveBounds); +} + GenericType::GenericType() : index(Unifiable::freshIndex()) , name("g" + std::to_string(index)) @@ -554,12 +577,12 @@ BlockedType::BlockedType() { } -Constraint* BlockedType::getOwner() const +const Constraint* BlockedType::getOwner() const { return owner; } -void BlockedType::setOwner(Constraint* newOwner) +void BlockedType::setOwner(const Constraint* newOwner) { LUAU_ASSERT(owner == nullptr); @@ -569,7 +592,7 @@ void BlockedType::setOwner(Constraint* newOwner) owner = newOwner; } -void BlockedType::replaceOwner(Constraint* newOwner) +void BlockedType::replaceOwner(const Constraint* newOwner) { owner = newOwner; } diff --git a/Analysis/src/TypeArena.cpp b/Analysis/src/TypeArena.cpp index 617bd305..e4e9e293 100644 --- a/Analysis/src/TypeArena.cpp +++ b/Analysis/src/TypeArena.cpp @@ -3,6 +3,7 @@ #include "Luau/TypeArena.h" LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena); +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) namespace Luau { @@ -22,7 +23,34 @@ TypeId TypeArena::addTV(Type&& tv) return allocated; } -TypeId TypeArena::freshType(TypeLevel level) +TypeId TypeArena::freshType(NotNull builtins, TypeLevel level) +{ + TypeId allocated = types.allocate(FreeType{level, builtins->neverType, builtins->unknownType}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypeId TypeArena::freshType(NotNull builtins, Scope* scope) +{ + TypeId allocated = types.allocate(FreeType{scope, builtins->neverType, builtins->unknownType}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypeId TypeArena::freshType(NotNull builtins, Scope* scope, TypeLevel level) +{ + TypeId allocated = types.allocate(FreeType{scope, level, builtins->neverType, builtins->unknownType}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypeId TypeArena::freshType_DEPRECATED(TypeLevel level) { TypeId allocated = types.allocate(FreeType{level}); @@ -31,7 +59,7 @@ TypeId TypeArena::freshType(TypeLevel level) return allocated; } -TypeId TypeArena::freshType(Scope* scope) +TypeId TypeArena::freshType_DEPRECATED(Scope* scope) { TypeId allocated = types.allocate(FreeType{scope}); @@ -40,7 +68,7 @@ TypeId TypeArena::freshType(Scope* scope) return allocated; } -TypeId TypeArena::freshType(Scope* scope, TypeLevel level) +TypeId TypeArena::freshType_DEPRECATED(Scope* scope, TypeLevel level) { TypeId allocated = types.allocate(FreeType{scope, level}); diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 6f28b11c..6fc60b2f 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -386,8 +386,12 @@ public: } AstType* operator()(const NegationType& ntv) { - // FIXME: do the same thing we do with ErrorType - throw InternalCompilerError("Cannot convert NegationType into AstNode"); + AstArray params; + params.size = 1; + params.data = static_cast(allocator->allocate(sizeof(AstType*))); + params.data[0] = AstTypeOrPack{Luau::visit(*this, ntv.ty->ty), nullptr}; + + return allocator->alloc(Location(), std::nullopt, AstName("negate"), std::nullopt, Location(), true, params); } AstType* operator()(const TypeFunctionInstanceType& tfit) { diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 5397a0e8..32c0f4db 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -7,7 +7,6 @@ #include "Luau/DcrLogger.h" #include "Luau/DenseHash.h" #include "Luau/Error.h" -#include "Luau/InsertionOrderedMap.h" #include "Luau/Instantiation.h" #include "Luau/Metamethods.h" #include "Luau/Normalize.h" @@ -27,12 +26,12 @@ #include "Luau/VisitType.h" #include -#include -#include LUAU_FASTFLAG(DebugLuauMagicTypes) +LUAU_FASTFLAG(InferGlobalTypes) LUAU_FASTFLAGVARIABLE(LuauTableKeysAreRValues) +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) namespace Luau { @@ -175,7 +174,7 @@ struct InternalTypeFunctionFinder : TypeOnceVisitor DenseHashSet mentionedFunctions{nullptr}; DenseHashSet mentionedFunctionPacks{nullptr}; - InternalTypeFunctionFinder(std::vector& declStack) + explicit InternalTypeFunctionFinder(std::vector& declStack) { TypeFunctionFinder f; for (TypeId fn : declStack) @@ -268,6 +267,7 @@ struct InternalTypeFunctionFinder : TypeOnceVisitor void check( NotNull builtinTypes, + NotNull simplifier, NotNull typeFunctionRuntime, NotNull unifierState, NotNull limits, @@ -278,7 +278,7 @@ void check( { LUAU_TIMETRACE_SCOPE("check", "Typechecking"); - TypeChecker2 typeChecker{builtinTypes, typeFunctionRuntime, unifierState, limits, logger, &sourceModule, module}; + TypeChecker2 typeChecker{builtinTypes, simplifier, typeFunctionRuntime, unifierState, limits, logger, &sourceModule, module}; typeChecker.visit(sourceModule.root); @@ -295,6 +295,7 @@ void check( TypeChecker2::TypeChecker2( NotNull builtinTypes, + NotNull simplifier, NotNull typeFunctionRuntime, NotNull unifierState, NotNull limits, @@ -303,6 +304,7 @@ TypeChecker2::TypeChecker2( Module* module ) : builtinTypes(builtinTypes) + , simplifier(simplifier) , typeFunctionRuntime(typeFunctionRuntime) , logger(logger) , limits(limits) @@ -310,7 +312,7 @@ TypeChecker2::TypeChecker2( , sourceModule(sourceModule) , module(module) , normalizer{&module->internalTypes, builtinTypes, unifierState, /* cacheInhabitance */ true} - , _subtyping{builtinTypes, NotNull{&module->internalTypes}, NotNull{&normalizer}, typeFunctionRuntime, NotNull{unifierState->iceHandler}} + , _subtyping{builtinTypes, NotNull{&module->internalTypes}, simplifier, NotNull{&normalizer}, typeFunctionRuntime, NotNull{unifierState->iceHandler}} , subtyping(&_subtyping) { } @@ -492,7 +494,9 @@ TypeId TypeChecker2::checkForTypeFunctionInhabitance(TypeId instance, Location l reduceTypeFunctions( instance, location, - TypeFunctionContext{NotNull{&module->internalTypes}, builtinTypes, stack.back(), NotNull{&normalizer}, typeFunctionRuntime, ice, limits}, + TypeFunctionContext{ + NotNull{&module->internalTypes}, builtinTypes, stack.back(), simplifier, NotNull{&normalizer}, typeFunctionRuntime, ice, limits + }, true ) .errors; @@ -501,7 +505,7 @@ TypeId TypeChecker2::checkForTypeFunctionInhabitance(TypeId instance, Location l return instance; } -TypePackId TypeChecker2::lookupPack(AstExpr* expr) +TypePackId TypeChecker2::lookupPack(AstExpr* expr) const { // If a type isn't in the type graph, it probably means that a recursion limit was exceeded. // We'll just return anyType in these cases. Typechecking against any is very fast and this @@ -551,7 +555,7 @@ TypeId TypeChecker2::lookupAnnotation(AstType* annotation) return checkForTypeFunctionInhabitance(follow(*ty), annotation->location); } -std::optional TypeChecker2::lookupPackAnnotation(AstTypePack* annotation) +std::optional TypeChecker2::lookupPackAnnotation(AstTypePack* annotation) const { TypePackId* tp = module->astResolvedTypePacks.find(annotation); if (tp != nullptr) @@ -559,7 +563,7 @@ std::optional TypeChecker2::lookupPackAnnotation(AstTypePack* annota return {}; } -TypeId TypeChecker2::lookupExpectedType(AstExpr* expr) +TypeId TypeChecker2::lookupExpectedType(AstExpr* expr) const { if (TypeId* ty = module->astExpectedTypes.find(expr)) return follow(*ty); @@ -567,7 +571,7 @@ TypeId TypeChecker2::lookupExpectedType(AstExpr* expr) return builtinTypes->anyType; } -TypePackId TypeChecker2::lookupExpectedPack(AstExpr* expr, TypeArena& arena) +TypePackId TypeChecker2::lookupExpectedPack(AstExpr* expr, TypeArena& arena) const { if (TypeId* ty = module->astExpectedTypes.find(expr)) return arena.addTypePack(TypePack{{follow(*ty)}, std::nullopt}); @@ -591,7 +595,7 @@ TypePackId TypeChecker2::reconstructPack(AstArray exprs, TypeArena& ar return arena.addTypePack(TypePack{head, tail}); } -Scope* TypeChecker2::findInnermostScope(Location location) +Scope* TypeChecker2::findInnermostScope(Location location) const { Scope* bestScope = module->getModuleScope().get(); @@ -1014,7 +1018,8 @@ void TypeChecker2::visit(AstStatForIn* forInStatement) { reportError(OptionalValueAccess{iteratorTy}, forInStatement->values.data[0]->location); } - else if (std::optional iterMmTy = findMetatableEntry(builtinTypes, module->errors, iteratorTy, "__iter", forInStatement->values.data[0]->location)) + else if (std::optional iterMmTy = + findMetatableEntry(builtinTypes, module->errors, iteratorTy, "__iter", forInStatement->values.data[0]->location)) { Instantiation instantiation{TxnLog::empty(), &arena, builtinTypes, TypeLevel{}, scope}; @@ -1349,7 +1354,17 @@ void TypeChecker2::visit(AstExprGlobal* expr) { NotNull scope = stack.back(); if (!scope->lookup(expr->name)) + { reportError(UnknownSymbol{expr->name.value, UnknownSymbol::Binding}, expr->location); + } + else if (FFlag::InferGlobalTypes) + { + if (scope->shouldWarnGlobal(expr->name.value) && !warnedGlobals.contains(expr->name.value)) + { + reportError(UnknownSymbol{expr->name.value, UnknownSymbol::Binding}, expr->location); + warnedGlobals.insert(expr->name.value); + } + } } void TypeChecker2::visit(AstExprVarargs* expr) @@ -1437,10 +1452,11 @@ void TypeChecker2::visitCall(AstExprCall* call) TypePackId argsTp = module->internalTypes.addTypePack(args); if (auto ftv = get(follow(*originalCallTy))) { - if (ftv->dcrMagicTypeCheck) + if (ftv->magic) { - ftv->dcrMagicTypeCheck(MagicFunctionTypeCheckContext{NotNull{this}, builtinTypes, call, argsTp, scope}); - return; + bool usedMagic = ftv->magic->typeCheck(MagicFunctionTypeCheckContext{NotNull{this}, builtinTypes, call, argsTp, scope}); + if (usedMagic) + return; } } @@ -1448,6 +1464,7 @@ void TypeChecker2::visitCall(AstExprCall* call) OverloadResolver resolver{ builtinTypes, NotNull{&module->internalTypes}, + simplifier, NotNull{&normalizer}, typeFunctionRuntime, NotNull{stack.back()}, @@ -1545,7 +1562,7 @@ void TypeChecker2::visit(AstExprCall* call) visitCall(call); } -std::optional TypeChecker2::tryStripUnionFromNil(TypeId ty) +std::optional TypeChecker2::tryStripUnionFromNil(TypeId ty) const { if (const UnionType* utv = get(ty)) { @@ -2089,7 +2106,10 @@ TypeId TypeChecker2::visit(AstExprBinary* expr, AstNode* overrideKey) } else { - expectedRets = module->internalTypes.addTypePack({module->internalTypes.freshType(scope, TypeLevel{})}); + expectedRets = module->internalTypes.addTypePack( + {FFlag::LuauFreeTypesMustHaveBounds ? module->internalTypes.freshType(builtinTypes, scope, TypeLevel{}) + : module->internalTypes.freshType_DEPRECATED(scope, TypeLevel{})} + ); } TypeId expectedTy = module->internalTypes.addType(FunctionType(expectedArgs, expectedRets)); @@ -2341,7 +2361,8 @@ TypeId TypeChecker2::flattenPack(TypePackId pack) return *fst; else if (auto ftp = get(pack)) { - TypeId result = module->internalTypes.addType(FreeType{ftp->scope}); + TypeId result = FFlag::LuauFreeTypesMustHaveBounds ? module->internalTypes.freshType(builtinTypes, ftp->scope) + : module->internalTypes.addType(FreeType{ftp->scope}); TypePackId freeTail = module->internalTypes.addTypePack(FreeTypePack{ftp->scope}); TypePack* resultPack = emplaceTypePack(asMutable(pack)); @@ -2403,6 +2424,8 @@ void TypeChecker2::visit(AstType* ty) return visit(t); else if (auto t = ty->as()) return visit(t); + else if (auto t = ty->as()) + return visit(t->type); } void TypeChecker2::visit(AstTypeReference* ty) @@ -3024,7 +3047,7 @@ PropertyType TypeChecker2::hasIndexTypeFromType( { TypeId indexType = follow(tt->indexer->indexType); TypeId givenType = module->internalTypes.addType(SingletonType{StringSingleton{prop}}); - if (isSubtype(givenType, indexType, NotNull{module->getModuleScope().get()}, builtinTypes, *ice)) + if (isSubtype(givenType, indexType, NotNull{module->getModuleScope().get()}, builtinTypes, simplifier, *ice)) return {NormalizationResult::True, {tt->indexer->indexResultType}}; } diff --git a/Analysis/src/TypeFunction.cpp b/Analysis/src/TypeFunction.cpp index e7620cc4..a5f69460 100644 --- a/Analysis/src/TypeFunction.cpp +++ b/Analysis/src/TypeFunction.cpp @@ -46,10 +46,10 @@ LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyApplicationCartesianProductLimit, 5'0 LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyUseGuesserDepth, -1); LUAU_FASTFLAGVARIABLE(DebugLuauLogTypeFamilies) -LUAU_FASTFLAG(LuauUserTypeFunFixRegister) -LUAU_FASTFLAG(LuauRemoveNotAnyHack) -LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionResetState) -LUAU_FASTFLAG(LuauUserTypeFunExportedAndLocal) +LUAU_FASTFLAG(DebugLuauEqSatSimplification) +LUAU_FASTFLAGVARIABLE(LuauMetatableTypeFunctions) +LUAU_FASTFLAGVARIABLE(LuauClipNestedAndRecursiveUnion) +LUAU_FASTFLAGVARIABLE(LuauDoNotGeneralizeInTypeFunctions) namespace Luau { @@ -220,6 +220,9 @@ struct TypeFunctionReducer template void handleTypeFunctionReduction(T subject, TypeFunctionReductionResult reduction) { + for (auto& message : reduction.messages) + result.messages.emplace_back(location, UserDefinedTypeFunctionError{std::move(message)}); + if (reduction.result) replace(subject, *reduction.result); else @@ -229,7 +232,7 @@ struct TypeFunctionReducer if (reduction.error.has_value()) result.errors.emplace_back(location, UserDefinedTypeFunctionError{*reduction.error}); - if (reduction.uninhabited || force) + if (reduction.reductionStatus != Reduction::MaybeOk || force) { if (FFlag::DebugLuauLogTypeFamilies) printf("%s is uninhabited\n", toString(subject, {true}).c_str()); @@ -239,7 +242,7 @@ struct TypeFunctionReducer else if constexpr (std::is_same_v) result.errors.emplace_back(location, UninhabitedTypePackFunction{subject}); } - else if (!reduction.uninhabited && !force) + else if (reduction.reductionStatus == Reduction::MaybeOk && !force) { if (FFlag::DebugLuauLogTypeFamilies) printf( @@ -528,7 +531,7 @@ static std::optional> tryDistributeTypeFunct ) { // op (a | b) (c | d) ~ (op a (c | d)) | (op b (c | d)) ~ (op a c) | (op a d) | (op b c) | (op b d) - bool uninhabited = false; + Reduction reductionStatus = Reduction::MaybeOk; std::vector blockedTypes; std::vector results; size_t cartesianProductSize = 1; @@ -557,7 +560,7 @@ static std::optional> tryDistributeTypeFunct // TODO: We'd like to report that the type function application is too complex here. if (size_t(DFInt::LuauTypeFamilyApplicationCartesianProductLimit) <= cartesianProductSize) - return {{std::nullopt, true, {}, {}}}; + return {{std::nullopt, Reduction::Erroneous, {}, {}}}; } if (!firstUnion) @@ -572,21 +575,22 @@ static std::optional> tryDistributeTypeFunct TypeFunctionReductionResult result = f(instance, arguments, packParams, ctx, args...); blockedTypes.insert(blockedTypes.end(), result.blockedTypes.begin(), result.blockedTypes.end()); - uninhabited |= result.uninhabited; + if (result.reductionStatus != Reduction::MaybeOk) + reductionStatus = result.reductionStatus; - if (result.uninhabited || !result.result) + if (reductionStatus != Reduction::MaybeOk || !result.result) break; else results.push_back(*result.result); } - if (uninhabited || !blockedTypes.empty()) - return {{std::nullopt, uninhabited, blockedTypes, {}}}; + if (reductionStatus != Reduction::MaybeOk || !blockedTypes.empty()) + return {{std::nullopt, reductionStatus, blockedTypes, {}}}; if (!results.empty()) { if (results.size() == 1) - return {{results[0], false, {}, {}}}; + return {{results[0], Reduction::MaybeOk, {}, {}}}; TypeId resultTy = ctx->arena->addType(TypeFunctionInstanceType{ NotNull{&builtinTypeFunctions().unionFunc}, @@ -594,7 +598,7 @@ static std::optional> tryDistributeTypeFunct {}, }); - return {{resultTy, false, {}, {}}}; + return {{resultTy, Reduction::MaybeOk, {}, {}}}; } return std::nullopt; @@ -609,32 +613,21 @@ TypeFunctionReductionResult userDefinedTypeFunction( { auto typeFunction = getMutable(instance); - if (FFlag::LuauUserTypeFunExportedAndLocal) + if (typeFunction->userFuncData.owner.expired()) { - if (typeFunction->userFuncData.owner.expired()) - { - ctx->ice->ice("user-defined type function module has expired"); - return {std::nullopt, true, {}, {}}; - } - - if (!typeFunction->userFuncName || !typeFunction->userFuncData.definition) - { - ctx->ice->ice("all user-defined type functions must have an associated function definition"); - return {std::nullopt, true, {}, {}}; - } + ctx->ice->ice("user-defined type function module has expired"); + return {std::nullopt, Reduction::Erroneous, {}, {}}; } - else + + if (!typeFunction->userFuncName || !typeFunction->userFuncData.definition) { - if (!ctx->userFuncName) - { - ctx->ice->ice("all user-defined type functions must have an associated function definition"); - return {std::nullopt, true, {}, {}}; - } + ctx->ice->ice("all user-defined type functions must have an associated function definition"); + return {std::nullopt, Reduction::Erroneous, {}, {}}; } // If type functions cannot be evaluated because of errors in the code, we do not generate any additional ones if (!ctx->typeFunctionRuntime->allowEvaluation) - return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + return {ctx->builtins->errorRecoveryType(), Reduction::MaybeOk, {}, {}}; for (auto typeParam : typeParams) { @@ -642,76 +635,83 @@ TypeFunctionReductionResult userDefinedTypeFunction( // block if we need to if (isPending(ty, ctx->solver)) - return {std::nullopt, false, {ty}, {}}; + return {std::nullopt, Reduction::MaybeOk, {ty}, {}}; } - if (FFlag::LuauUserTypeFunExportedAndLocal) + // Ensure that whole type function environment is registered + for (auto& [name, definition] : typeFunction->userFuncData.environment) { - // Ensure that whole type function environment is registered - for (auto& [name, definition] : typeFunction->userFuncData.environment) + if (std::optional error = ctx->typeFunctionRuntime->registerFunction(definition.first)) { - if (std::optional error = ctx->typeFunctionRuntime->registerFunction(definition)) - { - // Failure to register at this point means that original definition had to error out and should not have been present in the - // environment - ctx->ice->ice("user-defined type function reference cannot be registered"); - return {std::nullopt, true, {}, {}}; - } + // Failure to register at this point means that original definition had to error out and should not have been present in the + // environment + ctx->ice->ice("user-defined type function reference cannot be registered"); + return {std::nullopt, Reduction::Erroneous, {}, {}}; } } - AstName name = FFlag::LuauUserTypeFunExportedAndLocal ? typeFunction->userFuncData.definition->name : *ctx->userFuncName; + AstName name = typeFunction->userFuncData.definition->name; lua_State* global = ctx->typeFunctionRuntime->state.get(); if (global == nullptr) - return {std::nullopt, true, {}, {}, format("'%s' type function: cannot be evaluated in this context", name.value)}; + return {std::nullopt, Reduction::Erroneous, {}, {}, format("'%s' type function: cannot be evaluated in this context", name.value)}; // Separate sandboxed thread for individual execution and private globals lua_State* L = lua_newthread(global); LuauTempThreadPopper popper(global); - if (FFlag::LuauUserTypeFunExportedAndLocal) + // Build up the environment table of each function we have visible + for (auto& [_, curr] : typeFunction->userFuncData.environment) { - // Fetch the function we want to evaluate - lua_pushlightuserdata(L, typeFunction->userFuncData.definition); + // Environment table has to be filled only once in the current execution context + if (ctx->typeFunctionRuntime->initialized.find(curr.first)) + continue; + ctx->typeFunctionRuntime->initialized.insert(curr.first); + + lua_pushlightuserdata(L, curr.first); lua_gettable(L, LUA_REGISTRYINDEX); if (!lua_isfunction(L, -1)) { ctx->ice->ice("user-defined type function reference cannot be found in the registry"); - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; } - // Build up the environment + // Build up the environment of the current function, where some might not be visible lua_getfenv(L, -1); lua_setreadonly(L, -1, false); for (auto& [name, definition] : typeFunction->userFuncData.environment) { - lua_pushlightuserdata(L, definition); - lua_gettable(L, LUA_REGISTRYINDEX); - - if (!lua_isfunction(L, -1)) + // Filter visibility based on original scope depth + if (definition.second >= curr.second) { - ctx->ice->ice("user-defined type function reference cannot be found in the registry"); - return {std::nullopt, true, {}, {}}; - } + lua_pushlightuserdata(L, definition.first); + lua_gettable(L, LUA_REGISTRYINDEX); - lua_setfield(L, -2, name.c_str()); + if (!lua_isfunction(L, -1)) + break; // Don't have to report an error here, we will visit each function in outer loop + + lua_setfield(L, -2, name.c_str()); + } } lua_setreadonly(L, -1, true); - lua_pop(L, 1); - } - else - { - lua_getglobal(global, name.value); - lua_xmove(global, L, 1); + lua_pop(L, 2); } - if (FFlag::LuauUserDefinedTypeFunctionResetState) - resetTypeFunctionState(L); + // Fetch the function we want to evaluate + lua_pushlightuserdata(L, typeFunction->userFuncData.definition); + lua_gettable(L, LUA_REGISTRYINDEX); + + if (!lua_isfunction(L, -1)) + { + ctx->ice->ice("user-defined type function reference cannot be found in the registry"); + return {std::nullopt, Reduction::Erroneous, {}, {}}; + } + + resetTypeFunctionState(L); // Push serialized arguments onto the stack @@ -727,7 +727,7 @@ TypeFunctionReductionResult userDefinedTypeFunction( TypeFunctionTypeId serializedTy = serialize(ty, runtimeBuilder.get()); // Check if there were any errors while serializing if (runtimeBuilder->errors.size() != 0) - return {std::nullopt, true, {}, {}, runtimeBuilder->errors.front()}; + return {std::nullopt, Reduction::Erroneous, {}, {}, runtimeBuilder->errors.front()}; allocTypeUserData(L, serializedTy->type); } @@ -743,12 +743,23 @@ TypeFunctionReductionResult userDefinedTypeFunction( throw UserCancelError(ctx->ice->moduleName); }; + ctx->typeFunctionRuntime->messages.clear(); + if (auto error = checkResultForError(L, name.value, lua_pcall(L, int(typeParams.size()), 1, 0))) - return {std::nullopt, true, {}, {}, error}; + return {std::nullopt, Reduction::Erroneous, {}, {}, error, ctx->typeFunctionRuntime->messages}; // If the return value is not a type userdata, return with error message if (!isTypeUserData(L, 1)) - return {std::nullopt, true, {}, {}, format("'%s' type function: returned a non-type value", name.value)}; + { + return { + std::nullopt, + Reduction::Erroneous, + {}, + {}, + format("'%s' type function: returned a non-type value", name.value), + ctx->typeFunctionRuntime->messages + }; + } TypeFunctionTypeId retTypeFunctionTypeId = getTypeUserData(L, 1); @@ -759,9 +770,9 @@ TypeFunctionReductionResult userDefinedTypeFunction( // At least 1 error occurred while deserializing if (runtimeBuilder->errors.size() > 0) - return {std::nullopt, true, {}, {}, runtimeBuilder->errors.front()}; + return {std::nullopt, Reduction::Erroneous, {}, {}, runtimeBuilder->errors.front(), ctx->typeFunctionRuntime->messages}; - return {retTypeId, false, {}, {}}; + return {retTypeId, Reduction::MaybeOk, {}, {}, std::nullopt, ctx->typeFunctionRuntime->messages}; } TypeFunctionReductionResult notTypeFunction( @@ -780,16 +791,16 @@ TypeFunctionReductionResult notTypeFunction( TypeId ty = follow(typeParams.at(0)); if (ty == instance) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; if (isPending(ty, ctx->solver)) - return {std::nullopt, false, {ty}, {}}; + return {std::nullopt, Reduction::MaybeOk, {ty}, {}}; if (auto result = tryDistributeTypeFunctionApp(notTypeFunction, instance, typeParams, packParams, ctx)) return *result; // `not` operates on anything and returns a `boolean` always. - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; } TypeFunctionReductionResult lenTypeFunction( @@ -808,19 +819,19 @@ TypeFunctionReductionResult lenTypeFunction( TypeId operandTy = follow(typeParams.at(0)); if (operandTy == instance) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; // check to see if the operand type is resolved enough, and wait to reduce if not // the use of `typeFromNormal` later necessitates blocking on local types. if (isPending(operandTy, ctx->solver)) - return {std::nullopt, false, {operandTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {operandTy}, {}}; // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, operandTy, /* avoidSealingTables */ true); if (!maybeGeneralized) - return {std::nullopt, false, {operandTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {operandTy}, {}}; operandTy = *maybeGeneralized; } @@ -829,21 +840,21 @@ TypeFunctionReductionResult lenTypeFunction( // if the type failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normTy || inhabited == NormalizationResult::HitLimits) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // if the operand type is error suppressing, we can immediately reduce to `number`. if (normTy->shouldSuppressErrors()) - return {ctx->builtins->numberType, false, {}, {}}; + return {ctx->builtins->numberType, Reduction::MaybeOk, {}, {}}; // # always returns a number, even if its operand is never. // if we're checking the length of a string, that works! if (inhabited == NormalizationResult::False || normTy->isSubtypeOfString()) - return {ctx->builtins->numberType, false, {}, {}}; + return {ctx->builtins->numberType, Reduction::MaybeOk, {}, {}}; // we use the normalized operand here in case there was an intersection or union. TypeId normalizedOperand = follow(ctx->normalizer->typeFromNormal(*normTy)); if (normTy->hasTopTable() || get(normalizedOperand)) - return {ctx->builtins->numberType, false, {}, {}}; + return {ctx->builtins->numberType, Reduction::MaybeOk, {}, {}}; if (auto result = tryDistributeTypeFunctionApp(lenTypeFunction, instance, typeParams, packParams, ctx)) return *result; @@ -854,35 +865,35 @@ TypeFunctionReductionResult lenTypeFunction( std::optional mmType = findMetatableEntry(ctx->builtins, dummy, operandTy, "__len", Location{}); if (!mmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; mmType = follow(*mmType); if (isPending(*mmType, ctx->solver)) - return {std::nullopt, false, {*mmType}, {}}; + return {std::nullopt, Reduction::MaybeOk, {*mmType}, {}}; const FunctionType* mmFtv = get(*mmType); if (!mmFtv) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; std::optional instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType); if (!instantiatedMmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; const FunctionType* instantiatedMmFtv = get(*instantiatedMmType); if (!instantiatedMmFtv) - return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + return {ctx->builtins->errorRecoveryType(), Reduction::MaybeOk, {}, {}}; TypePackId inferredArgPack = ctx->arena->addTypePack({operandTy}); Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice}; if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) - return {std::nullopt, true, {}, {}}; // occurs check failed + return {std::nullopt, Reduction::Erroneous, {}, {}}; // occurs check failed - Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->simplifier, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; // `len` must return a `number`. - return {ctx->builtins->numberType, false, {}, {}}; + return {ctx->builtins->numberType, Reduction::MaybeOk, {}, {}}; } TypeFunctionReductionResult unmTypeFunction( @@ -901,18 +912,18 @@ TypeFunctionReductionResult unmTypeFunction( TypeId operandTy = follow(typeParams.at(0)); if (operandTy == instance) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; // check to see if the operand type is resolved enough, and wait to reduce if not if (isPending(operandTy, ctx->solver)) - return {std::nullopt, false, {operandTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {operandTy}, {}}; // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, operandTy); if (!maybeGeneralized) - return {std::nullopt, false, {operandTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {operandTy}, {}}; operandTy = *maybeGeneralized; } @@ -920,19 +931,19 @@ TypeFunctionReductionResult unmTypeFunction( // if the operand failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normTy) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // if the operand is error suppressing, we can just go ahead and reduce. if (normTy->shouldSuppressErrors()) - return {operandTy, false, {}, {}}; + return {operandTy, Reduction::MaybeOk, {}, {}}; // if we have a `never`, we can never observe that the operation didn't work. if (is(operandTy)) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; // If the type is exactly `number`, we can reduce now. if (normTy->isExactlyNumber()) - return {ctx->builtins->numberType, false, {}, {}}; + return {ctx->builtins->numberType, Reduction::MaybeOk, {}, {}}; if (auto result = tryDistributeTypeFunctionApp(unmTypeFunction, instance, typeParams, packParams, ctx)) return *result; @@ -943,37 +954,37 @@ TypeFunctionReductionResult unmTypeFunction( std::optional mmType = findMetatableEntry(ctx->builtins, dummy, operandTy, "__unm", Location{}); if (!mmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; mmType = follow(*mmType); if (isPending(*mmType, ctx->solver)) - return {std::nullopt, false, {*mmType}, {}}; + return {std::nullopt, Reduction::MaybeOk, {*mmType}, {}}; const FunctionType* mmFtv = get(*mmType); if (!mmFtv) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; std::optional instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType); if (!instantiatedMmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; const FunctionType* instantiatedMmFtv = get(*instantiatedMmType); if (!instantiatedMmFtv) - return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + return {ctx->builtins->errorRecoveryType(), Reduction::MaybeOk, {}, {}}; TypePackId inferredArgPack = ctx->arena->addTypePack({operandTy}); Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice}; if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) - return {std::nullopt, true, {}, {}}; // occurs check failed + return {std::nullopt, Reduction::Erroneous, {}, {}}; // occurs check failed - Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->simplifier, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; if (std::optional ret = first(instantiatedMmFtv->retTypes)) - return {*ret, false, {}, {}}; + return {ret, Reduction::MaybeOk, {}, {}}; else - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; } void dummyStateClose(lua_State*) {} @@ -997,21 +1008,18 @@ std::optional TypeFunctionRuntime::registerFunction(AstStatTypeFunc lua_State* global = state.get(); - if (FFlag::LuauUserTypeFunExportedAndLocal) + // Fetch to check if function is already registered + lua_pushlightuserdata(global, function); + lua_gettable(global, LUA_REGISTRYINDEX); + + if (!lua_isnil(global, -1)) { - // Fetch to check if function is already registered - lua_pushlightuserdata(global, function); - lua_gettable(global, LUA_REGISTRYINDEX); - - if (!lua_isnil(global, -1)) - { - lua_pop(global, 1); - return std::nullopt; - } - lua_pop(global, 1); + return std::nullopt; } + lua_pop(global, 1); + AstName name = function->name; // Construct ParseResult containing the type function @@ -1024,7 +1032,7 @@ std::optional TypeFunctionRuntime::registerFunction(AstStatTypeFunc AstStat* stmtArray[] = {&stmtReturn}; AstArray stmts{stmtArray, 1}; AstStatBlock exec{Location{}, stmts}; - ParseResult parseResult{&exec, 1}; + ParseResult parseResult{&exec, 1, {}, {}, {}, CstNodeMap{nullptr}}; BytecodeBuilder builder; try @@ -1065,19 +1073,10 @@ std::optional TypeFunctionRuntime::registerFunction(AstStatTypeFunc return format("Could not find '%s' type function in the global scope", name.value); } - if (FFlag::LuauUserTypeFunExportedAndLocal) - { - // Store resulting function in the registry - lua_pushlightuserdata(global, function); - lua_xmove(L, global, 1); - lua_settable(global, LUA_REGISTRYINDEX); - } - else - { - // Store resulting function in the global environment - lua_xmove(L, global, 1); - lua_setglobal(global, name.value); - } + // Store resulting function in the registry + lua_pushlightuserdata(global, function); + lua_xmove(L, global, 1); + lua_settable(global, LUA_REGISTRYINDEX); return std::nullopt; } @@ -1096,8 +1095,7 @@ void TypeFunctionRuntime::prepareState() registerTypeUserData(L); - if (FFlag::LuauUserTypeFunFixRegister) - registerTypesLibrary(L); + registerTypesLibrary(L); luaL_sandbox(L); luaL_sandboxthread(L); @@ -1107,6 +1105,7 @@ TypeFunctionContext::TypeFunctionContext(NotNull cs, NotNullarena) , builtins(cs->builtinTypes) , scope(scope) + , simplifier(cs->simplifier) , normalizer(cs->normalizer) , typeFunctionRuntime(cs->typeFunctionRuntime) , ice(NotNull{&cs->iceReporter}) @@ -1148,30 +1147,30 @@ TypeFunctionReductionResult numericBinopTypeFunction( // isPending of `lhsTy` or `rhsTy` would return true, even if it cycles. We want a different answer for that. if (lhsTy == instance || rhsTy == instance) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; // if we have a `never`, we can never observe that the math operator is unreachable. if (is(lhsTy) || is(rhsTy)) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; const Location location = ctx->constraint ? ctx->constraint->location : Location{}; // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(lhsTy, ctx->solver)) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (isPending(rhsTy, ctx->solver)) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); if (!lhsMaybeGeneralized) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (!rhsMaybeGeneralized) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; lhsTy = *lhsMaybeGeneralized; rhsTy = *rhsMaybeGeneralized; @@ -1183,15 +1182,15 @@ TypeFunctionReductionResult numericBinopTypeFunction( // if either failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normLhsTy || !normRhsTy) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // if one of the types is error suppressing, we can reduce to `any` since we should suppress errors in the result of the usage. if (normLhsTy->shouldSuppressErrors() || normRhsTy->shouldSuppressErrors()) - return {ctx->builtins->anyType, false, {}, {}}; + return {ctx->builtins->anyType, Reduction::MaybeOk, {}, {}}; // if we're adding two `number` types, the result is `number`. if (normLhsTy->isExactlyNumber() && normRhsTy->isExactlyNumber()) - return {ctx->builtins->numberType, false, {}, {}}; + return {ctx->builtins->numberType, Reduction::MaybeOk, {}, {}}; if (auto result = tryDistributeTypeFunctionApp(numericBinopTypeFunction, instance, typeParams, packParams, ctx, metamethod)) return *result; @@ -1209,36 +1208,56 @@ TypeFunctionReductionResult numericBinopTypeFunction( } if (!mmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; mmType = follow(*mmType); if (isPending(*mmType, ctx->solver)) - return {std::nullopt, false, {*mmType}, {}}; + return {std::nullopt, Reduction::MaybeOk, {*mmType}, {}}; TypePackId argPack = ctx->arena->addTypePack({lhsTy, rhsTy}); SolveResult solveResult; if (!reversed) solveResult = solveFunctionCall( - ctx->arena, ctx->builtins, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack + ctx->arena, + ctx->builtins, + ctx->simplifier, + ctx->normalizer, + ctx->typeFunctionRuntime, + ctx->ice, + ctx->limits, + ctx->scope, + location, + *mmType, + argPack ); else { TypePack* p = getMutable(argPack); std::swap(p->head.front(), p->head.back()); solveResult = solveFunctionCall( - ctx->arena, ctx->builtins, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack + ctx->arena, + ctx->builtins, + ctx->simplifier, + ctx->normalizer, + ctx->typeFunctionRuntime, + ctx->ice, + ctx->limits, + ctx->scope, + location, + *mmType, + argPack ); } if (!solveResult.typePackId.has_value()) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; TypePack extracted = extendTypePack(*ctx->arena, ctx->builtins, *solveResult.typePackId, 1); if (extracted.head.empty()) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; - return {extracted.head.front(), false, {}, {}}; + return {extracted.head.front(), Reduction::MaybeOk, {}, {}}; } TypeFunctionReductionResult addTypeFunction( @@ -1371,24 +1390,24 @@ TypeFunctionReductionResult concatTypeFunction( // isPending of `lhsTy` or `rhsTy` would return true, even if it cycles. We want a different answer for that. if (lhsTy == instance || rhsTy == instance) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(lhsTy, ctx->solver)) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (isPending(rhsTy, ctx->solver)) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); if (!lhsMaybeGeneralized) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (!rhsMaybeGeneralized) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; lhsTy = *lhsMaybeGeneralized; rhsTy = *rhsMaybeGeneralized; @@ -1399,19 +1418,19 @@ TypeFunctionReductionResult concatTypeFunction( // if either failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normLhsTy || !normRhsTy) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // if one of the types is error suppressing, we can reduce to `any` since we should suppress errors in the result of the usage. if (normLhsTy->shouldSuppressErrors() || normRhsTy->shouldSuppressErrors()) - return {ctx->builtins->anyType, false, {}, {}}; + return {ctx->builtins->anyType, Reduction::MaybeOk, {}, {}}; - // if we have a `never`, we can never observe that the numeric operator didn't work. + // if we have a `never`, we can never observe that the operator didn't work. if (is(lhsTy) || is(rhsTy)) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; // if we're concatenating two elements that are either strings or numbers, the result is `string`. if ((normLhsTy->isSubtypeOfString() || normLhsTy->isExactlyNumber()) && (normRhsTy->isSubtypeOfString() || normRhsTy->isExactlyNumber())) - return {ctx->builtins->stringType, false, {}, {}}; + return {ctx->builtins->stringType, Reduction::MaybeOk, {}, {}}; if (auto result = tryDistributeTypeFunctionApp(concatTypeFunction, instance, typeParams, packParams, ctx)) return *result; @@ -1429,23 +1448,23 @@ TypeFunctionReductionResult concatTypeFunction( } if (!mmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; mmType = follow(*mmType); if (isPending(*mmType, ctx->solver)) - return {std::nullopt, false, {*mmType}, {}}; + return {std::nullopt, Reduction::MaybeOk, {*mmType}, {}}; const FunctionType* mmFtv = get(*mmType); if (!mmFtv) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; std::optional instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType); if (!instantiatedMmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; const FunctionType* instantiatedMmFtv = get(*instantiatedMmType); if (!instantiatedMmFtv) - return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + return {ctx->builtins->errorRecoveryType(), Reduction::MaybeOk, {}, {}}; std::vector inferredArgs; if (!reversed) @@ -1456,13 +1475,13 @@ TypeFunctionReductionResult concatTypeFunction( TypePackId inferredArgPack = ctx->arena->addTypePack(std::move(inferredArgs)); Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice}; if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) - return {std::nullopt, true, {}, {}}; // occurs check failed + return {std::nullopt, Reduction::Erroneous, {}, {}}; // occurs check failed - Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->simplifier, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; - return {ctx->builtins->stringType, false, {}, {}}; + return {ctx->builtins->stringType, Reduction::MaybeOk, {}, {}}; } TypeFunctionReductionResult andTypeFunction( @@ -1483,27 +1502,27 @@ TypeFunctionReductionResult andTypeFunction( // t1 = and ~> lhs if (follow(rhsTy) == instance && lhsTy != rhsTy) - return {lhsTy, false, {}, {}}; + return {lhsTy, Reduction::MaybeOk, {}, {}}; // t1 = and ~> rhs if (follow(lhsTy) == instance && lhsTy != rhsTy) - return {rhsTy, false, {}, {}}; + return {rhsTy, Reduction::MaybeOk, {}, {}}; // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(lhsTy, ctx->solver)) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (isPending(rhsTy, ctx->solver)) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); if (!lhsMaybeGeneralized) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (!rhsMaybeGeneralized) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; lhsTy = *lhsMaybeGeneralized; rhsTy = *rhsMaybeGeneralized; @@ -1517,7 +1536,7 @@ TypeFunctionReductionResult andTypeFunction( blockedTypes.push_back(ty); for (auto ty : overallResult.blockedTypes) blockedTypes.push_back(ty); - return {overallResult.result, false, std::move(blockedTypes), {}}; + return {overallResult.result, Reduction::MaybeOk, std::move(blockedTypes), {}}; } TypeFunctionReductionResult orTypeFunction( @@ -1538,27 +1557,27 @@ TypeFunctionReductionResult orTypeFunction( // t1 = or ~> lhs if (follow(rhsTy) == instance && lhsTy != rhsTy) - return {lhsTy, false, {}, {}}; + return {lhsTy, Reduction::MaybeOk, {}, {}}; // t1 = or ~> rhs if (follow(lhsTy) == instance && lhsTy != rhsTy) - return {rhsTy, false, {}, {}}; + return {rhsTy, Reduction::MaybeOk, {}, {}}; // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(lhsTy, ctx->solver)) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (isPending(rhsTy, ctx->solver)) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); if (!lhsMaybeGeneralized) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (!rhsMaybeGeneralized) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; lhsTy = *lhsMaybeGeneralized; rhsTy = *rhsMaybeGeneralized; @@ -1572,7 +1591,7 @@ TypeFunctionReductionResult orTypeFunction( blockedTypes.push_back(ty); for (auto ty : overallResult.blockedTypes) blockedTypes.push_back(ty); - return {overallResult.result, false, std::move(blockedTypes), {}}; + return {overallResult.result, Reduction::MaybeOk, std::move(blockedTypes), {}}; } static TypeFunctionReductionResult comparisonTypeFunction( @@ -1594,12 +1613,12 @@ static TypeFunctionReductionResult comparisonTypeFunction( TypeId rhsTy = follow(typeParams.at(1)); if (lhsTy == instance || rhsTy == instance) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; if (isPending(lhsTy, ctx->solver)) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (isPending(rhsTy, ctx->solver)) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; // Algebra Reduction Rules for comparison type functions // Note that comparing to never tells you nothing about the other operand @@ -1636,15 +1655,15 @@ static TypeFunctionReductionResult comparisonTypeFunction( rhsTy = follow(rhsTy); // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); if (!lhsMaybeGeneralized) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (!rhsMaybeGeneralized) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; lhsTy = *lhsMaybeGeneralized; rhsTy = *rhsMaybeGeneralized; @@ -1659,23 +1678,23 @@ static TypeFunctionReductionResult comparisonTypeFunction( // if either failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normLhsTy || !normRhsTy || lhsInhabited == NormalizationResult::HitLimits || rhsInhabited == NormalizationResult::HitLimits) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // if one of the types is error suppressing, we can just go ahead and reduce. if (normLhsTy->shouldSuppressErrors() || normRhsTy->shouldSuppressErrors()) - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; // if we have an uninhabited type (e.g. `never`), we can never observe that the comparison didn't work. if (lhsInhabited == NormalizationResult::False || rhsInhabited == NormalizationResult::False) - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; // If both types are some strict subset of `string`, we can reduce now. if (normLhsTy->isSubtypeOfString() && normRhsTy->isSubtypeOfString()) - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; // If both types are exactly `number`, we can reduce now. if (normLhsTy->isExactlyNumber() && normRhsTy->isExactlyNumber()) - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; if (auto result = tryDistributeTypeFunctionApp(comparisonTypeFunction, instance, typeParams, packParams, ctx, metamethod)) return *result; @@ -1689,34 +1708,34 @@ static TypeFunctionReductionResult comparisonTypeFunction( mmType = findMetatableEntry(ctx->builtins, dummy, rhsTy, metamethod, Location{}); if (!mmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; mmType = follow(*mmType); if (isPending(*mmType, ctx->solver)) - return {std::nullopt, false, {*mmType}, {}}; + return {std::nullopt, Reduction::MaybeOk, {*mmType}, {}}; const FunctionType* mmFtv = get(*mmType); if (!mmFtv) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; std::optional instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType); if (!instantiatedMmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; const FunctionType* instantiatedMmFtv = get(*instantiatedMmType); if (!instantiatedMmFtv) - return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + return {ctx->builtins->errorRecoveryType(), Reduction::MaybeOk, {}, {}}; TypePackId inferredArgPack = ctx->arena->addTypePack({lhsTy, rhsTy}); Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice}; if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) - return {std::nullopt, true, {}, {}}; // occurs check failed + return {std::nullopt, Reduction::Erroneous, {}, {}}; // occurs check failed - Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->simplifier, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; } TypeFunctionReductionResult ltTypeFunction( @@ -1769,20 +1788,20 @@ TypeFunctionReductionResult eqTypeFunction( // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(lhsTy, ctx->solver)) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (isPending(rhsTy, ctx->solver)) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); if (!lhsMaybeGeneralized) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (!rhsMaybeGeneralized) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; lhsTy = *lhsMaybeGeneralized; rhsTy = *rhsMaybeGeneralized; @@ -1795,15 +1814,15 @@ TypeFunctionReductionResult eqTypeFunction( // if either failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normLhsTy || !normRhsTy || lhsInhabited == NormalizationResult::HitLimits || rhsInhabited == NormalizationResult::HitLimits) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // if one of the types is error suppressing, we can just go ahead and reduce. if (normLhsTy->shouldSuppressErrors() || normRhsTy->shouldSuppressErrors()) - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; // if we have a `never`, we can never observe that the comparison didn't work. if (lhsInhabited == NormalizationResult::False || rhsInhabited == NormalizationResult::False) - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; // findMetatableEntry demands the ability to emit errors, so we must give it // the necessary state to do that, even if we intend to just eat the errors. @@ -1818,49 +1837,49 @@ TypeFunctionReductionResult eqTypeFunction( if (!mmType) { if (intersectInhabited == NormalizationResult::True) - return {ctx->builtins->booleanType, false, {}, {}}; // if it's inhabited, everything is okay! + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; // if it's inhabited, everything is okay! // we might be in a case where we still want to accept the comparison... if (intersectInhabited == NormalizationResult::False) { // if they're both subtypes of `string` but have no common intersection, the comparison is allowed but always `false`. if (normLhsTy->isSubtypeOfString() && normRhsTy->isSubtypeOfString()) - return {ctx->builtins->falseType, false, {}, {}}; + return {ctx->builtins->falseType, Reduction::MaybeOk, {}, {}}; // if they're both subtypes of `boolean` but have no common intersection, the comparison is allowed but always `false`. if (normLhsTy->isSubtypeOfBooleans() && normRhsTy->isSubtypeOfBooleans()) - return {ctx->builtins->falseType, false, {}, {}}; + return {ctx->builtins->falseType, Reduction::MaybeOk, {}, {}}; } - return {std::nullopt, true, {}, {}}; // if it's not, then this type function is irreducible! + return {std::nullopt, Reduction::Erroneous, {}, {}}; // if it's not, then this type function is irreducible! } mmType = follow(*mmType); if (isPending(*mmType, ctx->solver)) - return {std::nullopt, false, {*mmType}, {}}; + return {std::nullopt, Reduction::MaybeOk, {*mmType}, {}}; const FunctionType* mmFtv = get(*mmType); if (!mmFtv) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; std::optional instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType); if (!instantiatedMmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; const FunctionType* instantiatedMmFtv = get(*instantiatedMmType); if (!instantiatedMmFtv) - return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + return {ctx->builtins->errorRecoveryType(), Reduction::MaybeOk, {}, {}}; TypePackId inferredArgPack = ctx->arena->addTypePack({lhsTy, rhsTy}); Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice}; if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) - return {std::nullopt, true, {}, {}}; // occurs check failed + return {std::nullopt, Reduction::Erroneous, {}, {}}; // occurs check failed - Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->simplifier, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; } // Collect types that prevent us from reducing a particular refinement. @@ -1905,13 +1924,13 @@ TypeFunctionReductionResult refineTypeFunction( // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(targetTy, ctx->solver)) - return {std::nullopt, false, {targetTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {targetTy}, {}}; else { for (auto t : discriminantTypes) { if (isPending(t, ctx->solver)) - return {std::nullopt, false, {t}, {}}; + return {std::nullopt, Reduction::MaybeOk, {t}, {}}; } } // Refine a target type and a discriminant one at a time. @@ -1919,7 +1938,7 @@ TypeFunctionReductionResult refineTypeFunction( auto stepRefine = [&ctx](TypeId target, TypeId discriminant) -> std::pair> { std::vector toBlock; - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional targetMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, target); std::optional discriminantMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, discriminant); @@ -1940,57 +1959,68 @@ TypeFunctionReductionResult refineTypeFunction( if (!frb.found.empty()) return {nullptr, {frb.found.begin(), frb.found.end()}}; - /* HACK: Refinements sometimes produce a type T & ~any under the assumption - * that ~any is the same as any. This is so so weird, but refinements needs - * some way to say "I may refine this, but I'm not sure." - * - * It does this by refining on a blocked type and deferring the decision - * until it is unblocked. - * - * Refinements also get negated, so we wind up with types like T & ~*blocked* - * - * We need to treat T & ~any as T in this case. - */ - if (auto nt = get(discriminant)) + if (FFlag::DebugLuauEqSatSimplification) { - if (FFlag::LuauRemoveNotAnyHack) + auto simplifyResult = eqSatSimplify(ctx->simplifier, ctx->arena->addType(IntersectionType{{target, discriminant}})); + if (simplifyResult) + { + if (ctx->solver) + { + for (TypeId newTf : simplifyResult->newTypeFunctions) + ctx->pushConstraint(ReduceConstraint{newTf}); + } + + return {simplifyResult->result, {}}; + } + else + return {nullptr, {}}; + } + else + { + /* HACK: Refinements sometimes produce a type T & ~any under the assumption + * that ~any is the same as any. This is so so weird, but refinements needs + * some way to say "I may refine this, but I'm not sure." + * + * It does this by refining on a blocked type and deferring the decision + * until it is unblocked. + * + * Refinements also get negated, so we wind up with types like T & ~*blocked* + * + * We need to treat T & ~any as T in this case. + */ + if (auto nt = get(discriminant)) { if (get(follow(nt->ty))) return {target, {}}; } - else + + // If the target type is a table, then simplification already implements the logic to deal with refinements properly since the + // type of the discriminant is guaranteed to only ever be an (arbitrarily-nested) table of a single property type. + if (get(target)) { - if (get(follow(nt->ty))) - return {target, {}}; + SimplifyResult result = simplifyIntersection(ctx->builtins, ctx->arena, target, discriminant); + if (!result.blockedTypes.empty()) + return {nullptr, {result.blockedTypes.begin(), result.blockedTypes.end()}}; + + return {result.result, {}}; } + + // In the general case, we'll still use normalization though. + TypeId intersection = ctx->arena->addType(IntersectionType{{target, discriminant}}); + std::shared_ptr normIntersection = ctx->normalizer->normalize(intersection); + std::shared_ptr normType = ctx->normalizer->normalize(target); + + // if the intersection failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!normIntersection || !normType) + return {nullptr, {}}; + + TypeId resultTy = ctx->normalizer->typeFromNormal(*normIntersection); + // include the error type if the target type is error-suppressing and the intersection we computed is not + if (normType->shouldSuppressErrors() && !normIntersection->shouldSuppressErrors()) + resultTy = ctx->arena->addType(UnionType{{resultTy, ctx->builtins->errorType}}); + + return {resultTy, {}}; } - - // If the target type is a table, then simplification already implements the logic to deal with refinements properly since the - // type of the discriminant is guaranteed to only ever be an (arbitrarily-nested) table of a single property type. - if (get(target)) - { - SimplifyResult result = simplifyIntersection(ctx->builtins, ctx->arena, target, discriminant); - if (!result.blockedTypes.empty()) - return {nullptr, {result.blockedTypes.begin(), result.blockedTypes.end()}}; - - return {result.result, {}}; - } - - // In the general case, we'll still use normalization though. - TypeId intersection = ctx->arena->addType(IntersectionType{{target, discriminant}}); - std::shared_ptr normIntersection = ctx->normalizer->normalize(intersection); - std::shared_ptr normType = ctx->normalizer->normalize(target); - - // if the intersection failed to normalize, we can't reduce, but know nothing about inhabitance. - if (!normIntersection || !normType) - return {nullptr, {}}; - - TypeId resultTy = ctx->normalizer->typeFromNormal(*normIntersection); - // include the error type if the target type is error-suppressing and the intersection we computed is not - if (normType->shouldSuppressErrors() && !normIntersection->shouldSuppressErrors()) - resultTy = ctx->arena->addType(UnionType{{resultTy, ctx->builtins->errorType}}); - - return {resultTy, {}}; }; // refine target with each discriminant type in sequence (reverse of insertion order) @@ -2003,15 +2033,15 @@ TypeFunctionReductionResult refineTypeFunction( auto [refined, blocked] = stepRefine(target, discriminant); if (blocked.empty() && refined == nullptr) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; if (!blocked.empty()) - return {std::nullopt, false, blocked, {}}; + return {std::nullopt, Reduction::MaybeOk, blocked, {}}; target = refined; discriminantTypes.pop_back(); } - return {target, false, {}, {}}; + return {target, Reduction::MaybeOk, {}, {}}; } TypeFunctionReductionResult singletonTypeFunction( @@ -2031,14 +2061,14 @@ TypeFunctionReductionResult singletonTypeFunction( // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(type, ctx->solver)) - return {std::nullopt, false, {type}, {}}; + return {std::nullopt, Reduction::MaybeOk, {type}, {}}; // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, type); if (!maybeGeneralized) - return {std::nullopt, false, {type}, {}}; + return {std::nullopt, Reduction::MaybeOk, {type}, {}}; type = *maybeGeneralized; } @@ -2049,12 +2079,49 @@ TypeFunctionReductionResult singletonTypeFunction( // if we have a singleton type or `nil`, which is its own singleton type... if (get(followed) || isNil(followed)) - return {type, false, {}, {}}; + return {type, Reduction::MaybeOk, {}, {}}; // otherwise, we'll return the top type, `unknown`. - return {ctx->builtins->unknownType, false, {}, {}}; + return {ctx->builtins->unknownType, Reduction::MaybeOk, {}, {}}; } +struct CollectUnionTypeOptions : TypeOnceVisitor +{ + NotNull ctx; + DenseHashSet options{nullptr}; + DenseHashSet blockingTypes{nullptr}; + + explicit CollectUnionTypeOptions(NotNull ctx) + : TypeOnceVisitor(/* skipBoundTypes */ true) + , ctx(ctx) + { + } + + bool visit(TypeId ty) override + { + options.insert(ty); + if (isPending(ty, ctx->solver)) + blockingTypes.insert(ty); + return false; + } + + bool visit(TypePackId tp) override + { + return false; + } + + bool visit(TypeId ty, const TypeFunctionInstanceType& tfit) override + { + if (tfit.function->name != builtinTypeFunctions().unionFunc.name) + { + options.insert(ty); + blockingTypes.insert(ty); + return false; + } + return true; + } +}; + TypeFunctionReductionResult unionTypeFunction( TypeId instance, const std::vector& typeParams, @@ -2070,7 +2137,36 @@ TypeFunctionReductionResult unionTypeFunction( // if we only have one parameter, there's nothing to do. if (typeParams.size() == 1) - return {follow(typeParams[0]), false, {}, {}}; + return {follow(typeParams[0]), Reduction::MaybeOk, {}, {}}; + + if (FFlag::LuauClipNestedAndRecursiveUnion) + { + + CollectUnionTypeOptions collector{ctx}; + collector.traverse(instance); + + if (!collector.blockingTypes.empty()) + { + std::vector blockingTypes{collector.blockingTypes.begin(), collector.blockingTypes.end()}; + return {std::nullopt, Reduction::MaybeOk, std::move(blockingTypes), {}}; + } + + TypeId resultTy = ctx->builtins->neverType; + for (auto ty : collector.options) + { + SimplifyResult result = simplifyUnion(ctx->builtins, ctx->arena, resultTy, ty); + // This condition might fire if one of the arguments to this type + // function is a free type somewhere deep in a nested union or + // intersection type, even though we ran a pass above to capture + // some blocked types. + if (!result.blockedTypes.empty()) + return {std::nullopt, Reduction::MaybeOk, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; + + resultTy = result.result; + } + + return {resultTy, Reduction::MaybeOk, {}, {}}; + } // we need to follow all of the type parameters. std::vector types; @@ -2098,12 +2194,12 @@ TypeFunctionReductionResult unionTypeFunction( // if we still have a `lastType` at the end, we're taking the short-circuit and reducing early. if (lastType) - return {lastType, false, {}, {}}; + return {lastType, Reduction::MaybeOk, {}, {}}; // check to see if the operand types are resolved enough, and wait to reduce if not for (auto ty : types) if (isPending(ty, ctx->solver)) - return {std::nullopt, false, {ty}, {}}; + return {std::nullopt, Reduction::MaybeOk, {ty}, {}}; // fold over the types with `simplifyUnion` TypeId resultTy = ctx->builtins->neverType; @@ -2111,12 +2207,12 @@ TypeFunctionReductionResult unionTypeFunction( { SimplifyResult result = simplifyUnion(ctx->builtins, ctx->arena, resultTy, ty); if (!result.blockedTypes.empty()) - return {std::nullopt, false, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; + return {std::nullopt, Reduction::MaybeOk, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; resultTy = result.result; } - return {resultTy, false, {}, {}}; + return {resultTy, Reduction::MaybeOk, {}, {}}; } @@ -2135,7 +2231,7 @@ TypeFunctionReductionResult intersectTypeFunction( // if we only have one parameter, there's nothing to do. if (typeParams.size() == 1) - return {follow(typeParams[0]), false, {}, {}}; + return {follow(typeParams[0]), Reduction::MaybeOk, {}, {}}; // we need to follow all of the type parameters. std::vector types; @@ -2143,23 +2239,20 @@ TypeFunctionReductionResult intersectTypeFunction( for (auto ty : typeParams) types.emplace_back(follow(ty)); - if (FFlag::LuauRemoveNotAnyHack) - { - // if we only have two parameters and one is `*no-refine*`, we're all done. - if (types.size() == 2 && get(types[1])) - return {types[0], false, {}, {}}; - else if (types.size() == 2 && get(types[0])) - return {types[1], false, {}, {}}; - } + // if we only have two parameters and one is `*no-refine*`, we're all done. + if (types.size() == 2 && get(types[1])) + return {types[0], Reduction::MaybeOk, {}, {}}; + else if (types.size() == 2 && get(types[0])) + return {types[1], Reduction::MaybeOk, {}, {}}; // check to see if the operand types are resolved enough, and wait to reduce if not // if any of them are `never`, the intersection will always be `never`, so we can reduce directly. for (auto ty : types) { if (isPending(ty, ctx->solver)) - return {std::nullopt, false, {ty}, {}}; + return {std::nullopt, Reduction::MaybeOk, {ty}, {}}; else if (get(ty)) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; } // fold over the types with `simplifyIntersection` @@ -2167,12 +2260,12 @@ TypeFunctionReductionResult intersectTypeFunction( for (auto ty : types) { // skip any `*no-refine*` types. - if (FFlag::LuauRemoveNotAnyHack && get(ty)) + if (get(ty)) continue; SimplifyResult result = simplifyIntersection(ctx->builtins, ctx->arena, resultTy, ty); if (!result.blockedTypes.empty()) - return {std::nullopt, false, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; + return {std::nullopt, Reduction::MaybeOk, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; resultTy = result.result; } @@ -2183,10 +2276,10 @@ TypeFunctionReductionResult intersectTypeFunction( if (get(resultTy)) { TypeId intersection = ctx->arena->addType(IntersectionType{typeParams}); - return {intersection, false, {}, {}}; + return {intersection, Reduction::MaybeOk, {}, {}}; } - return {resultTy, false, {}, {}}; + return {resultTy, Reduction::MaybeOk, {}, {}}; } // computes the keys of `ty` into `result` @@ -2286,17 +2379,17 @@ TypeFunctionReductionResult keyofFunctionImpl( // if the operand failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normTy) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // if we don't have either just tables or just classes, we've got nothing to get keys of (at least until a future version perhaps adds classes // as well) if (normTy->hasTables() == normTy->hasClasses()) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; // this is sort of atrocious, but we're trying to reject any type that has not normalized to a table or a union of tables. if (normTy->hasTops() || normTy->hasBooleans() || normTy->hasErrors() || normTy->hasNils() || normTy->hasNumbers() || normTy->hasStrings() || normTy->hasThreads() || normTy->hasBuffers() || normTy->hasFunctions() || normTy->hasTyvars()) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; // we're going to collect the keys in here Set keys{{}}; @@ -2315,7 +2408,7 @@ TypeFunctionReductionResult keyofFunctionImpl( // collect all the properties from the first class type if (!computeKeysOf(*classesIter, keys, seen, isRaw, ctx)) - return {ctx->builtins->stringType, false, {}, {}}; // if it failed, we have a top type! + return {ctx->builtins->stringType, Reduction::MaybeOk, {}, {}}; // if it failed, we have a top type! // we need to look at each class to remove any keys that are not common amongst them all while (++classesIter != classesIterEnd) @@ -2350,7 +2443,7 @@ TypeFunctionReductionResult keyofFunctionImpl( // collect all the properties from the first table type if (!computeKeysOf(*tablesIter, keys, seen, isRaw, ctx)) - return {ctx->builtins->stringType, false, {}, {}}; // if it failed, we have the top table type! + return {ctx->builtins->stringType, Reduction::MaybeOk, {}, {}}; // if it failed, we have the top table type! // we need to look at each tables to remove any keys that are not common amongst them all while (++tablesIter != normTy->tables.end()) @@ -2374,7 +2467,7 @@ TypeFunctionReductionResult keyofFunctionImpl( // if the set of keys is empty, `keyof` is `never` if (keys.empty()) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; // everything is validated, we need only construct our big union of singletons now! std::vector singletons; @@ -2387,9 +2480,9 @@ TypeFunctionReductionResult keyofFunctionImpl( // We can take straight take it from the first entry // because it was added into the type arena already. if (singletons.size() == 1) - return {singletons.front(), false, {}, {}}; + return {singletons.front(), Reduction::MaybeOk, {}, {}}; - return {ctx->arena->addType(UnionType{singletons}), false, {}, {}}; + return {ctx->arena->addType(UnionType{singletons}), Reduction::MaybeOk, {}, {}}; } TypeFunctionReductionResult keyofTypeFunction( @@ -2460,7 +2553,7 @@ bool searchPropsAndIndexer( // index into tbl's indexer if (tblIndexer) { - if (isSubtype(ty, tblIndexer->indexType, ctx->scope, ctx->builtins, *ctx->ice)) + if (isSubtype(ty, tblIndexer->indexType, ctx->scope, ctx->builtins, ctx->simplifier, *ctx->ice)) { TypeId idxResultTy = follow(tblIndexer->indexResultType); @@ -2535,32 +2628,32 @@ TypeFunctionReductionResult indexFunctionImpl( // if the indexee failed to normalize, we can't reduce, but know nothing about inhabitance. if (!indexeeNormTy) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // if we don't have either just tables or just classes, we've got nothing to index into if (indexeeNormTy->hasTables() == indexeeNormTy->hasClasses()) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; // we're trying to reject any type that has not normalized to a table/class or a union of tables/classes. if (indexeeNormTy->hasTops() || indexeeNormTy->hasBooleans() || indexeeNormTy->hasErrors() || indexeeNormTy->hasNils() || indexeeNormTy->hasNumbers() || indexeeNormTy->hasStrings() || indexeeNormTy->hasThreads() || indexeeNormTy->hasBuffers() || indexeeNormTy->hasFunctions() || indexeeNormTy->hasTyvars()) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; TypeId indexerTy = follow(typeParams.at(1)); if (isPending(indexerTy, ctx->solver)) - return {std::nullopt, false, {indexerTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {indexerTy}, {}}; std::shared_ptr indexerNormTy = ctx->normalizer->normalize(indexerTy); // if the indexer failed to normalize, we can't reduce, but know nothing about inhabitance. if (!indexerNormTy) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // we're trying to reject any type that is not a string singleton or primitive (string, number, boolean, thread, nil, function, table, or buffer) if (indexerNormTy->hasTops() || indexerNormTy->hasErrors()) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; // indexer can be a union —> break them down into a vector const std::vector* typesToFind = nullptr; @@ -2577,7 +2670,7 @@ TypeFunctionReductionResult indexFunctionImpl( LUAU_ASSERT(!indexeeNormTy->hasTables()); if (isRaw) // rawget should never reduce for classes (to match the behavior of the rawget global function) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; // at least one class is guaranteed to be in the iterator by .hasClasses() for (auto classesIter = indexeeNormTy->classes.ordering.begin(); classesIter != indexeeNormTy->classes.ordering.end(); ++classesIter) @@ -2586,7 +2679,7 @@ TypeFunctionReductionResult indexFunctionImpl( if (!classTy) { LUAU_ASSERT(false); // this should not be possible according to normalization's spec - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; } for (TypeId ty : *typesToFind) @@ -2615,10 +2708,10 @@ TypeFunctionReductionResult indexFunctionImpl( ErrorVec dummy; std::optional mmType = findMetatableEntry(ctx->builtins, dummy, *classesIter, "__index", Location{}); if (!mmType) // if a metatable does not exist, there is no where else to look - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; if (!tblIndexInto(ty, *mmType, properties, ctx, isRaw)) // if indexer is not in the metatable, we fail to reduce - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; } } } @@ -2632,7 +2725,7 @@ TypeFunctionReductionResult indexFunctionImpl( { for (TypeId ty : *typesToFind) if (!tblIndexInto(ty, *tablesIter, properties, ctx, isRaw)) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; } } @@ -2649,9 +2742,9 @@ TypeFunctionReductionResult indexFunctionImpl( // If the type being reduced to is a single type, no need to union if (properties.size() == 1) - return {*properties.begin(), false, {}, {}}; + return {*properties.begin(), Reduction::MaybeOk, {}, {}}; - return {ctx->arena->addType(UnionType{std::vector(properties.begin(), properties.end())}), false, {}, {}}; + return {ctx->arena->addType(UnionType{std::vector(properties.begin(), properties.end())}), Reduction::MaybeOk, {}, {}}; } TypeFunctionReductionResult indexTypeFunction( @@ -2686,6 +2779,215 @@ TypeFunctionReductionResult rawgetTypeFunction( return indexFunctionImpl(typeParams, packParams, ctx, /* isRaw */ true); } +TypeFunctionReductionResult setmetatableTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("setmetatable type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + const Location location = ctx->constraint ? ctx->constraint->location : Location{}; + + TypeId targetTy = follow(typeParams.at(0)); + TypeId metatableTy = follow(typeParams.at(1)); + + std::shared_ptr targetNorm = ctx->normalizer->normalize(targetTy); + + // if the operand failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!targetNorm) + return {std::nullopt, Reduction::MaybeOk, {}, {}}; + + // cannot setmetatable on something without table parts. + if (!targetNorm->hasTables()) + return {std::nullopt, Reduction::Erroneous, {}, {}}; + + // we're trying to reject any type that has not normalized to a table or a union/intersection of tables. + if (targetNorm->hasTops() || targetNorm->hasBooleans() || targetNorm->hasErrors() || targetNorm->hasNils() || + targetNorm->hasNumbers() || targetNorm->hasStrings() || targetNorm->hasThreads() || targetNorm->hasBuffers() || + targetNorm->hasFunctions() || targetNorm->hasTyvars() || targetNorm->hasClasses()) + return {std::nullopt, Reduction::Erroneous, {}, {}}; + + // if the supposed metatable is not a table, we will fail to reduce. + if (!get(metatableTy) && !get(metatableTy)) + return {std::nullopt, Reduction::Erroneous, {}, {}}; + + if (targetNorm->tables.size() == 1) + { + TypeId table = *targetNorm->tables.begin(); + + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + + std::optional metatableMetamethod = findMetatableEntry(ctx->builtins, dummy, table, "__metatable", location); + + // if the `__metatable` metamethod is present, then the table is locked and we cannot `setmetatable` on it. + if (metatableMetamethod) + return {std::nullopt, Reduction::Erroneous, {}, {}}; + + TypeId withMetatable = ctx->arena->addType(MetatableType{table, metatableTy}); + + return {withMetatable, Reduction::MaybeOk, {}, {}}; + } + + TypeId result = ctx->builtins->neverType; + + for (auto componentTy : targetNorm->tables) + { + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + + std::optional metatableMetamethod = findMetatableEntry(ctx->builtins, dummy, componentTy, "__metatable", location); + + // if the `__metatable` metamethod is present, then the table is locked and we cannot `setmetatable` on it. + if (metatableMetamethod) + return {std::nullopt, Reduction::Erroneous, {}, {}}; + + TypeId withMetatable = ctx->arena->addType(MetatableType{componentTy, metatableTy}); + SimplifyResult simplified = simplifyUnion(ctx->builtins, ctx->arena, result, withMetatable); + + if (!simplified.blockedTypes.empty()) + { + std::vector blockedTypes{}; + blockedTypes.reserve(simplified.blockedTypes.size()); + for (auto ty : simplified.blockedTypes) + blockedTypes.push_back(ty); + return {std::nullopt, Reduction::MaybeOk, blockedTypes, {}}; + } + + result = simplified.result; + } + + return {result, Reduction::MaybeOk, {}, {}}; +} + +static TypeFunctionReductionResult getmetatableHelper( + TypeId targetTy, + const Location& location, + NotNull ctx +) +{ + targetTy = follow(targetTy); + + std::optional metatable = std::nullopt; + bool erroneous = true; + + if (auto table = get(targetTy)) + erroneous = false; + + if (auto mt = get(targetTy)) + { + metatable = mt->metatable; + erroneous = false; + } + + if (auto clazz = get(targetTy)) + { + metatable = clazz->metatable; + erroneous = false; + } + + if (auto primitive = get(targetTy)) + { + metatable = primitive->metatable; + erroneous = false; + } + + if (auto singleton = get(targetTy)) + { + if (get(singleton)) + { + auto primitiveString = get(ctx->builtins->stringType); + metatable = primitiveString->metatable; + } + erroneous = false; + } + + if (erroneous) + return {std::nullopt, Reduction::Erroneous, {}, {}}; + + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + + std::optional metatableMetamethod = findMetatableEntry(ctx->builtins, dummy, targetTy, "__metatable", location); + + if (metatableMetamethod) + return {metatableMetamethod, Reduction::MaybeOk, {}, {}}; + + if (metatable) + return {metatable, Reduction::MaybeOk, {}, {}}; + + return {ctx->builtins->nilType, Reduction::MaybeOk, {}, {}}; +} + +TypeFunctionReductionResult getmetatableTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 1 || !packParams.empty()) + { + ctx->ice->ice("getmetatable type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + const Location location = ctx->constraint ? ctx->constraint->location : Location{}; + + TypeId targetTy = follow(typeParams.at(0)); + + if (isPending(targetTy, ctx->solver)) + return {std::nullopt, Reduction::MaybeOk, {targetTy}, {}}; + + if (auto ut = get(targetTy)) + { + std::vector options{}; + options.reserve(ut->options.size()); + + for (auto option : ut->options) + { + TypeFunctionReductionResult result = getmetatableHelper(option, location, ctx); + + if (!result.result) + return result; + + options.push_back(*result.result); + } + + return {ctx->arena->addType(UnionType{std::move(options)}), Reduction::MaybeOk, {}, {}}; + } + + if (auto it = get(targetTy)) + { + std::vector parts{}; + parts.reserve(it->parts.size()); + + for (auto part : it->parts) + { + TypeFunctionReductionResult result = getmetatableHelper(part, location, ctx); + + if (!result.result) + return result; + + parts.push_back(*result.result); + } + + return {ctx->arena->addType(IntersectionType{std::move(parts)}), Reduction::MaybeOk, {}, {}}; + } + + return getmetatableHelper(targetTy, location, ctx); +} + + BuiltinTypeFunctions::BuiltinTypeFunctions() : userFunc{"user", userDefinedTypeFunction} , notFunc{"not", notTypeFunction} @@ -2712,6 +3014,8 @@ BuiltinTypeFunctions::BuiltinTypeFunctions() , rawkeyofFunc{"rawkeyof", rawkeyofTypeFunction} , indexFunc{"index", indexTypeFunction} , rawgetFunc{"rawget", rawgetTypeFunction} + , setmetatableFunc{"setmetatable", setmetatableTypeFunction} + , getmetatableFunc{"getmetatable", getmetatableTypeFunction} { } @@ -2758,6 +3062,12 @@ void BuiltinTypeFunctions::addToScope(NotNull arena, NotNull s scope->exportedTypeBindings[indexFunc.name] = mkBinaryTypeFunction(&indexFunc); scope->exportedTypeBindings[rawgetFunc.name] = mkBinaryTypeFunction(&rawgetFunc); + + if (FFlag::LuauMetatableTypeFunctions) + { + scope->exportedTypeBindings[setmetatableFunc.name] = mkBinaryTypeFunction(&setmetatableFunc); + scope->exportedTypeBindings[getmetatableFunc.name] = mkUnaryTypeFunction(&getmetatableFunc); + } } const BuiltinTypeFunctions& builtinTypeFunctions() diff --git a/Analysis/src/TypeFunctionRuntime.cpp b/Analysis/src/TypeFunctionRuntime.cpp index 8a129462..c5c54477 100644 --- a/Analysis/src/TypeFunctionRuntime.cpp +++ b/Analysis/src/TypeFunctionRuntime.cpp @@ -14,9 +14,9 @@ #include LUAU_DYNAMIC_FASTINT(LuauTypeFunctionSerdeIterationLimit) -LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixRegister) -LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixNoReadWrite) -LUAU_FASTFLAGVARIABLE(LuauUserTypeFunThreadBuffer) +LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixInner) +LUAU_FASTFLAGVARIABLE(LuauUserTypeFunGenerics) +LUAU_FASTFLAGVARIABLE(LuauUserTypeFunCloneTail) namespace Luau { @@ -134,11 +134,9 @@ static std::string getTag(lua_State* L, TypeFunctionTypeId ty) return "number"; else if (auto s = get(ty); s && s->type == TypeFunctionPrimitiveType::Type::String) return "string"; - else if (auto s = get(ty); - FFlag::LuauUserTypeFunThreadBuffer && s && s->type == TypeFunctionPrimitiveType::Type::Thread) + else if (auto s = get(ty); s && s->type == TypeFunctionPrimitiveType::Type::Thread) return "thread"; - else if (auto s = get(ty); - FFlag::LuauUserTypeFunThreadBuffer && s && s->type == TypeFunctionPrimitiveType::Type::Buffer) + else if (auto s = get(ty); s && s->type == TypeFunctionPrimitiveType::Type::Buffer) return "buffer"; else if (get(ty)) return "unknown"; @@ -160,6 +158,8 @@ static std::string getTag(lua_State* L, TypeFunctionTypeId ty) return "function"; else if (get(ty)) return "class"; + else if (FFlag::LuauUserTypeFunGenerics && get(ty)) + return "generic"; LUAU_UNREACHABLE(); luaL_error(L, "VM encountered unexpected type variant when determining tag"); @@ -266,6 +266,20 @@ static int createSingleton(lua_State* L) luaL_error(L, "types.singleton: can't create singleton from `%s` type", lua_typename(L, 1)); } +// Luau: `types.generic(name: string, ispack: boolean?) -> type +// Create a generic type with the specified type. If an optinal boolean is set to true, result is a generic pack +static int createGeneric(lua_State* L) +{ + const char* name = luaL_checkstring(L, 1); + bool isPack = luaL_optboolean(L, 2, false); + + if (strlen(name) == 0) + luaL_error(L, "types.generic: generic name cannot be empty"); + + allocTypeUserData(L, TypeFunctionGenericType{/* isNamed */ true, isPack, name}); + return 1; +} + // Luau: `self:value() -> type` // Returns the value of a singleton static int getSingletonValue(lua_State* L) @@ -413,10 +427,21 @@ static int getNegatedValue(lua_State* L) luaL_error(L, "type.inner: expected 1 argument, but got %d", argumentCount); TypeFunctionTypeId self = getTypeUserData(L, 1); - if (auto tfnt = get(self); !tfnt) - allocTypeUserData(L, tfnt->type->type); + + if (FFlag::LuauUserTypeFunFixInner) + { + if (auto tfnt = get(self); tfnt) + allocTypeUserData(L, tfnt->type->type); + else + luaL_error(L, "type.inner: cannot call inner method on non-negation type: `%s` type", getTag(L, self).c_str()); + } else - luaL_error(L, "type.inner: cannot call inner method on non-negation type: `%s` type", getTag(L, self).c_str()); + { + if (auto tfnt = get(self); !tfnt) + allocTypeUserData(L, tfnt->type->type); + else + luaL_error(L, "type.inner: cannot call inner method on non-negation type: `%s` type", getTag(L, self).c_str()); + } return 1; } @@ -657,10 +682,8 @@ static int readTableProp(lua_State* L) auto prop = tftt->props.at(tfsst->value); if (prop.readTy) allocTypeUserData(L, (*prop.readTy)->type); - else if (FFlag::LuauUserTypeFunFixNoReadWrite) - lua_pushnil(L); else - luaL_error(L, "type.readproperty: property %s is write-only, and therefore does not have a read type.", tfsst->value.c_str()); + lua_pushnil(L); return 1; } @@ -697,10 +720,8 @@ static int writeTableProp(lua_State* L) auto prop = tftt->props.at(tfsst->value); if (prop.writeTy) allocTypeUserData(L, (*prop.writeTy)->type); - else if (FFlag::LuauUserTypeFunFixNoReadWrite) - lua_pushnil(L); else - luaL_error(L, "type.writeproperty: property %s is read-only, and therefore does not have a write type.", tfsst->value.c_str()); + lua_pushnil(L); return 1; } @@ -768,9 +789,159 @@ static int setTableMetatable(lua_State* L) return 0; } -// Luau: `types.newfunction(parameters: {head: {type}?, tail: type?}, returns: {head: {type}?, tail: type?}) -> type` -// Returns the type instance representing a function -static int createFunction(lua_State* L) +static std::tuple, std::vector> getGenerics(lua_State* L, int idx, const char* fname) +{ + std::vector types; + std::vector packs; + + if (lua_istable(L, idx)) + { + lua_pushvalue(L, idx); + + for (int i = 1; i <= lua_objlen(L, -1); i++) + { + lua_pushinteger(L, i); + lua_gettable(L, -2); + + if (lua_isnil(L, -1)) + { + lua_pop(L, 1); + break; + } + + TypeFunctionTypeId ty = getTypeUserData(L, -1); + + if (auto gty = get(ty)) + { + if (gty->isPack) + { + packs.push_back(allocateTypeFunctionTypePack(L, TypeFunctionGenericTypePack{gty->isNamed, gty->name})); + } + else + { + if (!packs.empty()) + luaL_error(L, "%s: generic type cannot follow a generic pack", fname); + + types.push_back(ty); + } + } + else + { + luaL_error(L, "%s: table member was not a generic type", fname); + } + + lua_pop(L, 1); + } + + lua_pop(L, 1); + } + else if (!lua_isnoneornil(L, idx)) + { + luaL_typeerrorL(L, idx, "table"); + } + + return {types, packs}; +} + +static TypeFunctionTypePackId getTypePack(lua_State* L, int headIdx, int tailIdx) +{ + TypeFunctionTypePackId result = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{}); + + std::vector head; + + if (lua_istable(L, headIdx)) + { + lua_pushvalue(L, headIdx); + + for (int i = 1; i <= lua_objlen(L, -1); i++) + { + lua_pushinteger(L, i); + lua_gettable(L, -2); + + if (lua_isnil(L, -1)) + { + lua_pop(L, 1); + break; + } + + head.push_back(getTypeUserData(L, -1)); + lua_pop(L, 1); + } + + lua_pop(L, 1); + } + + std::optional tail; + + if (auto type = optionalTypeUserData(L, tailIdx)) + { + if (auto gty = get(*type); gty && gty->isPack) + tail = allocateTypeFunctionTypePack(L, TypeFunctionGenericTypePack{gty->isNamed, gty->name}); + else + tail = allocateTypeFunctionTypePack(L, TypeFunctionVariadicTypePack{*type}); + } + + if (head.size() == 0 && tail.has_value()) + result = *tail; + else + result = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{head, tail}); + + return result; +} + +static void pushTypePack(lua_State* L, TypeFunctionTypePackId tp) +{ + if (auto tftp = get(tp)) + { + lua_createtable(L, 0, 2); + + if (!tftp->head.empty()) + { + lua_createtable(L, int(tftp->head.size()), 0); + int pos = 1; + + for (auto el : tftp->head) + { + allocTypeUserData(L, el->type); + lua_rawseti(L, -2, pos++); + } + + lua_setfield(L, -2, "head"); + } + + if (tftp->tail.has_value()) + { + if (auto tfvp = get(*tftp->tail)) + allocTypeUserData(L, tfvp->type->type); + else if (auto tfgp = get(*tftp->tail)) + allocTypeUserData(L, TypeFunctionGenericType{tfgp->isNamed, true, tfgp->name}); + else + luaL_error(L, "unsupported type pack type"); + + lua_setfield(L, -2, "tail"); + } + } + else if (auto tfvp = get(tp)) + { + lua_createtable(L, 0, 1); + + allocTypeUserData(L, tfvp->type->type); + lua_setfield(L, -2, "tail"); + } + else if (auto tfgp = get(tp)) + { + lua_createtable(L, 0, 1); + + allocTypeUserData(L, TypeFunctionGenericType{tfgp->isNamed, true, tfgp->name}); + lua_setfield(L, -2, "tail"); + } + else + { + luaL_error(L, "unsupported type pack type"); + } +} + +static int createFunction_DEPRECATED(lua_State* L) { int argumentCount = lua_gettop(L); if (argumentCount > 2) @@ -858,7 +1029,62 @@ static int createFunction(lua_State* L) else if (!lua_isnoneornil(L, 2)) luaL_typeerrorL(L, 2, "table"); - allocTypeUserData(L, TypeFunctionFunctionType{argTypes, retTypes}); + allocTypeUserData(L, TypeFunctionFunctionType{{}, {}, argTypes, retTypes}); + + return 1; +} + +// Luau: `types.newfunction(parameters: {head: {type}?, tail: type?}, returns: {head: {type}?, tail: type?}, generics: {type}?) -> type` +// Returns the type instance representing a function +static int createFunction(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount > 3) + luaL_error(L, "types.newfunction: expected 0-3 arguments, but got %d", argumentCount); + + TypeFunctionTypePackId argTypes = nullptr; + + if (lua_istable(L, 1)) + { + lua_getfield(L, 1, "head"); + lua_getfield(L, 1, "tail"); + + argTypes = getTypePack(L, -2, -1); + + lua_pop(L, 2); + } + else if (!lua_isnoneornil(L, 1)) + { + luaL_typeerrorL(L, 1, "table"); + } + else + { + argTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{}); + } + + TypeFunctionTypePackId retTypes = nullptr; + + if (lua_istable(L, 2)) + { + lua_getfield(L, 2, "head"); + lua_getfield(L, 2, "tail"); + + retTypes = getTypePack(L, -2, -1); + + lua_pop(L, 2); + } + else if (!lua_isnoneornil(L, 2)) + { + luaL_typeerrorL(L, 2, "table"); + } + else + { + retTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{}); + } + + auto [genericTypes, genericPacks] = getGenerics(L, 3, "types.newfunction"); + + allocTypeUserData(L, TypeFunctionFunctionType{std::move(genericTypes), std::move(genericPacks), argTypes, retTypes}); return 1; } @@ -876,38 +1102,45 @@ static int setFunctionParameters(lua_State* L) if (!tfft) luaL_error(L, "type.setparameters: expected self to be a function, but got %s instead", getTag(L, self).c_str()); - std::vector head{}; - if (lua_istable(L, 2)) + if (FFlag::LuauUserTypeFunGenerics) { - int argSize = lua_objlen(L, 2); - for (int i = 1; i <= argSize; i++) - { - lua_pushinteger(L, i); - lua_gettable(L, 2); - - if (lua_isnil(L, -1)) - { - lua_pop(L, 1); - break; - } - - TypeFunctionTypeId ty = getTypeUserData(L, -1); - head.push_back(ty); - - lua_pop(L, 1); // Remove `ty` from stack - } + tfft->argTypes = getTypePack(L, 2, 3); } - else if (!lua_isnoneornil(L, 2)) - luaL_typeerrorL(L, 2, "table"); + else + { + std::vector head{}; + if (lua_istable(L, 2)) + { + int argSize = lua_objlen(L, 2); + for (int i = 1; i <= argSize; i++) + { + lua_pushinteger(L, i); + lua_gettable(L, 2); - std::optional tail; - if (auto type = optionalTypeUserData(L, 3)) - tail = allocateTypeFunctionTypePack(L, TypeFunctionVariadicTypePack{*type}); + if (lua_isnil(L, -1)) + { + lua_pop(L, 1); + break; + } - if (head.size() == 0 && tail.has_value()) // Make argTypes a variadic type pack - tfft->argTypes = *tail; - else // Make argTypes a type pack - tfft->argTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{head, tail}); + TypeFunctionTypeId ty = getTypeUserData(L, -1); + head.push_back(ty); + + lua_pop(L, 1); // Remove `ty` from stack + } + } + else if (!lua_isnoneornil(L, 2)) + luaL_typeerrorL(L, 2, "table"); + + std::optional tail; + if (auto type = optionalTypeUserData(L, 3)) + tail = allocateTypeFunctionTypePack(L, TypeFunctionVariadicTypePack{*type}); + + if (head.size() == 0 && tail.has_value()) // Make argTypes a variadic type pack + tfft->argTypes = *tail; + else // Make argTypes a type pack + tfft->argTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{head, tail}); + } return 0; } @@ -925,52 +1158,60 @@ static int getFunctionParameters(lua_State* L) if (!tfft) luaL_error(L, "type.parameters: expected self to be a function, but got %s instead", getTag(L, self).c_str()); - if (auto tftp = get(tfft->argTypes)) + if (FFlag::LuauUserTypeFunGenerics) { - int size = 0; - if (tftp->head.size() > 0) - size++; - if (tftp->tail.has_value()) - size++; - - lua_createtable(L, 0, size); - - int argSize = (int)tftp->head.size(); - if (argSize > 0) + pushTypePack(L, tfft->argTypes); + } + else + { + if (auto tftp = get(tfft->argTypes)) { - lua_createtable(L, argSize, 0); - for (int i = 0; i < argSize; i++) + int size = 0; + if (tftp->head.size() > 0) + size++; + if (tftp->tail.has_value()) + size++; + + lua_createtable(L, 0, size); + + int argSize = (int)tftp->head.size(); + if (argSize > 0) { - allocTypeUserData(L, tftp->head[i]->type); - lua_rawseti(L, -2, i + 1); // Luau is 1-indexed while C++ is 0-indexed + lua_createtable(L, argSize, 0); + for (int i = 0; i < argSize; i++) + { + allocTypeUserData(L, tftp->head[i]->type); + lua_rawseti(L, -2, i + 1); // Luau is 1-indexed while C++ is 0-indexed + } + lua_setfield(L, -2, "head"); } - lua_setfield(L, -2, "head"); + + if (tftp->tail.has_value()) + { + auto tfvp = get(*tftp->tail); + if (!tfvp) + LUAU_ASSERT(!"We should only be supporting variadic packs as TypeFunctionTypePack.tail at the moment"); + + allocTypeUserData(L, tfvp->type->type); + lua_setfield(L, -2, "tail"); + } + + return 1; } - if (tftp->tail.has_value()) + if (auto tfvp = get(tfft->argTypes)) { - auto tfvp = get(*tftp->tail); - if (!tfvp) - LUAU_ASSERT(!"We should only be supporting variadic packs as TypeFunctionTypePack.tail at the moment"); + lua_createtable(L, 0, 1); allocTypeUserData(L, tfvp->type->type); lua_setfield(L, -2, "tail"); + + return 1; } - return 1; + lua_createtable(L, 0, 0); } - if (auto tfvp = get(tfft->argTypes)) - { - lua_createtable(L, 0, 1); - - allocTypeUserData(L, tfvp->type->type); - lua_setfield(L, -2, "tail"); - - return 1; - } - - lua_createtable(L, 0, 0); return 1; } @@ -987,38 +1228,45 @@ static int setFunctionReturns(lua_State* L) if (!tfft) luaL_error(L, "type.setreturns: expected self to be a function, but got %s instead", getTag(L, self).c_str()); - std::vector head{}; - if (lua_istable(L, 2)) + if (FFlag::LuauUserTypeFunGenerics) { - int argSize = lua_objlen(L, 2); - for (int i = 1; i <= argSize; i++) - { - lua_pushinteger(L, i); - lua_gettable(L, 2); - - if (lua_isnil(L, -1)) - { - lua_pop(L, 1); - break; - } - - TypeFunctionTypeId ty = getTypeUserData(L, -1); - head.push_back(ty); - - lua_pop(L, 1); // Remove `ty` from stack - } + tfft->retTypes = getTypePack(L, 2, 3); } - else if (!lua_isnoneornil(L, 2)) - luaL_typeerrorL(L, 2, "table"); + else + { + std::vector head{}; + if (lua_istable(L, 2)) + { + int argSize = lua_objlen(L, 2); + for (int i = 1; i <= argSize; i++) + { + lua_pushinteger(L, i); + lua_gettable(L, 2); - std::optional tail; - if (auto type = optionalTypeUserData(L, 3)) - tail = allocateTypeFunctionTypePack(L, TypeFunctionVariadicTypePack{*type}); + if (lua_isnil(L, -1)) + { + lua_pop(L, 1); + break; + } - if (head.size() == 0 && tail.has_value()) // Make retTypes a variadic type pack - tfft->retTypes = *tail; - else // Make retTypes a type pack - tfft->retTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{head, tail}); + TypeFunctionTypeId ty = getTypeUserData(L, -1); + head.push_back(ty); + + lua_pop(L, 1); // Remove `ty` from stack + } + } + else if (!lua_isnoneornil(L, 2)) + luaL_typeerrorL(L, 2, "table"); + + std::optional tail; + if (auto type = optionalTypeUserData(L, 3)) + tail = allocateTypeFunctionTypePack(L, TypeFunctionVariadicTypePack{*type}); + + if (head.size() == 0 && tail.has_value()) // Make retTypes a variadic type pack + tfft->retTypes = *tail; + else // Make retTypes a type pack + tfft->retTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{head, tail}); + } return 0; } @@ -1036,52 +1284,109 @@ static int getFunctionReturns(lua_State* L) if (!tfft) luaL_error(L, "type.returns: expected self to be a function, but got %s instead", getTag(L, self).c_str()); - if (auto tftp = get(tfft->retTypes)) + if (FFlag::LuauUserTypeFunGenerics) { - int size = 0; - if (tftp->head.size() > 0) - size++; - if (tftp->tail.has_value()) - size++; - - lua_createtable(L, 0, size); - - int argSize = (int)tftp->head.size(); - if (argSize > 0) + pushTypePack(L, tfft->retTypes); + } + else + { + if (auto tftp = get(tfft->retTypes)) { - lua_createtable(L, argSize, 0); - for (int i = 0; i < argSize; i++) + int size = 0; + if (tftp->head.size() > 0) + size++; + if (tftp->tail.has_value()) + size++; + + lua_createtable(L, 0, size); + + int argSize = (int)tftp->head.size(); + if (argSize > 0) { - allocTypeUserData(L, tftp->head[i]->type); - lua_rawseti(L, -2, i + 1); // Luau is 1-indexed while C++ is 0-indexed + lua_createtable(L, argSize, 0); + for (int i = 0; i < argSize; i++) + { + allocTypeUserData(L, tftp->head[i]->type); + lua_rawseti(L, -2, i + 1); // Luau is 1-indexed while C++ is 0-indexed + } + lua_setfield(L, -2, "head"); } - lua_setfield(L, -2, "head"); + + if (tftp->tail.has_value()) + { + auto tfvp = get(*tftp->tail); + if (!tfvp) + LUAU_ASSERT(!"We should only be supporting variadic packs as TypeFunctionTypePack.tail at the moment"); + + allocTypeUserData(L, tfvp->type->type); + lua_setfield(L, -2, "tail"); + } + + return 1; } - if (tftp->tail.has_value()) + if (auto tfvp = get(tfft->retTypes)) { - auto tfvp = get(*tftp->tail); - if (!tfvp) - LUAU_ASSERT(!"We should only be supporting variadic packs as TypeFunctionTypePack.tail at the moment"); + lua_createtable(L, 0, 1); allocTypeUserData(L, tfvp->type->type); lua_setfield(L, -2, "tail"); + + return 1; } - return 1; + lua_createtable(L, 0, 0); } - if (auto tfvp = get(tfft->retTypes)) + return 1; +} + +// Luau: `self:setgenerics(generics: {type}?)` +static int setFunctionGenerics(lua_State* L) +{ + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfft = getMutable(self); + if (!tfft) + luaL_error(L, "type.setgenerics: expected self to be a function, but got %s instead", getTag(L, self).c_str()); + + int argumentCount = lua_gettop(L); + if (argumentCount > 3) + luaL_error(L, "type.setgenerics: expected 3 arguments, but got %d", argumentCount); + + auto [genericTypes, genericPacks] = getGenerics(L, 2, "types.setgenerics"); + + tfft->generics = std::move(genericTypes); + tfft->genericPacks = std::move(genericPacks); + + return 0; +} + +// Luau: `self:generics() -> {type}` +static int getFunctionGenerics(lua_State* L) +{ + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfft = get(self); + if (!tfft) + luaL_error(L, "type.generics: expected self to be a function, but got %s instead", getTag(L, self).c_str()); + + lua_createtable(L, int(tfft->generics.size()) + int(tfft->genericPacks.size()), 0); + + int pos = 1; + + for (const auto& el : tfft->generics) { - lua_createtable(L, 0, 1); - - allocTypeUserData(L, tfvp->type->type); - lua_setfield(L, -2, "tail"); - - return 1; + allocTypeUserData(L, el->type); + lua_rawseti(L, -2, pos++); + } + + for (const auto& el : tfft->genericPacks) + { + auto gty = get(el); + LUAU_ASSERT(gty); + allocTypeUserData(L, TypeFunctionGenericType{gty->isNamed, true, gty->name}); + lua_rawseti(L, -2, pos++); } - lua_createtable(L, 0, 0); return 1; } @@ -1107,6 +1412,36 @@ static int getClassParent(lua_State* L) return 1; } +// Luau: `self:name() -> string?` +// Returns the name of the generic or 'nil' if the generic is unnamed +static int getGenericName(lua_State* L) +{ + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfgt = get(self); + if (!tfgt) + luaL_error(L, "type.name: expected self to be a generic, but got %s instead", getTag(L, self).c_str()); + + if (tfgt->isNamed) + lua_pushstring(L, tfgt->name.c_str()); + else + lua_pushnil(L); + + return 1; +} + +// Luau: `self:ispack() -> boolean` +// Returns true if the generic is a pack +static int getGenericIsPack(lua_State* L) +{ + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfgt = get(self); + if (!tfgt) + luaL_error(L, "type.ispack: expected self to be a generic, but got %s instead", getTag(L, self).c_str()); + + lua_pushboolean(L, tfgt->isPack); + return 1; +} + // Luau: `self:properties() -> {[type]: { read: type?, write: type? }}` // Returns the properties of a table or class type static int getProps(lua_State* L) @@ -1376,7 +1711,7 @@ static int checkTag(lua_State* L) TypeFunctionTypeId deepClone(NotNull runtime, TypeFunctionTypeId ty); // Forward declaration -// Luau: `types.copy(arg: string) -> type` +// Luau: `types.copy(arg: type) -> type` // Returns a deep copy of the argument static int deepCopy(lua_State* L) { @@ -1408,8 +1743,6 @@ static int isEqualToType(lua_State* L) void registerTypesLibrary(lua_State* L) { - LUAU_ASSERT(FFlag::LuauUserTypeFunFixRegister); - luaL_Reg fields[] = { {"unknown", createUnknown}, {"never", createNever}, @@ -1417,8 +1750,8 @@ void registerTypesLibrary(lua_State* L) {"boolean", createBoolean}, {"number", createNumber}, {"string", createString}, - {FFlag::LuauUserTypeFunThreadBuffer ? "thread" : nullptr, FFlag::LuauUserTypeFunThreadBuffer ? createThread : nullptr}, - {FFlag::LuauUserTypeFunThreadBuffer ? "buffer" : nullptr, FFlag::LuauUserTypeFunThreadBuffer ? createBuffer : nullptr}, + {"thread", createThread}, + {"buffer", createBuffer}, {nullptr, nullptr} }; @@ -1428,8 +1761,9 @@ void registerTypesLibrary(lua_State* L) {"unionof", createUnion}, {"intersectionof", createIntersection}, {"newtable", createTable}, - {"newfunction", createFunction}, + {"newfunction", FFlag::LuauUserTypeFunGenerics ? createFunction : createFunction_DEPRECATED}, {"copy", deepCopy}, + {FFlag::LuauUserTypeFunGenerics ? "generic" : nullptr, FFlag::LuauUserTypeFunGenerics ? createGeneric : nullptr}, {nullptr, nullptr} }; @@ -1464,170 +1798,109 @@ static int typeUserdataIndex(lua_State* L) void registerTypeUserData(lua_State* L) { - if (FFlag::LuauUserTypeFunFixRegister) - { - luaL_Reg typeUserdataMethods[] = { - {"is", checkTag}, + luaL_Reg typeUserdataMethods[] = { + {"is", checkTag}, - // Negation type methods - {"inner", getNegatedValue}, + // Negation type methods + {"inner", getNegatedValue}, - // Singleton type methods - {"value", getSingletonValue}, + // Singleton type methods + {"value", getSingletonValue}, - // Table type methods - {"setproperty", setTableProp}, - {"setreadproperty", setReadTableProp}, - {"setwriteproperty", setWriteTableProp}, - {"readproperty", readTableProp}, - {"writeproperty", writeTableProp}, - {"properties", getProps}, - {"setindexer", setTableIndexer}, - {"setreadindexer", setTableReadIndexer}, - {"setwriteindexer", setTableWriteIndexer}, - {"indexer", getIndexer}, - {"readindexer", getReadIndexer}, - {"writeindexer", getWriteIndexer}, - {"setmetatable", setTableMetatable}, - {"metatable", getMetatable}, + // Table type methods + {"setproperty", setTableProp}, + {"setreadproperty", setReadTableProp}, + {"setwriteproperty", setWriteTableProp}, + {"readproperty", readTableProp}, + {"writeproperty", writeTableProp}, + {"properties", getProps}, + {"setindexer", setTableIndexer}, + {"setreadindexer", setTableReadIndexer}, + {"setwriteindexer", setTableWriteIndexer}, + {"indexer", getIndexer}, + {"readindexer", getReadIndexer}, + {"writeindexer", getWriteIndexer}, + {"setmetatable", setTableMetatable}, + {"metatable", getMetatable}, - // Function type methods - {"setparameters", setFunctionParameters}, - {"parameters", getFunctionParameters}, - {"setreturns", setFunctionReturns}, - {"returns", getFunctionReturns}, + // Function type methods + {"setparameters", setFunctionParameters}, + {"parameters", getFunctionParameters}, + {"setreturns", setFunctionReturns}, + {"returns", getFunctionReturns}, + {"setgenerics", setFunctionGenerics}, + {"generics", getFunctionGenerics}, - // Union and Intersection type methods - {"components", getComponents}, + // Union and Intersection type methods + {"components", getComponents}, - // Class type methods - {"parent", getClassParent}, + // Class type methods + {"parent", getClassParent}, - {nullptr, nullptr} - }; + // Function type methods (cont.) + {FFlag::LuauUserTypeFunGenerics ? "setgenerics" : nullptr, FFlag::LuauUserTypeFunGenerics ? setFunctionGenerics : nullptr}, + {FFlag::LuauUserTypeFunGenerics ? "generics" : nullptr, FFlag::LuauUserTypeFunGenerics ? getFunctionGenerics : nullptr}, - // Create and register metatable for type userdata - luaL_newmetatable(L, "type"); + // Generic type methods + {FFlag::LuauUserTypeFunGenerics ? "name" : nullptr, FFlag::LuauUserTypeFunGenerics ? getGenericName : nullptr}, + {FFlag::LuauUserTypeFunGenerics ? "ispack" : nullptr, FFlag::LuauUserTypeFunGenerics ? getGenericIsPack : nullptr}, - // Protect metatable from being changed - lua_pushstring(L, "The metatable is locked"); - lua_setfield(L, -2, "__metatable"); + {nullptr, nullptr} + }; - lua_pushcfunction(L, isEqualToType, "__eq"); - lua_setfield(L, -2, "__eq"); + // Create and register metatable for type userdata + luaL_newmetatable(L, "type"); - // Indexing will be a dynamic function because some type fields are dynamic - lua_newtable(L); - luaL_register(L, nullptr, typeUserdataMethods); - lua_setreadonly(L, -1, true); - lua_pushcclosure(L, typeUserdataIndex, "__index", 1); - lua_setfield(L, -2, "__index"); + // Protect metatable from being changed + lua_pushstring(L, "The metatable is locked"); + lua_setfield(L, -2, "__metatable"); - lua_setreadonly(L, -1, true); - lua_pop(L, 1); - } - else - { - // List of fields for type userdata - luaL_Reg typeUserdataFields[] = { - {"unknown", createUnknown}, - {"never", createNever}, - {"any", createAny}, - {"boolean", createBoolean}, - {"number", createNumber}, - {"string", createString}, - {nullptr, nullptr} - }; + lua_pushcfunction(L, isEqualToType, "__eq"); + lua_setfield(L, -2, "__eq"); - // List of methods for type userdata - luaL_Reg typeUserdataMethods[] = { - {"singleton", createSingleton}, - {"negationof", createNegation}, - {"unionof", createUnion}, - {"intersectionof", createIntersection}, - {"newtable", createTable}, - {"newfunction", createFunction}, - {"copy", deepCopy}, + // Indexing will be a dynamic function because some type fields are dynamic + lua_newtable(L); + luaL_register(L, nullptr, typeUserdataMethods); + lua_setreadonly(L, -1, true); + lua_pushcclosure(L, typeUserdataIndex, "__index", 1); + lua_setfield(L, -2, "__index"); - // Common methods - {"is", checkTag}, - - // Negation type methods - {"inner", getNegatedValue}, - - // Singleton type methods - {"value", getSingletonValue}, - - // Table type methods - {"setproperty", setTableProp}, - {"setreadproperty", setReadTableProp}, - {"setwriteproperty", setWriteTableProp}, - {"readproperty", readTableProp}, - {"writeproperty", writeTableProp}, - {"properties", getProps}, - {"setindexer", setTableIndexer}, - {"setreadindexer", setTableReadIndexer}, - {"setwriteindexer", setTableWriteIndexer}, - {"indexer", getIndexer}, - {"readindexer", getReadIndexer}, - {"writeindexer", getWriteIndexer}, - {"setmetatable", setTableMetatable}, - {"metatable", getMetatable}, - - // Function type methods - {"setparameters", setFunctionParameters}, - {"parameters", getFunctionParameters}, - {"setreturns", setFunctionReturns}, - {"returns", getFunctionReturns}, - - // Union and Intersection type methods - {"components", getComponents}, - - // Class type methods - {"parent", getClassParent}, - {"indexer", getIndexer}, - {nullptr, nullptr} - }; - - // Create and register metatable for type userdata - luaL_newmetatable(L, "type"); - - // Protect metatable from being fetched. - lua_pushstring(L, "The metatable is locked"); - lua_setfield(L, -2, "__metatable"); - - // Set type userdata metatable's __eq to type_equals() - lua_pushcfunction(L, isEqualToType, "__eq"); - lua_setfield(L, -2, "__eq"); - - // Set type userdata metatable's __index to itself - lua_pushvalue(L, -1); // Push a copy of type userdata metatable - lua_setfield(L, -2, "__index"); - - luaL_register(L, nullptr, typeUserdataMethods); - - // Set fields for type userdata - for (luaL_Reg* l = typeUserdataFields; l->name; l++) - { - l->func(L); - lua_setfield(L, -2, l->name); - } - - // Set types library as a global name "types" - lua_setglobal(L, "types"); - } + lua_setreadonly(L, -1, true); + lua_pop(L, 1); // Sets up a destructor for the type userdata. lua_setuserdatadtor(L, kTypeUserdataTag, deallocTypeUserData); } // Used to redirect all the removed global functions to say "this function is unsupported" -int unsupportedFunction(lua_State* L) +static int unsupportedFunction(lua_State* L) { luaL_errorL(L, "this function is not supported in type functions"); return 0; } +static int print(lua_State* L) +{ + std::string result; + + int n = lua_gettop(L); + for (int i = 1; i <= n; i++) + { + size_t l = 0; + const char* s = luaL_tolstring(L, i, &l); // convert to string using __tostring et al + if (i > 1) + result.append('\t', 1); + result.append(s, l); + lua_pop(L, 1); + } + + auto ctx = getTypeFunctionRuntime(L); + + ctx->messages.push_back(std::move(result)); + + return 0; +} + // Add libraries / globals for type function environment void setTypeFunctionEnvironment(lua_State* L) { @@ -1660,12 +1933,15 @@ void setTypeFunctionEnvironment(lua_State* L) lua_pop(L, 1); // Remove certain global functions from the base library - static const std::string unavailableGlobals[] = {"gcinfo", "getfenv", "newproxy", "setfenv", "pcall", "xpcall"}; + static const char* unavailableGlobals[] = {"gcinfo", "getfenv", "newproxy", "setfenv", "pcall", "xpcall"}; for (auto& name : unavailableGlobals) { - lua_pushcfunction(L, unsupportedFunction, "Removing global function from type function environment"); - lua_setglobal(L, name.c_str()); + lua_pushcfunction(L, unsupportedFunction, name); + lua_setglobal(L, name); } + + lua_pushcfunction(L, print, "print"); + lua_setglobal(L, "print"); } void resetTypeFunctionState(lua_State* L) @@ -1821,6 +2097,27 @@ bool areEqual(SeenSet& seen, const TypeFunctionFunctionType& lhs, const TypeFunc if (seenSetContains(seen, &lhs, &rhs)) return true; + if (FFlag::LuauUserTypeFunGenerics) + { + if (lhs.generics.size() != rhs.generics.size()) + return false; + + for (auto l = lhs.generics.begin(), r = rhs.generics.begin(); l != lhs.generics.end() && r != rhs.generics.end(); ++l, ++r) + { + if (!areEqual(seen, **l, **r)) + return false; + } + + if (lhs.genericPacks.size() != rhs.genericPacks.size()) + return false; + + for (auto l = lhs.genericPacks.begin(), r = rhs.genericPacks.begin(); l != lhs.genericPacks.end() && r != rhs.genericPacks.end(); ++l, ++r) + { + if (!areEqual(seen, **l, **r)) + return false; + } + } + if (bool(lhs.argTypes) != bool(rhs.argTypes)) return false; @@ -1921,6 +2218,16 @@ bool areEqual(SeenSet& seen, const TypeFunctionType& lhs, const TypeFunctionType return areEqual(seen, *lf, *rf); } + if (FFlag::LuauUserTypeFunGenerics) + { + { + const TypeFunctionGenericType* lg = get(&lhs); + const TypeFunctionGenericType* rg = get(&rhs); + if (lg && rg) + return lg->isNamed == rg->isNamed && lg->isPack == rg->isPack && lg->name == rg->name; + } + } + return false; } @@ -1967,6 +2274,16 @@ bool areEqual(SeenSet& seen, const TypeFunctionTypePackVar& lhs, const TypeFunct return areEqual(seen, *lv, *rv); } + if (FFlag::LuauUserTypeFunGenerics) + { + { + const TypeFunctionGenericTypePack* lg = get(&lhs); + const TypeFunctionGenericTypePack* rg = get(&rhs); + if (lg && rg) + return lg->isNamed == rg->isNamed && lg->name == rg->name; + } + } + return false; } @@ -2156,12 +2473,10 @@ private: target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::String)); break; case TypeFunctionPrimitiveType::Thread: - if (FFlag::LuauUserTypeFunThreadBuffer) - target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Thread)); + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Thread)); break; case TypeFunctionPrimitiveType::Buffer: - if (FFlag::LuauUserTypeFunThreadBuffer) - target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Buffer)); + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Buffer)); break; default: break; @@ -2191,10 +2506,14 @@ private: else if (auto f = get(ty)) { TypeFunctionTypePackId emptyTypePack = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{}); - target = typeFunctionRuntime->typeArena.allocate(TypeFunctionFunctionType{emptyTypePack, emptyTypePack}); + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionFunctionType{{}, {}, emptyTypePack, emptyTypePack}); } else if (auto c = get(ty)) target = ty; // Don't copy a class since they are immutable + else if (auto g = get(ty); FFlag::LuauUserTypeFunGenerics && g) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionGenericType{g->isNamed, g->isPack, g->name}); + else + LUAU_ASSERT(!"Unknown type"); types[ty] = target; queue.emplace_back(ty, target); @@ -2212,6 +2531,10 @@ private: target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{{}}); else if (auto vPack = get(tp)) target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionVariadicTypePack{}); + else if (auto gPack = get(tp); gPack && FFlag::LuauUserTypeFunGenerics) + target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionGenericTypePack{gPack->isNamed, gPack->name}); + else + LUAU_ASSERT(!"Unknown type"); packs[tp] = target; queue.emplace_back(tp, target); @@ -2242,6 +2565,9 @@ private: cloneChildren(f1, f2); else if (auto [c1, c2] = std::tuple{getMutable(ty), getMutable(tfti)}; c1 && c2) cloneChildren(c1, c2); + else if (auto [g1, g2] = std::tuple{getMutable(ty), getMutable(tfti)}; + FFlag::LuauUserTypeFunGenerics && g1 && g2) + cloneChildren(g1, g2); else LUAU_ASSERT(!"Unknown pair?"); // First and argument should always represent the same types } @@ -2253,6 +2579,9 @@ private: else if (auto [vPack1, vPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; vPack1 && vPack2) cloneChildren(vPack1, vPack2); + else if (auto [gPack1, gPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; + FFlag::LuauUserTypeFunGenerics && gPack1 && gPack2) + cloneChildren(gPack1, gPack2); else LUAU_ASSERT(!"Unknown pair?"); // First and argument should always represent the same types } @@ -2333,6 +2662,17 @@ private: void cloneChildren(TypeFunctionFunctionType* f1, TypeFunctionFunctionType* f2) { + if (FFlag::LuauUserTypeFunGenerics) + { + f2->generics.reserve(f1->generics.size()); + for (auto ty : f1->generics) + f2->generics.push_back(shallowClone(ty)); + + f2->genericPacks.reserve(f1->genericPacks.size()); + for (auto tp : f1->genericPacks) + f2->genericPacks.push_back(shallowClone(tp)); + } + f2->argTypes = shallowClone(f1->argTypes); f2->retTypes = shallowClone(f1->retTypes); } @@ -2342,16 +2682,32 @@ private: // noop. } + void cloneChildren(TypeFunctionGenericType* g1, TypeFunctionGenericType* g2) + { + // noop. + } + void cloneChildren(TypeFunctionTypePack* t1, TypeFunctionTypePack* t2) { for (TypeFunctionTypeId& ty : t1->head) t2->head.push_back(shallowClone(ty)); + + if (FFlag::LuauUserTypeFunCloneTail) + { + if (t1->tail) + t2->tail = shallowClone(*t1->tail); + } } void cloneChildren(TypeFunctionVariadicTypePack* v1, TypeFunctionVariadicTypePack* v2) { v2->type = shallowClone(v1->type); } + + void cloneChildren(TypeFunctionGenericTypePack* g1, TypeFunctionGenericTypePack* g2) + { + // noop. + } }; TypeFunctionTypeId deepClone(NotNull runtime, TypeFunctionTypeId ty) diff --git a/Analysis/src/TypeFunctionRuntimeBuilder.cpp b/Analysis/src/TypeFunctionRuntimeBuilder.cpp index a102e5da..6b7fa419 100644 --- a/Analysis/src/TypeFunctionRuntimeBuilder.cpp +++ b/Analysis/src/TypeFunctionRuntimeBuilder.cpp @@ -20,8 +20,7 @@ // currently, controls serialization, deserialization, and `type.copy` LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFunctionSerdeIterationLimit, 100'000); -LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixMetatable) -LUAU_FASTFLAG(LuauUserTypeFunThreadBuffer) +LUAU_FASTFLAG(LuauUserTypeFunGenerics) namespace Luau { @@ -161,26 +160,10 @@ private: target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::String)); break; case PrimitiveType::Thread: - if (FFlag::LuauUserTypeFunThreadBuffer) - { - target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Thread)); - } - else - { - std::string error = format("Argument of primitive type %s is not currently serializable by type functions", toString(ty).c_str()); - state->errors.push_back(error); - } + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Thread)); break; case PrimitiveType::Buffer: - if (FFlag::LuauUserTypeFunThreadBuffer) - { - target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Buffer)); - } - else - { - std::string error = format("Argument of primitive type %s is not currently serializable by type functions", toString(ty).c_str()); - state->errors.push_back(error); - } + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Buffer)); break; case PrimitiveType::Function: case PrimitiveType::Table: @@ -222,13 +205,22 @@ private: else if (auto f = get(ty)) { TypeFunctionTypePackId emptyTypePack = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{}); - target = typeFunctionRuntime->typeArena.allocate(TypeFunctionFunctionType{emptyTypePack, emptyTypePack}); + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionFunctionType{{}, {}, emptyTypePack, emptyTypePack}); } else if (auto c = get(ty)) { state->classesSerialized[c->name] = ty; target = typeFunctionRuntime->typeArena.allocate(TypeFunctionClassType{{}, std::nullopt, std::nullopt, std::nullopt, c->name}); } + else if (auto g = get(ty); FFlag::LuauUserTypeFunGenerics && g) + { + Name name = g->name; + + if (!g->explicitName) + name = format("g%d", g->index); + + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionGenericType{g->explicitName, false, name}); + } else { std::string error = format("Argument of type %s is not currently serializable by type functions", toString(ty).c_str()); @@ -253,6 +245,15 @@ private: target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{{}}); else if (auto vPack = get(tp)) target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionVariadicTypePack{}); + else if (auto gPack = get(tp); FFlag::LuauUserTypeFunGenerics && gPack) + { + Name name = gPack->name; + + if (!gPack->explicitName) + name = format("g%d", gPack->index); + + target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionGenericTypePack{gPack->explicitName, name}); + } else { std::string error = format("Argument of type pack %s is not currently serializable by type functions", toString(tp).c_str()); @@ -290,6 +291,9 @@ private: serializeChildren(f1, f2); else if (auto [c1, c2] = std::tuple{get(ty), getMutable(tfti)}; c1 && c2) serializeChildren(c1, c2); + else if (auto [g1, g2] = std::tuple{get(ty), getMutable(tfti)}; + FFlag::LuauUserTypeFunGenerics && g1 && g2) + serializeChildren(g1, g2); else { // Either this or ty and tfti do not represent the same type std::string error = format("Argument of type %s is not currently serializable by type functions", toString(ty).c_str()); @@ -303,6 +307,9 @@ private: serializeChildren(tPack1, tPack2); else if (auto [vPack1, vPack2] = std::tuple{get(tp), getMutable(tftp)}; vPack1 && vPack2) serializeChildren(vPack1, vPack2); + else if (auto [gPack1, gPack2] = std::tuple{get(tp), getMutable(tftp)}; + FFlag::LuauUserTypeFunGenerics && gPack1 && gPack2) + serializeChildren(gPack1, gPack2); else { // Either this or ty and tfti do not represent the same type std::string error = format("Argument of type pack %s is not currently serializable by type functions", toString(tp).c_str()); @@ -383,27 +390,26 @@ private: void serializeChildren(const MetatableType* m1, TypeFunctionTableType* m2) { - if (FFlag::LuauUserTypeFunFixMetatable) - { - // Serialize main part of the metatable immediately - if (auto tableTy = get(m1->table)) - serializeChildren(tableTy, m2); - } - else - { - auto tmpTable = get(shallowSerialize(m1->table)); - if (!tmpTable) - state->ctx->ice->ice("Serializing user defined type function arguments: metatable's table is not a TableType"); - - m2->props = tmpTable->props; - m2->indexer = tmpTable->indexer; - } + // Serialize main part of the metatable immediately + if (auto tableTy = get(m1->table)) + serializeChildren(tableTy, m2); m2->metatable = shallowSerialize(m1->metatable); } void serializeChildren(const FunctionType* f1, TypeFunctionFunctionType* f2) { + if (FFlag::LuauUserTypeFunGenerics) + { + f2->generics.reserve(f1->generics.size()); + for (auto ty : f1->generics) + f2->generics.push_back(shallowSerialize(ty)); + + f2->genericPacks.reserve(f1->genericPacks.size()); + for (auto tp : f1->genericPacks) + f2->genericPacks.push_back(shallowSerialize(tp)); + } + f2->argTypes = shallowSerialize(f1->argTypes); f2->retTypes = shallowSerialize(f1->retTypes); } @@ -433,6 +439,11 @@ private: c2->parent = shallowSerialize(*c1->parent); } + void serializeChildren(const GenericType* g1, TypeFunctionGenericType* g2) + { + // noop. + } + void serializeChildren(const TypePack* t1, TypeFunctionTypePack* t2) { for (const TypeId& ty : t1->head) @@ -446,6 +457,25 @@ private: { v2->type = shallowSerialize(v1->ty); } + + void serializeChildren(const GenericTypePack* v1, TypeFunctionGenericTypePack* v2) + { + // noop. + } +}; + +template +struct SerializedGeneric +{ + bool isNamed = false; + std::string name; + T type = nullptr; +}; + +struct SerializedFunctionScope +{ + size_t oldQueueSize = 0; + TypeFunctionFunctionType* function = nullptr; }; // Complete inverse of TypeFunctionSerializer @@ -466,6 +496,15 @@ class TypeFunctionDeserializer // second must be PrimitiveType; else there should be an error std::vector> queue; + // Generic types and packs currently in scope + // Generics are resolved by name even if runtime generic type pointers are different + // Multiple names mapping to the same generic can be in scope for nested generic functions + std::vector> genericTypes; + std::vector> genericPacks; + + // To track when generics go out of scope, we have a list of queue positions at which a specific function has introduced generics + std::vector functionScopes; + SeenTypes types; // Mapping of TypeFunctionTypeIds that have been shallow deserialized to TypeIds SeenTypePacks packs; // Mapping of TypeFunctionTypePackIds that have been shallow deserialized to TypePackIds @@ -477,7 +516,9 @@ public: , typeFunctionRuntime(state->ctx->typeFunctionRuntime) , queue({}) , types({}) - , packs({}){}; + , packs({}) + { + } TypeId deserialize(TypeFunctionTypeId ty) { @@ -531,6 +572,16 @@ private: queue.pop_back(); deserializeChildren(tfti, ty); + + if (FFlag::LuauUserTypeFunGenerics) + { + // If we have completed working on all children of a function, remove the generic parameters from scope + if (!functionScopes.empty() && queue.size() == functionScopes.back().oldQueueSize && state->errors.empty()) + { + closeFunctionScope(functionScopes.back().function); + functionScopes.pop_back(); + } + } } } @@ -563,6 +614,21 @@ private: } } + void closeFunctionScope(TypeFunctionFunctionType* f) + { + if (!f->generics.empty()) + { + LUAU_ASSERT(genericTypes.size() >= f->generics.size()); + genericTypes.erase(genericTypes.begin() + int(genericTypes.size() - f->generics.size()), genericTypes.end()); + } + + if (!f->genericPacks.empty()) + { + LUAU_ASSERT(genericPacks.size() >= f->genericPacks.size()); + genericPacks.erase(genericPacks.begin() + int(genericPacks.size() - f->genericPacks.size()), genericPacks.end()); + } + } + TypeId shallowDeserialize(TypeFunctionTypeId ty) { if (auto it = find(ty)) @@ -587,16 +653,10 @@ private: target = state->ctx->builtins->stringType; break; case TypeFunctionPrimitiveType::Type::Thread: - if (FFlag::LuauUserTypeFunThreadBuffer) - target = state->ctx->builtins->threadType; - else - state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + target = state->ctx->builtins->threadType; break; case TypeFunctionPrimitiveType::Type::Buffer: - if (FFlag::LuauUserTypeFunThreadBuffer) - target = state->ctx->builtins->bufferType; - else - state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + target = state->ctx->builtins->bufferType; break; default: state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); @@ -642,6 +702,33 @@ private: else state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious class type is being deserialized"); } + else if (auto g = get(ty); FFlag::LuauUserTypeFunGenerics && g) + { + if (g->isPack) + { + state->errors.push_back(format("Generic type pack '%s...' cannot be placed in a type position", g->name.c_str())); + return nullptr; + } + else + { + auto it = std::find_if( + genericTypes.rbegin(), + genericTypes.rend(), + [&](const SerializedGeneric& el) + { + return g->isNamed == el.isNamed && g->name == el.name; + } + ); + + if (it == genericTypes.rend()) + { + state->errors.push_back(format("Generic type '%s' is not in a scope of the active generic function", g->name.c_str())); + return nullptr; + } + + target = it->type; + } + } else state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); @@ -658,11 +745,36 @@ private: // Create a shallow deserialization TypePackId target = {}; if (auto tPack = get(tp)) + { target = state->ctx->arena->addTypePack(TypePack{}); + } else if (auto vPack = get(tp)) + { target = state->ctx->arena->addTypePack(VariadicTypePack{}); + } + else if (auto gPack = get(tp); FFlag::LuauUserTypeFunGenerics && gPack) + { + auto it = std::find_if( + genericPacks.rbegin(), + genericPacks.rend(), + [&](const SerializedGeneric& el) + { + return gPack->isNamed == el.isNamed && gPack->name == el.name; + } + ); + + if (it == genericPacks.rend()) + { + state->errors.push_back(format("Generic type pack '%s...' is not in a scope of the active generic function", gPack->name.c_str())); + return nullptr; + } + + target = it->type; + } else + { state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + } packs[tp] = target; queue.emplace_back(tp, target); @@ -697,6 +809,9 @@ private: deserializeChildren(f2, f1); else if (auto [c1, c2] = std::tuple{getMutable(ty), getMutable(tfti)}; c1 && c2) deserializeChildren(c2, c1); + else if (auto [g1, g2] = std::tuple{getMutable(ty), getMutable(tfti)}; + FFlag::LuauUserTypeFunGenerics && g1 && g2) + deserializeChildren(g2, g1); else state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); } @@ -708,6 +823,9 @@ private: else if (auto [vPack1, vPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; vPack1 && vPack2) deserializeChildren(vPack2, vPack1); + else if (auto [gPack1, gPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; + FFlag::LuauUserTypeFunGenerics && gPack1 && gPack2) + deserializeChildren(gPack2, gPack1); else state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); } @@ -791,6 +909,64 @@ private: void deserializeChildren(TypeFunctionFunctionType* f2, FunctionType* f1) { + if (FFlag::LuauUserTypeFunGenerics) + { + functionScopes.push_back({queue.size(), f2}); + + std::set> genericNames; + + // Introduce generic function parameters into scope + for (auto ty : f2->generics) + { + auto gty = get(ty); + LUAU_ASSERT(gty && !gty->isPack); + + std::pair nameKey = std::make_pair(gty->isNamed, gty->name); + + // Duplicates are not allowed + if (genericNames.find(nameKey) != genericNames.end()) + { + state->errors.push_back(format("Duplicate type parameter '%s'", gty->name.c_str())); + return; + } + + genericNames.insert(nameKey); + + TypeId mapping = state->ctx->arena->addTV(Type(gty->isNamed ? GenericType{state->ctx->scope.get(), gty->name} : GenericType{})); + genericTypes.push_back({gty->isNamed, gty->name, mapping}); + } + + for (auto tp : f2->genericPacks) + { + auto gtp = get(tp); + LUAU_ASSERT(gtp); + + std::pair nameKey = std::make_pair(gtp->isNamed, gtp->name); + + // Duplicates are not allowed + if (genericNames.find(nameKey) != genericNames.end()) + { + state->errors.push_back(format("Duplicate type parameter '%s'", gtp->name.c_str())); + return; + } + + genericNames.insert(nameKey); + + TypePackId mapping = + state->ctx->arena->addTypePack(TypePackVar(gtp->isNamed ? GenericTypePack{state->ctx->scope.get(), gtp->name} : GenericTypePack{}) + ); + genericPacks.push_back({gtp->isNamed, gtp->name, mapping}); + } + + f1->generics.reserve(f2->generics.size()); + for (auto ty : f2->generics) + f1->generics.push_back(shallowDeserialize(ty)); + + f1->genericPacks.reserve(f2->genericPacks.size()); + for (auto tp : f2->genericPacks) + f1->genericPacks.push_back(shallowDeserialize(tp)); + } + if (f2->argTypes) f1->argTypes = shallowDeserialize(f2->argTypes); @@ -803,6 +979,11 @@ private: // noop. } + void deserializeChildren(TypeFunctionGenericType* g2, GenericType* g1) + { + // noop. + } + void deserializeChildren(TypeFunctionTypePack* t2, TypePack* t1) { for (TypeFunctionTypeId& ty : t2->head) @@ -816,6 +997,11 @@ private: { v1->ty = shallowDeserialize(v2->type); } + + void deserializeChildren(TypeFunctionGenericTypePack* v2, GenericTypePack* v1) + { + // noop. + } }; TypeFunctionTypeId serialize(TypeId ty, TypeFunctionRuntimeBuilderState* state) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 911d4b5e..25d1f5c2 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -32,7 +32,9 @@ LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification) LUAU_FASTFLAG(LuauInstantiateInSubtyping) -LUAU_FASTFLAGVARIABLE(LuauMetatableFollow) +LUAU_FASTFLAGVARIABLE(LuauOldSolverCreatesChildScopePointers) +LUAU_FASTFLAG(LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) namespace Luau { @@ -761,8 +763,12 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& state struct Demoter : Substitution { - Demoter(TypeArena* arena) + TypeArena* arena = nullptr; + NotNull builtins; + Demoter(TypeArena* arena, NotNull builtins) : Substitution(TxnLog::empty(), arena) + , arena(arena) + , builtins(builtins) { } @@ -788,7 +794,8 @@ struct Demoter : Substitution { auto ftv = get(ty); LUAU_ASSERT(ftv); - return addType(FreeType{demotedLevel(ftv->level)}); + return FFlag::LuauFreeTypesMustHaveBounds ? arena->freshType(builtins, demotedLevel(ftv->level)) + : addType(FreeType{demotedLevel(ftv->level)}); } TypePackId clean(TypePackId tp) override @@ -835,7 +842,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatReturn& retur } } - Demoter demoter{¤tModule->internalTypes}; + Demoter demoter{¤tModule->internalTypes, builtinTypes}; demoter.demote(expectedTypes); TypePackId retPack = checkExprList(scope, return_.location, return_.list, false, {}, expectedTypes).type; @@ -2799,10 +2806,10 @@ TypeId TypeChecker::checkRelationalOperation( reportError( expr.location, GenericError{ - format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), toString(expr.op).c_str()) - } - ); - } + format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), toString(expr.op).c_str()) + } + ); + } return booleanType; } @@ -2866,7 +2873,7 @@ TypeId TypeChecker::checkRelationalOperation( std::optional metamethod = findMetatableEntry(lhsType, metamethodName, expr.location, /* addErrors= */ true); if (metamethod) { - if (const FunctionType* ftv = get(FFlag::LuauMetatableFollow ? follow(*metamethod) : *metamethod)) + if (const FunctionType* ftv = get(follow(*metamethod))) { if (isEquality) { @@ -4408,7 +4415,7 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st } } - Demoter demoter{¤tModule->internalTypes}; + Demoter demoter{¤tModule->internalTypes, builtinTypes}; demoter.demote(expectedTypes); return expectedTypes; @@ -4506,10 +4513,10 @@ std::unique_ptr> TypeChecker::checkCallOverload( // When this function type has magic functions and did return something, we select that overload instead. // TODO: pass in a Unifier object to the magic functions? This will allow the magic functions to cooperate with overload resolution. - if (ftv->magicFunction) + if (ftv->magic) { // TODO: We're passing in the wrong TypePackId. Should be argPack, but a unit test fails otherwise. CLI-40458 - if (std::optional> ret = ftv->magicFunction(*this, scope, expr, argListResult)) + if (std::optional> ret = ftv->magic->handleOldSolver(*this, scope, expr, argListResult)) return std::make_unique>(std::move(*ret)); } @@ -5205,6 +5212,13 @@ LUAU_NOINLINE void TypeChecker::reportErrorCodeTooComplex(const Location& locati ScopePtr TypeChecker::childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel) { ScopePtr scope = std::make_shared(parent, subLevel); + if (FFlag::LuauOldSolverCreatesChildScopePointers) + { + scope->location = location; + scope->returnType = parent->returnType; + parent->children.emplace_back(scope.get()); + } + currentModule->scopes.push_back(std::make_pair(location, scope)); return scope; } @@ -5215,6 +5229,12 @@ ScopePtr TypeChecker::childScope(const ScopePtr& parent, const Location& locatio ScopePtr scope = std::make_shared(parent); scope->level = parent->level; scope->varargPack = parent->varargPack; + if (FFlag::LuauOldSolverCreatesChildScopePointers) + { + scope->location = location; + scope->returnType = parent->returnType; + parent->children.emplace_back(scope.get()); + } currentModule->scopes.push_back(std::make_pair(location, scope)); return scope; @@ -5260,7 +5280,8 @@ TypeId TypeChecker::freshType(const ScopePtr& scope) TypeId TypeChecker::freshType(TypeLevel level) { - return currentModule->internalTypes.addType(Type(FreeType(level))); + return FFlag::LuauFreeTypesMustHaveBounds ? currentModule->internalTypes.freshType(builtinTypes, level) + : currentModule->internalTypes.addType(Type(FreeType(level))); } TypeId TypeChecker::singletonType(bool value) @@ -5705,6 +5726,12 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno } else if (const auto& un = annotation.as()) { + if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) + { + if (un->types.size == 1) + return resolveType(scope, *un->types.data[0]); + } + std::vector types; for (AstType* ann : un->types) types.push_back(resolveType(scope, *ann)); @@ -5713,12 +5740,22 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno } else if (const auto& un = annotation.as()) { + if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) + { + if (un->types.size == 1) + return resolveType(scope, *un->types.data[0]); + } + std::vector types; for (AstType* ann : un->types) types.push_back(resolveType(scope, *ann)); return addType(IntersectionType{types}); } + else if (const auto& g = annotation.as()) + { + return resolveType(scope, *g->type); + } else if (const auto& tsb = annotation.as()) { return singletonType(tsb->value); diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index ed7d5ebf..bb68503f 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -5,12 +5,15 @@ #include "Luau/Normalize.h" #include "Luau/Scope.h" #include "Luau/ToString.h" +#include "Luau/Type.h" #include "Luau/TypeInfer.h" #include LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete); +LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope); +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) namespace Luau { @@ -318,9 +321,11 @@ TypePack extendTypePack( { FreeType ft{ftp->scope, builtinTypes->neverType, builtinTypes->unknownType}; t = arena.addType(ft); + if (FFlag::LuauTrackInteriorFreeTypesOnScope) + trackInteriorFreeType(ftp->scope, t); } else - t = arena.freshType(ftp->scope); + t = FFlag::LuauFreeTypesMustHaveBounds ? arena.freshType(builtinTypes, ftp->scope) : arena.freshType_DEPRECATED(ftp->scope); } newPack.head.push_back(t); @@ -533,7 +538,7 @@ std::vector findBlockedArgTypesIn(AstExprCall* expr, NotNull toBlock; BlockedTypeInLiteralVisitor v{astTypes, NotNull{&toBlock}}; - for (auto arg: expr->args) + for (auto arg : expr->args) { if (isLiteral(arg) || arg->is()) { @@ -543,5 +548,21 @@ std::vector findBlockedArgTypesIn(AstExprCall* expr, NotNullparent.get()) + { + if (scope->interiorFreeTypes) + { + scope->interiorFreeTypes->push_back(ty); + return; + } + } + // There should at least be *one* generalization constraint per module + // where `interiorFreeTypes` is present, which would be the one made + // by ConstraintGenerator::visitModuleRoot. + LUAU_ASSERT(!"No scopes in parent chain had a present `interiorFreeTypes` member."); +} } // namespace Luau diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 5d71d5cb..926245ea 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -22,6 +22,7 @@ LUAU_FASTFLAGVARIABLE(LuauTransitiveSubtyping) LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAGVARIABLE(LuauFixIndexerSubtypingOrdering) LUAU_FASTFLAGVARIABLE(LuauUnifierRecursionOnRestart) +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) namespace Luau { @@ -1648,7 +1649,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (FFlag::LuauSolverV2) return freshType(NotNull{types}, builtinTypes, scope); else - return types->freshType(scope, level); + return FFlag::LuauFreeTypesMustHaveBounds ? types->freshType(builtinTypes, scope, level) : types->freshType_DEPRECATED(scope, level); }; const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt}); diff --git a/Ast/include/Luau/Allocator.h b/Ast/include/Luau/Allocator.h index 7fd951ae..eaabcd8a 100644 --- a/Ast/include/Luau/Allocator.h +++ b/Ast/include/Luau/Allocator.h @@ -45,4 +45,4 @@ private: size_t offset; }; -} +} // namespace Luau diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 736f24a2..d4764656 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -1204,6 +1204,18 @@ public: const AstArray value; }; +class AstTypeGroup : public AstType +{ +public: + LUAU_RTTI(AstTypeGroup) + + explicit AstTypeGroup(const Location& location, AstType* type); + + void visit(AstVisitor* visitor) override; + + AstType* type; +}; + class AstTypePack : public AstNode { public: @@ -1470,6 +1482,10 @@ public: { return visit(static_cast(node)); } + virtual bool visit(class AstTypeGroup* node) + { + return visit(static_cast(node)); + } virtual bool visit(class AstTypeError* node) { return visit(static_cast(node)); diff --git a/Ast/include/Luau/Cst.h b/Ast/include/Luau/Cst.h new file mode 100644 index 00000000..bea3df90 --- /dev/null +++ b/Ast/include/Luau/Cst.h @@ -0,0 +1,334 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Location.h" + +#include + +namespace Luau +{ + +extern int gCstRttiIndex; + +template +struct CstRtti +{ + static const int value; +}; + +template +const int CstRtti::value = ++gCstRttiIndex; + +#define LUAU_CST_RTTI(Class) \ + static int CstClassIndex() \ + { \ + return CstRtti::value; \ + } + +class CstNode +{ +public: + explicit CstNode(int classIndex) + : classIndex(classIndex) + { + } + + template + bool is() const + { + return classIndex == T::CstClassIndex(); + } + template + T* as() + { + return classIndex == T::CstClassIndex() ? static_cast(this) : nullptr; + } + template + const T* as() const + { + return classIndex == T::CstClassIndex() ? static_cast(this) : nullptr; + } + + const int classIndex; +}; + +class CstExprConstantNumber : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprConstantNumber) + + explicit CstExprConstantNumber(const AstArray& value); + + AstArray value; +}; + +class CstExprConstantString : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprConstantNumber) + + enum QuoteStyle + { + QuotedSingle, + QuotedDouble, + QuotedRaw, + QuotedInterp, + }; + + CstExprConstantString(AstArray sourceString, QuoteStyle quoteStyle, unsigned int blockDepth); + + AstArray sourceString; + QuoteStyle quoteStyle; + unsigned int blockDepth; +}; + +class CstExprCall : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprCall) + + CstExprCall(std::optional openParens, std::optional closeParens, AstArray commaPositions); + + std::optional openParens; + std::optional closeParens; + AstArray commaPositions; +}; + +class CstExprIndexExpr : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprIndexExpr) + + CstExprIndexExpr(Position openBracketPosition, Position closeBracketPosition); + + Position openBracketPosition; + Position closeBracketPosition; +}; + +class CstExprTable : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprTable) + + enum Separator + { + Comma, + Semicolon, + }; + + struct Item + { + std::optional indexerOpenPosition; // '[', only if Kind == General + std::optional indexerClosePosition; // ']', only if Kind == General + std::optional equalsPosition; // only if Kind != List + std::optional separator; // may be missing for last Item + std::optional separatorPosition; + }; + + explicit CstExprTable(const AstArray& items); + + AstArray items; +}; + +// TODO: Shared between unary and binary, should we split? +class CstExprOp : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprOp) + + explicit CstExprOp(Position opPosition); + + Position opPosition; +}; + +class CstExprIfElse : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprIfElse) + + CstExprIfElse(Position thenPosition, Position elsePosition, bool isElseIf); + + Position thenPosition; + Position elsePosition; + bool isElseIf; +}; + +class CstExprInterpString : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprInterpString) + + explicit CstExprInterpString(AstArray> sourceStrings, AstArray stringPositions); + + AstArray> sourceStrings; + AstArray stringPositions; +}; + +class CstStatDo : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatDo) + + explicit CstStatDo(Position endPosition); + + Position endPosition; +}; + +class CstStatRepeat : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatRepeat) + + explicit CstStatRepeat(Position untilPosition); + + Position untilPosition; +}; + +class CstStatReturn : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatReturn) + + explicit CstStatReturn(AstArray commaPositions); + + AstArray commaPositions; +}; + +class CstStatLocal : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatLocal) + + CstStatLocal(AstArray varsCommaPositions, AstArray valuesCommaPositions); + + AstArray varsCommaPositions; + AstArray valuesCommaPositions; +}; + +class CstStatFor : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatFor) + + CstStatFor(Position equalsPosition, Position endCommaPosition, std::optional stepCommaPosition); + + Position equalsPosition; + Position endCommaPosition; + std::optional stepCommaPosition; +}; + +class CstStatForIn : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatForIn) + + CstStatForIn(AstArray varsCommaPositions, AstArray valuesCommaPositions); + + AstArray varsCommaPositions; + AstArray valuesCommaPositions; +}; + +class CstStatAssign : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatAssign) + + CstStatAssign(AstArray varsCommaPositions, Position equalsPosition, AstArray valuesCommaPositions); + + AstArray varsCommaPositions; + Position equalsPosition; + AstArray valuesCommaPositions; +}; + +class CstStatCompoundAssign : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatCompoundAssign) + + explicit CstStatCompoundAssign(Position opPosition); + + Position opPosition; +}; + +class CstStatLocalFunction : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatLocalFunction) + + explicit CstStatLocalFunction(Position functionKeywordPosition); + + Position functionKeywordPosition; +}; + +class CstTypeReference : public CstNode +{ +public: + LUAU_CST_RTTI(CstTypeReference) + + CstTypeReference( + std::optional prefixPointPosition, + Position openParametersPosition, + AstArray parametersCommaPositions, + Position closeParametersPosition + ); + + std::optional prefixPointPosition; + Position openParametersPosition; + AstArray parametersCommaPositions; + Position closeParametersPosition; +}; + +class CstTypeTable : public CstNode +{ +public: + LUAU_CST_RTTI(CstTypeTable) + + struct Item + { + enum struct Kind + { + Indexer, + Property, + StringProperty, + }; + + Kind kind; + Position indexerOpenPosition; // '[', only if Kind != Property + Position indexerClosePosition; // ']' only if Kind != Property + Position colonPosition; + std::optional separator; // may be missing for last Item + std::optional separatorPosition; + + CstExprConstantString* stringInfo = nullptr; // only if Kind == StringProperty + }; + + CstTypeTable(AstArray items, bool isArray); + + AstArray items; + bool isArray = false; +}; + +class CstTypeTypeof : public CstNode +{ +public: + LUAU_CST_RTTI(CstTypeTypeof) + + CstTypeTypeof(Position openPosition, Position closePosition); + + Position openPosition; + Position closePosition; +}; + +class CstTypeSingletonString : public CstNode +{ +public: + LUAU_CST_RTTI(CstTypeSingletonString) + + CstTypeSingletonString(AstArray sourceString, CstExprConstantString::QuoteStyle quoteStyle, unsigned int blockDepth); + + AstArray sourceString; + CstExprConstantString::QuoteStyle quoteStyle; + unsigned int blockDepth; +}; + +} // namespace Luau \ No newline at end of file diff --git a/Ast/include/Luau/Lexer.h b/Ast/include/Luau/Lexer.h index f91f6115..20814860 100644 --- a/Ast/include/Luau/Lexer.h +++ b/Ast/include/Luau/Lexer.h @@ -87,6 +87,12 @@ struct Lexeme Reserved_END }; + enum struct QuoteStyle + { + Single, + Double, + }; + Type type; Location location; @@ -111,6 +117,8 @@ public: Lexeme(const Location& location, Type type, const char* name); unsigned int getLength() const; + unsigned int getBlockDepth() const; + QuoteStyle getQuoteStyle() const; std::string toString() const; }; @@ -230,17 +238,6 @@ private: bool skipComments; bool readNames; - // This offset represents a column offset to be applied to any positions created by the lexer until the next new line. - // For example: - // local x = 4 - // local y = 5 - // If we start lexing from the position of `l` in `local x = 4`, the line number will be 1, and the column will be 4 - // However, because the lexer calculates line offsets by 'index in source buffer where there is a newline', the column - // count will start at 0. For this reason, for just the first line, we'll need to store the offset. - unsigned int lexResumeOffset; - - - enum class BraceType { InterpolatedString, diff --git a/Ast/include/Luau/Location.h b/Ast/include/Luau/Location.h index 3fc8921a..95d4c78a 100644 --- a/Ast/include/Luau/Location.h +++ b/Ast/include/Luau/Location.h @@ -14,12 +14,37 @@ struct Position { } - bool operator==(const Position& rhs) const; - bool operator!=(const Position& rhs) const; - bool operator<(const Position& rhs) const; - bool operator>(const Position& rhs) const; - bool operator<=(const Position& rhs) const; - bool operator>=(const Position& rhs) const; + bool operator==(const Position& rhs) const + { + return this->column == rhs.column && this->line == rhs.line; + } + + bool operator!=(const Position& rhs) const + { + return !(*this == rhs); + } + bool operator<(const Position& rhs) const + { + if (line == rhs.line) + return column < rhs.column; + else + return line < rhs.line; + } + bool operator>(const Position& rhs) const + { + if (line == rhs.line) + return column > rhs.column; + else + return line > rhs.line; + } + bool operator<=(const Position& rhs) const + { + return *this == rhs || *this < rhs; + } + bool operator>=(const Position& rhs) const + { + return *this == rhs || *this > rhs; + } void shift(const Position& start, const Position& oldEnd, const Position& newEnd); }; @@ -52,8 +77,14 @@ struct Location { } - bool operator==(const Location& rhs) const; - bool operator!=(const Location& rhs) const; + bool operator==(const Location& rhs) const + { + return this->begin == rhs.begin && this->end == rhs.end; + } + bool operator!=(const Location& rhs) const + { + return !(*this == rhs); + } bool encloses(const Location& l) const; bool overlaps(const Location& l) const; diff --git a/Ast/include/Luau/ParseOptions.h b/Ast/include/Luau/ParseOptions.h index ff727a0b..ac8e9348 100644 --- a/Ast/include/Luau/ParseOptions.h +++ b/Ast/include/Luau/ParseOptions.h @@ -29,6 +29,8 @@ struct ParseOptions bool allowDeclarationSyntax = false; bool captureComments = false; std::optional parseFragment = std::nullopt; + bool storeCstData = false; + bool noErrorLimit = false; }; } // namespace Luau diff --git a/Ast/include/Luau/ParseResult.h b/Ast/include/Luau/ParseResult.h index 9c0a9527..1ad9c5e9 100644 --- a/Ast/include/Luau/ParseResult.h +++ b/Ast/include/Luau/ParseResult.h @@ -10,6 +10,7 @@ namespace Luau { class AstStatBlock; +class CstNode; class ParseError : public std::exception { @@ -55,6 +56,8 @@ struct Comment Location location; }; +using CstNodeMap = DenseHashMap; + struct ParseResult { AstStatBlock* root; @@ -64,6 +67,8 @@ struct ParseResult std::vector errors; std::vector commentLocations; + + CstNodeMap cstNodeMap{nullptr}; }; static constexpr const char* kParseNameError = "%error-id%"; diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 475d19da..584782ee 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -8,6 +8,7 @@ #include "Luau/StringUtils.h" #include "Luau/DenseHash.h" #include "Luau/Common.h" +#include "Luau/Cst.h" #include #include @@ -116,7 +117,7 @@ private: AstStat* parseFor(); // funcname ::= Name {`.' Name} [`:' Name] - AstExpr* parseFunctionName(Location start, bool& hasself, AstName& debugname); + AstExpr* parseFunctionName(Location start_DEPRECATED, bool& hasself, AstName& debugname); // function funcname funcbody LUAU_FORCEINLINE AstStat* parseFunctionStat(const AstArray& attributes = {nullptr, 0}); @@ -173,14 +174,18 @@ private: ); // explist ::= {exp `,'} exp - void parseExprList(TempVector& result); + void parseExprList(TempVector& result, TempVector* commaPositions = nullptr); // binding ::= Name [`:` Type] Binding parseBinding(); // bindinglist ::= (binding | `...') {`,' bindinglist} // Returns the location of the vararg ..., or std::nullopt if the function is not vararg. - std::tuple parseBindingList(TempVector& result, bool allowDot3 = false); + std::tuple parseBindingList( + TempVector& result, + bool allowDot3 = false, + TempVector* commaPositions = nullptr + ); AstType* parseOptionalType(); @@ -201,7 +206,17 @@ private: std::optional parseOptionalReturnType(); std::pair parseReturnType(); - AstTableIndexer* parseTableIndexer(AstTableAccess access, std::optional accessLocation); + struct TableIndexerResult + { + AstTableIndexer* node; + Position indexerOpenPosition; + Position indexerClosePosition; + Position colonPosition; + }; + + TableIndexerResult parseTableIndexer(AstTableAccess access, std::optional accessLocation); + // Remove with FFlagLuauStoreCSTData + AstTableIndexer* parseTableIndexer_DEPRECATED(AstTableAccess access, std::optional accessLocation); AstTypeOrPack parseFunctionType(bool allowPack, const AstArray& attributes); AstType* parseFunctionTypeTail( @@ -259,6 +274,8 @@ private: // args ::= `(' [explist] `)' | tableconstructor | String AstExpr* parseFunctionArgs(AstExpr* func, bool self); + std::optional tableSeparator(); + // tableconstructor ::= `{' [fieldlist] `}' // fieldlist ::= field {fieldsep field} [fieldsep] // field ::= `[' exp `]' `=' exp | Name `=' exp | exp @@ -280,9 +297,13 @@ private: std::pair, AstArray> parseGenericTypeList(bool withDefaultValues); // `<' Type[, ...] `>' - AstArray parseTypeParams(); + AstArray parseTypeParams( + Position* openingPosition = nullptr, + TempVector* commaPositions = nullptr, + Position* closingPosition = nullptr + ); - std::optional> parseCharArray(); + std::optional> parseCharArray(AstArray* originalString = nullptr); AstExpr* parseString(); AstExpr* parseNumber(); @@ -292,6 +313,9 @@ private: void restoreLocals(unsigned int offset); + /// Returns string quote style and block depth + std::pair extractStringDetails(); + // check that parser is at lexeme/symbol, move to next lexeme/symbol on success, report failure and continue on failure bool expectAndConsume(char value, const char* context = nullptr); bool expectAndConsume(Lexeme::Type type, const char* context = nullptr); @@ -435,6 +459,7 @@ private: std::vector scratchAttr; std::vector scratchStat; std::vector> scratchString; + std::vector> scratchString2; std::vector scratchExpr; std::vector scratchExprAux; std::vector scratchName; @@ -442,15 +467,20 @@ private: std::vector scratchBinding; std::vector scratchLocal; std::vector scratchTableTypeProps; + std::vector scratchCstTableTypeProps; std::vector scratchType; std::vector scratchTypeOrPack; std::vector scratchDeclaredClassProps; std::vector scratchItem; + std::vector scratchCstItem; std::vector scratchArgName; std::vector scratchGenericTypes; std::vector scratchGenericTypePacks; std::vector> scratchOptArgName; + std::vector scratchPosition; std::string scratchData; + + CstNodeMap cstNodeMap; }; } // namespace Luau diff --git a/Ast/src/Allocator.cpp b/Ast/src/Allocator.cpp index f8a99db4..c7614d8c 100644 --- a/Ast/src/Allocator.cpp +++ b/Ast/src/Allocator.cpp @@ -63,4 +63,4 @@ void* Allocator::allocate(size_t size) return page->data; } -} +} // namespace Luau diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index 7e0efd43..5fa63149 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -1091,6 +1091,18 @@ void AstTypeSingletonString::visit(AstVisitor* visitor) visitor->visit(this); } +AstTypeGroup::AstTypeGroup(const Location& location, AstType* type) + : AstType(ClassIndex(), location) + , type(type) +{ +} + +void AstTypeGroup::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + type->visit(visitor); +} + AstTypeError::AstTypeError(const Location& location, const AstArray& types, bool isMissing, unsigned messageIndex) : AstType(ClassIndex(), location) , types(types) @@ -1151,10 +1163,7 @@ void AstTypePackGeneric::visit(AstVisitor* visitor) bool isLValue(const AstExpr* expr) { - return expr->is() - || expr->is() - || expr->is() - || expr->is(); + return expr->is() || expr->is() || expr->is() || expr->is(); } AstName getIdentifier(AstExpr* node) diff --git a/Ast/src/Cst.cpp b/Ast/src/Cst.cpp new file mode 100644 index 00000000..e2faf6e7 --- /dev/null +++ b/Ast/src/Cst.cpp @@ -0,0 +1,169 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Ast.h" +#include "Luau/Cst.h" +#include "Luau/Common.h" + +namespace Luau +{ + +int gCstRttiIndex = 0; + +CstExprConstantNumber::CstExprConstantNumber(const AstArray& value) + : CstNode(CstClassIndex()) + , value(value) +{ +} + +CstExprConstantString::CstExprConstantString(AstArray sourceString, QuoteStyle quoteStyle, unsigned int blockDepth) + : CstNode(CstClassIndex()) + , sourceString(sourceString) + , quoteStyle(quoteStyle) + , blockDepth(blockDepth) +{ + LUAU_ASSERT(blockDepth == 0 || quoteStyle == QuoteStyle::QuotedRaw); +} + +CstExprCall::CstExprCall(std::optional openParens, std::optional closeParens, AstArray commaPositions) + : CstNode(CstClassIndex()) + , openParens(openParens) + , closeParens(closeParens) + , commaPositions(commaPositions) +{ +} + +CstExprIndexExpr::CstExprIndexExpr(Position openBracketPosition, Position closeBracketPosition) + : CstNode(CstClassIndex()) + , openBracketPosition(openBracketPosition) + , closeBracketPosition(closeBracketPosition) +{ +} + +CstExprTable::CstExprTable(const AstArray& items) + : CstNode(CstClassIndex()) + , items(items) +{ +} + +CstExprOp::CstExprOp(Position opPosition) + : CstNode(CstClassIndex()) + , opPosition(opPosition) +{ +} + +CstExprIfElse::CstExprIfElse(Position thenPosition, Position elsePosition, bool isElseIf) + : CstNode(CstClassIndex()) + , thenPosition(thenPosition) + , elsePosition(elsePosition) + , isElseIf(isElseIf) +{ +} + +CstExprInterpString::CstExprInterpString(AstArray> sourceStrings, AstArray stringPositions) + : CstNode(CstClassIndex()) + , sourceStrings(sourceStrings) + , stringPositions(stringPositions) +{ +} + +CstStatDo::CstStatDo(Position endPosition) + : CstNode(CstClassIndex()) + , endPosition(endPosition) +{ +} + +CstStatRepeat::CstStatRepeat(Position untilPosition) + : CstNode(CstClassIndex()) + , untilPosition(untilPosition) +{ +} + +CstStatReturn::CstStatReturn(AstArray commaPositions) + : CstNode(CstClassIndex()) + , commaPositions(commaPositions) +{ +} + +CstStatLocal::CstStatLocal(AstArray varsCommaPositions, AstArray valuesCommaPositions) + : CstNode(CstClassIndex()) + , varsCommaPositions(varsCommaPositions) + , valuesCommaPositions(valuesCommaPositions) +{ +} + +CstStatFor::CstStatFor(Position equalsPosition, Position endCommaPosition, std::optional stepCommaPosition) + : CstNode(CstClassIndex()) + , equalsPosition(equalsPosition) + , endCommaPosition(endCommaPosition) + , stepCommaPosition(stepCommaPosition) +{ +} + +CstStatForIn::CstStatForIn(AstArray varsCommaPositions, AstArray valuesCommaPositions) + : CstNode(CstClassIndex()) + , varsCommaPositions(varsCommaPositions) + , valuesCommaPositions(valuesCommaPositions) +{ +} + +CstStatAssign::CstStatAssign( + AstArray varsCommaPositions, + Position equalsPosition, + AstArray valuesCommaPositions +) + : CstNode(CstClassIndex()) + , varsCommaPositions(varsCommaPositions) + , equalsPosition(equalsPosition) + , valuesCommaPositions(valuesCommaPositions) +{ +} + +CstStatCompoundAssign::CstStatCompoundAssign(Position opPosition) + : CstNode(CstClassIndex()) + , opPosition(opPosition) +{ +} + +CstStatLocalFunction::CstStatLocalFunction(Position functionKeywordPosition) + : CstNode(CstClassIndex()) + , functionKeywordPosition(functionKeywordPosition) +{ +} + +CstTypeReference::CstTypeReference( + std::optional prefixPointPosition, + Position openParametersPosition, + AstArray parametersCommaPositions, + Position closeParametersPosition +) + : CstNode(CstClassIndex()) + , prefixPointPosition(prefixPointPosition) + , openParametersPosition(openParametersPosition) + , parametersCommaPositions(parametersCommaPositions) + , closeParametersPosition(closeParametersPosition) +{ +} + +CstTypeTable::CstTypeTable(AstArray items, bool isArray) + : CstNode(CstClassIndex()) + , items(items) + , isArray(isArray) +{ +} + +CstTypeTypeof::CstTypeTypeof(Position openPosition, Position closePosition) + : CstNode(CstClassIndex()) + , openPosition(openPosition) + , closePosition(closePosition) +{ +} + +CstTypeSingletonString::CstTypeSingletonString(AstArray sourceString, CstExprConstantString::QuoteStyle quoteStyle, unsigned int blockDepth) + : CstNode(CstClassIndex()) + , sourceString(sourceString) + , quoteStyle(quoteStyle) + , blockDepth(blockDepth) +{ + LUAU_ASSERT(quoteStyle != CstExprConstantString::QuotedInterp); +} + +} // namespace Luau diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index 03532e06..557295e0 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -8,7 +8,9 @@ #include -LUAU_FASTFLAGVARIABLE(LexerResumesFromPosition) +LUAU_FASTFLAGVARIABLE(LexerResumesFromPosition2) +LUAU_FASTFLAGVARIABLE(LexerFixInterpStringStart) + namespace Luau { @@ -304,20 +306,51 @@ static char unescape(char ch) } } +unsigned int Lexeme::getBlockDepth() const +{ + LUAU_ASSERT(type == Lexeme::RawString || type == Lexeme::BlockComment); + + // If we have a well-formed string, we are guaranteed to see 2 `]` characters after the end of the string contents + LUAU_ASSERT(*(data + length) == ']'); + unsigned int depth = 0; + do + { + depth++; + } while (*(data + length + depth) != ']'); + + return depth - 1; +} + +Lexeme::QuoteStyle Lexeme::getQuoteStyle() const +{ + LUAU_ASSERT(type == Lexeme::QuotedString); + + // If we have a well-formed string, we are guaranteed to see a closing delimiter after the string + LUAU_ASSERT(data); + + char quote = *(data + length); + if (quote == '\'') + return Lexeme::QuoteStyle::Single; + else if (quote == '"') + return Lexeme::QuoteStyle::Double; + + LUAU_ASSERT(!"Unknown quote style"); + return Lexeme::QuoteStyle::Double; // unreachable, but required due to compiler warning +} + Lexer::Lexer(const char* buffer, size_t bufferSize, AstNameTable& names, Position startPosition) : buffer(buffer) , bufferSize(bufferSize) , offset(0) - , line(FFlag::LexerResumesFromPosition ? startPosition.line : 0) - , lineOffset(0) + , line(FFlag::LexerResumesFromPosition2 ? startPosition.line : 0) + , lineOffset(FFlag::LexerResumesFromPosition2 ? 0u - startPosition.column : 0) , lexeme( - (FFlag::LexerResumesFromPosition ? Location(Position(startPosition.line, startPosition.column), 0) : Location(Position(0, 0), 0)), + (FFlag::LexerResumesFromPosition2 ? Location(Position(startPosition.line, startPosition.column), 0) : Location(Position(0, 0), 0)), Lexeme::Eof ) , names(names) , skipComments(false) , readNames(true) - , lexResumeOffset(FFlag::LexerResumesFromPosition ? startPosition.column : 0) { } @@ -372,7 +405,6 @@ Lexeme Lexer::lookahead() Location currentPrevLocation = prevLocation; size_t currentBraceStackSize = braceStack.size(); BraceType currentBraceType = braceStack.empty() ? BraceType::Normal : braceStack.back(); - unsigned int currentLexResumeOffset = lexResumeOffset; Lexeme result = next(); @@ -381,7 +413,6 @@ Lexeme Lexer::lookahead() lineOffset = currentLineOffset; lexeme = currentLexeme; prevLocation = currentPrevLocation; - lexResumeOffset = currentLexResumeOffset; if (braceStack.size() < currentBraceStackSize) braceStack.push_back(currentBraceType); @@ -412,9 +443,10 @@ char Lexer::peekch(unsigned int lookahead) const return (offset + lookahead < bufferSize) ? buffer[offset + lookahead] : 0; } +LUAU_FORCEINLINE Position Lexer::position() const { - return Position(line, offset - lineOffset + (FFlag::LexerResumesFromPosition ? lexResumeOffset : 0)); + return Position(line, offset - lineOffset); } LUAU_FORCEINLINE @@ -433,9 +465,6 @@ void Lexer::consumeAny() { line++; lineOffset = offset + 1; - // every new line, we reset - if (FFlag::LexerResumesFromPosition) - lexResumeOffset = 0; } offset++; @@ -764,7 +793,7 @@ Lexeme Lexer::readNext() return Lexeme(Location(start, 1), '}'); } - return readInterpolatedStringSection(position(), Lexeme::InterpStringMid, Lexeme::InterpStringEnd); + return readInterpolatedStringSection(FFlag::LexerFixInterpStringStart ? start : position(), Lexeme::InterpStringMid, Lexeme::InterpStringEnd); } case '=': diff --git a/Ast/src/Location.cpp b/Ast/src/Location.cpp index c2c66d9f..e96fafb7 100644 --- a/Ast/src/Location.cpp +++ b/Ast/src/Location.cpp @@ -4,42 +4,6 @@ namespace Luau { -bool Position::operator==(const Position& rhs) const -{ - return this->column == rhs.column && this->line == rhs.line; -} - -bool Position::operator!=(const Position& rhs) const -{ - return !(*this == rhs); -} - -bool Position::operator<(const Position& rhs) const -{ - if (line == rhs.line) - return column < rhs.column; - else - return line < rhs.line; -} - -bool Position::operator>(const Position& rhs) const -{ - if (line == rhs.line) - return column > rhs.column; - else - return line > rhs.line; -} - -bool Position::operator<=(const Position& rhs) const -{ - return *this == rhs || *this < rhs; -} - -bool Position::operator>=(const Position& rhs) const -{ - return *this == rhs || *this > rhs; -} - void Position::shift(const Position& start, const Position& oldEnd, const Position& newEnd) { if (*this >= start) @@ -54,16 +18,6 @@ void Position::shift(const Position& start, const Position& oldEnd, const Positi } } -bool Location::operator==(const Location& rhs) const -{ - return this->begin == rhs.begin && this->end == rhs.end; -} - -bool Location::operator!=(const Location& rhs) const -{ - return !(*this == rhs); -} - bool Location::encloses(const Location& l) const { return begin <= l.begin && end >= l.end; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 1a533fa5..3fa0ccc9 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -18,13 +18,16 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) // flag so that we don't break production games by reverting syntax changes. // See docs/SyntaxChanges.md for an explanation. LUAU_FASTFLAGVARIABLE(LuauSolverV2) -LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionsSyntax2) -LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunParseExport) LUAU_FASTFLAGVARIABLE(LuauAllowFragmentParsing) -LUAU_FASTFLAGVARIABLE(LuauPortableStringZeroCheck) LUAU_FASTFLAGVARIABLE(LuauAllowComplexTypesInGenericParams) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryForTableTypes) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryForClassNames) +LUAU_FASTFLAGVARIABLE(LuauFixFunctionNameStartPosition) +LUAU_FASTFLAGVARIABLE(LuauExtendStatEndPosWithSemicolon) +LUAU_FASTFLAGVARIABLE(LuauStoreCSTData) +LUAU_FASTFLAGVARIABLE(LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) +LUAU_FASTFLAGVARIABLE(LuauAstTypeGroup) +LUAU_FASTFLAGVARIABLE(ParserNoErrorLimit) namespace Luau { @@ -167,14 +170,14 @@ ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& n AstStatBlock* root = p.parseChunk(); size_t lines = p.lexer.current().location.end.line + (bufferSize > 0 && buffer[bufferSize - 1] != '\n'); - return ParseResult{root, lines, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations)}; + return ParseResult{root, lines, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations), std::move(p.cstNodeMap)}; } catch (ParseError& err) { // when catching a fatal error, append it to the list of non-fatal errors and return p.parseErrors.push_back(err); - return ParseResult{nullptr, 0, {}, p.parseErrors}; + return ParseResult{nullptr, 0, {}, p.parseErrors, {}, std::move(p.cstNodeMap)}; } } @@ -185,6 +188,7 @@ Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Alloc , recursionCounter(0) , endMismatchSuspect(Lexeme(Location(), Lexeme::Eof)) , localMap(AstName()) + , cstNodeMap(nullptr) { Function top; top.vararg = true; @@ -290,6 +294,10 @@ AstStatBlock* Parser::parseBlockNoScope() { nextLexeme(); stat->hasSemicolon = true; + if (FFlag::LuauExtendStatEndPosWithSemicolon) + { + stat->location.end = lexer.previousLocation().end; + } } body.push_back(stat); @@ -493,6 +501,7 @@ AstStat* Parser::parseRepeat() functionStack.back().loopDepth--; + Position untilPosition = lexer.current().location.begin; bool hasUntil = expectMatchEndAndConsume(Lexeme::ReservedUntil, matchRepeat); body->hasEnd = hasUntil; @@ -500,7 +509,17 @@ AstStat* Parser::parseRepeat() restoreLocals(localsBegin); - return allocator.alloc(Location(start, cond->location), cond, body, hasUntil); + if (FFlag::LuauStoreCSTData) + { + AstStatRepeat* node = allocator.alloc(Location(start, cond->location), cond, body, hasUntil); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(untilPosition); + return node; + } + else + { + return allocator.alloc(Location(start, cond->location), cond, body, hasUntil); + } } // do block end @@ -515,8 +534,12 @@ AstStat* Parser::parseDo() body->location.begin = start.begin; + Position endPosition = lexer.current().location.begin; body->hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstNodeMap[body] = allocator.alloc(endPosition); + return body; } @@ -556,18 +579,22 @@ AstStat* Parser::parseFor() if (lexer.current().type == '=') { + Position equalsPosition = lexer.current().location.begin; nextLexeme(); AstExpr* from = parseExpr(); + Position endCommaPosition = lexer.current().location.begin; expectAndConsume(',', "index range"); AstExpr* to = parseExpr(); + std::optional stepCommaPosition = std::nullopt; AstExpr* step = nullptr; if (lexer.current().type == ',') { + stepCommaPosition = lexer.current().location.begin; nextLexeme(); step = parseExpr(); @@ -593,25 +620,46 @@ AstStat* Parser::parseFor() bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); body->hasEnd = hasEnd; - return allocator.alloc(Location(start, end), var, from, to, step, body, hasDo, matchDo.location); + if (FFlag::LuauStoreCSTData) + { + AstStatFor* node = allocator.alloc(Location(start, end), var, from, to, step, body, hasDo, matchDo.location); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(equalsPosition, endCommaPosition, stepCommaPosition); + return node; + } + else + { + return allocator.alloc(Location(start, end), var, from, to, step, body, hasDo, matchDo.location); + } } else { TempVector names(scratchBinding); + TempVector varsCommaPosition(scratchPosition); names.push_back(varname); if (lexer.current().type == ',') { - nextLexeme(); + if (FFlag::LuauStoreCSTData && options.storeCstData) + { + varsCommaPosition.push_back(lexer.current().location.begin); + nextLexeme(); + parseBindingList(names, false, &varsCommaPosition); + } + else + { + nextLexeme(); - parseBindingList(names); + parseBindingList(names); + } } Location inLocation = lexer.current().location; bool hasIn = expectAndConsume(Lexeme::ReservedIn, "for loop"); TempVector values(scratchExpr); - parseExprList(values); + TempVector valuesCommaPositions(scratchPosition); + parseExprList(values, (FFlag::LuauStoreCSTData && options.storeCstData) ? &valuesCommaPositions : nullptr); Lexeme matchDo = lexer.current(); bool hasDo = expectAndConsume(Lexeme::ReservedDo, "for loop"); @@ -636,12 +684,23 @@ AstStat* Parser::parseFor() bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); body->hasEnd = hasEnd; - return allocator.alloc(Location(start, end), copy(vars), copy(values), body, hasIn, inLocation, hasDo, matchDo.location); + if (FFlag::LuauStoreCSTData) + { + AstStatForIn* node = + allocator.alloc(Location(start, end), copy(vars), copy(values), body, hasIn, inLocation, hasDo, matchDo.location); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(copy(varsCommaPosition), copy(valuesCommaPositions)); + return node; + } + else + { + return allocator.alloc(Location(start, end), copy(vars), copy(values), body, hasIn, inLocation, hasDo, matchDo.location); + } } } // funcname ::= Name {`.' Name} [`:' Name] -AstExpr* Parser::parseFunctionName(Location start, bool& hasself, AstName& debugname) +AstExpr* Parser::parseFunctionName(Location start_DEPRECATED, bool& hasself, AstName& debugname) { if (lexer.current().type == Lexeme::Name) debugname = AstName(lexer.current().name); @@ -661,7 +720,9 @@ AstExpr* Parser::parseFunctionName(Location start, bool& hasself, AstName& debug // while we could concatenate the name chain, for now let's just write the short name debugname = name.name; - expr = allocator.alloc(Location(start, name.location), expr, name.name, name.location, opPosition, '.'); + expr = allocator.alloc( + Location(FFlag::LuauFixFunctionNameStartPosition ? expr->location : start_DEPRECATED, name.location), expr, name.name, name.location, opPosition, '.' + ); // note: while the parser isn't recursive here, we're generating recursive structures of unbounded depth incrementRecursionCounter("function name"); @@ -680,7 +741,9 @@ AstExpr* Parser::parseFunctionName(Location start, bool& hasself, AstName& debug // while we could concatenate the name chain, for now let's just write the short name debugname = name.name; - expr = allocator.alloc(Location(start, name.location), expr, name.name, name.location, opPosition, ':'); + expr = allocator.alloc( + Location(FFlag::LuauFixFunctionNameStartPosition ? expr->location : start_DEPRECATED, name.location), expr, name.name, name.location, opPosition, ':' + ); hasself = true; } @@ -828,6 +891,7 @@ AstStat* Parser::parseLocal(const AstArray& attributes) Lexeme matchFunction = lexer.current(); nextLexeme(); + Position functionKeywordPosition = matchFunction.location.begin; // matchFunction is only used for diagnostics; to make it suitable for detecting missed indentation between // `local function` and `end`, we patch the token to begin at the column where `local` starts if (matchFunction.location.begin.line == start.begin.line) @@ -843,7 +907,17 @@ AstStat* Parser::parseLocal(const AstArray& attributes) Location location{start.begin, body->location.end}; - return allocator.alloc(location, var, body); + if (FFlag::LuauStoreCSTData) + { + AstStatLocalFunction* node = allocator.alloc(location, var, body); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(functionKeywordPosition); + return node; + } + else + { + return allocator.alloc(location, var, body); + } } else { @@ -861,13 +935,18 @@ AstStat* Parser::parseLocal(const AstArray& attributes) matchRecoveryStopOnToken['=']++; TempVector names(scratchBinding); - parseBindingList(names); + TempVector varsCommaPositions(scratchPosition); + if (FFlag::LuauStoreCSTData && options.storeCstData) + parseBindingList(names, false, &varsCommaPositions); + else + parseBindingList(names); matchRecoveryStopOnToken['=']--; TempVector vars(scratchLocal); TempVector values(scratchExpr); + TempVector valuesCommaPositions(scratchPosition); std::optional equalsSignLocation; @@ -877,7 +956,7 @@ AstStat* Parser::parseLocal(const AstArray& attributes) nextLexeme(); - parseExprList(values); + parseExprList(values, (FFlag::LuauStoreCSTData && options.storeCstData) ? &valuesCommaPositions : nullptr); } for (size_t i = 0; i < names.size(); ++i) @@ -885,7 +964,17 @@ AstStat* Parser::parseLocal(const AstArray& attributes) Location end = values.empty() ? lexer.previousLocation() : values.back()->location; - return allocator.alloc(Location(start, end), copy(vars), copy(values), equalsSignLocation); + if (FFlag::LuauStoreCSTData) + { + AstStatLocal* node = allocator.alloc(Location(start, end), copy(vars), copy(values), equalsSignLocation); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(copy(varsCommaPositions), copy(valuesCommaPositions)); + return node; + } + else + { + return allocator.alloc(Location(start, end), copy(vars), copy(values), equalsSignLocation); + } } } @@ -897,24 +986,32 @@ AstStat* Parser::parseReturn() nextLexeme(); TempVector list(scratchExpr); + TempVector commaPositions(scratchPosition); if (!blockFollow(lexer.current()) && lexer.current().type != ';') - parseExprList(list); + parseExprList(list, (FFlag::LuauStoreCSTData && options.storeCstData) ? &commaPositions : nullptr); Location end = list.empty() ? start : list.back()->location; - return allocator.alloc(Location(start, end), copy(list)); + if (FFlag::LuauStoreCSTData) + { + AstStatReturn* node = allocator.alloc(Location(start, end), copy(list)); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(copy(commaPositions)); + return node; + } + else + { + return allocator.alloc(Location(start, end), copy(list)); + } } // type Name [`<' varlist `>'] `=' Type AstStat* Parser::parseTypeAlias(const Location& start, bool exported) { // parsing a type function - if (FFlag::LuauUserDefinedTypeFunctionsSyntax2) - { - if (lexer.current().type == Lexeme::ReservedFunction) - return parseTypeFunction(start, exported); - } + if (lexer.current().type == Lexeme::ReservedFunction) + return parseTypeFunction(start, exported); // parsing a type alias @@ -941,12 +1038,6 @@ AstStat* Parser::parseTypeFunction(const Location& start, bool exported) Lexeme matchFn = lexer.current(); nextLexeme(); - if (!FFlag::LuauUserDefinedTypeFunParseExport) - { - if (exported) - report(start, "Type function cannot be exported"); - } - // parse the name of the type function std::optional fnName = parseNameOpt("type function name"); if (!fnName) @@ -1134,8 +1225,7 @@ AstStat* Parser::parseDeclaration(const Location& start, const AstArraydata, 0, chars->size) != nullptr - : strnlen(chars->data, chars->size) < chars->size); + bool containsNull = chars && (memchr(chars->data, 0, chars->size) != nullptr); if (chars && !containsNull) { @@ -1154,14 +1244,21 @@ AstStat* Parser::parseDeclaration(const Location& start, const AstArraylocation, "Cannot have more than one class indexer"); } else { - indexer = parseTableIndexer(AstTableAccess::ReadWrite, std::nullopt); + if (FFlag::LuauStoreCSTData) + indexer = parseTableIndexer(AstTableAccess::ReadWrite, std::nullopt).node; + else + indexer = parseTableIndexer_DEPRECATED(AstTableAccess::ReadWrite, std::nullopt); } } else @@ -1223,10 +1320,13 @@ AstStat* Parser::parseAssignment(AstExpr* initial) initial = reportExprError(initial->location, copy({initial}), "Assigned expression must be a variable or a field"); TempVector vars(scratchExpr); + TempVector varsCommaPositions(scratchPosition); vars.push_back(initial); while (lexer.current().type == ',') { + if (FFlag::LuauStoreCSTData && options.storeCstData) + varsCommaPositions.push_back(lexer.current().location.begin); nextLexeme(); AstExpr* expr = parsePrimaryExpr(/* asStatement= */ true); @@ -1237,12 +1337,23 @@ AstStat* Parser::parseAssignment(AstExpr* initial) vars.push_back(expr); } + Position equalsPosition = lexer.current().location.begin; expectAndConsume('=', "assignment"); TempVector values(scratchExprAux); - parseExprList(values); + TempVector valuesCommaPositions(scratchPosition); + parseExprList(values, FFlag::LuauStoreCSTData && options.storeCstData ? &valuesCommaPositions : nullptr); - return allocator.alloc(Location(initial->location, values.back()->location), copy(vars), copy(values)); + if (FFlag::LuauStoreCSTData) + { + AstStatAssign* node = allocator.alloc(Location(initial->location, values.back()->location), copy(vars), copy(values)); + cstNodeMap[node] = allocator.alloc(copy(varsCommaPositions), equalsPosition, copy(valuesCommaPositions)); + return node; + } + else + { + return allocator.alloc(Location(initial->location, values.back()->location), copy(vars), copy(values)); + } } // var [`+=' | `-=' | `*=' | `/=' | `%=' | `^=' | `..='] exp @@ -1253,11 +1364,22 @@ AstStat* Parser::parseCompoundAssignment(AstExpr* initial, AstExprBinary::Op op) initial = reportExprError(initial->location, copy({initial}), "Assigned expression must be a variable or a field"); } + Position opPosition = lexer.current().location.begin; nextLexeme(); AstExpr* value = parseExpr(); - return allocator.alloc(Location(initial->location, value->location), op, initial, value); + if (FFlag::LuauStoreCSTData) + { + AstStatCompoundAssign* node = allocator.alloc(Location(initial->location, value->location), op, initial, value); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(opPosition); + return node; + } + else + { + return allocator.alloc(Location(initial->location, value->location), op, initial, value); + } } std::pair> Parser::prepareFunctionArguments(const Location& start, bool hasself, const TempVector& args) @@ -1373,12 +1495,14 @@ std::pair Parser::parseFunctionBody( } // explist ::= {exp `,'} exp -void Parser::parseExprList(TempVector& result) +void Parser::parseExprList(TempVector& result, TempVector* commaPositions) { result.push_back(parseExpr()); while (lexer.current().type == ',') { + if (FFlag::LuauStoreCSTData && commaPositions) + commaPositions->push_back(lexer.current().location.begin); nextLexeme(); if (lexer.current().type == ')') @@ -1405,7 +1529,7 @@ Parser::Binding Parser::parseBinding() } // bindinglist ::= (binding | `...') [`,' bindinglist] -std::tuple Parser::parseBindingList(TempVector& result, bool allowDot3) +std::tuple Parser::parseBindingList(TempVector& result, bool allowDot3, TempVector* commaPositions) { while (true) { @@ -1428,6 +1552,8 @@ std::tuple Parser::parseBindingList(TempVectorpush_back(lexer.current().location.begin); nextLexeme(); } @@ -1562,15 +1688,31 @@ std::pair Parser::parseReturnType() if (lexer.current().type != Lexeme::SkinnyArrow && resultNames.empty()) { // If it turns out that it's just '(A)', it's possible that there are unions/intersections to follow, so fold over it. - if (result.size() == 1) + if (FFlag::LuauAstTypeGroup) { - AstType* returnType = parseTypeSuffix(result[0], innerBegin); + if (result.size() == 1 && varargAnnotation == nullptr) + { + AstType* returnType = parseTypeSuffix(allocator.alloc(location, result[0]), begin.location); - // If parseType parses nothing, then returnType->location.end only points at the last non-type-pack - // type to successfully parse. We need the span of the whole annotation. - Position endPos = result.size() == 1 ? location.end : returnType->location.end; + // If parseType parses nothing, then returnType->location.end only points at the last non-type-pack + // type to successfully parse. We need the span of the whole annotation. + Position endPos = result.size() == 1 ? location.end : returnType->location.end; - return {Location{location.begin, endPos}, AstTypeList{copy(&returnType, 1), varargAnnotation}}; + return {Location{location.begin, endPos}, AstTypeList{copy(&returnType, 1), varargAnnotation}}; + } + } + else + { + if (result.size() == 1) + { + AstType* returnType = parseTypeSuffix(result[0], innerBegin); + + // If parseType parses nothing, then returnType->location.end only points at the last non-type-pack + // type to successfully parse. We need the span of the whole annotation. + Position endPos = result.size() == 1 ? location.end : returnType->location.end; + + return {Location{location.begin, endPos}, AstTypeList{copy(&returnType, 1), varargAnnotation}}; + } } return {location, AstTypeList{copy(result), varargAnnotation}}; @@ -1581,8 +1723,61 @@ std::pair Parser::parseReturnType() return {Location{location, tail->location}, AstTypeList{copy(&tail, 1), varargAnnotation}}; } +std::pair Parser::extractStringDetails() +{ + LUAU_ASSERT(FFlag::LuauStoreCSTData); + + CstExprConstantString::QuoteStyle style; + unsigned int blockDepth = 0; + + switch (lexer.current().type) + { + case Lexeme::QuotedString: + style = lexer.current().getQuoteStyle() == Lexeme::QuoteStyle::Double ? CstExprConstantString::QuotedDouble + : CstExprConstantString::QuotedSingle; + break; + case Lexeme::InterpStringSimple: + style = CstExprConstantString::QuotedInterp; + break; + case Lexeme::RawString: + { + style = CstExprConstantString::QuotedRaw; + blockDepth = lexer.current().getBlockDepth(); + break; + } + default: + LUAU_ASSERT(false && "Invalid string type"); + } + + return {style, blockDepth}; +} + // TableIndexer ::= `[' Type `]' `:' Type -AstTableIndexer* Parser::parseTableIndexer(AstTableAccess access, std::optional accessLocation) +Parser::TableIndexerResult Parser::parseTableIndexer(AstTableAccess access, std::optional accessLocation) +{ + const Lexeme begin = lexer.current(); + nextLexeme(); // [ + + AstType* index = parseType(); + + Position indexerClosePosition = lexer.current().location.begin; + expectMatchAndConsume(']', begin); + + Position colonPosition = lexer.current().location.begin; + expectAndConsume(':', "table field"); + + AstType* result = parseType(); + + return { + allocator.alloc(AstTableIndexer{index, result, Location(begin.location, result->location), access, accessLocation}), + begin.location.begin, + indexerClosePosition, + colonPosition, + }; +} + +// Remove with FFlagLuauStoreCSTData +AstTableIndexer* Parser::parseTableIndexer_DEPRECATED(AstTableAccess access, std::optional accessLocation) { const Lexeme begin = lexer.current(); nextLexeme(); // [ @@ -1607,6 +1802,7 @@ AstType* Parser::parseTableType(bool inDeclarationContext) incrementRecursionCounter("type annotation"); TempVector props(scratchTableTypeProps); + TempVector cstItems(scratchCstTableTypeProps); AstTableIndexer* indexer = nullptr; Location start = lexer.current().location; @@ -1614,6 +1810,8 @@ AstType* Parser::parseTableType(bool inDeclarationContext) MatchLexeme matchBrace = lexer.current(); expectAndConsume('{', "table type"); + bool isArray = false; + while (lexer.current().type != '}') { AstTableAccess access = AstTableAccess::ReadWrite; @@ -1639,19 +1837,39 @@ AstType* Parser::parseTableType(bool inDeclarationContext) { const Lexeme begin = lexer.current(); nextLexeme(); // [ - std::optional> chars = parseCharArray(); + CstExprConstantString::QuoteStyle style; + unsigned int blockDepth = 0; + if (FFlag::LuauStoreCSTData && options.storeCstData) + std::tie(style, blockDepth) = extractStringDetails(); + + AstArray sourceString; + std::optional> chars = parseCharArray(options.storeCstData ? &sourceString : nullptr); + + Position indexerClosePosition = lexer.current().location.begin; expectMatchAndConsume(']', begin); + Position colonPosition = lexer.current().location.begin; expectAndConsume(':', "table field"); AstType* type = parseType(); // since AstName contains a char*, it can't contain null - bool containsNull = chars && (FFlag::LuauPortableStringZeroCheck ? memchr(chars->data, 0, chars->size) != nullptr - : strnlen(chars->data, chars->size) < chars->size); + bool containsNull = chars && (memchr(chars->data, 0, chars->size) != nullptr); if (chars && !containsNull) + { props.push_back(AstTableProp{AstName(chars->data), begin.location, type, access, accessLocation}); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstItems.push_back(CstTypeTable::Item{ + CstTypeTable::Item::Kind::StringProperty, + begin.location.begin, + indexerClosePosition, + colonPosition, + tableSeparator(), + lexer.current().location.begin, + allocator.alloc(sourceString, style, blockDepth) + }); + } else report(begin.location, "String literal contains malformed escape sequence or \\0"); } @@ -1661,14 +1879,35 @@ AstType* Parser::parseTableType(bool inDeclarationContext) { // maybe we don't need to parse the entire badIndexer... // however, we either have { or [ to lint, not the entire table type or the bad indexer. - AstTableIndexer* badIndexer = parseTableIndexer(access, accessLocation); + AstTableIndexer* badIndexer; + if (FFlag::LuauStoreCSTData) + badIndexer = parseTableIndexer(access, accessLocation).node; + else + badIndexer = parseTableIndexer_DEPRECATED(access, accessLocation); // we lose all additional indexer expressions from the AST after error recovery here report(badIndexer->location, "Cannot have more than one table indexer"); } else { - indexer = parseTableIndexer(access, accessLocation); + if (FFlag::LuauStoreCSTData) + { + auto tableIndexerResult = parseTableIndexer(access, accessLocation); + indexer = tableIndexerResult.node; + if (options.storeCstData) + cstItems.push_back(CstTypeTable::Item{ + CstTypeTable::Item::Kind::Indexer, + tableIndexerResult.indexerOpenPosition, + tableIndexerResult.indexerClosePosition, + tableIndexerResult.colonPosition, + tableSeparator(), + lexer.current().location.begin, + }); + } + else + { + indexer = parseTableIndexer_DEPRECATED(access, accessLocation); + } } } else if (props.empty() && !indexer && !(lexer.current().type == Lexeme::Name && lexer.lookahead().type == ':')) @@ -1676,6 +1915,7 @@ AstType* Parser::parseTableType(bool inDeclarationContext) AstType* type = parseType(); // array-like table type: {T} desugars into {[number]: T} + isArray = true; AstType* index = allocator.alloc(type->location, std::nullopt, nameNumber, std::nullopt, type->location); indexer = allocator.alloc(AstTableIndexer{index, type, type->location, access, accessLocation}); @@ -1688,11 +1928,21 @@ AstType* Parser::parseTableType(bool inDeclarationContext) if (!name) break; + Position colonPosition = lexer.current().location.begin; expectAndConsume(':', "table field"); AstType* type = parseType(inDeclarationContext); props.push_back(AstTableProp{name->name, name->location, type, access, accessLocation}); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstItems.push_back(CstTypeTable::Item{ + CstTypeTable::Item::Kind::Property, + Position{0, 0}, + Position{0, 0}, + colonPosition, + tableSeparator(), + lexer.current().location.begin + }); } if (lexer.current().type == ',' || lexer.current().type == ';') @@ -1711,7 +1961,17 @@ AstType* Parser::parseTableType(bool inDeclarationContext) if (!expectMatchAndConsume('}', matchBrace, /* searchForMissing = */ FFlag::LuauErrorRecoveryForTableTypes)) end = lexer.previousLocation(); - return allocator.alloc(Location(start, end), copy(props), indexer); + if (FFlag::LuauStoreCSTData) + { + AstTypeTable* node = allocator.alloc(Location(start, end), copy(props), indexer); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(copy(cstItems), isArray); + return node; + } + else + { + return allocator.alloc(Location(start, end), copy(props), indexer); + } } // ReturnType ::= Type | `(' TypeList `)' @@ -1756,7 +2016,12 @@ AstTypeOrPack Parser::parseFunctionType(bool allowPack, const AstArray if (allowPack) return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, nullptr})}; else - return {params[0], {}}; + { + if (FFlag::LuauAstTypeGroup) + return {allocator.alloc(Location(parameterStart.location, params[0]->location), params[0]), {}}; + else + return {params[0], {}}; + } } if (!forceFunctionType && !returnTypeIntroducer && allowPack) @@ -1878,8 +2143,16 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin) ParseError::raise(parts.back()->location, "Exceeded allowed type length; simplify your type annotation to make the code compile"); } - if (parts.size() == 1) - return parts[0]; + if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) + { + if (parts.size() == 1 && !isUnion && !isIntersection) + return parts[0]; + } + else + { + if (parts.size() == 1) + return parts[0]; + } if (isUnion && isIntersection) { @@ -1983,13 +2256,35 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext) } else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString) { - if (std::optional> value = parseCharArray()) + if (FFlag::LuauStoreCSTData) { - AstArray svalue = *value; - return {allocator.alloc(start, svalue)}; + CstExprConstantString::QuoteStyle style; + unsigned int blockDepth = 0; + if (options.storeCstData) + std::tie(style, blockDepth) = extractStringDetails(); + + AstArray originalString; + if (std::optional> value = parseCharArray(options.storeCstData ? &originalString : nullptr)) + { + AstArray svalue = *value; + auto node = allocator.alloc(start, svalue); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(originalString, style, blockDepth); + return {node}; + } + else + return {reportTypeError(start, {}, "String literal contains malformed escape sequence")}; } else - return {reportTypeError(start, {}, "String literal contains malformed escape sequence")}; + { + if (std::optional> value = parseCharArray()) + { + AstArray svalue = *value; + return {allocator.alloc(start, svalue)}; + } + else + return {reportTypeError(start, {}, "String literal contains malformed escape sequence")}; + } } else if (lexer.current().type == Lexeme::InterpStringBegin || lexer.current().type == Lexeme::InterpStringSimple) { @@ -2005,17 +2300,30 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext) else if (lexer.current().type == Lexeme::Name) { std::optional prefix; + std::optional prefixPointPosition; std::optional prefixLocation; Name name = parseName("type name"); if (lexer.current().type == '.') { - Position pointPosition = lexer.current().location.begin; - nextLexeme(); + if (FFlag::LuauStoreCSTData) + { + prefixPointPosition = lexer.current().location.begin; + nextLexeme(); - prefix = name.name; - prefixLocation = name.location; - name = parseIndexName("field name", pointPosition); + prefix = name.name; + prefixLocation = name.location; + name = parseIndexName("field name", *prefixPointPosition); + } + else + { + Position pointPosition = lexer.current().location.begin; + nextLexeme(); + + prefix = name.name; + prefixLocation = name.location; + name = parseIndexName("field name", pointPosition); + } } else if (lexer.current().type == Lexeme::Dot3) { @@ -2033,23 +2341,53 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext) expectMatchAndConsume(')', typeofBegin); - return {allocator.alloc(Location(start, end), expr), {}}; + if (FFlag::LuauStoreCSTData) + { + AstTypeTypeof* node = allocator.alloc(Location(start, end), expr); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(typeofBegin.location.begin, end.begin); + return {node, {}}; + } + else + { + return {allocator.alloc(Location(start, end), expr), {}}; + } } bool hasParameters = false; AstArray parameters{}; + Position parametersOpeningPosition{0, 0}; + TempVector parametersCommaPositions(scratchPosition); + Position parametersClosingPosition{0, 0}; if (lexer.current().type == '<') { hasParameters = true; - parameters = parseTypeParams(); + if (FFlag::LuauStoreCSTData && options.storeCstData) + parameters = parseTypeParams(¶metersOpeningPosition, ¶metersCommaPositions, ¶metersClosingPosition); + else + parameters = parseTypeParams(); } Location end = lexer.previousLocation(); - return { - allocator.alloc(Location(start, end), prefix, name.name, prefixLocation, name.location, hasParameters, parameters), {} - }; + if (FFlag::LuauStoreCSTData) + { + AstTypeReference* node = + allocator.alloc(Location(start, end), prefix, name.name, prefixLocation, name.location, hasParameters, parameters); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc( + prefixPointPosition, parametersOpeningPosition, copy(parametersCommaPositions), parametersClosingPosition + ); + return {node, {}}; + } + else + { + return { + allocator.alloc(Location(start, end), prefix, name.name, prefixLocation, name.location, hasParameters, parameters), + {} + }; + } } else if (lexer.current().type == '{') { @@ -2246,7 +2584,8 @@ std::optional Parser::checkBinaryConfusables(const BinaryOpPr report(Location(start, next.location), "Unexpected '||'; did you mean 'or'?"); return AstExprBinary::Or; } - else if (curr.type == '!' && next.type == '=' && curr.location.end == next.location.begin && binaryPriority[AstExprBinary::CompareNe].left > limit) + else if (curr.type == '!' && next.type == '=' && curr.location.end == next.location.begin && + binaryPriority[AstExprBinary::CompareNe].left > limit) { nextLexeme(); report(Location(start, next.location), "Unexpected '!='; did you mean '~='?"); @@ -2299,11 +2638,14 @@ AstExpr* Parser::parseExpr(unsigned int limit) if (uop) { + Position opPosition = lexer.current().location.begin; nextLexeme(); AstExpr* subexpr = parseExpr(unaryPriority); expr = allocator.alloc(Location(start, subexpr->location), *uop, subexpr); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstNodeMap[expr] = allocator.alloc(opPosition); } else { @@ -2318,12 +2660,15 @@ AstExpr* Parser::parseExpr(unsigned int limit) while (op && binaryPriority[*op].left > limit) { + Position opPosition = lexer.current().location.begin; nextLexeme(); // read sub-expression with higher priority AstExpr* next = parseExpr(binaryPriority[*op].right); expr = allocator.alloc(Location(start, next->location), *op, expr, next); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstNodeMap[expr] = allocator.alloc(opPosition); op = parseBinaryOp(lexer.current()); if (!op) @@ -2352,11 +2697,8 @@ AstExpr* Parser::parseNameExpr(const char* context) { AstLocal* local = *value; - if (FFlag::LuauUserDefinedTypeFunctionsSyntax2) - { - if (local->functionDepth < typeFunctionDepth) - return reportExprError(lexer.current().location, {}, "Type function cannot reference outer local '%s'", local->name.value); - } + if (local->functionDepth < typeFunctionDepth) + return reportExprError(lexer.current().location, {}, "Type function cannot reference outer local '%s'", local->name.value); return allocator.alloc(name->location, local, local->functionDepth != functionStack.size() - 1); } @@ -2426,11 +2768,14 @@ AstExpr* Parser::parsePrimaryExpr(bool asStatement) AstExpr* index = parseExpr(); + Position closeBracketPosition = lexer.current().location.begin; Position end = lexer.current().location.end; expectMatchAndConsume(']', matchBracket); expr = allocator.alloc(Location(start, end), expr, index); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstNodeMap[expr] = allocator.alloc(matchBracket.position, closeBracketPosition); } else if (lexer.current().type == ':') { @@ -2597,7 +2942,8 @@ AstExpr* Parser::parseSimpleExpr() { return parseNumber(); } - else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::InterpStringSimple) + else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString || + lexer.current().type == Lexeme::InterpStringSimple) { return parseString(); } @@ -2657,16 +3003,29 @@ AstExpr* Parser::parseFunctionArgs(AstExpr* func, bool self) nextLexeme(); TempVector args(scratchExpr); + TempVector commaPositions(scratchPosition); if (lexer.current().type != ')') - parseExprList(args); + parseExprList(args, (FFlag::LuauStoreCSTData && options.storeCstData) ? &commaPositions : nullptr); Location end = lexer.current().location; Position argEnd = end.end; expectMatchAndConsume(')', matchParen); - return allocator.alloc(Location(func->location, end), func, copy(args), self, Location(argStart, argEnd)); + if (FFlag::LuauStoreCSTData) + { + AstExprCall* node = allocator.alloc(Location(func->location, end), func, copy(args), self, Location(argStart, argEnd)); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc( + matchParen.position, lexer.previousLocation().begin, copy(commaPositions) + ); + return node; + } + else + { + return allocator.alloc(Location(func->location, end), func, copy(args), self, Location(argStart, argEnd)); + } } else if (lexer.current().type == '{') { @@ -2674,14 +3033,35 @@ AstExpr* Parser::parseFunctionArgs(AstExpr* func, bool self) AstExpr* expr = parseTableConstructor(); Position argEnd = lexer.previousLocation().end; - return allocator.alloc(Location(func->location, expr->location), func, copy(&expr, 1), self, Location(argStart, argEnd)); + if (FFlag::LuauStoreCSTData) + { + AstExprCall* node = + allocator.alloc(Location(func->location, expr->location), func, copy(&expr, 1), self, Location(argStart, argEnd)); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(std::nullopt, std::nullopt, AstArray{nullptr, 0}); + return node; + } + else + { + return allocator.alloc(Location(func->location, expr->location), func, copy(&expr, 1), self, Location(argStart, argEnd)); + } } else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString) { Location argLocation = lexer.current().location; AstExpr* expr = parseString(); - return allocator.alloc(Location(func->location, expr->location), func, copy(&expr, 1), self, argLocation); + if (FFlag::LuauStoreCSTData) + { + AstExprCall* node = allocator.alloc(Location(func->location, expr->location), func, copy(&expr, 1), self, argLocation); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(std::nullopt, std::nullopt, AstArray{nullptr, 0}); + return node; + } + else + { + return allocator.alloc(Location(func->location, expr->location), func, copy(&expr, 1), self, argLocation); + } } else { @@ -2715,6 +3095,17 @@ LUAU_NOINLINE void Parser::reportAmbiguousCallError() ); } +std::optional Parser::tableSeparator() +{ + LUAU_ASSERT(FFlag::LuauStoreCSTData); + if (lexer.current().type == ',') + return CstExprTable::Comma; + else if (lexer.current().type == ';') + return CstExprTable::Semicolon; + else + return std::nullopt; +} + // tableconstructor ::= `{' [fieldlist] `}' // fieldlist ::= field {fieldsep field} [fieldsep] // field ::= `[' exp `]' `=' exp | Name `=' exp | exp @@ -2722,6 +3113,7 @@ LUAU_NOINLINE void Parser::reportAmbiguousCallError() AstExpr* Parser::parseTableConstructor() { TempVector items(scratchItem); + TempVector cstItems(scratchCstItem); Location start = lexer.current().location; @@ -2735,23 +3127,29 @@ AstExpr* Parser::parseTableConstructor() if (lexer.current().type == '[') { + Position indexerOpenPosition = lexer.current().location.begin; MatchLexeme matchLocationBracket = lexer.current(); nextLexeme(); AstExpr* key = parseExpr(); + Position indexerClosePosition = lexer.current().location.begin; expectMatchAndConsume(']', matchLocationBracket); + Position equalsPosition = lexer.current().location.begin; expectAndConsume('=', "table field"); AstExpr* value = parseExpr(); items.push_back({AstExprTable::Item::General, key, value}); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstItems.push_back({indexerOpenPosition, indexerClosePosition, equalsPosition, tableSeparator(), lexer.current().location.begin}); } else if (lexer.current().type == Lexeme::Name && lexer.lookahead().type == '=') { Name name = parseName("table field"); + Position equalsPosition = lexer.current().location.begin; expectAndConsume('=', "table field"); AstArray nameString; @@ -2765,12 +3163,16 @@ AstExpr* Parser::parseTableConstructor() func->debugname = name.name; items.push_back({AstExprTable::Item::Record, key, value}); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstItems.push_back({std::nullopt, std::nullopt, equalsPosition, tableSeparator(), lexer.current().location.begin}); } else { AstExpr* expr = parseExpr(); items.push_back({AstExprTable::Item::List, nullptr, expr}); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstItems.push_back({std::nullopt, std::nullopt, std::nullopt, tableSeparator(), lexer.current().location.begin}); } if (lexer.current().type == ',' || lexer.current().type == ';') @@ -2792,7 +3194,17 @@ AstExpr* Parser::parseTableConstructor() if (!expectMatchAndConsume('}', matchBrace)) end = lexer.previousLocation(); - return allocator.alloc(Location(start, end), copy(items)); + if (FFlag::LuauStoreCSTData) + { + AstExprTable* node = allocator.alloc(Location(start, end), copy(items)); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(copy(cstItems)); + return node; + } + else + { + return allocator.alloc(Location(start, end), copy(items)); + } } AstExpr* Parser::parseIfElseExpr() @@ -2804,11 +3216,14 @@ AstExpr* Parser::parseIfElseExpr() AstExpr* condition = parseExpr(); + Position thenPosition = lexer.current().location.begin; bool hasThen = expectAndConsume(Lexeme::ReservedThen, "if then else expression"); AstExpr* trueExpr = parseExpr(); AstExpr* falseExpr = nullptr; + Position elsePosition = lexer.current().location.begin; + bool isElseIf = false; if (lexer.current().type == Lexeme::ReservedElseif) { unsigned int oldRecursionCount = recursionCounter; @@ -2816,6 +3231,8 @@ AstExpr* Parser::parseIfElseExpr() hasElse = true; falseExpr = parseIfElseExpr(); recursionCounter = oldRecursionCount; + if (FFlag::LuauStoreCSTData) + isElseIf = true; } else { @@ -2825,7 +3242,17 @@ AstExpr* Parser::parseIfElseExpr() Location end = falseExpr->location; - return allocator.alloc(Location(start, end), condition, hasThen, trueExpr, hasElse, falseExpr); + if (FFlag::LuauStoreCSTData) + { + AstExprIfElse* node = allocator.alloc(Location(start, end), condition, hasThen, trueExpr, hasElse, falseExpr); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(thenPosition, elsePosition, isElseIf); + return node; + } + else + { + return allocator.alloc(Location(start, end), condition, hasThen, trueExpr, hasElse, falseExpr); + } } // Name @@ -2975,13 +3402,15 @@ std::pair, AstArray> Parser::parseG return {generics, genericPacks}; } -AstArray Parser::parseTypeParams() +AstArray Parser::parseTypeParams(Position* openingPosition, TempVector* commaPositions, Position* closingPosition) { TempVector parameters{scratchTypeOrPack}; if (lexer.current().type == '<') { Lexeme begin = lexer.current(); + if (FFlag::LuauStoreCSTData && openingPosition) + *openingPosition = begin.location.begin; nextLexeme(); while (true) @@ -3027,7 +3456,15 @@ AstArray Parser::parseTypeParams() // the next lexeme is one that follows a type // (&, |, ?), then assume that this was actually a // parenthesized type. - parameters.push_back({parseTypeSuffix(explicitTypePack->typeList.types.data[0], begin), {}}); + if (FFlag::LuauAstTypeGroup) + { + auto parenthesizedType = explicitTypePack->typeList.types.data[0]; + parameters.push_back( + {parseTypeSuffix(allocator.alloc(parenthesizedType->location, parenthesizedType), begin), {}} + ); + } + else + parameters.push_back({parseTypeSuffix(explicitTypePack->typeList.types.data[0], begin), {}}); } else { @@ -3069,18 +3506,24 @@ AstArray Parser::parseTypeParams() } if (lexer.current().type == ',') + { + if (FFlag::LuauStoreCSTData && commaPositions) + commaPositions->push_back(lexer.current().location.begin); nextLexeme(); + } else break; } + if (FFlag::LuauStoreCSTData && closingPosition) + *closingPosition = lexer.current().location.begin; expectMatchAndConsume('>', begin); } return copy(parameters); } -std::optional> Parser::parseCharArray() +std::optional> Parser::parseCharArray(AstArray* originalString) { LUAU_ASSERT( lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString || @@ -3088,6 +3531,11 @@ std::optional> Parser::parseCharArray() ); scratchData.assign(lexer.current().data, lexer.current().getLength()); + if (FFlag::LuauStoreCSTData) + { + if (originalString) + *originalString = copy(scratchData); + } if (lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::InterpStringSimple) { @@ -3125,15 +3573,38 @@ AstExpr* Parser::parseString() LUAU_ASSERT(false && "Invalid string type"); } - if (std::optional> value = parseCharArray()) - return allocator.alloc(location, *value, style); + if (FFlag::LuauStoreCSTData) + { + CstExprConstantString::QuoteStyle fullStyle; + unsigned int blockDepth; + if (options.storeCstData) + std::tie(fullStyle, blockDepth) = extractStringDetails(); + + AstArray originalString; + if (std::optional> value = parseCharArray(options.storeCstData ? &originalString : nullptr)) + { + AstExprConstantString* node = allocator.alloc(location, *value, style); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(originalString, fullStyle, blockDepth); + return node; + } + else + return reportExprError(location, {}, "String literal contains malformed escape sequence"); + } else - return reportExprError(location, {}, "String literal contains malformed escape sequence"); + { + if (std::optional> value = parseCharArray()) + return allocator.alloc(location, *value, style); + else + return reportExprError(location, {}, "String literal contains malformed escape sequence"); + } } AstExpr* Parser::parseInterpString() { TempVector> strings(scratchString); + TempVector> sourceStrings(scratchString2); + TempVector stringPositions(scratchPosition); TempVector expressions(scratchExpr); Location startLocation = lexer.current().location; @@ -3151,6 +3622,12 @@ AstExpr* Parser::parseInterpString() scratchData.assign(currentLexeme.data, currentLexeme.getLength()); + if (FFlag::LuauStoreCSTData && options.storeCstData) + { + sourceStrings.push_back(copy(scratchData)); + stringPositions.push_back(currentLexeme.location.begin); + } + if (!Lexer::fixupQuotedString(scratchData)) { nextLexeme(); @@ -3215,7 +3692,15 @@ AstExpr* Parser::parseInterpString() AstArray> stringsArray = copy(strings); AstArray expressionsArray = copy(expressions); - return allocator.alloc(Location{startLocation, endLocation}, stringsArray, expressionsArray); + if (FFlag::LuauStoreCSTData) + { + AstExprInterpString* node = allocator.alloc(Location{startLocation, endLocation}, stringsArray, expressionsArray); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(copy(sourceStrings), copy(stringPositions)); + return node; + } + else + return allocator.alloc(Location{startLocation, endLocation}, stringsArray, expressionsArray); } AstExpr* Parser::parseNumber() @@ -3223,6 +3708,9 @@ AstExpr* Parser::parseNumber() Location start = lexer.current().location; scratchData.assign(lexer.current().data, lexer.current().getLength()); + AstArray sourceData; + if (FFlag::LuauStoreCSTData && options.storeCstData) + sourceData = copy(scratchData); // Remove all internal _ - they don't hold any meaning and this allows parsing code to just pass the string pointer to strtod et al if (scratchData.find('_') != std::string::npos) @@ -3237,7 +3725,17 @@ AstExpr* Parser::parseNumber() if (result == ConstantNumberParseResult::Malformed) return reportExprError(start, {}, "Malformed number"); - return allocator.alloc(start, value, result); + if (FFlag::LuauStoreCSTData) + { + AstExprConstantNumber* node = allocator.alloc(start, value, result); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(sourceData); + return node; + } + else + { + return allocator.alloc(start, value, result); + } } AstLocal* Parser::pushLocal(const Binding& binding) @@ -3514,7 +4012,7 @@ void Parser::report(const Location& location, const char* format, va_list args) parseErrors.emplace_back(location, message); - if (parseErrors.size() >= unsigned(FInt::LuauParseErrorLimit)) + if (parseErrors.size() >= unsigned(FInt::LuauParseErrorLimit) && (!FFlag::ParserNoErrorLimit || !options.noErrorLimit)) ParseError::raise(location, "Reached error limit (%d)", int(FInt::LuauParseErrorLimit)); } diff --git a/CLI/Coverage.h b/CLI/include/Luau/Coverage.h similarity index 100% rename from CLI/Coverage.h rename to CLI/include/Luau/Coverage.h diff --git a/CLI/FileUtils.h b/CLI/include/Luau/FileUtils.h similarity index 79% rename from CLI/FileUtils.h rename to CLI/include/Luau/FileUtils.h index f723c765..80e36378 100644 --- a/CLI/FileUtils.h +++ b/CLI/include/Luau/FileUtils.h @@ -10,7 +10,7 @@ std::optional getCurrentWorkingDirectory(); std::string normalizePath(std::string_view path); -std::string resolvePath(std::string_view relativePath, std::string_view baseFilePath); +std::optional resolvePath(std::string_view relativePath, std::string_view baseFilePath); std::optional readFile(const std::string& name); std::optional readStdin(); @@ -23,7 +23,7 @@ bool isDirectory(const std::string& path); bool traverseDirectory(const std::string& path, const std::function& callback); std::vector splitPath(std::string_view path); -std::string joinPaths(const std::string& lhs, const std::string& rhs); -std::optional getParentPath(const std::string& path); +std::string joinPaths(std::string_view lhs, std::string_view rhs); +std::optional getParentPath(std::string_view path); std::vector getSourceFiles(int argc, char** argv); diff --git a/CLI/Flags.h b/CLI/include/Luau/Flags.h similarity index 100% rename from CLI/Flags.h rename to CLI/include/Luau/Flags.h diff --git a/CLI/Profiler.h b/CLI/include/Luau/Profiler.h similarity index 100% rename from CLI/Profiler.h rename to CLI/include/Luau/Profiler.h diff --git a/CLI/Repl.h b/CLI/include/Luau/Repl.h similarity index 100% rename from CLI/Repl.h rename to CLI/include/Luau/Repl.h diff --git a/CLI/Require.h b/CLI/include/Luau/Require.h similarity index 100% rename from CLI/Require.h rename to CLI/include/Luau/Require.h diff --git a/CLI/Analyze.cpp b/CLI/src/Analyze.cpp similarity index 99% rename from CLI/Analyze.cpp rename to CLI/src/Analyze.cpp index bc78f7cb..e10a2c2e 100644 --- a/CLI/Analyze.cpp +++ b/CLI/src/Analyze.cpp @@ -7,9 +7,9 @@ #include "Luau/TypeAttach.h" #include "Luau/Transpiler.h" -#include "FileUtils.h" -#include "Flags.h" -#include "Require.h" +#include "Luau/FileUtils.h" +#include "Luau/Flags.h" +#include "Luau/Require.h" #include #include diff --git a/CLI/Ast.cpp b/CLI/src/Ast.cpp similarity index 98% rename from CLI/Ast.cpp rename to CLI/src/Ast.cpp index b5a922aa..5341d889 100644 --- a/CLI/Ast.cpp +++ b/CLI/src/Ast.cpp @@ -8,7 +8,7 @@ #include "Luau/ParseOptions.h" #include "Luau/ToString.h" -#include "FileUtils.h" +#include "Luau/FileUtils.h" static void displayHelp(const char* argv0) { diff --git a/CLI/Bytecode.cpp b/CLI/src/Bytecode.cpp similarity index 99% rename from CLI/Bytecode.cpp rename to CLI/src/Bytecode.cpp index 2da9570b..dc8e4833 100644 --- a/CLI/Bytecode.cpp +++ b/CLI/src/Bytecode.cpp @@ -7,8 +7,8 @@ #include "Luau/BytecodeBuilder.h" #include "Luau/Parser.h" #include "Luau/BytecodeSummary.h" -#include "FileUtils.h" -#include "Flags.h" +#include "Luau/FileUtils.h" +#include "Luau/Flags.h" #include diff --git a/CLI/Compile.cpp b/CLI/src/Compile.cpp similarity index 99% rename from CLI/Compile.cpp rename to CLI/src/Compile.cpp index 7d95387c..6f41b42d 100644 --- a/CLI/Compile.cpp +++ b/CLI/src/Compile.cpp @@ -8,8 +8,8 @@ #include "Luau/Parser.h" #include "Luau/TimeTrace.h" -#include "FileUtils.h" -#include "Flags.h" +#include "Luau/FileUtils.h" +#include "Luau/Flags.h" #include @@ -341,7 +341,8 @@ static bool compileFile(const char* name, CompileFormat format, Luau::CodeGen::A bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Remarks); bcb.setDumpSource(*source); } - else if (format == CompileFormat::Codegen || format == CompileFormat::CodegenAsm || format == CompileFormat::CodegenIr || format == CompileFormat::CodegenVerbose) + else if (format == CompileFormat::Codegen || format == CompileFormat::CodegenAsm || format == CompileFormat::CodegenIr || + format == CompileFormat::CodegenVerbose) { bcb.setDumpFlags( Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals | diff --git a/CLI/Coverage.cpp b/CLI/src/Coverage.cpp similarity index 98% rename from CLI/Coverage.cpp rename to CLI/src/Coverage.cpp index a509ab89..7330d492 100644 --- a/CLI/Coverage.cpp +++ b/CLI/src/Coverage.cpp @@ -1,5 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Coverage.h" +#include "Luau/Coverage.h" #include "lua.h" diff --git a/CLI/FileUtils.cpp b/CLI/src/FileUtils.cpp similarity index 71% rename from CLI/FileUtils.cpp rename to CLI/src/FileUtils.cpp index 6925e99f..d54d94e0 100644 --- a/CLI/FileUtils.cpp +++ b/CLI/src/FileUtils.cpp @@ -1,5 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "FileUtils.h" +#include "Luau/FileUtils.h" #include "Luau/Common.h" @@ -20,6 +20,7 @@ #endif #include +#include #ifdef _WIN32 static std::wstring fromUtf8(const std::string& path) @@ -90,108 +91,76 @@ std::optional getCurrentWorkingDirectory() return std::nullopt; } -// Returns the normal/canonical form of a path (e.g. "../subfolder/../module.luau" -> "../module.luau") std::string normalizePath(std::string_view path) { - return resolvePath(path, ""); -} + const std::vector components = splitPath(path); + std::vector normalizedComponents; -// Takes a path that is relative to the file at baseFilePath and returns the path explicitly rebased onto baseFilePath. -// For absolute paths, baseFilePath will be ignored, and this function will resolve the path to a canonical path: -// (e.g. "/Users/.././Users/johndoe" -> "/Users/johndoe"). -std::string resolvePath(std::string_view path, std::string_view baseFilePath) -{ - std::vector pathComponents; - std::vector baseFilePathComponents; + const bool isAbsolute = isAbsolutePath(path); - // Dependent on whether the final resolved path is absolute or relative - // - if relative (when path and baseFilePath are both relative), resolvedPathPrefix remains empty - // - if absolute (if either path or baseFilePath are absolute), resolvedPathPrefix is "C:\", "/", etc. - std::string resolvedPathPrefix; - bool isResolvedPathRelative = false; - - if (isAbsolutePath(path)) - { - // path is absolute, we use path's prefix and ignore baseFilePath - size_t afterPrefix = path.find_first_of("\\/") + 1; - resolvedPathPrefix = path.substr(0, afterPrefix); - pathComponents = splitPath(path.substr(afterPrefix)); - } - else - { - size_t afterPrefix = baseFilePath.find_first_of("\\/") + 1; - baseFilePathComponents = splitPath(baseFilePath.substr(afterPrefix)); - if (isAbsolutePath(baseFilePath)) - { - // path is relative and baseFilePath is absolute, we use baseFilePath's prefix - resolvedPathPrefix = baseFilePath.substr(0, afterPrefix); - } - else - { - // path and baseFilePath are both relative, we do not set a prefix (resolved path will be relative) - isResolvedPathRelative = true; - } - pathComponents = splitPath(path); - } - - // Remove filename from components - if (!baseFilePathComponents.empty()) - baseFilePathComponents.pop_back(); - - // Resolve the path by applying pathComponents to baseFilePathComponents - int numPrependedParents = 0; - for (std::string_view component : pathComponents) + // 1. Normalize path components + const size_t startIndex = isAbsolute ? 1 : 0; + for (size_t i = startIndex; i < components.size(); i++) { + std::string_view component = components[i]; if (component == "..") { - if (baseFilePathComponents.empty()) + if (normalizedComponents.empty()) { - if (isResolvedPathRelative) - numPrependedParents++; // "../" will later be added to the beginning of the resolved path + if (!isAbsolute) + { + normalizedComponents.emplace_back(".."); + } } - else if (baseFilePathComponents.back() != "..") + else if (normalizedComponents.back() == "..") { - baseFilePathComponents.pop_back(); // Resolve cases like "folder/subfolder/../../file" to "file" + normalizedComponents.emplace_back(".."); + } + else + { + normalizedComponents.pop_back(); } } - else if (component != "." && !component.empty()) + else if (!component.empty() && component != ".") { - baseFilePathComponents.push_back(component); + normalizedComponents.emplace_back(component); } } - // Create resolved path prefix for relative paths - if (isResolvedPathRelative) + std::string normalizedPath; + + // 2. Add correct prefix to formatted path + if (isAbsolute) { - if (numPrependedParents > 0) - { - resolvedPathPrefix.reserve(numPrependedParents * 3); - for (int i = 0; i < numPrependedParents; i++) - { - resolvedPathPrefix += "../"; - } - } - else - { - resolvedPathPrefix = "./"; - } + normalizedPath += components[0]; + normalizedPath += "/"; + } + else if (normalizedComponents.empty() || normalizedComponents[0] != "..") + { + normalizedPath += "./"; } - // Join baseFilePathComponents to form the resolved path - std::string resolvedPath = resolvedPathPrefix; - for (auto iter = baseFilePathComponents.begin(); iter != baseFilePathComponents.end(); ++iter) + // 3. Join path components to form the normalized path + for (auto iter = normalizedComponents.begin(); iter != normalizedComponents.end(); ++iter) { - if (iter != baseFilePathComponents.begin()) - resolvedPath += "/"; + if (iter != normalizedComponents.begin()) + normalizedPath += "/"; - resolvedPath += *iter; + normalizedPath += *iter; } - if (resolvedPath.size() > resolvedPathPrefix.size() && resolvedPath.back() == '/') - { - // Remove trailing '/' if present - resolvedPath.pop_back(); - } - return resolvedPath; + if (normalizedPath.size() >= 2 && normalizedPath[normalizedPath.size() - 1] == '.' && normalizedPath[normalizedPath.size() - 2] == '.') + normalizedPath += "/"; + + return normalizedPath; +} + +std::optional resolvePath(std::string_view path, std::string_view baseFilePath) +{ + std::optional baseFilePathParent = getParentPath(baseFilePath); + if (!baseFilePathParent) + return std::nullopt; + + return normalizePath(joinPaths(*baseFilePathParent, path)); } bool hasFileExtension(std::string_view name, const std::vector& extensions) @@ -416,16 +385,16 @@ std::vector splitPath(std::string_view path) return components; } -std::string joinPaths(const std::string& lhs, const std::string& rhs) +std::string joinPaths(std::string_view lhs, std::string_view rhs) { - std::string result = lhs; + std::string result = std::string(lhs); if (!result.empty() && result.back() != '/' && result.back() != '\\') result += '/'; result += rhs; return result; } -std::optional getParentPath(const std::string& path) +std::optional getParentPath(std::string_view path) { if (path == "" || path == "." || path == "/") return std::nullopt; @@ -441,7 +410,7 @@ std::optional getParentPath(const std::string& path) return "/"; if (slash != std::string::npos) - return path.substr(0, slash); + return std::string(path.substr(0, slash)); return ""; } @@ -471,10 +440,12 @@ std::vector getSourceFiles(int argc, char** argv) if (argv[i][0] == '-' && argv[i][1] != '\0') continue; - if (isDirectory(argv[i])) + std::string normalized = normalizePath(argv[i]); + + if (isDirectory(normalized)) { traverseDirectory( - argv[i], + normalized, [&](const std::string& name) { std::string ext = getExtension(name); @@ -486,7 +457,7 @@ std::vector getSourceFiles(int argc, char** argv) } else { - files.push_back(argv[i]); + files.push_back(normalized); } } diff --git a/CLI/Flags.cpp b/CLI/src/Flags.cpp similarity index 100% rename from CLI/Flags.cpp rename to CLI/src/Flags.cpp diff --git a/CLI/Profiler.cpp b/CLI/src/Profiler.cpp similarity index 100% rename from CLI/Profiler.cpp rename to CLI/src/Profiler.cpp diff --git a/CLI/Reduce.cpp b/CLI/src/Reduce.cpp similarity index 99% rename from CLI/Reduce.cpp rename to CLI/src/Reduce.cpp index 7f8c459c..e66d80dc 100644 --- a/CLI/Reduce.cpp +++ b/CLI/src/Reduce.cpp @@ -5,7 +5,7 @@ #include "Luau/Parser.h" #include "Luau/Transpiler.h" -#include "FileUtils.h" +#include "Luau/FileUtils.h" #include #include diff --git a/CLI/Repl.cpp b/CLI/src/Repl.cpp similarity index 99% rename from CLI/Repl.cpp rename to CLI/src/Repl.cpp index 3bda38f1..2dec1d8c 100644 --- a/CLI/Repl.cpp +++ b/CLI/src/Repl.cpp @@ -1,5 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Repl.h" +#include "Luau/Repl.h" #include "Luau/Common.h" #include "lua.h" @@ -10,11 +10,11 @@ #include "Luau/Parser.h" #include "Luau/TimeTrace.h" -#include "Coverage.h" -#include "FileUtils.h" -#include "Flags.h" -#include "Profiler.h" -#include "Require.h" +#include "Luau/Coverage.h" +#include "Luau/FileUtils.h" +#include "Luau/Flags.h" +#include "Luau/Profiler.h" +#include "Luau/Require.h" #include "isocline.h" diff --git a/CLI/ReplEntry.cpp b/CLI/src/ReplEntry.cpp similarity index 89% rename from CLI/ReplEntry.cpp rename to CLI/src/ReplEntry.cpp index 8543e3f7..7e5f9e06 100644 --- a/CLI/ReplEntry.cpp +++ b/CLI/src/ReplEntry.cpp @@ -1,5 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Repl.h" +#include "Luau/Repl.h" int main(int argc, char** argv) { diff --git a/CLI/Require.cpp b/CLI/src/Require.cpp similarity index 92% rename from CLI/Require.cpp rename to CLI/src/Require.cpp index 4c1c3ac6..1039f85c 100644 --- a/CLI/Require.cpp +++ b/CLI/src/Require.cpp @@ -1,7 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Require.h" +#include "Luau/Require.h" -#include "FileUtils.h" +#include "Luau/FileUtils.h" #include "Luau/Common.h" #include "Luau/Config.h" @@ -141,8 +141,17 @@ bool RequireResolver::resolveAndStoreDefaultPaths() return false; // resolvePath automatically sanitizes/normalizes the paths - resolvedRequire.identifier = resolvePath(pathToResolve, identifierContext); - resolvedRequire.absolutePath = resolvePath(pathToResolve, *absolutePathContext); + std::optional identifier = resolvePath(pathToResolve, identifierContext); + std::optional absolutePath = resolvePath(pathToResolve, *absolutePathContext); + + if (!identifier || !absolutePath) + { + errorHandler.reportError("could not resolve require path"); + return false; + } + + resolvedRequire.identifier = std::move(*identifier); + resolvedRequire.absolutePath = std::move(*absolutePath); } else { @@ -181,7 +190,7 @@ std::optional RequireResolver::getRequiringContextAbsolute() else { // Require statement is being executed in a file, must resolve relative to CWD - requiringFile = resolvePath(requireContext.getPath(), joinPaths(*cwd, "stdin")); + requiringFile = normalizePath(joinPaths(*cwd, requireContext.getPath())); } } std::replace(requiringFile.begin(), requiringFile.end(), '\\', '/'); @@ -190,7 +199,7 @@ std::optional RequireResolver::getRequiringContextAbsolute() std::string RequireResolver::getRequiringContextRelative() { - return requireContext.isStdin() ? "" : requireContext.getPath(); + return requireContext.isStdin() ? "./" : requireContext.getPath(); } bool RequireResolver::substituteAliasIfPresent(std::string& path) @@ -301,4 +310,4 @@ bool RequireResolver::parseConfigInDirectory(const std::string& directory) } return true; -} \ No newline at end of file +} diff --git a/CLI/Web.cpp b/CLI/src/Web.cpp similarity index 100% rename from CLI/Web.cpp rename to CLI/src/Web.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 51fa919e..5286fd9f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -68,11 +68,12 @@ include(Sources.cmake) target_include_directories(Luau.Common INTERFACE Common/include) target_compile_features(Luau.CLI.lib PUBLIC cxx_std_17) -target_link_libraries(Luau.CLI.lib PRIVATE Luau.Common) +target_include_directories(Luau.CLI.lib PUBLIC CLI/include) +target_link_libraries(Luau.CLI.lib PRIVATE Luau.Common Luau.Config) target_compile_features(Luau.Ast PUBLIC cxx_std_17) target_include_directories(Luau.Ast PUBLIC Ast/include) -target_link_libraries(Luau.Ast PUBLIC Luau.Common Luau.CLI.lib) +target_link_libraries(Luau.Ast PUBLIC Luau.Common) target_compile_features(Luau.Compiler PUBLIC cxx_std_17) target_include_directories(Luau.Compiler PUBLIC Compiler/include) diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index 30790ee5..ca5fa7a9 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -160,6 +160,7 @@ public: void vmaxsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vminsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vcmpeqsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vcmpltsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vblendvpd(RegisterX64 dst, RegisterX64 src1, OperandX64 mask, RegisterX64 src3); diff --git a/CodeGen/include/Luau/CodeAllocator.h b/CodeGen/include/Luau/CodeAllocator.h index dcc1de85..db1774d8 100644 --- a/CodeGen/include/Luau/CodeAllocator.h +++ b/CodeGen/include/Luau/CodeAllocator.h @@ -1,7 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/CodeGen.h" +#include "Luau/CodeGenOptions.h" #include diff --git a/CodeGen/include/Luau/CodeGen.h b/CodeGen/include/Luau/CodeGen.h index 0cf9d9a5..2e689fe2 100644 --- a/CodeGen/include/Luau/CodeGen.h +++ b/CodeGen/include/Luau/CodeGen.h @@ -1,7 +1,10 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include +#include "Luau/CodeGenCommon.h" +#include "Luau/CodeGenOptions.h" +#include "Luau/LoweringStats.h" + #include #include #include @@ -12,25 +15,11 @@ struct lua_State; -#if defined(__x86_64__) || defined(_M_X64) -#define CODEGEN_TARGET_X64 -#elif defined(__aarch64__) || defined(_M_ARM64) -#define CODEGEN_TARGET_A64 -#endif - namespace Luau { namespace CodeGen { -enum CodeGenFlags -{ - // Only run native codegen for modules that have been marked with --!native - CodeGen_OnlyNativeModules = 1 << 0, - // Run native codegen for functions that the compiler considers not profitable - CodeGen_ColdFunctions = 1 << 1, -}; - // These enum values can be reported through telemetry. // To ensure consistency, changes should be additive. enum class CodeGenCompilationResult @@ -72,106 +61,6 @@ struct CompilationResult } }; -struct IrBuilder; -struct IrOp; - -using HostVectorOperationBytecodeType = uint8_t (*)(const char* member, size_t memberLength); -using HostVectorAccessHandler = bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos); -using HostVectorNamecallHandler = - bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos); - -enum class HostMetamethod -{ - Add, - Sub, - Mul, - Div, - Idiv, - Mod, - Pow, - Minus, - Equal, - LessThan, - LessEqual, - Length, - Concat, -}; - -using HostUserdataOperationBytecodeType = uint8_t (*)(uint8_t type, const char* member, size_t memberLength); -using HostUserdataMetamethodBytecodeType = uint8_t (*)(uint8_t lhsTy, uint8_t rhsTy, HostMetamethod method); -using HostUserdataAccessHandler = - bool (*)(IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos); -using HostUserdataMetamethodHandler = - bool (*)(IrBuilder& builder, uint8_t lhsTy, uint8_t rhsTy, int resultReg, IrOp lhs, IrOp rhs, HostMetamethod method, int pcpos); -using HostUserdataNamecallHandler = bool (*)( - IrBuilder& builder, - uint8_t type, - const char* member, - size_t memberLength, - int argResReg, - int sourceReg, - int params, - int results, - int pcpos -); - -struct HostIrHooks -{ - // Suggest result type of a vector field access - HostVectorOperationBytecodeType vectorAccessBytecodeType = nullptr; - - // Suggest result type of a vector function namecall - HostVectorOperationBytecodeType vectorNamecallBytecodeType = nullptr; - - // Handle vector value field access - // 'sourceReg' is guaranteed to be a vector - // Guards should take a VM exit to 'pcpos' - HostVectorAccessHandler vectorAccess = nullptr; - - // Handle namecall performed on a vector value - // 'sourceReg' (self argument) is guaranteed to be a vector - // All other arguments can be of any type - // Guards should take a VM exit to 'pcpos' - HostVectorNamecallHandler vectorNamecall = nullptr; - - // Suggest result type of a userdata field access - HostUserdataOperationBytecodeType userdataAccessBytecodeType = nullptr; - - // Suggest result type of a metamethod call - HostUserdataMetamethodBytecodeType userdataMetamethodBytecodeType = nullptr; - - // Suggest result type of a userdata namecall - HostUserdataOperationBytecodeType userdataNamecallBytecodeType = nullptr; - - // Handle userdata value field access - // 'sourceReg' is guaranteed to be a userdata, but tag has to be checked - // Write to 'resultReg' might invalidate 'sourceReg' - // Guards should take a VM exit to 'pcpos' - HostUserdataAccessHandler userdataAccess = nullptr; - - // Handle metamethod operation on a userdata value - // 'lhs' and 'rhs' operands can be VM registers of constants - // Operand types have to be checked and userdata operand tags have to be checked - // Write to 'resultReg' might invalidate source operands - // Guards should take a VM exit to 'pcpos' - HostUserdataMetamethodHandler userdataMetamethod = nullptr; - - // Handle namecall performed on a userdata value - // 'sourceReg' (self argument) is guaranteed to be a userdata, but tag has to be checked - // All other arguments can be of any type - // Guards should take a VM exit to 'pcpos' - HostUserdataNamecallHandler userdataNamecall = nullptr; -}; - -struct CompilationOptions -{ - unsigned int flags = 0; - HostIrHooks hooks; - - // null-terminated array of userdata types names that might have custom lowering - const char* const* userdataTypes = nullptr; -}; - struct CompilationStats { size_t bytecodeSizeBytes = 0; @@ -184,8 +73,6 @@ struct CompilationStats uint32_t functionsBound = 0; }; -using AllocationCallback = void(void* context, void* oldPointer, size_t oldSize, void* newPointer, size_t newSize); - bool isSupported(); class SharedCodeGenContext; @@ -249,153 +136,6 @@ CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, unsig CompilationResult compile(lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats = nullptr); CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats = nullptr); -using AnnotatorFn = void (*)(void* context, std::string& result, int fid, int instpos); - -// Output "#" before IR blocks and instructions -enum class IncludeIrPrefix -{ - No, - Yes -}; - -// Output user count and last use information of blocks and instructions -enum class IncludeUseInfo -{ - No, - Yes -}; - -// Output CFG informations like block predecessors, successors and etc -enum class IncludeCfgInfo -{ - No, - Yes -}; - -// Output VM register live in/out information for blocks -enum class IncludeRegFlowInfo -{ - No, - Yes -}; - -struct AssemblyOptions -{ - enum Target - { - Host, - A64, - A64_NoFeatures, - X64_Windows, - X64_SystemV, - }; - - Target target = Host; - - CompilationOptions compilationOptions; - - bool outputBinary = false; - - bool includeAssembly = false; - bool includeIr = false; - bool includeOutlinedCode = false; - bool includeIrTypes = false; - - IncludeIrPrefix includeIrPrefix = IncludeIrPrefix::Yes; - IncludeUseInfo includeUseInfo = IncludeUseInfo::Yes; - IncludeCfgInfo includeCfgInfo = IncludeCfgInfo::Yes; - IncludeRegFlowInfo includeRegFlowInfo = IncludeRegFlowInfo::Yes; - - // Optional annotator function can be provided to describe each instruction, it takes function id and sequential instruction id - AnnotatorFn annotator = nullptr; - void* annotatorContext = nullptr; -}; - -struct BlockLinearizationStats -{ - unsigned int constPropInstructionCount = 0; - double timeSeconds = 0.0; - - BlockLinearizationStats& operator+=(const BlockLinearizationStats& that) - { - this->constPropInstructionCount += that.constPropInstructionCount; - this->timeSeconds += that.timeSeconds; - - return *this; - } - - BlockLinearizationStats operator+(const BlockLinearizationStats& other) const - { - BlockLinearizationStats result(*this); - result += other; - return result; - } -}; - -enum FunctionStatsFlags -{ - // Enable stats collection per function - FunctionStats_Enable = 1 << 0, - // Compute function bytecode summary - FunctionStats_BytecodeSummary = 1 << 1, -}; - -struct FunctionStats -{ - std::string name; - int line = -1; - unsigned bcodeCount = 0; - unsigned irCount = 0; - unsigned asmCount = 0; - unsigned asmSize = 0; - std::vector> bytecodeSummary; -}; - -struct LoweringStats -{ - unsigned totalFunctions = 0; - unsigned skippedFunctions = 0; - int spillsToSlot = 0; - int spillsToRestore = 0; - unsigned maxSpillSlotsUsed = 0; - unsigned blocksPreOpt = 0; - unsigned blocksPostOpt = 0; - unsigned maxBlockInstructions = 0; - - int regAllocErrors = 0; - int loweringErrors = 0; - - BlockLinearizationStats blockLinearizationStats; - - unsigned functionStatsFlags = 0; - std::vector functions; - - LoweringStats operator+(const LoweringStats& other) const - { - LoweringStats result(*this); - result += other; - return result; - } - - LoweringStats& operator+=(const LoweringStats& that) - { - this->totalFunctions += that.totalFunctions; - this->skippedFunctions += that.skippedFunctions; - this->spillsToSlot += that.spillsToSlot; - this->spillsToRestore += that.spillsToRestore; - this->maxSpillSlotsUsed = std::max(this->maxSpillSlotsUsed, that.maxSpillSlotsUsed); - this->blocksPreOpt += that.blocksPreOpt; - this->blocksPostOpt += that.blocksPostOpt; - this->maxBlockInstructions = std::max(this->maxBlockInstructions, that.maxBlockInstructions); - this->regAllocErrors += that.regAllocErrors; - this->loweringErrors += that.loweringErrors; - this->blockLinearizationStats += that.blockLinearizationStats; - if (this->functionStatsFlags & FunctionStats_Enable) - this->functions.insert(this->functions.end(), that.functions.begin(), that.functions.end()); - return *this; - } -}; - // Generates assembly for target function and all inner functions std::string getAssembly(lua_State* L, int idx, AssemblyOptions options = {}, LoweringStats* stats = nullptr); diff --git a/CodeGen/include/Luau/CodeGenCommon.h b/CodeGen/include/Luau/CodeGenCommon.h index 84090423..a9d1761c 100644 --- a/CodeGen/include/Luau/CodeGenCommon.h +++ b/CodeGen/include/Luau/CodeGenCommon.h @@ -10,3 +10,9 @@ #else #define CODEGEN_ASSERT(expr) (void)sizeof(!!(expr)) #endif + +#if defined(__x86_64__) || defined(_M_X64) +#define CODEGEN_TARGET_X64 +#elif defined(__aarch64__) || defined(_M_ARM64) +#define CODEGEN_TARGET_A64 +#endif diff --git a/CodeGen/include/Luau/CodeGenOptions.h b/CodeGen/include/Luau/CodeGenOptions.h new file mode 100644 index 00000000..de95efa6 --- /dev/null +++ b/CodeGen/include/Luau/CodeGenOptions.h @@ -0,0 +1,188 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include + +#include +#include + +namespace Luau +{ +namespace CodeGen +{ + +enum CodeGenFlags +{ + // Only run native codegen for modules that have been marked with --!native + CodeGen_OnlyNativeModules = 1 << 0, + // Run native codegen for functions that the compiler considers not profitable + CodeGen_ColdFunctions = 1 << 1, +}; + +using AllocationCallback = void(void* context, void* oldPointer, size_t oldSize, void* newPointer, size_t newSize); + +struct IrBuilder; +struct IrOp; + +using HostVectorOperationBytecodeType = uint8_t (*)(const char* member, size_t memberLength); +using HostVectorAccessHandler = bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos); +using HostVectorNamecallHandler = + bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos); + +enum class HostMetamethod +{ + Add, + Sub, + Mul, + Div, + Idiv, + Mod, + Pow, + Minus, + Equal, + LessThan, + LessEqual, + Length, + Concat, +}; + +using HostUserdataOperationBytecodeType = uint8_t (*)(uint8_t type, const char* member, size_t memberLength); +using HostUserdataMetamethodBytecodeType = uint8_t (*)(uint8_t lhsTy, uint8_t rhsTy, HostMetamethod method); +using HostUserdataAccessHandler = + bool (*)(IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos); +using HostUserdataMetamethodHandler = + bool (*)(IrBuilder& builder, uint8_t lhsTy, uint8_t rhsTy, int resultReg, IrOp lhs, IrOp rhs, HostMetamethod method, int pcpos); +using HostUserdataNamecallHandler = bool (*)( + IrBuilder& builder, + uint8_t type, + const char* member, + size_t memberLength, + int argResReg, + int sourceReg, + int params, + int results, + int pcpos +); + +struct HostIrHooks +{ + // Suggest result type of a vector field access + HostVectorOperationBytecodeType vectorAccessBytecodeType = nullptr; + + // Suggest result type of a vector function namecall + HostVectorOperationBytecodeType vectorNamecallBytecodeType = nullptr; + + // Handle vector value field access + // 'sourceReg' is guaranteed to be a vector + // Guards should take a VM exit to 'pcpos' + HostVectorAccessHandler vectorAccess = nullptr; + + // Handle namecall performed on a vector value + // 'sourceReg' (self argument) is guaranteed to be a vector + // All other arguments can be of any type + // Guards should take a VM exit to 'pcpos' + HostVectorNamecallHandler vectorNamecall = nullptr; + + // Suggest result type of a userdata field access + HostUserdataOperationBytecodeType userdataAccessBytecodeType = nullptr; + + // Suggest result type of a metamethod call + HostUserdataMetamethodBytecodeType userdataMetamethodBytecodeType = nullptr; + + // Suggest result type of a userdata namecall + HostUserdataOperationBytecodeType userdataNamecallBytecodeType = nullptr; + + // Handle userdata value field access + // 'sourceReg' is guaranteed to be a userdata, but tag has to be checked + // Write to 'resultReg' might invalidate 'sourceReg' + // Guards should take a VM exit to 'pcpos' + HostUserdataAccessHandler userdataAccess = nullptr; + + // Handle metamethod operation on a userdata value + // 'lhs' and 'rhs' operands can be VM registers of constants + // Operand types have to be checked and userdata operand tags have to be checked + // Write to 'resultReg' might invalidate source operands + // Guards should take a VM exit to 'pcpos' + HostUserdataMetamethodHandler userdataMetamethod = nullptr; + + // Handle namecall performed on a userdata value + // 'sourceReg' (self argument) is guaranteed to be a userdata, but tag has to be checked + // All other arguments can be of any type + // Guards should take a VM exit to 'pcpos' + HostUserdataNamecallHandler userdataNamecall = nullptr; +}; + +struct CompilationOptions +{ + unsigned int flags = 0; + HostIrHooks hooks; + + // null-terminated array of userdata types names that might have custom lowering + const char* const* userdataTypes = nullptr; +}; + + +using AnnotatorFn = void (*)(void* context, std::string& result, int fid, int instpos); + +// Output "#" before IR blocks and instructions +enum class IncludeIrPrefix +{ + No, + Yes +}; + +// Output user count and last use information of blocks and instructions +enum class IncludeUseInfo +{ + No, + Yes +}; + +// Output CFG informations like block predecessors, successors and etc +enum class IncludeCfgInfo +{ + No, + Yes +}; + +// Output VM register live in/out information for blocks +enum class IncludeRegFlowInfo +{ + No, + Yes +}; + +struct AssemblyOptions +{ + enum Target + { + Host, + A64, + A64_NoFeatures, + X64_Windows, + X64_SystemV, + }; + + Target target = Host; + + CompilationOptions compilationOptions; + + bool outputBinary = false; + + bool includeAssembly = false; + bool includeIr = false; + bool includeOutlinedCode = false; + bool includeIrTypes = false; + + IncludeIrPrefix includeIrPrefix = IncludeIrPrefix::Yes; + IncludeUseInfo includeUseInfo = IncludeUseInfo::Yes; + IncludeCfgInfo includeCfgInfo = IncludeCfgInfo::Yes; + IncludeRegFlowInfo includeRegFlowInfo = IncludeRegFlowInfo::Yes; + + // Optional annotator function can be provided to describe each instruction, it takes function id and sequential instruction id + AnnotatorFn annotator = nullptr; + void* annotatorContext = nullptr; +}; + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 779fe012..38519f95 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -20,6 +20,8 @@ namespace Luau namespace CodeGen { +struct LoweringStats; + // IR extensions to LuauBuiltinFunction enum (these only exist inside IR, and start from 256 to avoid collisions) enum { @@ -67,18 +69,18 @@ enum class IrCmd : uint8_t LOAD_ENV, // Get pointer (TValue) to table array at index - // A: pointer (Table) + // A: pointer (LuaTable) // B: int GET_ARR_ADDR, // Get pointer (LuaNode) to table node element at the active cached slot index - // A: pointer (Table) + // A: pointer (LuaTable) // B: unsigned int (pcpos) // C: Kn GET_SLOT_NODE_ADDR, // Get pointer (LuaNode) to table node element at the main position of the specified key hash - // A: pointer (Table) + // A: pointer (LuaTable) // B: unsigned int (hash) GET_HASH_NODE_ADDR, @@ -185,6 +187,11 @@ enum class IrCmd : uint8_t // A: double SIGN_NUM, + // Select B if C == D, otherwise select A + // A, B: double (endpoints) + // C, D: double (condition arguments) + SELECT_NUM, + // Add/Sub/Mul/Div/Idiv two vectors // A, B: TValue ADD_VEC, @@ -268,7 +275,7 @@ enum class IrCmd : uint8_t JUMP_SLOT_MATCH, // Get table length - // A: pointer (Table) + // A: pointer (LuaTable) TABLE_LEN, // Get string length @@ -281,11 +288,11 @@ enum class IrCmd : uint8_t NEW_TABLE, // Duplicate a table - // A: pointer (Table) + // A: pointer (LuaTable) DUP_TABLE, // Insert an integer key into a table and return the pointer to inserted value (TValue) - // A: pointer (Table) + // A: pointer (LuaTable) // B: int (key) TABLE_SETNUM, @@ -425,13 +432,13 @@ enum class IrCmd : uint8_t CHECK_TRUTHY, // Guard against readonly table - // A: pointer (Table) + // A: pointer (LuaTable) // B: block/vmexit/undef // When undef is specified instead of a block, execution is aborted on check failure CHECK_READONLY, // Guard against table having a metatable - // A: pointer (Table) + // A: pointer (LuaTable) // B: block/vmexit/undef // When undef is specified instead of a block, execution is aborted on check failure CHECK_NO_METATABLE, @@ -442,7 +449,7 @@ enum class IrCmd : uint8_t CHECK_SAFE_ENV, // Guard against index overflowing the table array size - // A: pointer (Table) + // A: pointer (LuaTable) // B: int (index) // C: block/vmexit/undef // When undef is specified instead of a block, execution is aborted on check failure @@ -498,11 +505,11 @@ enum class IrCmd : uint8_t BARRIER_OBJ, // Handle GC write barrier (backwards) for a write into a table - // A: pointer (Table) + // A: pointer (LuaTable) BARRIER_TABLE_BACK, // Handle GC write barrier (forward) for a write into a table - // A: pointer (Table) + // A: pointer (LuaTable) // B: Rn (TValue that was written to the object) // C: tag/undef (tag of the value that was written) BARRIER_TABLE_FORWARD, @@ -1044,6 +1051,8 @@ struct IrFunction CfgInfo cfg; + LoweringStats* stats = nullptr; + IrBlock& blockOp(IrOp op) { CODEGEN_ASSERT(op.kind == IrOpKind::Block); diff --git a/CodeGen/include/Luau/IrDump.h b/CodeGen/include/Luau/IrDump.h index 9364f461..27a9feb4 100644 --- a/CodeGen/include/Luau/IrDump.h +++ b/CodeGen/include/Luau/IrDump.h @@ -2,11 +2,13 @@ #pragma once #include "Luau/IrData.h" -#include "Luau/CodeGen.h" +#include "Luau/CodeGenOptions.h" #include #include +struct Proto; + namespace Luau { namespace CodeGen @@ -23,6 +25,7 @@ struct IrToStringContext const std::vector& blocks; const std::vector& constants; const CfgInfo& cfg; + Proto* proto = nullptr; }; void toString(IrToStringContext& ctx, const IrInst& inst, uint32_t index); diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 773b23a6..1afa1a34 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -174,6 +174,7 @@ inline bool hasResult(IrCmd cmd) case IrCmd::SQRT_NUM: case IrCmd::ABS_NUM: case IrCmd::SIGN_NUM: + case IrCmd::SELECT_NUM: case IrCmd::ADD_VEC: case IrCmd::SUB_VEC: case IrCmd::MUL_VEC: diff --git a/CodeGen/include/Luau/LoweringStats.h b/CodeGen/include/Luau/LoweringStats.h new file mode 100644 index 00000000..532a5270 --- /dev/null +++ b/CodeGen/include/Luau/LoweringStats.h @@ -0,0 +1,103 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include +#include +#include + +namespace Luau +{ +namespace CodeGen +{ + +struct BlockLinearizationStats +{ + unsigned int constPropInstructionCount = 0; + double timeSeconds = 0.0; + + BlockLinearizationStats& operator+=(const BlockLinearizationStats& that) + { + this->constPropInstructionCount += that.constPropInstructionCount; + this->timeSeconds += that.timeSeconds; + + return *this; + } + + BlockLinearizationStats operator+(const BlockLinearizationStats& other) const + { + BlockLinearizationStats result(*this); + result += other; + return result; + } +}; + +enum FunctionStatsFlags +{ + // Enable stats collection per function + FunctionStats_Enable = 1 << 0, + // Compute function bytecode summary + FunctionStats_BytecodeSummary = 1 << 1, +}; + +struct FunctionStats +{ + std::string name; + int line = -1; + unsigned bcodeCount = 0; + unsigned irCount = 0; + unsigned asmCount = 0; + unsigned asmSize = 0; + std::vector> bytecodeSummary; +}; + +struct LoweringStats +{ + unsigned totalFunctions = 0; + unsigned skippedFunctions = 0; + int spillsToSlot = 0; + int spillsToRestore = 0; + unsigned maxSpillSlotsUsed = 0; + unsigned blocksPreOpt = 0; + unsigned blocksPostOpt = 0; + unsigned maxBlockInstructions = 0; + + int regAllocErrors = 0; + int loweringErrors = 0; + + BlockLinearizationStats blockLinearizationStats; + + unsigned functionStatsFlags = 0; + std::vector functions; + + LoweringStats operator+(const LoweringStats& other) const + { + LoweringStats result(*this); + result += other; + return result; + } + + LoweringStats& operator+=(const LoweringStats& that) + { + this->totalFunctions += that.totalFunctions; + this->skippedFunctions += that.skippedFunctions; + this->spillsToSlot += that.spillsToSlot; + this->spillsToRestore += that.spillsToRestore; + this->maxSpillSlotsUsed = std::max(this->maxSpillSlotsUsed, that.maxSpillSlotsUsed); + this->blocksPreOpt += that.blocksPreOpt; + this->blocksPostOpt += that.blocksPostOpt; + this->maxBlockInstructions = std::max(this->maxBlockInstructions, that.maxBlockInstructions); + + this->regAllocErrors += that.regAllocErrors; + this->loweringErrors += that.loweringErrors; + + this->blockLinearizationStats += that.blockLinearizationStats; + + if (this->functionStatsFlags & FunctionStats_Enable) + this->functions.insert(this->functions.end(), that.functions.begin(), that.functions.end()); + + return *this; + } +}; + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index 803732e2..1fb1b671 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -927,6 +927,11 @@ void AssemblyBuilderX64::vminsd(OperandX64 dst, OperandX64 src1, OperandX64 src2 placeAvx("vminsd", dst, src1, src2, 0x5d, false, AVX_0F, AVX_F2); } +void AssemblyBuilderX64::vcmpeqsd(OperandX64 dst, OperandX64 src1, OperandX64 src2) +{ + placeAvx("vcmpeqsd", dst, src1, src2, 0x00, 0xc2, false, AVX_0F, AVX_F2); +} + void AssemblyBuilderX64::vcmpltsd(OperandX64 dst, OperandX64 src1, OperandX64 src2) { placeAvx("vcmpltsd", dst, src1, src2, 0x01, 0xc2, false, AVX_0F, AVX_F2); diff --git a/CodeGen/src/BytecodeAnalysis.cpp b/CodeGen/src/BytecodeAnalysis.cpp index 85317b60..b859b111 100644 --- a/CodeGen/src/BytecodeAnalysis.cpp +++ b/CodeGen/src/BytecodeAnalysis.cpp @@ -2,7 +2,7 @@ #include "Luau/BytecodeAnalysis.h" #include "Luau/BytecodeUtils.h" -#include "Luau/CodeGen.h" +#include "Luau/CodeGenOptions.h" #include "Luau/IrData.h" #include "Luau/IrUtils.h" @@ -235,7 +235,7 @@ static uint8_t getBytecodeConstantTag(Proto* proto, unsigned ki) return LBC_TYPE_ANY; } -static void applyBuiltinCall(int bfid, BytecodeTypes& types) +static void applyBuiltinCall(LuauBuiltinFunction bfid, BytecodeTypes& types) { switch (bfid) { @@ -549,6 +549,12 @@ static void applyBuiltinCall(int bfid, BytecodeTypes& types) types.b = LBC_TYPE_VECTOR; types.c = LBC_TYPE_VECTOR; // We can mark optional arguments break; + case LBF_MATH_LERP: + types.result = LBC_TYPE_NUMBER; + types.a = LBC_TYPE_NUMBER; + types.b = LBC_TYPE_NUMBER; + types.c = LBC_TYPE_NUMBER; + break; } } @@ -842,7 +848,8 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = LBC_TYPE_NUMBER; else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; - else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + else if (hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; @@ -873,7 +880,8 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; } - else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + else if (hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) { regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); } @@ -895,7 +903,8 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER) regTags[ra] = LBC_TYPE_NUMBER; - else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + else if (hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; @@ -917,7 +926,8 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = LBC_TYPE_NUMBER; else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; - else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + else if (hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; @@ -948,7 +958,8 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; } - else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + else if (hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) { regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); } @@ -970,7 +981,8 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER) regTags[ra] = LBC_TYPE_NUMBER; - else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + else if (hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; @@ -991,7 +1003,8 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = LBC_TYPE_NUMBER; else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; - else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + else if (hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; @@ -1020,7 +1033,8 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; } - else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + else if (hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) { regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); } @@ -1086,7 +1100,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); int ra = LUAU_INSN_A(call); - applyBuiltinCall(bfid, bcType); + applyBuiltinCall(LuauBuiltinFunction(bfid), bcType); regTags[ra + 1] = bcType.a; regTags[ra + 2] = bcType.b; regTags[ra + 3] = bcType.c; @@ -1105,7 +1119,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); int ra = LUAU_INSN_A(call); - applyBuiltinCall(bfid, bcType); + applyBuiltinCall(LuauBuiltinFunction(bfid), bcType); regTags[LUAU_INSN_B(*pc)] = bcType.a; regTags[ra] = bcType.result; @@ -1122,7 +1136,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); int ra = LUAU_INSN_A(call); - applyBuiltinCall(bfid, bcType); + applyBuiltinCall(LuauBuiltinFunction(bfid), bcType); regTags[LUAU_INSN_B(*pc)] = bcType.a; regTags[int(pc[1])] = bcType.b; @@ -1141,7 +1155,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); int ra = LUAU_INSN_A(call); - applyBuiltinCall(bfid, bcType); + applyBuiltinCall(LuauBuiltinFunction(bfid), bcType); regTags[LUAU_INSN_B(*pc)] = bcType.a; regTags[aux & 0xff] = bcType.b; diff --git a/CodeGen/src/CodeBlockUnwind.cpp b/CodeGen/src/CodeBlockUnwind.cpp index cb2d693a..3e980566 100644 --- a/CodeGen/src/CodeBlockUnwind.cpp +++ b/CodeGen/src/CodeBlockUnwind.cpp @@ -2,6 +2,7 @@ #include "Luau/CodeBlockUnwind.h" #include "Luau/CodeAllocator.h" +#include "Luau/CodeGenCommon.h" #include "Luau/UnwindBuilder.h" #include @@ -19,9 +20,21 @@ #elif defined(__linux__) || defined(__APPLE__) -// Defined in unwind.h which may not be easily discoverable on various platforms -extern "C" void __register_frame(const void*) __attribute__((weak)); -extern "C" void __deregister_frame(const void*) __attribute__((weak)); +// __register_frame and __deregister_frame are defined in libgcc or libc++ +// (depending on how it's built). We want to declare them as weak symbols +// so that if they're provided by a shared library, we'll use them, and if +// not, we'll disable some c++ exception handling support. However, if they're +// declared as weak and the definitions are linked in a static library +// that's not linked with whole-archive, then the symbols will technically be defined here, +// and the linker won't look for the strong ones in the library. +#ifndef LUAU_ENABLE_REGISTER_FRAME +#define REGISTER_FRAME_WEAK __attribute__((weak)) +#else +#define REGISTER_FRAME_WEAK +#endif + +extern "C" void __register_frame(const void*) REGISTER_FRAME_WEAK; +extern "C" void __deregister_frame(const void*) REGISTER_FRAME_WEAK; extern "C" void __unw_add_dynamic_fde() __attribute__((weak)); #endif @@ -120,7 +133,7 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz #endif #elif defined(__linux__) || defined(__APPLE__) - if (!__register_frame) + if (!&__register_frame) return nullptr; visitFdeEntries(unwindData, __register_frame); @@ -149,7 +162,7 @@ void destroyBlockUnwindInfo(void* context, void* unwindData) #endif #elif defined(__linux__) || defined(__APPLE__) - if (!__deregister_frame) + if (!&__deregister_frame) { CODEGEN_ASSERT(!"Cannot deregister unwind information"); return; diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 2850dd15..a518165f 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -3,7 +3,7 @@ #include "CodeGenLower.h" -#include "Luau/Common.h" +#include "Luau/CodeGenCommon.h" #include "Luau/CodeAllocator.h" #include "Luau/CodeBlockUnwind.h" #include "Luau/IrBuilder.h" @@ -44,6 +44,7 @@ LUAU_FASTFLAGVARIABLE(DebugCodegenNoOpt) LUAU_FASTFLAGVARIABLE(DebugCodegenOptSize) LUAU_FASTFLAGVARIABLE(DebugCodegenSkipNumbering) +LUAU_FASTFLAGVARIABLE(CodegenWiderLoweringStats) // Per-module IR instruction count limit LUAU_FASTINTVARIABLE(CodegenHeuristicsInstructionLimit, 1'048'576) // 1 M @@ -166,7 +167,7 @@ bool isSupported() if (sizeof(LuaNode) != 32) return false; - // Windows CRT uses stack unwinding in longjmp so we have to use unwind data; on other platforms, it's only necessary for C++ EH. + // Windows CRT uses stack unwinding in longjmp so we have to use unwind data; on other platforms, it's only necessary for C++ EH. #if defined(_WIN32) if (!isUnwindSupported()) return false; diff --git a/CodeGen/src/CodeGenAssembly.cpp b/CodeGen/src/CodeGenAssembly.cpp index bffce517..6bbdc473 100644 --- a/CodeGen/src/CodeGenAssembly.cpp +++ b/CodeGen/src/CodeGenAssembly.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/CodeGen.h" #include "Luau/BytecodeAnalysis.h" #include "Luau/BytecodeUtils.h" #include "Luau/BytecodeSummary.h" diff --git a/CodeGen/src/CodeGenContext.cpp b/CodeGen/src/CodeGenContext.cpp index 262d4a42..82dfa17e 100644 --- a/CodeGen/src/CodeGenContext.cpp +++ b/CodeGen/src/CodeGenContext.cpp @@ -5,6 +5,7 @@ #include "CodeGenLower.h" #include "CodeGenX64.h" +#include "Luau/CodeGenCommon.h" #include "Luau/CodeBlockUnwind.h" #include "Luau/UnwindBuilder.h" #include "Luau/UnwindBuilderDwarf2.h" diff --git a/CodeGen/src/CodeGenLower.h b/CodeGen/src/CodeGenLower.h index 03eaabea..406fe5c9 100644 --- a/CodeGen/src/CodeGenLower.h +++ b/CodeGen/src/CodeGenLower.h @@ -7,6 +7,7 @@ #include "Luau/IrBuilder.h" #include "Luau/IrDump.h" #include "Luau/IrUtils.h" +#include "Luau/LoweringStats.h" #include "Luau/OptimizeConstProp.h" #include "Luau/OptimizeDeadStore.h" #include "Luau/OptimizeFinalX64.h" @@ -24,6 +25,7 @@ LUAU_FASTFLAG(DebugCodegenNoOpt) LUAU_FASTFLAG(DebugCodegenOptSize) LUAU_FASTFLAG(DebugCodegenSkipNumbering) +LUAU_FASTFLAG(CodegenWiderLoweringStats) LUAU_FASTINT(CodegenHeuristicsInstructionLimit) LUAU_FASTINT(CodegenHeuristicsBlockLimit) LUAU_FASTINT(CodegenHeuristicsBlockInstructionLimit) @@ -101,7 +103,7 @@ inline bool lowerImpl( bool outputEnabled = options.includeAssembly || options.includeIr; - IrToStringContext ctx{build.text, function.blocks, function.constants, function.cfg}; + IrToStringContext ctx{build.text, function.blocks, function.constants, function.cfg, function.proto}; // We use this to skip outlined fallback blocks from IR/asm text output size_t textSize = build.text.length(); @@ -298,6 +300,9 @@ inline bool lowerFunction( CodeGenCompilationResult& codeGenCompilationResult ) { + if (FFlag::CodegenWiderLoweringStats) + ir.function.stats = stats; + killUnusedBlocks(ir.function); unsigned preOptBlockCount = 0; diff --git a/CodeGen/src/CodeGenUtils.cpp b/CodeGen/src/CodeGenUtils.cpp index 9bda7c81..26451eea 100644 --- a/CodeGen/src/CodeGenUtils.cpp +++ b/CodeGen/src/CodeGenUtils.cpp @@ -18,6 +18,8 @@ #include +LUAU_DYNAMIC_FASTFLAG(LuauPopIncompleteCi) + // All external function calls that can cause stack realloc or Lua calls have to be wrapped in VM_PROTECT // This makes sure that we save the pc (in case the Lua call needs to generate a backtrace) before the call, // and restores the stack pointer after in case stack gets reallocated @@ -61,7 +63,7 @@ namespace Luau namespace CodeGen { -bool forgLoopTableIter(lua_State* L, Table* h, int index, TValue* ra) +bool forgLoopTableIter(lua_State* L, LuaTable* h, int index, TValue* ra) { int sizearray = h->sizearray; @@ -104,7 +106,7 @@ bool forgLoopTableIter(lua_State* L, Table* h, int index, TValue* ra) return false; } -bool forgLoopNodeIter(lua_State* L, Table* h, int index, TValue* ra) +bool forgLoopNodeIter(lua_State* L, LuaTable* h, int index, TValue* ra) { int sizearray = h->sizearray; int sizenode = 1 << h->lsizenode; @@ -191,7 +193,14 @@ Closure* callProlog(lua_State* L, TValue* ra, StkId argtop, int nresults) // note: this reallocs stack, but we don't need to VM_PROTECT this // this is because we're going to modify base/savedpc manually anyhow // crucially, we can't use ra/argtop after this line - luaD_checkstack(L, ccl->stacksize); + if (DFFlag::LuauPopIncompleteCi) + { + luaD_checkstackfornewci(L, ccl->stacksize); + } + else + { + luaD_checkstack(L, ccl->stacksize); + } return ccl; } @@ -224,7 +233,7 @@ Udata* newUserdata(lua_State* L, size_t s, int tag) { Udata* u = luaU_newudata(L, s, tag); - if (Table* h = L->global->udatamt[tag]) + if (LuaTable* h = L->global->udatamt[tag]) { // currently, we always allocate unmarked objects, so forward barrier can be skipped LUAU_ASSERT(!isblack(obj2gco(u))); @@ -261,7 +270,14 @@ Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults) // note: this reallocs stack, but we don't need to VM_PROTECT this // this is because we're going to modify base/savedpc manually anyhow // crucially, we can't use ra/argtop after this line - luaD_checkstack(L, ccl->stacksize); + if (DFFlag::LuauPopIncompleteCi) + { + luaD_checkstackfornewci(L, ccl->stacksize); + } + else + { + luaD_checkstack(L, ccl->stacksize); + } LUAU_ASSERT(ci->top <= L->stack_last); @@ -329,7 +345,7 @@ const Instruction* executeGETGLOBAL(lua_State* L, const Instruction* pc, StkId b LUAU_ASSERT(ttisstring(kv)); // fast-path should already have been checked, so we skip checking for it here - Table* h = cl->env; + LuaTable* h = cl->env; int slot = LUAU_INSN_C(insn) & h->nodemask8; // slow-path, may invoke Lua calls via __index metamethod @@ -352,7 +368,7 @@ const Instruction* executeSETGLOBAL(lua_State* L, const Instruction* pc, StkId b LUAU_ASSERT(ttisstring(kv)); // fast-path should already have been checked, so we skip checking for it here - Table* h = cl->env; + LuaTable* h = cl->env; int slot = LUAU_INSN_C(insn) & h->nodemask8; // slow-path, may invoke Lua calls via __newindex metamethod @@ -378,7 +394,7 @@ const Instruction* executeGETTABLEKS(lua_State* L, const Instruction* pc, StkId // fast-path: built-in table if (ttistable(rb)) { - Table* h = hvalue(rb); + LuaTable* h = hvalue(rb); // we ignore the fast path that checks for the cached slot since IrTranslation already checks for it. @@ -490,7 +506,7 @@ const Instruction* executeSETTABLEKS(lua_State* L, const Instruction* pc, StkId // fast-path: built-in table if (ttistable(rb)) { - Table* h = hvalue(rb); + LuaTable* h = hvalue(rb); // we ignore the fast path that checks for the cached slot since IrTranslation already checks for it. @@ -575,7 +591,7 @@ const Instruction* executeNAMECALL(lua_State* L, const Instruction* pc, StkId ba } else { - Table* mt = ttisuserdata(rb) ? uvalue(rb)->metatable : L->global->mt[ttype(rb)]; + LuaTable* mt = ttisuserdata(rb) ? uvalue(rb)->metatable : L->global->mt[ttype(rb)]; const TValue* tmi = 0; // fast-path: metatable with __namecall @@ -589,7 +605,7 @@ const Instruction* executeNAMECALL(lua_State* L, const Instruction* pc, StkId ba } else if ((tmi = fasttm(L, mt, TM_INDEX)) && ttistable(tmi)) { - Table* h = hvalue(tmi); + LuaTable* h = hvalue(tmi); int slot = LUAU_INSN_C(insn) & h->nodemask8; LuaNode* n = &h->node[slot]; @@ -646,7 +662,7 @@ const Instruction* executeSETLIST(lua_State* L, const Instruction* pc, StkId bas L->top = L->ci->top; } - Table* h = hvalue(ra); + LuaTable* h = hvalue(ra); // TODO: we really don't need this anymore if (!ttistable(ra)) @@ -681,7 +697,7 @@ const Instruction* executeFORGPREP(lua_State* L, const Instruction* pc, StkId ba } else { - Table* mt = ttistable(ra) ? hvalue(ra)->metatable : ttisuserdata(ra) ? uvalue(ra)->metatable : cast_to(Table*, NULL); + LuaTable* mt = ttistable(ra) ? hvalue(ra)->metatable : ttisuserdata(ra) ? uvalue(ra)->metatable : cast_to(LuaTable*, NULL); if (const TValue* fn = fasttm(L, mt, TM_ITER)) { diff --git a/CodeGen/src/CodeGenUtils.h b/CodeGen/src/CodeGenUtils.h index 15d4c95d..1003a6f3 100644 --- a/CodeGen/src/CodeGenUtils.h +++ b/CodeGen/src/CodeGenUtils.h @@ -8,8 +8,8 @@ namespace Luau namespace CodeGen { -bool forgLoopTableIter(lua_State* L, Table* h, int index, TValue* ra); -bool forgLoopNodeIter(lua_State* L, Table* h, int index, TValue* ra); +bool forgLoopTableIter(lua_State* L, LuaTable* h, int index, TValue* ra); +bool forgLoopNodeIter(lua_State* L, LuaTable* h, int index, TValue* ra); bool forgLoopNonTableFallback(lua_State* L, int insnA, int aux); void forgPrepXnextFallback(lua_State* L, TValue* ra, int pc); diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index 79562b88..36b5130e 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -120,12 +120,12 @@ void getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, Regist CODEGEN_ASSERT(tmp != node); CODEGEN_ASSERT(table != node); - build.mov(node, qword[table + offsetof(Table, node)]); + build.mov(node, qword[table + offsetof(LuaTable, node)]); // compute cached slot build.mov(tmp, sCode); build.movzx(dwordReg(tmp), byte[tmp + pcpos * sizeof(Instruction) + kOffsetOfInstructionC]); - build.and_(byteReg(tmp), byte[table + offsetof(Table, nodemask8)]); + build.and_(byteReg(tmp), byte[table + offsetof(LuaTable, nodemask8)]); // LuaNode* n = &h->node[slot]; build.shl(dwordReg(tmp), kLuaNodeSizeLog2); @@ -282,7 +282,7 @@ void callBarrierTableFast(IrRegAllocX64& regs, AssemblyBuilderX64& build, Regist IrCallWrapperX64 callWrap(regs, build); callWrap.addArgument(SizeX64::qword, rState); callWrap.addArgument(SizeX64::qword, table, tableOp); - callWrap.addArgument(SizeX64::qword, addr[table + offsetof(Table, gclist)]); + callWrap.addArgument(SizeX64::qword, addr[table + offsetof(LuaTable, gclist)]); callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaC_barrierback)]); } diff --git a/CodeGen/src/EmitInstructionX64.cpp b/CodeGen/src/EmitInstructionX64.cpp index ae3d1308..207f7f56 100644 --- a/CodeGen/src/EmitInstructionX64.cpp +++ b/CodeGen/src/EmitInstructionX64.cpp @@ -292,7 +292,7 @@ void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int Label skipResize; // Resize if h->sizearray < last - build.cmp(dword[table + offsetof(Table, sizearray)], last); + build.cmp(dword[table + offsetof(LuaTable, sizearray)], last); build.jcc(ConditionX64::NotBelow, skipResize); // Argument setup reordered to avoid conflicts @@ -309,7 +309,7 @@ void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int RegisterX64 arrayDst = rdx; RegisterX64 offset = rcx; - build.mov(arrayDst, qword[table + offsetof(Table, array)]); + build.mov(arrayDst, qword[table + offsetof(LuaTable, array)]); const int kUnrollSetListLimit = 4; @@ -380,7 +380,7 @@ void emitInstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRep // &array[index] build.mov(dwordReg(elemPtr), dwordReg(index)); build.shl(dwordReg(elemPtr), kTValueSizeLog2); - build.add(elemPtr, qword[table + offsetof(Table, array)]); + build.add(elemPtr, qword[table + offsetof(LuaTable, array)]); // Clear extra variables since we might have more than two for (int i = 2; i < aux; ++i) @@ -391,7 +391,7 @@ void emitInstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRep // First we advance index through the array portion // while (unsigned(index) < unsigned(sizearray)) Label arrayLoop = build.setLabel(); - build.cmp(dwordReg(index), dword[table + offsetof(Table, sizearray)]); + build.cmp(dwordReg(index), dword[table + offsetof(LuaTable, sizearray)]); build.jcc(ConditionX64::NotBelow, skipArray); // If element is nil, we increment the index; if it's not, we still need 'index + 1' inside diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index 0d2f9bd3..0d4b0a1f 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -684,7 +684,7 @@ void computeCfgDominanceTreeChildren(IrFunction& function) info.domChildrenOffsets[domParent]++; } - // Convert counds to offsets using prefix sum + // Convert counts to offsets using prefix sum uint32_t total = 0; for (size_t blockIdx = 0; blockIdx < function.blocks.size(); blockIdx++) diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index a59db8e8..dcc9d879 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -4,6 +4,8 @@ #include "Luau/IrUtils.h" #include "lua.h" +#include "lobject.h" +#include "lstate.h" #include @@ -19,6 +21,7 @@ static const char* textForCondition[] = static_assert(sizeof(textForCondition) / sizeof(textForCondition[0]) == size_t(IrCondition::Count), "all conditions have to be covered"); const int kDetailsAlignColumn = 60; +const unsigned kMaxStringConstantPrintLength = 16; LUAU_PRINTF_ATTR(2, 3) static void append(std::string& result, const char* fmt, ...) @@ -39,6 +42,17 @@ static void padToDetailColumn(std::string& result, size_t lineStart) result.append(pad, ' '); } +static bool isPrintableStringConstant(const char* str, size_t len) +{ + for (size_t i = 0; i < len; ++i) + { + if (unsigned(str[i]) < ' ') + return false; + } + + return true; +} + static const char* getTagName(uint8_t tag) { switch (tag) @@ -155,6 +169,8 @@ const char* getCmdName(IrCmd cmd) return "ABS_NUM"; case IrCmd::SIGN_NUM: return "SIGN_NUM"; + case IrCmd::SELECT_NUM: + return "SELECT_NUM"; case IrCmd::ADD_VEC: return "ADD_VEC"; case IrCmd::SUB_VEC: @@ -431,6 +447,53 @@ void toString(IrToStringContext& ctx, const IrBlock& block, uint32_t index) append(ctx.result, "%s_%u", getBlockKindName(block.kind), index); } +static void appendVmConstant(std::string& result, Proto* proto, int index) +{ + TValue constant = proto->k[index]; + + if (constant.tt == LUA_TNIL) + { + append(result, "nil"); + } + else if (constant.tt == LUA_TBOOLEAN) + { + append(result, constant.value.b != 0 ? "true" : "false"); + } + else if (constant.tt == LUA_TNUMBER) + { + if (constant.value.n != constant.value.n) + append(result, "nan"); + else + append(result, "%.17g", constant.value.n); + } + else if (constant.tt == LUA_TSTRING) + { + TString* str = gco2ts(constant.value.gc); + const char* data = getstr(str); + + if (isPrintableStringConstant(data, str->len)) + { + if (str->len < kMaxStringConstantPrintLength) + append(result, "'%.*s'", int(str->len), data); + else + append(result, "'%.*s'...", int(kMaxStringConstantPrintLength), data); + } + } + else if (constant.tt == LUA_TVECTOR) + { + const float* v = constant.value.v; + +#if LUA_VECTOR_SIZE == 4 + if (v[3] != 0) + append(result, "%.9g, %.9g, %.9g, %.9g", v[0], v[1], v[2], v[3]); + else + append(result, "%.9g, %.9g, %.9g", v[0], v[1], v[2]); +#else + append(result, "%.9g, %.9g, %.9g", v[0], v[1], v[2]); +#endif + } +} + void toString(IrToStringContext& ctx, IrOp op) { switch (op.kind) @@ -458,6 +521,14 @@ void toString(IrToStringContext& ctx, IrOp op) break; case IrOpKind::VmConst: append(ctx.result, "K%d", vmConstOp(op)); + + if (ctx.proto) + { + append(ctx.result, " ("); + appendVmConstant(ctx.result, ctx.proto, vmConstOp(op)); + append(ctx.result, ")"); + } + break; case IrOpKind::VmUpvalue: append(ctx.result, "U%d", vmUpvalueOp(op)); @@ -770,7 +841,7 @@ void toStringDetailed( std::string toString(const IrFunction& function, IncludeUseInfo includeUseInfo) { std::string result; - IrToStringContext ctx{result, function.blocks, function.constants, function.cfg}; + IrToStringContext ctx{result, function.blocks, function.constants, function.cfg, function.proto}; for (size_t i = 0; i < function.blocks.size(); i++) { @@ -877,7 +948,7 @@ static void appendBlocks(IrToStringContext& ctx, const IrFunction& function, boo std::string toDot(const IrFunction& function, bool includeInst) { std::string result; - IrToStringContext ctx{result, function.blocks, function.constants, function.cfg}; + IrToStringContext ctx{result, function.blocks, function.constants, function.cfg, function.proto}; append(ctx.result, "digraph CFG {\n"); append(ctx.result, "node[shape=record]\n"); @@ -924,7 +995,7 @@ std::string toDot(const IrFunction& function, bool includeInst) std::string toDotCfg(const IrFunction& function) { std::string result; - IrToStringContext ctx{result, function.blocks, function.constants, function.cfg}; + IrToStringContext ctx{result, function.blocks, function.constants, function.cfg, function.proto}; append(ctx.result, "digraph CFG {\n"); append(ctx.result, "node[shape=record]\n"); @@ -947,7 +1018,7 @@ std::string toDotCfg(const IrFunction& function) std::string toDotDjGraph(const IrFunction& function) { std::string result; - IrToStringContext ctx{result, function.blocks, function.constants, function.cfg}; + IrToStringContext ctx{result, function.blocks, function.constants, function.cfg, function.proto}; append(ctx.result, "digraph CFG {\n"); diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index c7fcac27..1eece87f 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -4,6 +4,7 @@ #include "Luau/DenseHash.h" #include "Luau/IrData.h" #include "Luau/IrUtils.h" +#include "Luau/LoweringStats.h" #include "EmitCommonA64.h" #include "NativeState.h" @@ -12,7 +13,7 @@ #include "lgc.h" LUAU_FASTFLAG(LuauVectorLibNativeDot) -LUAU_FASTFLAG(LuauCodeGenVectorDeadStoreElim) +LUAU_FASTFLAG(LuauCodeGenLerp) namespace Luau { @@ -329,7 +330,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::GET_ARR_ADDR: { inst.regA64 = regs.allocReuse(KindA64::x, index, {inst.a}); - build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(Table, array))); + build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(LuaTable, array))); if (inst.b.kind == IrOpKind::Inst) { @@ -375,11 +376,11 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) // C field can be shifted as long as it's at the most significant byte of the instruction word CODEGEN_ASSERT(kOffsetOfInstructionC == 3); - build.ldrb(temp2, mem(regOp(inst.a), offsetof(Table, nodemask8))); + build.ldrb(temp2, mem(regOp(inst.a), offsetof(LuaTable, nodemask8))); build.and_(temp2, temp2, temp1w, -24); // note: this may clobber inst.a, so it's important that we don't use it after this - build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(Table, node))); + build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(LuaTable, node))); build.add(inst.regA64, inst.regA64, temp2x, kLuaNodeSizeLog2); // "zero extend" temp2 to get a larger shift (top 32 bits are zero) break; } @@ -392,13 +393,13 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) // hash & ((1 << lsizenode) - 1) == hash & ~(-1 << lsizenode) build.mov(temp1, -1); - build.ldrb(temp2, mem(regOp(inst.a), offsetof(Table, lsizenode))); + build.ldrb(temp2, mem(regOp(inst.a), offsetof(LuaTable, lsizenode))); build.lsl(temp1, temp1, temp2); build.mov(temp2, uintOp(inst.b)); build.bic(temp2, temp2, temp1); // note: this may clobber inst.a, so it's important that we don't use it after this - build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(Table, node))); + build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(LuaTable, node))); build.add(inst.regA64, inst.regA64, temp2x, kLuaNodeSizeLog2); // "zero extend" temp2 to get a larger shift (top 32 bits are zero) break; } @@ -499,7 +500,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.fcvt(temp4, temp3); build.str(temp4, AddressA64(addr.base, addr.data + 8)); - if (FFlag::LuauCodeGenVectorDeadStoreElim && inst.e.kind != IrOpKind::None) + if (inst.e.kind != IrOpKind::None) { RegisterA64 temp = regs.allocTemp(KindA64::w); build.mov(temp, tagOp(inst.e)); @@ -703,6 +704,20 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.fcsel(inst.regA64, temp1, inst.regA64, getConditionFP(IrCondition::Less)); break; } + case IrCmd::SELECT_NUM: + { + LUAU_ASSERT(FFlag::LuauCodeGenLerp); + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b, inst.c, inst.d}); + + RegisterA64 temp1 = tempDouble(inst.a); + RegisterA64 temp2 = tempDouble(inst.b); + RegisterA64 temp3 = tempDouble(inst.c); + RegisterA64 temp4 = tempDouble(inst.d); + + build.fcmp(temp3, temp4); + build.fcsel(inst.regA64, temp2, temp1, getConditionFP(IrCondition::Equal)); + break; + } case IrCmd::ADD_VEC: { inst.regA64 = regs.allocReuse(KindA64::q, index, {inst.a, inst.b}); @@ -1060,10 +1075,10 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) RegisterA64 temp1 = regs.allocTemp(KindA64::x); RegisterA64 temp2 = regs.allocTemp(KindA64::w); - build.ldr(temp1, mem(regOp(inst.a), offsetof(Table, metatable))); + build.ldr(temp1, mem(regOp(inst.a), offsetof(LuaTable, metatable))); build.cbz(temp1, labelOp(inst.c)); // no metatable - build.ldrb(temp2, mem(temp1, offsetof(Table, tmcache))); + build.ldrb(temp2, mem(temp1, offsetof(LuaTable, tmcache))); build.tst(temp2, 1 << intOp(inst.b)); // can't use tbz/tbnz because their jump offsets are too short build.b(ConditionA64::NotEqual, labelOp(inst.c)); // Equal = Zero after tst; tmcache caches *absence* of metamethods @@ -1500,7 +1515,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { Label fresh; // used when guard aborts execution or jumps to a VM exit RegisterA64 temp = regs.allocTemp(KindA64::w); - build.ldrb(temp, mem(regOp(inst.a), offsetof(Table, readonly))); + build.ldrb(temp, mem(regOp(inst.a), offsetof(LuaTable, readonly))); build.cbnz(temp, getTargetLabel(inst.b, fresh)); finalizeTargetLabel(inst.b, fresh); break; @@ -1509,7 +1524,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { Label fresh; // used when guard aborts execution or jumps to a VM exit RegisterA64 temp = regs.allocTemp(KindA64::x); - build.ldr(temp, mem(regOp(inst.a), offsetof(Table, metatable))); + build.ldr(temp, mem(regOp(inst.a), offsetof(LuaTable, metatable))); build.cbnz(temp, getTargetLabel(inst.b, fresh)); finalizeTargetLabel(inst.b, fresh); break; @@ -1520,7 +1535,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) RegisterA64 temp = regs.allocTemp(KindA64::x); RegisterA64 tempw = castReg(KindA64::w, temp); build.ldr(temp, mem(rClosure, offsetof(Closure, env))); - build.ldrb(tempw, mem(temp, offsetof(Table, safeenv))); + build.ldrb(tempw, mem(temp, offsetof(LuaTable, safeenv))); build.cbz(tempw, getTargetLabel(inst.a, fresh)); finalizeTargetLabel(inst.a, fresh); break; @@ -1531,7 +1546,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) Label& fail = getTargetLabel(inst.c, fresh); RegisterA64 temp = regs.allocTemp(KindA64::w); - build.ldr(temp, mem(regOp(inst.a), offsetof(Table, sizearray))); + build.ldr(temp, mem(regOp(inst.a), offsetof(LuaTable, sizearray))); if (inst.b.kind == IrOpKind::Inst) { @@ -1758,7 +1773,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) size_t spills = regs.spill(build, index, {reg}); build.mov(x1, reg); build.mov(x0, rState); - build.add(x2, x1, uint16_t(offsetof(Table, gclist))); + build.add(x2, x1, uint16_t(offsetof(LuaTable, gclist))); build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barrierback))); build.blr(x3); diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 814c6d8c..0f99959f 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -4,6 +4,7 @@ #include "Luau/DenseHash.h" #include "Luau/IrData.h" #include "Luau/IrUtils.h" +#include "Luau/LoweringStats.h" #include "Luau/IrCallWrapperX64.h" @@ -16,7 +17,7 @@ #include "lgc.h" LUAU_FASTFLAG(LuauVectorLibNativeDot) -LUAU_FASTFLAG(LuauCodeGenVectorDeadStoreElim) +LUAU_FASTFLAG(LuauCodeGenLerp) namespace Luau { @@ -158,13 +159,13 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.mov(dwordReg(inst.regX64), regOp(inst.b)); build.shl(dwordReg(inst.regX64), kTValueSizeLog2); - build.add(inst.regX64, qword[regOp(inst.a) + offsetof(Table, array)]); + build.add(inst.regX64, qword[regOp(inst.a) + offsetof(LuaTable, array)]); } else if (inst.b.kind == IrOpKind::Constant) { inst.regX64 = regs.allocRegOrReuse(SizeX64::qword, index, {inst.a}); - build.mov(inst.regX64, qword[regOp(inst.a) + offsetof(Table, array)]); + build.mov(inst.regX64, qword[regOp(inst.a) + offsetof(LuaTable, array)]); if (intOp(inst.b) != 0) build.lea(inst.regX64, addr[inst.regX64 + intOp(inst.b) * sizeof(TValue)]); @@ -192,9 +193,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) ScopedRegX64 tmp{regs, SizeX64::qword}; - build.mov(inst.regX64, qword[regOp(inst.a) + offsetof(Table, node)]); + build.mov(inst.regX64, qword[regOp(inst.a) + offsetof(LuaTable, node)]); build.mov(dwordReg(tmp.reg), 1); - build.mov(byteReg(shiftTmp.reg), byte[regOp(inst.a) + offsetof(Table, lsizenode)]); + build.mov(byteReg(shiftTmp.reg), byte[regOp(inst.a) + offsetof(LuaTable, lsizenode)]); build.shl(dwordReg(tmp.reg), byteReg(shiftTmp.reg)); build.dec(dwordReg(tmp.reg)); build.and_(dwordReg(tmp.reg), uintOp(inst.b)); @@ -299,7 +300,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) storeDoubleAsFloat(luauRegValueVector(vmRegOp(inst.a), 1), inst.c); storeDoubleAsFloat(luauRegValueVector(vmRegOp(inst.a), 2), inst.d); - if (FFlag::LuauCodeGenVectorDeadStoreElim && inst.e.kind != IrOpKind::None) + if (inst.e.kind != IrOpKind::None) build.mov(luauRegTag(vmRegOp(inst.a)), tagOp(inst.e)); break; case IrCmd::STORE_TVALUE: @@ -622,6 +623,30 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.vblendvpd(inst.regX64, tmp1.reg, build.f64x2(1, 1), inst.regX64); break; } + case IrCmd::SELECT_NUM: + { + LUAU_ASSERT(FFlag::LuauCodeGenLerp); + inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.c, inst.d}); // can't reuse b if a is a memory operand + + ScopedRegX64 tmp{regs, SizeX64::xmmword}; + + if (inst.c.kind == IrOpKind::Inst) + build.vcmpeqsd(tmp.reg, regOp(inst.c), memRegDoubleOp(inst.d)); + else + { + build.vmovsd(tmp.reg, memRegDoubleOp(inst.c)); + build.vcmpeqsd(tmp.reg, tmp.reg, memRegDoubleOp(inst.d)); + } + + if (inst.a.kind == IrOpKind::Inst) + build.vblendvpd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b), tmp.reg); + else + { + build.vmovsd(inst.regX64, memRegDoubleOp(inst.a)); + build.vblendvpd(inst.regX64, inst.regX64, memRegDoubleOp(inst.b), tmp.reg); + } + break; + } case IrCmd::ADD_VEC: { inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); @@ -929,13 +954,13 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { ScopedRegX64 tmp{regs, SizeX64::qword}; - build.mov(tmp.reg, qword[regOp(inst.a) + offsetof(Table, metatable)]); + build.mov(tmp.reg, qword[regOp(inst.a) + offsetof(LuaTable, metatable)]); regs.freeLastUseReg(function.instOp(inst.a), index); // Release before the call if it's the last use build.test(tmp.reg, tmp.reg); build.jcc(ConditionX64::Zero, labelOp(inst.c)); // No metatable - build.test(byte[tmp.reg + offsetof(Table, tmcache)], 1 << intOp(inst.b)); + build.test(byte[tmp.reg + offsetof(LuaTable, tmcache)], 1 << intOp(inst.b)); build.jcc(ConditionX64::NotZero, labelOp(inst.c)); // No tag method ScopedRegX64 tmp2{regs, SizeX64::qword}; @@ -1295,11 +1320,11 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) break; } case IrCmd::CHECK_READONLY: - build.cmp(byte[regOp(inst.a) + offsetof(Table, readonly)], 0); + build.cmp(byte[regOp(inst.a) + offsetof(LuaTable, readonly)], 0); jumpOrAbortOnUndef(ConditionX64::NotEqual, inst.b, next); break; case IrCmd::CHECK_NO_METATABLE: - build.cmp(qword[regOp(inst.a) + offsetof(Table, metatable)], 0); + build.cmp(qword[regOp(inst.a) + offsetof(LuaTable, metatable)], 0); jumpOrAbortOnUndef(ConditionX64::NotEqual, inst.b, next); break; case IrCmd::CHECK_SAFE_ENV: @@ -1308,16 +1333,16 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.mov(tmp.reg, sClosure); build.mov(tmp.reg, qword[tmp.reg + offsetof(Closure, env)]); - build.cmp(byte[tmp.reg + offsetof(Table, safeenv)], 0); + build.cmp(byte[tmp.reg + offsetof(LuaTable, safeenv)], 0); jumpOrAbortOnUndef(ConditionX64::Equal, inst.a, next); break; } case IrCmd::CHECK_ARRAY_SIZE: if (inst.b.kind == IrOpKind::Inst) - build.cmp(dword[regOp(inst.a) + offsetof(Table, sizearray)], regOp(inst.b)); + build.cmp(dword[regOp(inst.a) + offsetof(LuaTable, sizearray)], regOp(inst.b)); else if (inst.b.kind == IrOpKind::Constant) - build.cmp(dword[regOp(inst.a) + offsetof(Table, sizearray)], intOp(inst.b)); + build.cmp(dword[regOp(inst.a) + offsetof(LuaTable, sizearray)], intOp(inst.b)); else CODEGEN_ASSERT(!"Unsupported instruction form"); diff --git a/CodeGen/src/IrRegAllocA64.cpp b/CodeGen/src/IrRegAllocA64.cpp index bd2147a7..15a306c9 100644 --- a/CodeGen/src/IrRegAllocA64.cpp +++ b/CodeGen/src/IrRegAllocA64.cpp @@ -2,8 +2,8 @@ #include "IrRegAllocA64.h" #include "Luau/AssemblyBuilderA64.h" -#include "Luau/CodeGen.h" #include "Luau/IrUtils.h" +#include "Luau/LoweringStats.h" #include "BitUtils.h" #include "EmitCommonA64.h" diff --git a/CodeGen/src/IrRegAllocX64.cpp b/CodeGen/src/IrRegAllocX64.cpp index d647484b..64625868 100644 --- a/CodeGen/src/IrRegAllocX64.cpp +++ b/CodeGen/src/IrRegAllocX64.cpp @@ -1,8 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/IrRegAllocX64.h" -#include "Luau/CodeGen.h" #include "Luau/IrUtils.h" +#include "Luau/LoweringStats.h" #include "EmitCommonX64.h" diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index ebded522..a5fa3ad0 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -15,6 +15,7 @@ static const int kBit32BinaryOpUnrolledParams = 5; LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeCodegen); LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeDot); +LUAU_FASTFLAGVARIABLE(LuauCodeGenLerp); namespace Luau { @@ -284,6 +285,42 @@ static BuiltinImplResult translateBuiltinMathClamp( return {BuiltinImplType::UsesFallback, 1}; } +static BuiltinImplResult translateBuiltinMathLerp( + IrBuilder& build, + int nparams, + int ra, + int arg, + IrOp args, + IrOp arg3, + int nresults, + IrOp fallback, + int pcpos +) +{ + LUAU_ASSERT(FFlag::LuauCodeGenLerp); + + if (nparams < 3 || nresults > 1) + return {BuiltinImplType::None, -1}; + + builtinCheckDouble(build, build.vmReg(arg), pcpos); + builtinCheckDouble(build, args, pcpos); + builtinCheckDouble(build, arg3, pcpos); + + IrOp a = builtinLoadDouble(build, build.vmReg(arg)); + IrOp b = builtinLoadDouble(build, args); + IrOp t = builtinLoadDouble(build, arg3); + + IrOp l = build.inst(IrCmd::ADD_NUM, a, build.inst(IrCmd::MUL_NUM, build.inst(IrCmd::SUB_NUM, b, a), t)); + IrOp r = build.inst(IrCmd::SELECT_NUM, l, b, t, build.constDouble(1.0)); // select on t==1.0 + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), r); + + if (ra != arg) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::Full, 1}; +} + static BuiltinImplResult translateBuiltinMathUnary(IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, int nresults, int pcpos) { if (nparams < 1 || nresults > 1) @@ -1387,6 +1424,8 @@ BuiltinImplResult translateBuiltin( case LBF_VECTOR_MAX: return FFlag::LuauVectorLibNativeCodegen ? translateBuiltinVectorMap2(build, IrCmd::MAX_NUM, nparams, ra, arg, args, arg3, nresults, pcpos) : noneResult; + case LBF_MATH_LERP: + return FFlag::LuauCodeGenLerp ? translateBuiltinMathLerp(build, nparams, ra, arg, args, arg3, nresults, fallback, pcpos) : noneResult; default: return {BuiltinImplType::None, -1}; } diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 62829766..d15d57e2 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -3,7 +3,7 @@ #include "Luau/Bytecode.h" #include "Luau/BytecodeUtils.h" -#include "Luau/CodeGen.h" +#include "Luau/CodeGenOptions.h" #include "Luau/IrBuilder.h" #include "Luau/IrUtils.h" diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index 5f384807..02d19e49 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/IrUtils.h" +#include "Luau/CodeGenOptions.h" #include "Luau/IrBuilder.h" #include "BitUtils.h" @@ -9,10 +10,14 @@ #include "lua.h" #include "lnumutils.h" +#include +#include + #include #include LUAU_FASTFLAG(LuauVectorLibNativeDot); +LUAU_FASTFLAG(LuauCodeGenLerp); namespace Luau { @@ -70,6 +75,7 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::SQRT_NUM: case IrCmd::ABS_NUM: case IrCmd::SIGN_NUM: + case IrCmd::SELECT_NUM: return IrValueKind::Double; case IrCmd::ADD_VEC: case IrCmd::SUB_VEC: @@ -656,6 +662,16 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 substitute(function, inst, build.constDouble(v > 0.0 ? 1.0 : v < 0.0 ? -1.0 : 0.0)); } break; + case IrCmd::SELECT_NUM: + LUAU_ASSERT(FFlag::LuauCodeGenLerp); + if (inst.c.kind == IrOpKind::Constant && inst.d.kind == IrOpKind::Constant) + { + double c = function.doubleOp(inst.c); + double d = function.doubleOp(inst.d); + + substitute(function, inst, c == d ? inst.b : inst.a); + } + break; case IrCmd::NOT_ANY: if (inst.a.kind == IrOpKind::Constant) { diff --git a/CodeGen/src/NativeState.h b/CodeGen/src/NativeState.h index 941db252..b4f74132 100644 --- a/CodeGen/src/NativeState.h +++ b/CodeGen/src/NativeState.h @@ -44,25 +44,25 @@ struct NativeContext void (*luaV_dolen)(lua_State* L, StkId ra, const TValue* rb) = nullptr; void (*luaV_gettable)(lua_State* L, const TValue* t, TValue* key, StkId val) = nullptr; void (*luaV_settable)(lua_State* L, const TValue* t, TValue* key, StkId val) = nullptr; - void (*luaV_getimport)(lua_State* L, Table* env, TValue* k, StkId res, uint32_t id, bool propagatenil) = nullptr; + void (*luaV_getimport)(lua_State* L, LuaTable* env, TValue* k, StkId res, uint32_t id, bool propagatenil) = nullptr; void (*luaV_concat)(lua_State* L, int total, int last) = nullptr; - int (*luaH_getn)(Table* t) = nullptr; - Table* (*luaH_new)(lua_State* L, int narray, int lnhash) = nullptr; - Table* (*luaH_clone)(lua_State* L, Table* tt) = nullptr; - void (*luaH_resizearray)(lua_State* L, Table* t, int nasize) = nullptr; - TValue* (*luaH_setnum)(lua_State* L, Table* t, int key); + int (*luaH_getn)(LuaTable* t) = nullptr; + LuaTable* (*luaH_new)(lua_State* L, int narray, int lnhash) = nullptr; + LuaTable* (*luaH_clone)(lua_State* L, LuaTable* tt) = nullptr; + void (*luaH_resizearray)(lua_State* L, LuaTable* t, int nasize) = nullptr; + TValue* (*luaH_setnum)(lua_State* L, LuaTable* t, int key); - void (*luaC_barriertable)(lua_State* L, Table* t, GCObject* v) = nullptr; + void (*luaC_barriertable)(lua_State* L, LuaTable* t, GCObject* v) = nullptr; void (*luaC_barrierf)(lua_State* L, GCObject* o, GCObject* v) = nullptr; void (*luaC_barrierback)(lua_State* L, GCObject* o, GCObject** gclist) = nullptr; size_t (*luaC_step)(lua_State* L, bool assist) = nullptr; void (*luaF_close)(lua_State* L, StkId level) = nullptr; UpVal* (*luaF_findupval)(lua_State* L, StkId level) = nullptr; - Closure* (*luaF_newLclosure)(lua_State* L, int nelems, Table* e, Proto* p) = nullptr; + Closure* (*luaF_newLclosure)(lua_State* L, int nelems, LuaTable* e, Proto* p) = nullptr; - const TValue* (*luaT_gettm)(Table* events, TMS event, TString* ename) = nullptr; + const TValue* (*luaT_gettm)(LuaTable* events, TMS event, TString* ename) = nullptr; const TString* (*luaT_objtypenamestr)(lua_State* L, const TValue* o) = nullptr; double (*libm_exp)(double) = nullptr; @@ -87,8 +87,8 @@ struct NativeContext double (*libm_modf)(double, double*) = nullptr; // Helper functions - bool (*forgLoopTableIter)(lua_State* L, Table* h, int index, TValue* ra) = nullptr; - bool (*forgLoopNodeIter)(lua_State* L, Table* h, int index, TValue* ra) = nullptr; + bool (*forgLoopTableIter)(lua_State* L, LuaTable* h, int index, TValue* ra) = nullptr; + bool (*forgLoopNodeIter)(lua_State* L, LuaTable* h, int index, TValue* ra) = nullptr; bool (*forgLoopNonTableFallback)(lua_State* L, int insnA, int aux) = nullptr; void (*forgPrepXnextFallback)(lua_State* L, TValue* ra, int pc) = nullptr; Closure* (*callProlog)(lua_State* L, TValue* ra, StkId argtop, int nresults) = nullptr; diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 1e532280..e2fefbed 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -18,9 +19,10 @@ LUAU_FASTINTVARIABLE(LuauCodeGenMinLinearBlockPath, 3) LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64) LUAU_FASTINTVARIABLE(LuauCodeGenReuseUdataTagLimit, 64) +LUAU_FASTINTVARIABLE(LuauCodeGenLiveSlotReuseLimit, 8) LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks) -LUAU_FASTFLAG(LuauVectorLibNativeDot); -LUAU_FASTFLAGVARIABLE(LuauCodeGenArithOpt); +LUAU_FASTFLAG(LuauVectorLibNativeDot) +LUAU_FASTFLAGVARIABLE(LuauCodeGenLimitLiveSlotReuse) namespace Luau { @@ -50,6 +52,14 @@ struct RegisterLink uint32_t version = 0; }; +// Reference to an instruction together with the position of that instruction in the current block chain and the last position of reuse +struct NumberedInstruction +{ + uint32_t instIdx = 0; + uint32_t startPos = 0; + uint32_t finishPos = 0; +}; + // Data we know about the current VM state struct ConstPropState { @@ -190,7 +200,11 @@ struct ConstPropState // Same goes for table array elements as well void invalidateHeapTableData() { - getSlotNodeCache.clear(); + if (FFlag::LuauCodeGenLimitLiveSlotReuse) + getSlotNodeCache.clear(); + else + getSlotNodeCache_DEPRECATED.clear(); + checkSlotMatchCache.clear(); getArrAddrCache.clear(); @@ -409,6 +423,64 @@ struct ConstPropState valueMap[versionedVmRegLoad(loadCmd, storeInst.a)] = storeInst.b.index; } + // Used to compute the pressure of the cached value 'set' on the spill registers + // We want to find out the maximum live range intersection count between the cached value at 'slot' and current instruction + // Note that this pressure is approximate, as some values that might have been live at one point could have been marked dead later + int getMaxInternalOverlap(std::vector& set, size_t slot) + { + CODEGEN_ASSERT(FFlag::LuauCodeGenLimitLiveSlotReuse); + + // Start with one live range for the slot we want to reuse + int curr = 1; + + // For any slots where lifetime began before the slot of interest, mark as live if lifetime end is still active + // This saves us from processing slots [0; slot] in the range sweep later, which requires sorting the lifetime end points + for (size_t i = 0; i < slot; i++) + { + if (set[i].finishPos >= set[slot].startPos) + curr++; + } + + int max = curr; + + // Collect lifetime end points and sort them + rangeEndTemp.clear(); + + for (size_t i = slot + 1; i < set.size(); i++) + rangeEndTemp.push_back(set[i].finishPos); + + std::sort(rangeEndTemp.begin(), rangeEndTemp.end()); + + // Go over the lifetime begin/end ranges that we store as separate array and walk based on the smallest of values + for (size_t i1 = slot + 1, i2 = 0; i1 < set.size() && i2 < rangeEndTemp.size();) + { + if (rangeEndTemp[i2] == set[i1].startPos) + { + i1++; + i2++; + } + else if (rangeEndTemp[i2] < set[i1].startPos) + { + CODEGEN_ASSERT(curr > 0); + + curr--; + i2++; + } + else + { + curr++; + i1++; + + if (curr > max) + max = curr; + } + } + + // We might have unprocessed lifetime end entries, but we will never have unprocessed lifetime start entries + // Not that lifetime end entries can only decrease the current value and do not affect the end result (maximum) + return max; + } + void clear() { for (int i = 0; i <= maxReg; ++i) @@ -416,6 +488,9 @@ struct ConstPropState maxReg = 0; + if (FFlag::LuauCodeGenLimitLiveSlotReuse) + instPos = 0u; + inSafeEnv = false; checkedGc = false; @@ -436,6 +511,9 @@ struct ConstPropState // For range/full invalidations, we only want to visit a limited number of data that we have recorded int maxReg = 0; + // Number of the instruction being processed + uint32_t instPos = 0; + bool inSafeEnv = false; bool checkedGc = false; @@ -447,7 +525,8 @@ struct ConstPropState std::vector tryNumToIndexCache; // Fallback block argument might be different // Heap changes might affect table state - std::vector getSlotNodeCache; // Additionally, pcpos argument might be different + std::vector getSlotNodeCache; // Additionally, pcpos argument might be different + std::vector getSlotNodeCache_DEPRECATED; // Additionally, pcpos argument might be different std::vector checkSlotMatchCache; // Additionally, fallback block argument might be different std::vector getArrAddrCache; @@ -457,6 +536,8 @@ struct ConstPropState // Userdata tag cache can point to both NEW_USERDATA and CHECK_USERDATA_TAG instructions std::vector useradataTagCache; // Additionally, fallback block argument might be different + + std::vector rangeEndTemp; }; static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid, uint32_t firstReturnReg, int nresults) @@ -550,6 +631,7 @@ static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid case LBF_VECTOR_CLAMP: case LBF_VECTOR_MIN: case LBF_VECTOR_MAX: + case LBF_MATH_LERP: break; case LBF_TABLE_INSERT: state.invalidateHeap(); @@ -569,6 +651,9 @@ static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& function, IrBlock& block, IrInst& inst, uint32_t index) { + if (FFlag::LuauCodeGenLimitLiveSlotReuse) + state.instPos++; + switch (inst.cmd) { case IrCmd::LOAD_TAG: @@ -1175,19 +1260,49 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& state.getArrAddrCache.push_back(index); break; case IrCmd::GET_SLOT_NODE_ADDR: - for (uint32_t prevIdx : state.getSlotNodeCache) + if (FFlag::LuauCodeGenLimitLiveSlotReuse) { - const IrInst& prev = function.instructions[prevIdx]; - - if (prev.a == inst.a && prev.c == inst.c) + for (size_t i = 0; i < state.getSlotNodeCache.size(); i++) { - substitute(function, inst, IrOp{IrOpKind::Inst, prevIdx}); - return; // Break out from both the loop and the switch - } - } + auto&& [prevIdx, num, lastNum] = state.getSlotNodeCache[i]; - if (int(state.getSlotNodeCache.size()) < FInt::LuauCodeGenReuseSlotLimit) - state.getSlotNodeCache.push_back(index); + const IrInst& prev = function.instructions[prevIdx]; + + if (prev.a == inst.a && prev.c == inst.c) + { + // Check if this reuse will increase the overall register pressure over the limit + int limit = FInt::LuauCodeGenLiveSlotReuseLimit; + + if (int(state.getSlotNodeCache.size()) > limit && state.getMaxInternalOverlap(state.getSlotNodeCache, i) > limit) + return; + + // Update live range of the value from the optimization standpoint + lastNum = state.instPos; + + substitute(function, inst, IrOp{IrOpKind::Inst, prevIdx}); + return; // Break out from both the loop and the switch + } + } + + if (int(state.getSlotNodeCache.size()) < FInt::LuauCodeGenReuseSlotLimit) + state.getSlotNodeCache.push_back({index, state.instPos, state.instPos}); + } + else + { + for (uint32_t prevIdx : state.getSlotNodeCache_DEPRECATED) + { + const IrInst& prev = function.instructions[prevIdx]; + + if (prev.a == inst.a && prev.c == inst.c) + { + substitute(function, inst, IrOp{IrOpKind::Inst, prevIdx}); + return; // Break out from both the loop and the switch + } + } + + if (int(state.getSlotNodeCache_DEPRECATED.size()) < FInt::LuauCodeGenReuseSlotLimit) + state.getSlotNodeCache_DEPRECATED.push_back(index); + } break; case IrCmd::GET_HASH_NODE_ADDR: case IrCmd::GET_CLOSURE_UPVAL_ADDR: @@ -1198,17 +1313,12 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& break; case IrCmd::ADD_NUM: case IrCmd::SUB_NUM: - if (FFlag::LuauCodeGenArithOpt) + if (std::optional k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b))) { - if (std::optional k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b))) - { - // a + 0.0 and a - (-0.0) can't be folded since the behavior is different for negative zero - // however, a - 0.0 and a + (-0.0) can be folded into a - if (*k == 0.0 && bool(signbit(*k)) == (inst.cmd == IrCmd::ADD_NUM)) - substitute(function, inst, inst.a); - else - state.substituteOrRecord(inst, index); - } + // a + 0.0 and a - (-0.0) can't be folded since the behavior is different for negative zero + // however, a - 0.0 and a + (-0.0) can be folded into a + if (*k == 0.0 && bool(signbit(*k)) == (inst.cmd == IrCmd::ADD_NUM)) + substitute(function, inst, inst.a); else state.substituteOrRecord(inst, index); } @@ -1216,19 +1326,14 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& state.substituteOrRecord(inst, index); break; case IrCmd::MUL_NUM: - if (FFlag::LuauCodeGenArithOpt) + if (std::optional k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b))) { - if (std::optional k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b))) - { - if (*k == 1.0) // a * 1.0 = a - substitute(function, inst, inst.a); - else if (*k == 2.0) // a * 2.0 = a + a - replace(function, block, index, {IrCmd::ADD_NUM, inst.a, inst.a}); - else if (*k == -1.0) // a * -1.0 = -a - replace(function, block, index, {IrCmd::UNM_NUM, inst.a}); - else - state.substituteOrRecord(inst, index); - } + if (*k == 1.0) // a * 1.0 = a + substitute(function, inst, inst.a); + else if (*k == 2.0) // a * 2.0 = a + a + replace(function, block, index, {IrCmd::ADD_NUM, inst.a, inst.a}); + else if (*k == -1.0) // a * -1.0 = -a + replace(function, block, index, {IrCmd::UNM_NUM, inst.a}); else state.substituteOrRecord(inst, index); } @@ -1236,19 +1341,14 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& state.substituteOrRecord(inst, index); break; case IrCmd::DIV_NUM: - if (FFlag::LuauCodeGenArithOpt) + if (std::optional k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b))) { - if (std::optional k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b))) - { - if (*k == 1.0) // a / 1.0 = a - substitute(function, inst, inst.a); - else if (*k == -1.0) // a / -1.0 = -a - replace(function, block, index, {IrCmd::UNM_NUM, inst.a}); - else if (int exp = 0; frexp(*k, &exp) == 0.5 && exp >= -1000 && exp <= 1000) // a / 2^k = a * 2^-k - replace(function, block, index, {IrCmd::MUL_NUM, inst.a, build.constDouble(1.0 / *k)}); - else - state.substituteOrRecord(inst, index); - } + if (*k == 1.0) // a / 1.0 = a + substitute(function, inst, inst.a); + else if (*k == -1.0) // a / -1.0 = -a + replace(function, block, index, {IrCmd::UNM_NUM, inst.a}); + else if (int exp = 0; frexp(*k, &exp) == 0.5 && exp >= -1000 && exp <= 1000) // a / 2^k = a * 2^-k + replace(function, block, index, {IrCmd::MUL_NUM, inst.a, build.constDouble(1.0 / *k)}); else state.substituteOrRecord(inst, index); } @@ -1266,6 +1366,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::SQRT_NUM: case IrCmd::ABS_NUM: case IrCmd::SIGN_NUM: + case IrCmd::SELECT_NUM: case IrCmd::NOT_ANY: state.substituteOrRecord(inst, index); break; diff --git a/CodeGen/src/OptimizeDeadStore.cpp b/CodeGen/src/OptimizeDeadStore.cpp index 8362cf2b..1483e4a2 100644 --- a/CodeGen/src/OptimizeDeadStore.cpp +++ b/CodeGen/src/OptimizeDeadStore.cpp @@ -9,8 +9,6 @@ #include "lobject.h" -LUAU_FASTFLAGVARIABLE(LuauCodeGenVectorDeadStoreElim) - // TODO: optimization can be improved by knowing which registers are live in at each VM exit namespace Luau @@ -326,27 +324,19 @@ static bool tryReplaceTagWithFullStore( // And value store has to follow, as the pre-DSO code would not allow GC to observe an incomplete stack variable if (tag != LUA_TNIL && regInfo.valueInstIdx != ~0u) { - if (FFlag::LuauCodeGenVectorDeadStoreElim) - { - IrInst& prevValueInst = function.instructions[regInfo.valueInstIdx]; + IrInst& prevValueInst = function.instructions[regInfo.valueInstIdx]; - if (prevValueInst.cmd == IrCmd::STORE_VECTOR) - { - CODEGEN_ASSERT(prevValueInst.e.kind == IrOpKind::None); - IrOp prevValueX = prevValueInst.b; - IrOp prevValueY = prevValueInst.c; - IrOp prevValueZ = prevValueInst.d; - replace(function, block, instIndex, IrInst{IrCmd::STORE_VECTOR, targetOp, prevValueX, prevValueY, prevValueZ, tagOp}); - } - else - { - IrOp prevValueOp = prevValueInst.b; - replace(function, block, instIndex, IrInst{IrCmd::STORE_SPLIT_TVALUE, targetOp, tagOp, prevValueOp}); - } + if (prevValueInst.cmd == IrCmd::STORE_VECTOR) + { + CODEGEN_ASSERT(prevValueInst.e.kind == IrOpKind::None); + IrOp prevValueX = prevValueInst.b; + IrOp prevValueY = prevValueInst.c; + IrOp prevValueZ = prevValueInst.d; + replace(function, block, instIndex, IrInst{IrCmd::STORE_VECTOR, targetOp, prevValueX, prevValueY, prevValueZ, tagOp}); } else { - IrOp prevValueOp = function.instructions[regInfo.valueInstIdx].b; + IrOp prevValueOp = prevValueInst.b; replace(function, block, instIndex, IrInst{IrCmd::STORE_SPLIT_TVALUE, targetOp, tagOp, prevValueOp}); } } @@ -385,7 +375,7 @@ static bool tryReplaceTagWithFullStore( state.hasGcoToClear |= regInfo.maybeGco; return true; } - else if (FFlag::LuauCodeGenVectorDeadStoreElim && prev.cmd == IrCmd::STORE_VECTOR) + else if (prev.cmd == IrCmd::STORE_VECTOR) { // If the 'nil' is stored, we keep 'STORE_TAG Rn, tnil' as it writes the 'full' TValue if (tag != LUA_TNIL) @@ -455,7 +445,7 @@ static bool tryReplaceValueWithFullStore( regInfo.tvalueInstIdx = instIndex; return true; } - else if (FFlag::LuauCodeGenVectorDeadStoreElim && prev.cmd == IrCmd::STORE_VECTOR) + else if (prev.cmd == IrCmd::STORE_VECTOR) { IrOp prevTagOp = prev.e; CODEGEN_ASSERT(prevTagOp.kind != IrOpKind::None); @@ -483,8 +473,6 @@ static bool tryReplaceVectorValueWithFullStore( StoreRegInfo& regInfo ) { - CODEGEN_ASSERT(FFlag::LuauCodeGenVectorDeadStoreElim); - // If the tag+value pair is established, we can mark both as dead and use a single split TValue store if (regInfo.tagInstIdx != ~0u && regInfo.valueInstIdx != ~0u) { @@ -631,29 +619,22 @@ static void markDeadStoresInInst(RemoveDeadStoreState& state, IrBuilder& build, case IrCmd::STORE_VECTOR: if (inst.a.kind == IrOpKind::VmReg) { - if (FFlag::LuauCodeGenVectorDeadStoreElim) - { - int reg = vmRegOp(inst.a); + int reg = vmRegOp(inst.a); - if (function.cfg.captured.regs.test(reg)) - return; + if (function.cfg.captured.regs.test(reg)) + return; - StoreRegInfo& regInfo = state.info[reg]; + StoreRegInfo& regInfo = state.info[reg]; - if (tryReplaceVectorValueWithFullStore(state, build, function, block, index, regInfo)) - break; + if (tryReplaceVectorValueWithFullStore(state, build, function, block, index, regInfo)) + break; - // Partial value store can be removed by a new one if the tag is known - if (regInfo.knownTag != kUnknownTag) - state.killValueStore(regInfo); + // Partial value store can be removed by a new one if the tag is known + if (regInfo.knownTag != kUnknownTag) + state.killValueStore(regInfo); - regInfo.valueInstIdx = index; - regInfo.maybeGco = false; - } - else - { - state.useReg(vmRegOp(inst.a)); - } + regInfo.valueInstIdx = index; + regInfo.maybeGco = false; } break; case IrCmd::STORE_TVALUE: diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index 8d281393..a151056c 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -613,6 +613,9 @@ enum LuauBuiltinFunction LBF_VECTOR_CLAMP, LBF_VECTOR_MIN, LBF_VECTOR_MAX, + + // math.lerp + LBF_MATH_LERP, }; // Capture type, used in LOP_CAPTURE diff --git a/Common/include/Luau/ExperimentalFlags.h b/Common/include/Luau/ExperimentalFlags.h index c534bcb4..68ae1e8c 100644 --- a/Common/include/Luau/ExperimentalFlags.h +++ b/Common/include/Luau/ExperimentalFlags.h @@ -13,7 +13,8 @@ inline bool isFlagExperimental(const char* flag) static const char* const kList[] = { "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code "LuauFixIndexerSubtypingOrdering", // requires some small fixes to lua-apps code since this fixes a false negative - "StudioReportLuauAny2", // takes telemetry data for usage of any types + "StudioReportLuauAny2", // takes telemetry data for usage of any types + "LuauTableCloneClonesType3", // requires fixes in lua-apps code, terrifyingly "LuauSolverV2", // makes sure we always have at least one entry nullptr, diff --git a/Compiler/include/Luau/Compiler.h b/Compiler/include/Luau/Compiler.h index b37b58ff..2c82116d 100644 --- a/Compiler/include/Luau/Compiler.h +++ b/Compiler/include/Luau/Compiler.h @@ -13,6 +13,16 @@ struct ParseResult; class BytecodeBuilder; class BytecodeEncoder; +using CompileConstant = void*; + +// return a type identifier for a global library member +// values are defined by 'enum LuauBytecodeType' in Bytecode.h +using LibraryMemberTypeCallback = int (*)(const char* library, const char* member); + +// setup a value of a constant for a global library member +// use setCompileConstant*** set of functions for values +using LibraryMemberConstantCallback = void (*)(const char* library, const char* member, CompileConstant* constant); + // Note: this structure is duplicated in luacode.h, don't forget to change these in sync! struct CompileOptions { @@ -49,6 +59,15 @@ struct CompileOptions // null-terminated array of userdata types that will be included in the type information const char* const* userdataTypes = nullptr; + + // null-terminated array of globals which act as libraries and have members with known type and/or constant value + // when an import of one of these libraries is accessed, callbacks below will be called to receive that information + const char* const* librariesWithKnownMembers = nullptr; + LibraryMemberTypeCallback libraryMemberTypeCb = nullptr; + LibraryMemberConstantCallback libraryMemberConstantCb = nullptr; + + // null-terminated array of library functions that should not be compiled into a built-in fastcall ("name" "lib.name") + const char* const* disabledBuiltins = nullptr; }; class CompileError : public std::exception @@ -81,4 +100,10 @@ std::string compile( BytecodeEncoder* encoder = nullptr ); +void setCompileConstantNil(CompileConstant* constant); +void setCompileConstantBoolean(CompileConstant* constant, bool b); +void setCompileConstantNumber(CompileConstant* constant, double n); +void setCompileConstantVector(CompileConstant* constant, float x, float y, float z, float w); +void setCompileConstantString(CompileConstant* constant, const char* s, size_t l); + } // namespace Luau diff --git a/Compiler/include/luacode.h b/Compiler/include/luacode.h index 1eaf28d4..4445af43 100644 --- a/Compiler/include/luacode.h +++ b/Compiler/include/luacode.h @@ -3,12 +3,21 @@ #include -// Can be used to reconfigure visibility/exports for public APIs +// can be used to reconfigure visibility/exports for public APIs #ifndef LUACODE_API #define LUACODE_API extern #endif typedef struct lua_CompileOptions lua_CompileOptions; +typedef void* lua_CompileConstant; + +// return a type identifier for a global library member +// values are defined by 'enum LuauBytecodeType' in Bytecode.h +typedef int (*lua_LibraryMemberTypeCallback)(const char* library, const char* member); + +// setup a value of a constant for a global library member +// use luau_set_compile_constant_*** set of functions for values +typedef void (*lua_LibraryMemberConstantCallback)(const char* library, const char* member, lua_CompileConstant* constant); struct lua_CompileOptions { @@ -45,7 +54,25 @@ struct lua_CompileOptions // null-terminated array of userdata types that will be included in the type information const char* const* userdataTypes; + + // null-terminated array of globals which act as libraries and have members with known type and/or constant value + // when an import of one of these libraries is accessed, callbacks below will be called to receive that information + const char* const* librariesWithKnownMembers; + lua_LibraryMemberTypeCallback libraryMemberTypeCb; + lua_LibraryMemberConstantCallback libraryMemberConstantCb; + + // null-terminated array of library functions that should not be compiled into a built-in fastcall ("name" "lib.name") + const char* const* disabledBuiltins; }; // compile source to bytecode; when source compilation fails, the resulting bytecode contains the encoded error. use free() to destroy LUACODE_API char* luau_compile(const char* source, size_t size, lua_CompileOptions* options, size_t* outsize); + +// when libraryMemberConstantCb is called, these methods can be used to set a value of the opaque lua_CompileConstant struct +// vector component 'w' is not visible to VM runtime configured with LUA_VECTOR_SIZE == 3, but can affect constant folding during compilation +// string storage must outlive the invocation of 'luau_compile' which used the callback +LUACODE_API void luau_set_compile_constant_nil(lua_CompileConstant* constant); +LUACODE_API void luau_set_compile_constant_boolean(lua_CompileConstant* constant, int b); +LUACODE_API void luau_set_compile_constant_number(lua_CompileConstant* constant, double n); +LUACODE_API void luau_set_compile_constant_vector(lua_CompileConstant* constant, float x, float y, float z, float w); +LUACODE_API void luau_set_compile_constant_string(lua_CompileConstant* constant, const char* s, size_t l); diff --git a/Compiler/src/BuiltinFolding.cpp b/Compiler/src/BuiltinFolding.cpp index 0886e94a..d6aeb3dd 100644 --- a/Compiler/src/BuiltinFolding.cpp +++ b/Compiler/src/BuiltinFolding.cpp @@ -5,6 +5,9 @@ #include +LUAU_FASTFLAGVARIABLE(LuauVector2Constants) +LUAU_FASTFLAG(LuauCompileMathLerp) + namespace Luau { namespace Compile @@ -471,14 +474,29 @@ Constant foldBuiltin(int bfid, const Constant* args, size_t count) break; case LBF_VECTOR: - if (count >= 3 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number && args[2].type == Constant::Type_Number) + if (count >= 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number) { - if (count == 3) + if (count == 2 && FFlag::LuauVector2Constants) + return cvector(args[0].valueNumber, args[1].valueNumber, 0.0, 0.0); + else if (count == 3 && args[2].type == Constant::Type_Number) return cvector(args[0].valueNumber, args[1].valueNumber, args[2].valueNumber, 0.0); - else if (count == 4 && args[3].type == Constant::Type_Number) + else if (count == 4 && args[2].type == Constant::Type_Number && args[3].type == Constant::Type_Number) return cvector(args[0].valueNumber, args[1].valueNumber, args[2].valueNumber, args[3].valueNumber); } break; + + case LBF_MATH_LERP: + if (FFlag::LuauCompileMathLerp && count == 3 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number && + args[2].type == Constant::Type_Number) + { + double a = args[0].valueNumber; + double b = args[1].valueNumber; + double t = args[2].valueNumber; + + double v = (t == 1.0) ? b : a + (b - a) * t; + return cnum(v); + } + break; } return cvar(); diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp index d5d23629..64d4d3f2 100644 --- a/Compiler/src/Builtins.cpp +++ b/Compiler/src/Builtins.cpp @@ -3,8 +3,11 @@ #include "Luau/Bytecode.h" #include "Luau/Compiler.h" +#include "Luau/Lexer.h" -LUAU_FASTFLAGVARIABLE(LuauVectorBuiltins) +#include + +LUAU_FASTFLAGVARIABLE(LuauCompileMathLerp) namespace Luau { @@ -136,6 +139,8 @@ static int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& op return LBF_MATH_SIGN; if (builtin.method == "round") return LBF_MATH_ROUND; + if (FFlag::LuauCompileMathLerp && builtin.method == "lerp") + return LBF_MATH_LERP; } if (builtin.object == "bit32") @@ -222,7 +227,7 @@ static int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& op return LBF_BUFFER_WRITEF64; } - if (FFlag::LuauVectorBuiltins && builtin.object == "vector") + if (builtin.object == "vector") { if (builtin.method == "create") return LBF_VECTOR; @@ -270,23 +275,58 @@ static int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& op struct BuiltinVisitor : AstVisitor { DenseHashMap& result; + std::array builtinIsDisabled; const DenseHashMap& globals; const DenseHashMap& variables; const CompileOptions& options; + const AstNameTable& names; BuiltinVisitor( DenseHashMap& result, const DenseHashMap& globals, const DenseHashMap& variables, - const CompileOptions& options + const CompileOptions& options, + const AstNameTable& names ) : result(result) , globals(globals) , variables(variables) , options(options) + , names(names) { + builtinIsDisabled.fill(false); + + if (const char* const* ptr = options.disabledBuiltins) + { + for (; *ptr; ++ptr) + { + if (const char* dot = strchr(*ptr, '.')) + { + AstName library = names.getWithType(*ptr, dot - *ptr).first; + AstName name = names.get(dot + 1); + + if (library.value && name.value && getGlobalState(globals, name) == Global::Default) + { + Builtin builtin = Builtin{library, name}; + + if (int bfid = getBuiltinFunctionId(builtin, options); bfid >= 0) + builtinIsDisabled[bfid] = true; + } + } + else + { + if (AstName name = names.get(*ptr); name.value && getGlobalState(globals, name) == Global::Default) + { + Builtin builtin = Builtin{AstName(), name}; + + if (int bfid = getBuiltinFunctionId(builtin, options); bfid >= 0) + builtinIsDisabled[bfid] = true; + } + } + } + } } bool visit(AstExprCall* node) override @@ -297,6 +337,9 @@ struct BuiltinVisitor : AstVisitor int bfid = getBuiltinFunctionId(builtin, options); + if (bfid >= 0 && builtinIsDisabled[bfid]) + bfid = -1; + // getBuiltinFunctionId optimistically assumes all select() calls are builtin but actually the second argument must be a vararg if (bfid == LBF_SELECT_VARARG && !(node->args.size == 2 && node->args.data[1]->is())) bfid = -1; @@ -313,10 +356,11 @@ void analyzeBuiltins( const DenseHashMap& globals, const DenseHashMap& variables, const CompileOptions& options, - AstNode* root + AstNode* root, + const AstNameTable& names ) { - BuiltinVisitor visitor{result, globals, variables, options}; + BuiltinVisitor visitor{result, globals, variables, options, names}; root->visit(&visitor); } @@ -510,6 +554,10 @@ BuiltinInfo getBuiltinInfo(int bfid) case LBF_VECTOR_MIN: case LBF_VECTOR_MAX: return {-1, 1}; // variadic + + case LBF_MATH_LERP: + LUAU_ASSERT(FFlag::LuauCompileMathLerp); + return {3, 1, BuiltinInfo::Flag_NoneSafe}; } LUAU_UNREACHABLE(); diff --git a/Compiler/src/Builtins.h b/Compiler/src/Builtins.h index e6427c2a..cef48fa5 100644 --- a/Compiler/src/Builtins.h +++ b/Compiler/src/Builtins.h @@ -41,7 +41,8 @@ void analyzeBuiltins( const DenseHashMap& globals, const DenseHashMap& variables, const CompileOptions& options, - AstNode* root + AstNode* root, + const AstNameTable& names ); struct BuiltinInfo diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 685d94fa..46985628 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -1751,7 +1751,8 @@ void BytecodeBuilder::validateVariadic() const // variadic sequence since they are never executed if FASTCALL does anything, so it's okay to skip their validation until CALL // (we can't simply start a variadic sequence here because that would trigger assertions during linked CALL validation) } - else if (op == LOP_CLOSEUPVALS || op == LOP_NAMECALL || op == LOP_GETIMPORT || op == LOP_MOVE || op == LOP_GETUPVAL || op == LOP_GETGLOBAL || op == LOP_GETTABLEKS || op == LOP_COVERAGE) + else if (op == LOP_CLOSEUPVALS || op == LOP_NAMECALL || op == LOP_GETIMPORT || op == LOP_MOVE || op == LOP_GETUPVAL || op == LOP_GETGLOBAL || + op == LOP_GETTABLEKS || op == LOP_COVERAGE) { // instructions inside a variadic sequence must be neutral (can't change L->top) // while there are many neutral instructions like this, here we check that the instruction is one of the few diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 84700177..29cf5c05 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -26,8 +26,6 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) -LUAU_FASTFLAGVARIABLE(LuauCompileOptimizeRevArith) - namespace Luau { @@ -725,7 +723,7 @@ struct Compiler inlineFrames.push_back({func, oldLocals, target, targetCount}); // fold constant values updated above into expressions in the function body - foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldMathK, func->body); + foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldLibraryK, options.libraryMemberConstantCb, func->body); bool usedFallthrough = false; @@ -770,7 +768,7 @@ struct Compiler var->type = Constant::Type_Unknown; } - foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldMathK, func->body); + foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldLibraryK, options.libraryMemberConstantCb, func->body); } void compileExprCall(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop = false, bool multRet = false) @@ -1623,7 +1621,7 @@ struct Compiler return; } } - else if (FFlag::LuauCompileOptimizeRevArith && options.optimizationLevel >= 2 && (expr->op == AstExprBinary::Add || expr->op == AstExprBinary::Mul)) + else if (options.optimizationLevel >= 2 && (expr->op == AstExprBinary::Add || expr->op == AstExprBinary::Mul)) { // Optimization: replace k*r with r*k when r is known to be a number (otherwise metamethods may be called) if (LuauBytecodeType* ty = exprTypes.find(expr); ty && *ty == LBC_TYPE_NUMBER) @@ -3052,7 +3050,7 @@ struct Compiler locstants[var].type = Constant::Type_Number; locstants[var].valueNumber = from + iv * step; - foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldMathK, stat); + foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldLibraryK, options.libraryMemberConstantCb, stat); size_t iterJumps = loopJumps.size(); @@ -3080,7 +3078,7 @@ struct Compiler // clean up fold state in case we need to recompile - normally we compile the loop body once, but due to inlining we may need to do it again locstants[var].type = Constant::Type_Unknown; - foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldMathK, stat); + foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldLibraryK, options.libraryMemberConstantCb, stat); } void compileStatFor(AstStatFor* stat) @@ -4141,7 +4139,7 @@ struct Compiler BuiltinAstTypes builtinTypes; const DenseHashMap* builtinsFold = nullptr; - bool builtinsFoldMathK = false; + bool builtinsFoldLibraryK = false; // compileFunction state, gets reset for every function unsigned int regTop = 0; @@ -4221,16 +4219,37 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c compiler.builtinsFold = &compiler.builtins; if (AstName math = names.get("math"); math.value && getGlobalState(compiler.globals, math) == Global::Default) - compiler.builtinsFoldMathK = true; + { + compiler.builtinsFoldLibraryK = true; + } + else if (const char* const* ptr = options.librariesWithKnownMembers) + { + for (; *ptr; ++ptr) + { + if (AstName name = names.get(*ptr); name.value && getGlobalState(compiler.globals, name) == Global::Default) + { + compiler.builtinsFoldLibraryK = true; + break; + } + } + } } if (options.optimizationLevel >= 1) { // this pass tracks which calls are builtins and can be compiled more efficiently - analyzeBuiltins(compiler.builtins, compiler.globals, compiler.variables, options, root); + analyzeBuiltins(compiler.builtins, compiler.globals, compiler.variables, options, root, names); // this pass analyzes constantness of expressions - foldConstants(compiler.constants, compiler.variables, compiler.locstants, compiler.builtinsFold, compiler.builtinsFoldMathK, root); + foldConstants( + compiler.constants, + compiler.variables, + compiler.locstants, + compiler.builtinsFold, + compiler.builtinsFoldLibraryK, + options.libraryMemberConstantCb, + root + ); // this pass analyzes table assignments to estimate table shapes for initially empty tables predictTableShapes(compiler.tableShapes, root); @@ -4261,6 +4280,7 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c compiler.builtinTypes, compiler.builtins, compiler.globals, + options.libraryMemberTypeCb, bytecode ); @@ -4277,9 +4297,9 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c AstExprFunction main( root->location, - /*attributes=*/AstArray({nullptr, 0}), - /*generics= */ AstArray(), - /*genericPacks= */ AstArray(), + /* attributes= */ AstArray({nullptr, 0}), + /* generics= */ AstArray(), + /* genericPacks= */ AstArray(), /* self= */ nullptr, AstArray(), /* vararg= */ true, @@ -4340,4 +4360,50 @@ std::string compile(const std::string& source, const CompileOptions& options, co } } +void setCompileConstantNil(CompileConstant* constant) +{ + Compile::Constant* target = reinterpret_cast(constant); + + target->type = Compile::Constant::Type_Nil; +} + +void setCompileConstantBoolean(CompileConstant* constant, bool b) +{ + Compile::Constant* target = reinterpret_cast(constant); + + target->type = Compile::Constant::Type_Boolean; + target->valueBoolean = b; +} + +void setCompileConstantNumber(CompileConstant* constant, double n) +{ + Compile::Constant* target = reinterpret_cast(constant); + + target->type = Compile::Constant::Type_Number; + target->valueNumber = n; +} + +void setCompileConstantVector(CompileConstant* constant, float x, float y, float z, float w) +{ + Compile::Constant* target = reinterpret_cast(constant); + + target->type = Compile::Constant::Type_Vector; + target->valueVector[0] = x; + target->valueVector[1] = y; + target->valueVector[2] = z; + target->valueVector[3] = w; +} + +void setCompileConstantString(CompileConstant* constant, const char* s, size_t l) +{ + Compile::Constant* target = reinterpret_cast(constant); + + if (l > std::numeric_limits::max()) + CompileError::raise({}, "Exceeded custom string constant length limit"); + + target->type = Compile::Constant::Type_String; + target->stringLength = unsigned(l); + target->valueString = s; +} + } // namespace Luau diff --git a/Compiler/src/ConstantFolding.cpp b/Compiler/src/ConstantFolding.cpp index 2895bf08..818a5bf7 100644 --- a/Compiler/src/ConstantFolding.cpp +++ b/Compiler/src/ConstantFolding.cpp @@ -57,6 +57,14 @@ static void foldUnary(Constant& result, AstExprUnary::Op op, const Constant& arg result.type = Constant::Type_Number; result.valueNumber = -arg.valueNumber; } + else if (arg.type == Constant::Type_Vector) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = -arg.valueVector[0]; + result.valueVector[1] = -arg.valueVector[1]; + result.valueVector[2] = -arg.valueVector[2]; + result.valueVector[3] = -arg.valueVector[3]; + } break; case AstExprUnary::Len: @@ -82,6 +90,14 @@ static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& l result.type = Constant::Type_Number; result.valueNumber = la.valueNumber + ra.valueNumber; } + else if (la.type == Constant::Type_Vector && ra.type == Constant::Type_Vector) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = la.valueVector[0] + ra.valueVector[0]; + result.valueVector[1] = la.valueVector[1] + ra.valueVector[1]; + result.valueVector[2] = la.valueVector[2] + ra.valueVector[2]; + result.valueVector[3] = la.valueVector[3] + ra.valueVector[3]; + } break; case AstExprBinary::Sub: @@ -90,6 +106,14 @@ static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& l result.type = Constant::Type_Number; result.valueNumber = la.valueNumber - ra.valueNumber; } + else if (la.type == Constant::Type_Vector && ra.type == Constant::Type_Vector) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = la.valueVector[0] - ra.valueVector[0]; + result.valueVector[1] = la.valueVector[1] - ra.valueVector[1]; + result.valueVector[2] = la.valueVector[2] - ra.valueVector[2]; + result.valueVector[3] = la.valueVector[3] - ra.valueVector[3]; + } break; case AstExprBinary::Mul: @@ -98,6 +122,48 @@ static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& l result.type = Constant::Type_Number; result.valueNumber = la.valueNumber * ra.valueNumber; } + else if (la.type == Constant::Type_Vector && ra.type == Constant::Type_Vector) + { + bool hadW = la.valueVector[3] != 0.0f || ra.valueVector[3] != 0.0f; + float resultW = la.valueVector[3] * ra.valueVector[3]; + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = la.valueVector[0] * ra.valueVector[0]; + result.valueVector[1] = la.valueVector[1] * ra.valueVector[1]; + result.valueVector[2] = la.valueVector[2] * ra.valueVector[2]; + result.valueVector[3] = resultW; + } + } + else if (la.type == Constant::Type_Number && ra.type == Constant::Type_Vector) + { + bool hadW = ra.valueVector[3] != 0.0f; + float resultW = float(la.valueNumber) * ra.valueVector[3]; + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = float(la.valueNumber) * ra.valueVector[0]; + result.valueVector[1] = float(la.valueNumber) * ra.valueVector[1]; + result.valueVector[2] = float(la.valueNumber) * ra.valueVector[2]; + result.valueVector[3] = resultW; + } + } + else if (la.type == Constant::Type_Vector && ra.type == Constant::Type_Number) + { + bool hadW = la.valueVector[3] != 0.0f; + float resultW = la.valueVector[3] * float(ra.valueNumber); + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = la.valueVector[0] * float(ra.valueNumber); + result.valueVector[1] = la.valueVector[1] * float(ra.valueNumber); + result.valueVector[2] = la.valueVector[2] * float(ra.valueNumber); + result.valueVector[3] = resultW; + } + } break; case AstExprBinary::Div: @@ -106,6 +172,48 @@ static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& l result.type = Constant::Type_Number; result.valueNumber = la.valueNumber / ra.valueNumber; } + else if (la.type == Constant::Type_Vector && ra.type == Constant::Type_Vector) + { + bool hadW = la.valueVector[3] != 0.0f || ra.valueVector[3] != 0.0f; + float resultW = la.valueVector[3] / ra.valueVector[3]; + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = la.valueVector[0] / ra.valueVector[0]; + result.valueVector[1] = la.valueVector[1] / ra.valueVector[1]; + result.valueVector[2] = la.valueVector[2] / ra.valueVector[2]; + result.valueVector[3] = resultW; + } + } + else if (la.type == Constant::Type_Number && ra.type == Constant::Type_Vector) + { + bool hadW = ra.valueVector[3] != 0.0f; + float resultW = float(la.valueNumber) / ra.valueVector[3]; + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = float(la.valueNumber) / ra.valueVector[0]; + result.valueVector[1] = float(la.valueNumber) / ra.valueVector[1]; + result.valueVector[2] = float(la.valueNumber) / ra.valueVector[2]; + result.valueVector[3] = resultW; + } + } + else if (la.type == Constant::Type_Vector && ra.type == Constant::Type_Number) + { + bool hadW = la.valueVector[3] != 0.0f; + float resultW = la.valueVector[3] / float(ra.valueNumber); + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = la.valueVector[0] / float(ra.valueNumber); + result.valueVector[1] = la.valueVector[1] / float(ra.valueNumber); + result.valueVector[2] = la.valueVector[2] / float(ra.valueNumber); + result.valueVector[3] = resultW; + } + } break; case AstExprBinary::FloorDiv: @@ -114,6 +222,48 @@ static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& l result.type = Constant::Type_Number; result.valueNumber = floor(la.valueNumber / ra.valueNumber); } + else if (la.type == Constant::Type_Vector && ra.type == Constant::Type_Vector) + { + bool hadW = la.valueVector[3] != 0.0f || ra.valueVector[3] != 0.0f; + float resultW = floor(la.valueVector[3] / ra.valueVector[3]); + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = floor(la.valueVector[0] / ra.valueVector[0]); + result.valueVector[1] = floor(la.valueVector[1] / ra.valueVector[1]); + result.valueVector[2] = floor(la.valueVector[2] / ra.valueVector[2]); + result.valueVector[3] = resultW; + } + } + else if (la.type == Constant::Type_Number && ra.type == Constant::Type_Vector) + { + bool hadW = ra.valueVector[3] != 0.0f; + float resultW = floor(float(la.valueNumber) / ra.valueVector[3]); + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = floor(float(la.valueNumber) / ra.valueVector[0]); + result.valueVector[1] = floor(float(la.valueNumber) / ra.valueVector[1]); + result.valueVector[2] = floor(float(la.valueNumber) / ra.valueVector[2]); + result.valueVector[3] = resultW; + } + } + else if (la.type == Constant::Type_Vector && ra.type == Constant::Type_Number) + { + bool hadW = la.valueVector[3] != 0.0f; + float resultW = floor(la.valueVector[3] / float(ra.valueNumber)); + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = floor(la.valueVector[0] / float(ra.valueNumber)); + result.valueVector[1] = floor(la.valueVector[1] / float(ra.valueNumber)); + result.valueVector[2] = floor(la.valueVector[2] / float(ra.valueNumber)); + result.valueVector[3] = floor(la.valueVector[3] / float(ra.valueNumber)); + } + } break; case AstExprBinary::Mod: @@ -209,7 +359,8 @@ struct ConstantVisitor : AstVisitor DenseHashMap& locals; const DenseHashMap* builtins; - bool foldMathK = false; + bool foldLibraryK = false; + LibraryMemberConstantCallback libraryMemberConstantCb; bool wasEmpty = false; @@ -220,13 +371,15 @@ struct ConstantVisitor : AstVisitor DenseHashMap& variables, DenseHashMap& locals, const DenseHashMap* builtins, - bool foldMathK + bool foldLibraryK, + LibraryMemberConstantCallback libraryMemberConstantCb ) : constants(constants) , variables(variables) , locals(locals) , builtins(builtins) - , foldMathK(foldMathK) + , foldLibraryK(foldLibraryK) + , libraryMemberConstantCb(libraryMemberConstantCb) { // since we do a single pass over the tree, if the initial state was empty we don't need to clear out old entries wasEmpty = constants.empty() && locals.empty(); @@ -316,11 +469,16 @@ struct ConstantVisitor : AstVisitor { analyze(expr->expr); - if (foldMathK) + if (foldLibraryK) { - if (AstExprGlobal* eg = expr->expr->as(); eg && eg->name == "math") + if (AstExprGlobal* eg = expr->expr->as()) { - result = foldBuiltinMath(expr->index); + if (eg->name == "math") + result = foldBuiltinMath(expr->index); + + // if we have a custom handler and the constant hasn't been resolved + if (libraryMemberConstantCb && result.type == Constant::Type_Unknown) + libraryMemberConstantCb(eg->name.value, expr->index.value, reinterpret_cast(&result)); } } } @@ -468,11 +626,12 @@ void foldConstants( DenseHashMap& variables, DenseHashMap& locals, const DenseHashMap* builtins, - bool foldMathK, + bool foldLibraryK, + LibraryMemberConstantCallback libraryMemberConstantCb, AstNode* root ) { - ConstantVisitor visitor{constants, variables, locals, builtins, foldMathK}; + ConstantVisitor visitor{constants, variables, locals, builtins, foldLibraryK, libraryMemberConstantCb}; root->visit(&visitor); } diff --git a/Compiler/src/ConstantFolding.h b/Compiler/src/ConstantFolding.h index e4eb6428..2653c064 100644 --- a/Compiler/src/ConstantFolding.h +++ b/Compiler/src/ConstantFolding.h @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Compiler.h" + #include "ValueTracking.h" namespace Luau @@ -49,7 +51,8 @@ void foldConstants( DenseHashMap& variables, DenseHashMap& locals, const DenseHashMap* builtins, - bool foldMathK, + bool foldLibraryK, + LibraryMemberConstantCallback libraryMemberConstantCb, AstNode* root ); diff --git a/Compiler/src/CostModel.cpp b/Compiler/src/CostModel.cpp index 4c8e13c6..04adf3e3 100644 --- a/Compiler/src/CostModel.cpp +++ b/Compiler/src/CostModel.cpp @@ -130,7 +130,8 @@ struct CostVisitor : AstVisitor { return model(expr->expr); } - else if (node->is() || node->is() || node->is() || node->is()) + else if (node->is() || node->is() || node->is() || + node->is()) { return Cost(0, Cost::kLiteral); } diff --git a/Compiler/src/Types.cpp b/Compiler/src/Types.cpp index 7f5885a5..e251447b 100644 --- a/Compiler/src/Types.cpp +++ b/Compiler/src/Types.cpp @@ -3,8 +3,6 @@ #include "Luau/BytecodeBuilder.h" -LUAU_FASTFLAGVARIABLE(LuauCompileVectorTypeInfo) - namespace Luau { @@ -31,7 +29,7 @@ static LuauBytecodeType getPrimitiveType(AstName name) return LBC_TYPE_THREAD; else if (name == "buffer") return LBC_TYPE_BUFFER; - else if (FFlag::LuauCompileVectorTypeInfo && name == "vector") + else if (name == "vector") return LBC_TYPE_VECTOR; else if (name == "any" || name == "unknown") return LBC_TYPE_ANY; @@ -123,6 +121,10 @@ static LuauBytecodeType getType( { return LBC_TYPE_ANY; } + else if (const AstTypeGroup* group = ty->as()) + { + return getType(group->type, generics, typeAliases, resolveAliases, hostVectorType, userdataTypes, bytecode); + } return LBC_TYPE_ANY; } @@ -175,16 +177,30 @@ static bool isMatchingGlobal(const DenseHashMap& globa return false; } +static bool isMatchingGlobalMember( + const DenseHashMap& globals, + AstExprIndexName* expr, + const char* library, + const char* member +) +{ + if (AstExprGlobal* object = expr->expr->as()) + return getGlobalState(globals, object->name) == Compile::Global::Default && object->name == library && expr->index == member; + + return false; +} + struct TypeMapVisitor : AstVisitor { DenseHashMap& functionTypes; DenseHashMap& localTypes; DenseHashMap& exprTypes; - const char* hostVectorType; + const char* hostVectorType = nullptr; const DenseHashMap& userdataTypes; const BuiltinAstTypes& builtinTypes; const DenseHashMap& builtinCalls; const DenseHashMap& globals; + LibraryMemberTypeCallback libraryMemberTypeCb = nullptr; BytecodeBuilder& bytecode; DenseHashMap typeAliases; @@ -201,6 +217,7 @@ struct TypeMapVisitor : AstVisitor const BuiltinAstTypes& builtinTypes, const DenseHashMap& builtinCalls, const DenseHashMap& globals, + LibraryMemberTypeCallback libraryMemberTypeCb, BytecodeBuilder& bytecode ) : functionTypes(functionTypes) @@ -211,6 +228,7 @@ struct TypeMapVisitor : AstVisitor , builtinTypes(builtinTypes) , builtinCalls(builtinCalls) , globals(globals) + , libraryMemberTypeCb(libraryMemberTypeCb) , bytecode(bytecode) , typeAliases(AstName()) , resolvedLocals(nullptr) @@ -461,7 +479,48 @@ struct TypeMapVisitor : AstVisitor if (*typeBcPtr == LBC_TYPE_VECTOR) { if (node->index == "X" || node->index == "Y" || node->index == "Z") + { recordResolvedType(node, &builtinTypes.numberType); + return false; + } + } + } + + if (isMatchingGlobalMember(globals, node, "vector", "zero") || isMatchingGlobalMember(globals, node, "vector", "one")) + { + recordResolvedType(node, &builtinTypes.vectorType); + return false; + } + + if (libraryMemberTypeCb) + { + if (AstExprGlobal* object = node->expr->as()) + { + if (LuauBytecodeType ty = LuauBytecodeType(libraryMemberTypeCb(object->name.value, node->index.value)); ty != LBC_TYPE_ANY) + { + // TODO: 'resolvedExprs' is more limited than 'exprTypes' which limits full inference of more complex types that a user + // callback can return + switch (ty) + { + case LBC_TYPE_BOOLEAN: + resolvedExprs[node] = &builtinTypes.booleanType; + break; + case LBC_TYPE_NUMBER: + resolvedExprs[node] = &builtinTypes.numberType; + break; + case LBC_TYPE_STRING: + resolvedExprs[node] = &builtinTypes.stringType; + break; + case LBC_TYPE_VECTOR: + resolvedExprs[node] = &builtinTypes.vectorType; + break; + default: + break; + } + + exprTypes[node] = ty; + return false; + } } } @@ -682,6 +741,7 @@ struct TypeMapVisitor : AstVisitor case LBF_BUFFER_READF64: case LBF_VECTOR_MAGNITUDE: case LBF_VECTOR_DOT: + case LBF_MATH_LERP: recordResolvedType(node, &builtinTypes.numberType); break; @@ -733,10 +793,13 @@ void buildTypeMap( const BuiltinAstTypes& builtinTypes, const DenseHashMap& builtinCalls, const DenseHashMap& globals, + LibraryMemberTypeCallback libraryMemberTypeCb, BytecodeBuilder& bytecode ) { - TypeMapVisitor visitor(functionTypes, localTypes, exprTypes, hostVectorType, userdataTypes, builtinTypes, builtinCalls, globals, bytecode); + TypeMapVisitor visitor( + functionTypes, localTypes, exprTypes, hostVectorType, userdataTypes, builtinTypes, builtinCalls, globals, libraryMemberTypeCb, bytecode + ); root->visit(&visitor); } diff --git a/Compiler/src/Types.h b/Compiler/src/Types.h index 46610db2..e60b3b93 100644 --- a/Compiler/src/Types.h +++ b/Compiler/src/Types.h @@ -3,6 +3,7 @@ #include "Luau/Ast.h" #include "Luau/Bytecode.h" +#include "Luau/Compiler.h" #include "Luau/DenseHash.h" #include "ValueTracking.h" @@ -19,7 +20,7 @@ struct BuiltinAstTypes { } - // AstName use here will not match the AstNameTable, but the was we use them here always force a full string compare + // AstName use here will not match the AstNameTable, but the way we use them here always forces a full string compare AstTypeReference booleanType{{}, std::nullopt, AstName{"boolean"}, std::nullopt, {}}; AstTypeReference numberType{{}, std::nullopt, AstName{"number"}, std::nullopt, {}}; AstTypeReference stringType{{}, std::nullopt, AstName{"string"}, std::nullopt, {}}; @@ -38,6 +39,7 @@ void buildTypeMap( const BuiltinAstTypes& builtinTypes, const DenseHashMap& builtinCalls, const DenseHashMap& globals, + LibraryMemberTypeCallback libraryMemberTypeCb, BytecodeBuilder& bytecode ); diff --git a/Compiler/src/lcode.cpp b/Compiler/src/lcode.cpp index ee150b17..ff2edc3d 100644 --- a/Compiler/src/lcode.cpp +++ b/Compiler/src/lcode.cpp @@ -27,3 +27,28 @@ char* luau_compile(const char* source, size_t size, lua_CompileOptions* options, *outsize = result.size(); return copy; } + +void luau_set_compile_constant_nil(lua_CompileConstant* constant) +{ + Luau::setCompileConstantNil(constant); +} + +void luau_set_compile_constant_boolean(lua_CompileConstant* constant, int b) +{ + Luau::setCompileConstantBoolean(constant, b != 0); +} + +void luau_set_compile_constant_number(lua_CompileConstant* constant, double n) +{ + Luau::setCompileConstantNumber(constant, n); +} + +void luau_set_compile_constant_vector(lua_CompileConstant* constant, float x, float y, float z, float w) +{ + Luau::setCompileConstantVector(constant, x, y, z, w); +} + +void luau_set_compile_constant_string(lua_CompileConstant* constant, const char* s, size_t l) +{ + Luau::setCompileConstantString(constant, s, l); +} diff --git a/Config/include/Luau/Config.h b/Config/include/Luau/Config.h index 3f29a24f..89d018d2 100644 --- a/Config/include/Luau/Config.h +++ b/Config/include/Luau/Config.h @@ -42,6 +42,7 @@ struct Config { std::string value; std::string_view configLocation; + std::string originalCase; // The alias in its original case. }; DenseHashMap aliases{""}; diff --git a/Config/src/Config.cpp b/Config/src/Config.cpp index 15e58e29..44cbe2e5 100644 --- a/Config/src/Config.cpp +++ b/Config/src/Config.cpp @@ -26,9 +26,9 @@ Config::Config(const Config& other) , typeErrors(other.typeErrors) , globals(other.globals) { - for (const auto& [alias, aliasInfo] : other.aliases) + for (const auto& [_, aliasInfo] : other.aliases) { - setAlias(alias, aliasInfo.value, std::string(aliasInfo.configLocation)); + setAlias(aliasInfo.originalCase, aliasInfo.value, std::string(aliasInfo.configLocation)); } } @@ -44,8 +44,20 @@ Config& Config::operator=(const Config& other) void Config::setAlias(std::string alias, std::string value, const std::string& configLocation) { - AliasInfo& info = aliases[alias]; + std::string lowercasedAlias = alias; + std::transform( + lowercasedAlias.begin(), + lowercasedAlias.end(), + lowercasedAlias.begin(), + [](unsigned char c) + { + return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c; + } + ); + + AliasInfo& info = aliases[lowercasedAlias]; info.value = std::move(value); + info.originalCase = std::move(alias); if (!configLocationCache.contains(configLocation)) configLocationCache[configLocation] = std::make_unique(configLocation); @@ -175,7 +187,7 @@ bool isValidAlias(const std::string& alias) static Error parseAlias( Config& config, - std::string aliasKey, + const std::string& aliasKey, const std::string& aliasValue, const std::optional& aliasOptions ) @@ -183,21 +195,11 @@ static Error parseAlias( if (!isValidAlias(aliasKey)) return Error{"Invalid alias " + aliasKey}; - std::transform( - aliasKey.begin(), - aliasKey.end(), - aliasKey.begin(), - [](unsigned char c) - { - return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c; - } - ); - if (!aliasOptions) return Error("Cannot parse aliases without alias options"); if (aliasOptions->overwriteAliases || !config.aliases.contains(aliasKey)) - config.setAlias(std::move(aliasKey), aliasValue, aliasOptions->configLocation); + config.setAlias(aliasKey, aliasValue, aliasOptions->configLocation); return std::nullopt; } @@ -303,7 +305,8 @@ static Error parseJson(const std::string& contents, Action action) arrayTop = (lexer.current().type == '['); next(lexer); } - else if (lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::ReservedTrue || lexer.current().type == Lexeme::ReservedFalse) + else if (lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::ReservedTrue || + lexer.current().type == Lexeme::ReservedFalse) { std::string value = lexer.current().type == Lexeme::QuotedString ? std::string(lexer.current().data, lexer.current().getLength()) diff --git a/EqSat/include/Luau/EGraph.h b/EqSat/include/Luau/EGraph.h index 924da974..2703ad9d 100644 --- a/EqSat/include/Luau/EGraph.h +++ b/EqSat/include/Luau/EGraph.h @@ -51,13 +51,70 @@ struct Analysis final } }; +template +struct Node +{ + L node; + bool boring = false; + + struct Hash + { + size_t operator()(const Node& node) const + { + return typename L::Hash{}(node.node); + } + }; +}; + +template +struct NodeIterator +{ +private: + using iterator = std::vector>; + iterator iter; + +public: + L& operator*() + { + return iter->node; + } + + const L& operator*() const + { + return iter->node; + } + + iterator& operator++() + { + ++iter; + return *this; + } + + iterator operator++(int) + { + iterator copy = *this; + ++*this; + return copy; + } + + bool operator==(const iterator& rhs) const + { + return iter == rhs.iter; + } + + bool operator!=(const iterator& rhs) const + { + return iter != rhs.iter; + } +}; + /// Each e-class is a set of e-nodes representing equivalent terms from a given language, /// and an e-node is a function symbol paired with a list of children e-classes. template struct EClass final { Id id; - std::vector nodes; + std::vector> nodes; D data; std::vector> parents; }; @@ -125,9 +182,9 @@ struct EGraph final std::sort( eclass1.nodes.begin(), eclass1.nodes.end(), - [](const L& left, const L& right) + [](const Node& left, const Node& right) { - return left.index() < right.index(); + return left.node.index() < right.node.index(); } ); @@ -177,6 +234,11 @@ struct EGraph final return classes; } + void markBoring(Id id, size_t index) + { + get(id).nodes[index].boring = true; + } + private: Analysis analysis; @@ -198,8 +260,13 @@ private: { // An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where // canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...). - for (Id& id : enode.mutableOperands()) - id = find(id); + Luau::EqSat::canonicalize( + enode, + [&](Id id) + { + return find(id); + } + ); } bool isCanonical(const L& enode) const @@ -220,7 +287,7 @@ private: id, EClassT{ id, - {enode}, + {Node{enode, false}}, analysis.make(*this, enode), {}, } @@ -259,18 +326,18 @@ private: std::vector> parents = get(id).parents; for (auto& pair : parents) { - L& enode = pair.first; - Id id = pair.second; + L& parentNode = pair.first; + Id parentId = pair.second; // By removing the old enode from the hashcons map, we will always find our new canonicalized eclass id. - hashcons.erase(enode); - canonicalize(enode); - hashcons.insert_or_assign(enode, find(id)); + hashcons.erase(parentNode); + canonicalize(parentNode); + hashcons.insert_or_assign(parentNode, find(parentId)); - if (auto it = newParents.find(enode); it != newParents.end()) - merge(id, it->second); + if (auto it = newParents.find(parentNode); it != newParents.end()) + merge(parentId, it->second); - newParents.insert_or_assign(enode, find(id)); + newParents.insert_or_assign(parentNode, find(parentId)); } // We reacquire the pointer because the prior loop potentially merges @@ -282,22 +349,30 @@ private: for (const auto& [node, id] : newParents) eclass->parents.emplace_back(std::move(node), std::move(id)); - std::unordered_set newNodes; - for (L node : eclass->nodes) + std::unordered_map newNodes; + for (Node node : eclass->nodes) { - canonicalize(node); - newNodes.insert(std::move(node)); + canonicalize(node.node); + + bool& b = newNodes[std::move(node.node)]; + b = b || node.boring; } - eclass->nodes.assign(newNodes.begin(), newNodes.end()); + eclass->nodes.clear(); + + while (!newNodes.empty()) + { + auto n = newNodes.extract(newNodes.begin()); + eclass->nodes.push_back(Node{n.key(), n.mapped()}); + } // FIXME: Extract into sortByTag() std::sort( eclass->nodes.begin(), eclass->nodes.end(), - [](const L& left, const L& right) + [](const Node& left, const Node& right) { - return left.index() < right.index(); + return left.node.index() < right.node.index(); } ); } diff --git a/EqSat/include/Luau/Language.h b/EqSat/include/Luau/Language.h index 56fc7202..f9d3aa4d 100644 --- a/EqSat/include/Luau/Language.h +++ b/EqSat/include/Luau/Language.h @@ -244,6 +244,9 @@ private: template struct NodeSet { + template + friend void canonicalize(NodeSet& node, Find&& find); + template NodeSet(Args&&... args) : vector{std::forward(args)...} @@ -299,6 +302,9 @@ struct Language final template using WithinDomain = std::disjunction, Ts>...>; + template + friend void canonicalize(Language& enode, Find&& find); + template Language(T&& t, std::enable_if_t::value>* = 0) noexcept : v(std::forward(t)) @@ -382,4 +388,37 @@ private: VariantTy v; }; +template +void canonicalize(Node& node, Find&& find) +{ + // An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where + // canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...). + for (Id& id : node.mutableOperands()) + id = find(id); +} + +// Canonicalizing the Ids in a NodeSet may result in the set decreasing in size. +template +void canonicalize(NodeSet& node, Find&& find) +{ + for (Id& id : node.vector) + id = find(id); + + std::sort(begin(node.vector), end(node.vector)); + auto endIt = std::unique(begin(node.vector), end(node.vector)); + node.vector.erase(endIt, end(node.vector)); +} + +template +void canonicalize(Language& enode, Find&& find) +{ + visit( + [&](auto&& v) + { + Luau::EqSat::canonicalize(v, find); + }, + enode.v + ); +} + } // namespace Luau::EqSat diff --git a/Makefile b/Makefile index 6fb0b8f6..2ad0fc00 100644 --- a/Makefile +++ b/Makefile @@ -42,23 +42,23 @@ ISOCLINE_SOURCES=extern/isocline/src/isocline.c ISOCLINE_OBJECTS=$(ISOCLINE_SOURCES:%=$(BUILD)/%.o) ISOCLINE_TARGET=$(BUILD)/libisocline.a -TESTS_SOURCES=$(wildcard tests/*.cpp) CLI/FileUtils.cpp CLI/Flags.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp CLI/Require.cpp +TESTS_SOURCES=$(wildcard tests/*.cpp) CLI/src/FileUtils.cpp CLI/src/Flags.cpp CLI/src/Profiler.cpp CLI/src/Coverage.cpp CLI/src/Repl.cpp CLI/src/Require.cpp TESTS_OBJECTS=$(TESTS_SOURCES:%=$(BUILD)/%.o) TESTS_TARGET=$(BUILD)/luau-tests -REPL_CLI_SOURCES=CLI/FileUtils.cpp CLI/Flags.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp CLI/ReplEntry.cpp CLI/Require.cpp +REPL_CLI_SOURCES=CLI/src/FileUtils.cpp CLI/src/Flags.cpp CLI/src/Profiler.cpp CLI/src/Coverage.cpp CLI/src/Repl.cpp CLI/src/ReplEntry.cpp CLI/src/Require.cpp REPL_CLI_OBJECTS=$(REPL_CLI_SOURCES:%=$(BUILD)/%.o) REPL_CLI_TARGET=$(BUILD)/luau -ANALYZE_CLI_SOURCES=CLI/FileUtils.cpp CLI/Flags.cpp CLI/Require.cpp CLI/Analyze.cpp +ANALYZE_CLI_SOURCES=CLI/src/FileUtils.cpp CLI/src/Flags.cpp CLI/src/Require.cpp CLI/src/Analyze.cpp ANALYZE_CLI_OBJECTS=$(ANALYZE_CLI_SOURCES:%=$(BUILD)/%.o) ANALYZE_CLI_TARGET=$(BUILD)/luau-analyze -COMPILE_CLI_SOURCES=CLI/FileUtils.cpp CLI/Flags.cpp CLI/Compile.cpp +COMPILE_CLI_SOURCES=CLI/src/FileUtils.cpp CLI/src/Flags.cpp CLI/src/Compile.cpp COMPILE_CLI_OBJECTS=$(COMPILE_CLI_SOURCES:%=$(BUILD)/%.o) COMPILE_CLI_TARGET=$(BUILD)/luau-compile -BYTECODE_CLI_SOURCES=CLI/FileUtils.cpp CLI/Flags.cpp CLI/Bytecode.cpp +BYTECODE_CLI_SOURCES=CLI/src/FileUtils.cpp CLI/src/Flags.cpp CLI/src/Bytecode.cpp BYTECODE_CLI_OBJECTS=$(BYTECODE_CLI_SOURCES:%=$(BUILD)/%.o) BYTECODE_CLI_TARGET=$(BUILD)/luau-bytecode @@ -149,11 +149,11 @@ $(EQSAT_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IEqSat/include $(CODEGEN_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -ICodeGen/include -IVM/include -IVM/src # Code generation needs VM internals $(VM_OBJECTS): CXXFLAGS+=-std=c++11 -ICommon/include -IVM/include $(ISOCLINE_OBJECTS): CXXFLAGS+=-Wno-unused-function -Iextern/isocline/include -$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IConfig/include -IAnalysis/include -IEqSat/include -ICodeGen/include -IVM/include -ICLI -Iextern -DDOCTEST_CONFIG_DOUBLE_STRINGIFY -$(REPL_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include -Iextern -Iextern/isocline/include -$(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IEqSat/include -IConfig/include -Iextern -$(COMPILE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include -$(BYTECODE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include +$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IConfig/include -IAnalysis/include -IEqSat/include -ICodeGen/include -IVM/include -ICLI/include -Iextern -DDOCTEST_CONFIG_DOUBLE_STRINGIFY +$(REPL_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include -Iextern -Iextern/isocline/include -ICLI/include +$(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IEqSat/include -IConfig/include -Iextern -ICLI/include +$(COMPILE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include -ICLI/include +$(BYTECODE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include -ICLI/include $(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -IEqSat/include -IVM/include -ICodeGen/include -IConfig/include $(TESTS_TARGET): LDFLAGS+=-lpthread diff --git a/Sources.cmake b/Sources.cmake index 1adbe862..1c312cb9 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -17,6 +17,7 @@ target_sources(Luau.Ast PRIVATE Ast/include/Luau/Allocator.h Ast/include/Luau/Ast.h Ast/include/Luau/Confusables.h + Ast/include/Luau/Cst.h Ast/include/Luau/Lexer.h Ast/include/Luau/Location.h Ast/include/Luau/ParseOptions.h @@ -28,6 +29,7 @@ target_sources(Luau.Ast PRIVATE Ast/src/Allocator.cpp Ast/src/Ast.cpp Ast/src/Confusables.cpp + Ast/src/Cst.cpp Ast/src/Lexer.cpp Ast/src/Location.cpp Ast/src/Parser.cpp @@ -78,6 +80,7 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/include/Luau/CodeBlockUnwind.h CodeGen/include/Luau/CodeGen.h CodeGen/include/Luau/CodeGenCommon.h + CodeGen/include/Luau/CodeGenOptions.h CodeGen/include/Luau/ConditionA64.h CodeGen/include/Luau/ConditionX64.h CodeGen/include/Luau/IrAnalysis.h @@ -89,6 +92,7 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/include/Luau/IrUtils.h CodeGen/include/Luau/IrVisitUseDef.h CodeGen/include/Luau/Label.h + CodeGen/include/Luau/LoweringStats.h CodeGen/include/Luau/NativeProtoExecData.h CodeGen/include/Luau/OperandX64.h CodeGen/include/Luau/OptimizeConstProp.h @@ -389,36 +393,39 @@ target_sources(isocline PRIVATE # Common sources shared between all CLI apps target_sources(Luau.CLI.lib PRIVATE - CLI/FileUtils.cpp - CLI/Flags.cpp - CLI/Flags.h - CLI/FileUtils.h + CLI/include/Luau/FileUtils.h + CLI/include/Luau/Flags.h + CLI/include/Luau/Require.h + + CLI/src/FileUtils.cpp + CLI/src/Flags.cpp + CLI/src/Require.cpp ) if(TARGET Luau.Repl.CLI) # Luau.Repl.CLI Sources target_sources(Luau.Repl.CLI PRIVATE - CLI/Coverage.h - CLI/Coverage.cpp - CLI/Profiler.h - CLI/Profiler.cpp - CLI/Repl.cpp - CLI/ReplEntry.cpp - CLI/Require.cpp) + CLI/include/Luau/Coverage.h + CLI/include/Luau/Profiler.h + + CLI/src/Coverage.cpp + CLI/src/Profiler.cpp + CLI/src/Repl.cpp + CLI/src/ReplEntry.cpp + ) endif() if(TARGET Luau.Analyze.CLI) # Luau.Analyze.CLI Sources target_sources(Luau.Analyze.CLI PRIVATE - CLI/Analyze.cpp - CLI/Require.cpp + CLI/src/Analyze.cpp ) endif() if(TARGET Luau.Ast.CLI) # Luau.Ast.CLI Sources target_sources(Luau.Ast.CLI PRIVATE - CLI/Ast.cpp + CLI/src/Ast.cpp ) endif() @@ -543,12 +550,12 @@ endif() if(TARGET Luau.CLI.Test) # Luau.CLI.Test Sources target_sources(Luau.CLI.Test PRIVATE - CLI/Coverage.h - CLI/Coverage.cpp - CLI/Profiler.h - CLI/Profiler.cpp - CLI/Repl.cpp - CLI/Require.cpp + CLI/include/Luau/Coverage.h + CLI/include/Luau/Profiler.h + + CLI/src/Coverage.cpp + CLI/src/Profiler.cpp + CLI/src/Repl.cpp tests/RegisterCallbacks.h tests/RegisterCallbacks.cpp @@ -560,24 +567,24 @@ endif() if(TARGET Luau.Web) # Luau.Web Sources target_sources(Luau.Web PRIVATE - CLI/Web.cpp) + CLI/src/Web.cpp) endif() if(TARGET Luau.Reduce.CLI) # Luau.Reduce.CLI Sources target_sources(Luau.Reduce.CLI PRIVATE - CLI/Reduce.cpp + CLI/src/Reduce.cpp ) endif() if(TARGET Luau.Compile.CLI) # Luau.Compile.CLI Sources target_sources(Luau.Compile.CLI PRIVATE - CLI/Compile.cpp) + CLI/src/Compile.cpp) endif() if(TARGET Luau.Bytecode.CLI) # Luau.Bytecode.CLI Sources target_sources(Luau.Bytecode.CLI PRIVATE - CLI/Bytecode.cpp) + CLI/src/Bytecode.cpp) endif() diff --git a/VM/include/lua.h b/VM/include/lua.h index c4f5f714..303d7162 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -154,6 +154,7 @@ LUA_API const float* lua_tovector(lua_State* L, int idx); LUA_API int lua_toboolean(lua_State* L, int idx); LUA_API const char* lua_tolstring(lua_State* L, int idx, size_t* len); LUA_API const char* lua_tostringatom(lua_State* L, int idx, int* atom); +LUA_API const char* lua_tolstringatom(lua_State* L, int idx, size_t* len, int* atom); LUA_API const char* lua_namecallatom(lua_State* L, int* atom); LUA_API int lua_objlen(lua_State* L, int idx); LUA_API lua_CFunction lua_tocfunction(lua_State* L, int idx); @@ -335,6 +336,7 @@ LUA_API const char* lua_getlightuserdataname(lua_State* L, int tag); LUA_API void lua_clonefunction(lua_State* L, int idx); LUA_API void lua_cleartable(lua_State* L, int idx); +LUA_API void lua_clonetable(lua_State* L, int idx); LUA_API lua_Alloc lua_getallocf(lua_State* L, void** ud); diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 052d8c82..a956fa94 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -64,7 +64,7 @@ const char* luau_ident = "$Luau: Copyright (C) 2019-2024 Roblox Corporation $\n" ts->atom = L->global->cb.useratom ? L->global->cb.useratom(ts->data, ts->len) : -1; \ } -static Table* getcurrenv(lua_State* L) +static LuaTable* getcurrenv(lua_State* L) { if (L->ci == L->base_ci) // no enclosing function? return L->gt; // use global table as environment @@ -454,6 +454,29 @@ const char* lua_tostringatom(lua_State* L, int idx, int* atom) return getstr(s); } +const char* lua_tolstringatom(lua_State* L, int idx, size_t* len, int* atom) +{ + StkId o = index2addr(L, idx); + + if (!ttisstring(o)) + { + if (len) + *len = 0; + return NULL; + } + + TString* s = tsvalue(o); + if (len) + *len = s->len; + if (atom) + { + updateatom(L, s); + *atom = s->atom; + } + + return getstr(s); +} + const char* lua_namecallatom(lua_State* L, int* atom) { TString* s = L->namecall; @@ -762,7 +785,7 @@ void lua_setreadonly(lua_State* L, int objindex, int enabled) { const TValue* o = index2addr(L, objindex); api_check(L, ttistable(o)); - Table* t = hvalue(o); + LuaTable* t = hvalue(o); api_check(L, t != hvalue(registry(L))); t->readonly = bool(enabled); } @@ -771,7 +794,7 @@ int lua_getreadonly(lua_State* L, int objindex) { const TValue* o = index2addr(L, objindex); api_check(L, ttistable(o)); - Table* t = hvalue(o); + LuaTable* t = hvalue(o); int res = t->readonly; return res; } @@ -780,14 +803,14 @@ void lua_setsafeenv(lua_State* L, int objindex, int enabled) { const TValue* o = index2addr(L, objindex); api_check(L, ttistable(o)); - Table* t = hvalue(o); + LuaTable* t = hvalue(o); t->safeenv = bool(enabled); } int lua_getmetatable(lua_State* L, int objindex) { luaC_threadbarrier(L); - Table* mt = NULL; + LuaTable* mt = NULL; const TValue* obj = index2addr(L, objindex); switch (ttype(obj)) { @@ -894,7 +917,7 @@ int lua_setmetatable(lua_State* L, int objindex) api_checknelems(L, 1); TValue* obj = index2addr(L, objindex); api_checkvalidindex(L, obj); - Table* mt = NULL; + LuaTable* mt = NULL; if (!ttisnil(L->top - 1)) { api_check(L, ttistable(L->top - 1)); @@ -1214,7 +1237,7 @@ int lua_rawiter(lua_State* L, int idx, int iter) api_check(L, ttistable(t)); api_check(L, iter >= 0); - Table* h = hvalue(t); + LuaTable* h = hvalue(t); int sizearray = h->sizearray; // first we advance iter through the array portion @@ -1293,7 +1316,7 @@ void* lua_newuserdatataggedwithmetatable(lua_State* L, size_t sz, int tag) // currently, we always allocate unmarked objects, so forward barrier can be skipped LUAU_ASSERT(!isblack(obj2gco(u))); - Table* h = L->global->udatamt[tag]; + LuaTable* h = L->global->udatamt[tag]; api_check(L, h != nullptr); u->metatable = h; @@ -1394,7 +1417,7 @@ int lua_ref(lua_State* L, int idx) StkId p = index2addr(L, idx); if (!ttisnil(p)) { - Table* reg = hvalue(registry(L)); + LuaTable* reg = hvalue(registry(L)); if (g->registryfree != 0) { // reuse existing slot @@ -1421,7 +1444,7 @@ void lua_unref(lua_State* L, int ref) return; global_State* g = L->global; - Table* reg = hvalue(registry(L)); + LuaTable* reg = hvalue(registry(L)); TValue* slot = luaH_setnum(L, reg, ref); setnvalue(slot, g->registryfree); // NB: no barrier needed because value isn't collectable g->registryfree = ref; @@ -1462,7 +1485,7 @@ void lua_getuserdatametatable(lua_State* L, int tag) api_check(L, unsigned(tag) < LUA_UTAG_LIMIT); luaC_threadbarrier(L); - if (Table* h = L->global->udatamt[tag]) + if (LuaTable* h = L->global->udatamt[tag]) { sethvalue(L, L->top, h); } @@ -1510,12 +1533,22 @@ void lua_cleartable(lua_State* L, int idx) { StkId t = index2addr(L, idx); api_check(L, ttistable(t)); - Table* tt = hvalue(t); + LuaTable* tt = hvalue(t); if (tt->readonly) luaG_readonlyerror(L); luaH_clear(tt); } +void lua_clonetable(lua_State* L, int idx) +{ + StkId t = index2addr(L, idx); + api_check(L, ttistable(t)); + + LuaTable* tt = luaH_clone(L, hvalue(t)); + sethvalue(L, L->top, tt); + api_incr_top(L); +} + lua_Callbacks* lua_callbacks(lua_State* L) { return &L->global->cb; diff --git a/VM/src/lbuflib.cpp b/VM/src/lbuflib.cpp index 178261fb..17ca8b0b 100644 --- a/VM/src/lbuflib.cpp +++ b/VM/src/lbuflib.cpp @@ -10,6 +10,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauBufferBitMethods2) + // while C API returns 'size_t' for binary compatibility in case of future extensions, // in the current implementation, length and offset are limited to 31 bits // because offset is limited to an integer, a single 64bit comparison can be used and will not overflow @@ -247,7 +249,88 @@ static int buffer_fill(lua_State* L) return 0; } -static const luaL_Reg bufferlib[] = { +static int buffer_readbits(lua_State* L) +{ + size_t len = 0; + void* buf = luaL_checkbuffer(L, 1, &len); + int64_t bitoffset = (int64_t)luaL_checknumber(L, 2); + int bitcount = luaL_checkinteger(L, 3); + + if (bitoffset < 0) + luaL_error(L, "buffer access out of bounds"); + + if (unsigned(bitcount) > 32) + luaL_error(L, "bit count is out of range of [0; 32]"); + + if (uint64_t(bitoffset + bitcount) > uint64_t(len) * 8) + luaL_error(L, "buffer access out of bounds"); + + unsigned startbyte = unsigned(bitoffset / 8); + unsigned endbyte = unsigned((bitoffset + bitcount + 7) / 8); + + uint64_t data = 0; + +#if defined(LUAU_BIG_ENDIAN) + for (int i = int(endbyte) - 1; i >= int(startbyte); i--) + data = (data << 8) + uint8_t(((char*)buf)[i]); +#else + memcpy(&data, (char*)buf + startbyte, endbyte - startbyte); +#endif + + uint64_t subbyteoffset = bitoffset & 0x7; + uint64_t mask = (1ull << bitcount) - 1; + + lua_pushunsigned(L, unsigned((data >> subbyteoffset) & mask)); + return 1; +} + +static int buffer_writebits(lua_State* L) +{ + size_t len = 0; + void* buf = luaL_checkbuffer(L, 1, &len); + int64_t bitoffset = (int64_t)luaL_checknumber(L, 2); + int bitcount = luaL_checkinteger(L, 3); + unsigned value = luaL_checkunsigned(L, 4); + + if (bitoffset < 0) + luaL_error(L, "buffer access out of bounds"); + + if (unsigned(bitcount) > 32) + luaL_error(L, "bit count is out of range of [0; 32]"); + + if (uint64_t(bitoffset + bitcount) > uint64_t(len) * 8) + luaL_error(L, "buffer access out of bounds"); + + unsigned startbyte = unsigned(bitoffset / 8); + unsigned endbyte = unsigned((bitoffset + bitcount + 7) / 8); + + uint64_t data = 0; + +#if defined(LUAU_BIG_ENDIAN) + for (int i = int(endbyte) - 1; i >= int(startbyte); i--) + data = data * 256 + uint8_t(((char*)buf)[i]); +#else + memcpy(&data, (char*)buf + startbyte, endbyte - startbyte); +#endif + + uint64_t subbyteoffset = bitoffset & 0x7; + uint64_t mask = ((1ull << bitcount) - 1) << subbyteoffset; + + data = (data & ~mask) | ((uint64_t(value) << subbyteoffset) & mask); + +#if defined(LUAU_BIG_ENDIAN) + for (int i = int(startbyte); i < int(endbyte); i++) + { + ((char*)buf)[i] = data & 0xff; + data >>= 8; + } +#else + memcpy((char*)buf + startbyte, &data, endbyte - startbyte); +#endif + return 0; +} + +static const luaL_Reg bufferlib_DEPRECATED[] = { {"create", buffer_create}, {"fromstring", buffer_fromstring}, {"tostring", buffer_tostring}, @@ -275,9 +358,39 @@ static const luaL_Reg bufferlib[] = { {NULL, NULL}, }; +static const luaL_Reg bufferlib[] = { + {"create", buffer_create}, + {"fromstring", buffer_fromstring}, + {"tostring", buffer_tostring}, + {"readi8", buffer_readinteger}, + {"readu8", buffer_readinteger}, + {"readi16", buffer_readinteger}, + {"readu16", buffer_readinteger}, + {"readi32", buffer_readinteger}, + {"readu32", buffer_readinteger}, + {"readf32", buffer_readfp}, + {"readf64", buffer_readfp}, + {"writei8", buffer_writeinteger}, + {"writeu8", buffer_writeinteger}, + {"writei16", buffer_writeinteger}, + {"writeu16", buffer_writeinteger}, + {"writei32", buffer_writeinteger}, + {"writeu32", buffer_writeinteger}, + {"writef32", buffer_writefp}, + {"writef64", buffer_writefp}, + {"readstring", buffer_readstring}, + {"writestring", buffer_writestring}, + {"len", buffer_len}, + {"copy", buffer_copy}, + {"fill", buffer_fill}, + {"readbits", buffer_readbits}, + {"writebits", buffer_writebits}, + {NULL, NULL}, +}; + int luaopen_buffer(lua_State* L) { - luaL_register(L, LUA_BUFFERLIBNAME, bufferlib); + luaL_register(L, LUA_BUFFERLIBNAME, FFlag::LuauBufferBitMethods2 ? bufferlib : bufferlib_DEPRECATED); return 1; } diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index 0bca4495..74702fe5 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -25,6 +25,8 @@ #endif #endif +LUAU_FASTFLAG(LuauVector2Constructor) + // luauF functions implement FASTCALL instruction that performs a direct execution of some builtin functions from the VM // The rule of thumb is that FASTCALL functions can not call user code, yield, fail, or reallocate stack. // If types of the arguments mismatch, luauF_* needs to return -1 and the execution will fall back to the usual call path @@ -998,7 +1000,7 @@ static int luauF_rawset(lua_State* L, StkId res, TValue* arg0, int nresults, Stk else if (ttisvector(key) && luai_vecisnan(vvalue(key))) return -1; - Table* t = hvalue(arg0); + LuaTable* t = hvalue(arg0); if (t->readonly) return -1; @@ -1015,7 +1017,7 @@ static int luauF_tinsert(lua_State* L, StkId res, TValue* arg0, int nresults, St { if (nparams == 2 && nresults <= 0 && ttistable(arg0)) { - Table* t = hvalue(arg0); + LuaTable* t = hvalue(arg0); if (t->readonly) return -1; @@ -1032,7 +1034,7 @@ static int luauF_tunpack(lua_State* L, StkId res, TValue* arg0, int nresults, St { if (nparams >= 1 && nresults < 0 && ttistable(arg0)) { - Table* t = hvalue(arg0); + LuaTable* t = hvalue(arg0); int n = -1; if (nparams == 1) @@ -1055,26 +1057,60 @@ static int luauF_tunpack(lua_State* L, StkId res, TValue* arg0, int nresults, St static int luauF_vector(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { - if (nparams >= 3 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1)) + if (FFlag::LuauVector2Constructor) { - double x = nvalue(arg0); - double y = nvalue(args); - double z = nvalue(args + 1); + if (nparams >= 2 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args)) + { + float x = (float)nvalue(arg0); + float y = (float)nvalue(args); + float z = 0.0f; + + if (nparams >= 3) + { + if (!ttisnumber(args + 1)) + return -1; + z = (float)nvalue(args + 1); + } #if LUA_VECTOR_SIZE == 4 - double w = 0.0; - if (nparams >= 4) - { - if (!ttisnumber(args + 2)) - return -1; - w = nvalue(args + 2); - } - setvvalue(res, float(x), float(y), float(z), float(w)); + float w = 0.0f; + if (nparams >= 4) + { + if (!ttisnumber(args + 2)) + return -1; + w = (float)nvalue(args + 2); + } + setvvalue(res, x, y, z, w); #else - setvvalue(res, float(x), float(y), float(z), 0.0f); + setvvalue(res, x, y, z, 0.0f); #endif - return 1; + return 1; + } + } + else + { + if (nparams >= 3 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1)) + { + double x = nvalue(arg0); + double y = nvalue(args); + double z = nvalue(args + 1); + +#if LUA_VECTOR_SIZE == 4 + double w = 0.0; + if (nparams >= 4) + { + if (!ttisnumber(args + 2)) + return -1; + w = nvalue(args + 2); + } + setvvalue(res, float(x), float(y), float(z), float(w)); +#else + setvvalue(res, float(x), float(y), float(z), 0.0f); +#endif + + return 1; + } } return -1; @@ -1160,7 +1196,7 @@ static int luauF_rawlen(lua_State* L, StkId res, TValue* arg0, int nresults, Stk { if (ttistable(arg0)) { - Table* h = hvalue(arg0); + LuaTable* h = hvalue(arg0); setnvalue(res, double(luaH_getn(h))); return 1; } @@ -1204,7 +1240,7 @@ static int luauF_getmetatable(lua_State* L, StkId res, TValue* arg0, int nresult { if (nparams >= 1 && nresults <= 1) { - Table* mt = NULL; + LuaTable* mt = NULL; if (ttistable(arg0)) mt = hvalue(arg0)->metatable; else if (ttisuserdata(arg0)) @@ -1239,11 +1275,11 @@ static int luauF_setmetatable(lua_State* L, StkId res, TValue* arg0, int nresult // note: setmetatable(_, nil) is rare so we use fallback for it to optimize the fast path if (nparams >= 2 && nresults <= 1 && ttistable(arg0) && ttistable(args)) { - Table* t = hvalue(arg0); + LuaTable* t = hvalue(arg0); if (t->readonly || t->metatable != NULL) return -1; // note: overwriting non-null metatable is very rare but it requires __metatable check - Table* mt = hvalue(args); + LuaTable* mt = hvalue(args); t->metatable = mt; luaC_objbarrier(L, t, mt); @@ -1694,6 +1730,23 @@ static int luauF_vectormax(lua_State* L, StkId res, TValue* arg0, int nresults, return -1; } +static int luauF_lerp(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 3 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1)) + { + double a = nvalue(arg0); + double b = nvalue(args); + double t = nvalue(args + 1); + + double r = (t == 1.0) ? b : a + (b - a) * t; + + setnvalue(res, r); + return 1; + } + + return -1; +} + static int luauF_missing(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { return -1; @@ -1889,6 +1942,8 @@ const luau_FastFunction luauF_table[256] = { luauF_vectormin, luauF_vectormax, + luauF_lerp, + // When adding builtins, add them above this line; what follows is 64 "dummy" entries with luauF_missing fallback. // This is important so that older versions of the runtime that don't support newer builtins automatically fall back via luauF_missing. // Given the builtin addition velocity this should always provide a larger compatibility window than bytecode versions suggest. diff --git a/VM/src/lcorolib.cpp b/VM/src/lcorolib.cpp index 941137e9..5a372aec 100644 --- a/VM/src/lcorolib.cpp +++ b/VM/src/lcorolib.cpp @@ -6,7 +6,6 @@ #include "lstate.h" #include "lvm.h" -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauCoroCheckStack, false) LUAU_DYNAMIC_FASTFLAG(LuauStackLimit) #define CO_STATUS_ERROR -1 @@ -41,7 +40,7 @@ static int auxresume(lua_State* L, lua_State* co, int narg) luaL_error(L, "too many arguments to resume"); lua_xmove(L, co, narg); } - else if (DFFlag::LuauCoroCheckStack) + else { // coroutine might be completely full already if ((co->top - co->base) > LUAI_MAXCSTACK) diff --git a/VM/src/ldblib.cpp b/VM/src/ldblib.cpp index cab4dd6f..ff9fdd76 100644 --- a/VM/src/ldblib.cpp +++ b/VM/src/ldblib.cpp @@ -8,8 +8,6 @@ #include #include -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauDebugInfoInvArgLeftovers, false) - static lua_State* getthread(lua_State* L, int* arg) { if (lua_isthread(L, 1)) @@ -110,7 +108,7 @@ static int db_info(lua_State* L) default: // restore stack state of another thread as 'f' option might not have been visited yet - if (DFFlag::LuauDebugInfoInvArgLeftovers && L != L1) + if (L != L1) lua_settop(L1, l1top); luaL_argerror(L, arg + 2, "invalid option"); diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index 07cc117e..44da57c2 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -422,6 +422,20 @@ int luaG_isnative(lua_State* L, int level) return (ci->flags & LUA_CALLINFO_NATIVE) != 0 ? 1 : 0; } +int luaG_hasnative(lua_State* L, int level) +{ + if (unsigned(level) >= unsigned(L->ci - L->base_ci)) + return 0; + + CallInfo* ci = L->ci - level; + + Proto* proto = getluaproto(ci); + if (proto == nullptr) + return 0; + + return (proto->execdata != nullptr); +} + void lua_singlestep(lua_State* L, int enabled) { L->singlestep = bool(enabled); diff --git a/VM/src/ldebug.h b/VM/src/ldebug.h index 49b1ca88..f215e815 100644 --- a/VM/src/ldebug.h +++ b/VM/src/ldebug.h @@ -31,3 +31,4 @@ LUAI_FUNC bool luaG_onbreak(lua_State* L); LUAI_FUNC int luaG_getline(Proto* p, int pc); LUAI_FUNC int luaG_isnative(lua_State* L, int level); +LUAI_FUNC int luaG_hasnative(lua_State* L, int level); diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 28ab00b6..f9fe30d6 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -18,6 +18,7 @@ #include LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauStackLimit, false) +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauPopIncompleteCi, false) // keep max stack allocation request under 1GB #define MAX_STACK_SIZE (int(1024 / sizeof(TValue)) * 1024 * 1024) @@ -179,11 +180,23 @@ static void correctstack(lua_State* L, TValue* oldstack) L->base = (L->base - oldstack) + L->stack; } -void luaD_reallocstack(lua_State* L, int newsize) +void luaD_reallocstack(lua_State* L, int newsize, int fornewci) { // throw 'out of memory' error because space for a custom error message cannot be guaranteed here if (DFFlag::LuauStackLimit && newsize > MAX_STACK_SIZE) + { + // reallocation was performaed to setup a new CallInfo frame, which we have to remove + if (DFFlag::LuauPopIncompleteCi && fornewci) + { + CallInfo* cip = L->ci - 1; + + L->ci = cip; + L->base = cip->base; + L->top = cip->top; + } + luaD_throw(L, LUA_ERRMEM); + } TValue* oldstack = L->stack; int realsize = newsize + EXTRA_STACK; @@ -208,10 +221,17 @@ void luaD_reallocCI(lua_State* L, int newsize) void luaD_growstack(lua_State* L, int n) { - if (n <= L->stacksize) // double size is enough? - luaD_reallocstack(L, 2 * L->stacksize); + if (DFFlag::LuauPopIncompleteCi) + { + luaD_reallocstack(L, getgrownstacksize(L, n), 0); + } else - luaD_reallocstack(L, L->stacksize + n); + { + if (n <= L->stacksize) // double size is enough? + luaD_reallocstack(L, 2 * L->stacksize, 0); + else + luaD_reallocstack(L, L->stacksize + n, 0); + } } CallInfo* luaD_growCI(lua_State* L) diff --git a/VM/src/ldo.h b/VM/src/ldo.h index 0f7b42ad..707af0ee 100644 --- a/VM/src/ldo.h +++ b/VM/src/ldo.h @@ -7,11 +7,21 @@ #include "luaconf.h" #include "ldebug.h" +// returns target stack for 'n' extra elements to reallocate +// if possible, stack size growth factor is 2x +#define getgrownstacksize(L, n) ((n) <= L->stacksize ? 2 * L->stacksize : L->stacksize + (n)) + +#define luaD_checkstackfornewci(L, n) \ + if ((char*)L->stack_last - (char*)L->top <= (n) * (int)sizeof(TValue)) \ + luaD_reallocstack(L, getgrownstacksize(L, (n)), 1); \ + else \ + condhardstacktests(luaD_reallocstack(L, L->stacksize - EXTRA_STACK, 1)); + #define luaD_checkstack(L, n) \ if ((char*)L->stack_last - (char*)L->top <= (n) * (int)sizeof(TValue)) \ luaD_growstack(L, n); \ else \ - condhardstacktests(luaD_reallocstack(L, L->stacksize - EXTRA_STACK)); + condhardstacktests(luaD_reallocstack(L, L->stacksize - EXTRA_STACK, 0)); #define incr_top(L) \ { \ @@ -47,7 +57,7 @@ LUAI_FUNC CallInfo* luaD_growCI(lua_State* L); LUAI_FUNC void luaD_call(lua_State* L, StkId func, int nresults); LUAI_FUNC int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t oldtop, ptrdiff_t ef); LUAI_FUNC void luaD_reallocCI(lua_State* L, int newsize); -LUAI_FUNC void luaD_reallocstack(lua_State* L, int newsize); +LUAI_FUNC void luaD_reallocstack(lua_State* L, int newsize, int fornewci); LUAI_FUNC void luaD_growstack(lua_State* L, int n); LUAI_FUNC void luaD_checkCstack(lua_State* L); diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index 2a1e45c4..b172d0ad 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -55,7 +55,7 @@ Proto* luaF_newproto(lua_State* L) return f; } -Closure* luaF_newLclosure(lua_State* L, int nelems, Table* e, Proto* p) +Closure* luaF_newLclosure(lua_State* L, int nelems, LuaTable* e, Proto* p) { Closure* c = luaM_newgco(L, Closure, sizeLclosure(nelems), L->activememcat); luaC_init(L, c, LUA_TFUNCTION); @@ -70,7 +70,7 @@ Closure* luaF_newLclosure(lua_State* L, int nelems, Table* e, Proto* p) return c; } -Closure* luaF_newCclosure(lua_State* L, int nelems, Table* e) +Closure* luaF_newCclosure(lua_State* L, int nelems, LuaTable* e) { Closure* c = luaM_newgco(L, Closure, sizeCclosure(nelems), L->activememcat); luaC_init(L, c, LUA_TFUNCTION); diff --git a/VM/src/lfunc.h b/VM/src/lfunc.h index 679836e7..453cf581 100644 --- a/VM/src/lfunc.h +++ b/VM/src/lfunc.h @@ -8,8 +8,8 @@ #define sizeLclosure(n) (offsetof(Closure, l.uprefs) + sizeof(TValue) * (n)) LUAI_FUNC Proto* luaF_newproto(lua_State* L); -LUAI_FUNC Closure* luaF_newLclosure(lua_State* L, int nelems, Table* e, Proto* p); -LUAI_FUNC Closure* luaF_newCclosure(lua_State* L, int nelems, Table* e); +LUAI_FUNC Closure* luaF_newLclosure(lua_State* L, int nelems, LuaTable* e, Proto* p); +LUAI_FUNC Closure* luaF_newCclosure(lua_State* L, int nelems, LuaTable* e); LUAI_FUNC UpVal* luaF_findupval(lua_State* L, StkId level); LUAI_FUNC void luaF_close(lua_State* L, StkId level); LUAI_FUNC void luaF_closeupval(lua_State* L, UpVal* uv, bool dead); diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 6ba758df..c5e16e43 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -14,8 +14,6 @@ #include -LUAU_DYNAMIC_FASTFLAG(LuauCoroCheckStack) - /* * Luau uses an incremental non-generational non-moving mark&sweep garbage collector. * @@ -246,7 +244,7 @@ static void reallymarkobject(global_State* g, GCObject* o) } case LUA_TUSERDATA: { - Table* mt = gco2u(o)->metatable; + LuaTable* mt = gco2u(o)->metatable; gray2black(o); // udata are never gray if (mt) markobject(g, mt); @@ -294,7 +292,7 @@ static void reallymarkobject(global_State* g, GCObject* o) } } -static const char* gettablemode(global_State* g, Table* h) +static const char* gettablemode(global_State* g, LuaTable* h) { const TValue* mode = gfasttm(g, h->metatable, TM_MODE); @@ -304,13 +302,13 @@ static const char* gettablemode(global_State* g, Table* h) return NULL; } -static int traversetable(global_State* g, Table* h) +static int traversetable(global_State* g, LuaTable* h) { int i; int weakkey = 0; int weakvalue = 0; if (h->metatable) - markobject(g, cast_to(Table*, h->metatable)); + markobject(g, cast_to(LuaTable*, h->metatable)); // is there a weak mode? if (const char* modev = gettablemode(g, h)) @@ -439,26 +437,13 @@ static void shrinkstack(lua_State* L) if (L->size_ci > LUAI_MAXCALLS) // handling overflow? return; // do not touch the stacks - if (DFFlag::LuauCoroCheckStack) - { - if (3 * size_t(ci_used) < size_t(L->size_ci) && 2 * BASIC_CI_SIZE < L->size_ci) - luaD_reallocCI(L, L->size_ci / 2); // still big enough... - condhardstacktests(luaD_reallocCI(L, ci_used + 1)); + if (3 * size_t(ci_used) < size_t(L->size_ci) && 2 * BASIC_CI_SIZE < L->size_ci) + luaD_reallocCI(L, L->size_ci / 2); // still big enough... + condhardstacktests(luaD_reallocCI(L, ci_used + 1)); - if (3 * size_t(s_used) < size_t(L->stacksize) && 2 * (BASIC_STACK_SIZE + EXTRA_STACK) < L->stacksize) - luaD_reallocstack(L, L->stacksize / 2); // still big enough... - condhardstacktests(luaD_reallocstack(L, s_used)); - } - else - { - if (3 * ci_used < L->size_ci && 2 * BASIC_CI_SIZE < L->size_ci) - luaD_reallocCI(L, L->size_ci / 2); // still big enough... - condhardstacktests(luaD_reallocCI(L, ci_used + 1)); - - if (3 * s_used < L->stacksize && 2 * (BASIC_STACK_SIZE + EXTRA_STACK) < L->stacksize) - luaD_reallocstack(L, L->stacksize / 2); // still big enough... - condhardstacktests(luaD_reallocstack(L, s_used)); - } + if (3 * size_t(s_used) < size_t(L->stacksize) && 2 * (BASIC_STACK_SIZE + EXTRA_STACK) < L->stacksize) + luaD_reallocstack(L, L->stacksize / 2, 0); // still big enough... + condhardstacktests(luaD_reallocstack(L, s_used)); } /* @@ -474,11 +459,11 @@ static size_t propagatemark(global_State* g) { case LUA_TTABLE: { - Table* h = gco2h(o); + LuaTable* h = gco2h(o); g->gray = h->gclist; if (traversetable(g, h)) // table is weak? black2gray(o); // keep it gray - return sizeof(Table) + sizeof(TValue) * h->sizearray + sizeof(LuaNode) * sizenode(h); + return sizeof(LuaTable) + sizeof(TValue) * h->sizearray + sizeof(LuaNode) * sizenode(h); } case LUA_TFUNCTION: { @@ -568,8 +553,8 @@ static size_t cleartable(lua_State* L, GCObject* l) size_t work = 0; while (l) { - Table* h = gco2h(l); - work += sizeof(Table) + sizeof(TValue) * h->sizearray + sizeof(LuaNode) * sizenode(h); + LuaTable* h = gco2h(l); + work += sizeof(LuaTable) + sizeof(TValue) * h->sizearray + sizeof(LuaNode) * sizenode(h); int i = h->sizearray; while (i--) @@ -1170,7 +1155,7 @@ void luaC_barrierf(lua_State* L, GCObject* o, GCObject* v) makewhite(g, o); // mark as white just to avoid other barriers } -void luaC_barriertable(lua_State* L, Table* t, GCObject* v) +void luaC_barriertable(lua_State* L, LuaTable* t, GCObject* v) { global_State* g = L->global; GCObject* o = obj2gco(t); diff --git a/VM/src/lgc.h b/VM/src/lgc.h index 722de9d1..683542b6 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -131,7 +131,7 @@ LUAI_FUNC void luaC_fullgc(lua_State* L); LUAI_FUNC void luaC_initobj(lua_State* L, GCObject* o, uint8_t tt); LUAI_FUNC void luaC_upvalclosed(lua_State* L, UpVal* uv); LUAI_FUNC void luaC_barrierf(lua_State* L, GCObject* o, GCObject* v); -LUAI_FUNC void luaC_barriertable(lua_State* L, Table* t, GCObject* v); +LUAI_FUNC void luaC_barriertable(lua_State* L, LuaTable* t, GCObject* v); LUAI_FUNC void luaC_barrierback(lua_State* L, GCObject* o, GCObject** gclist); LUAI_FUNC void luaC_validate(lua_State* L); LUAI_FUNC void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat)); diff --git a/VM/src/lgcdebug.cpp b/VM/src/lgcdebug.cpp index 768561cb..7a47ab86 100644 --- a/VM/src/lgcdebug.cpp +++ b/VM/src/lgcdebug.cpp @@ -34,7 +34,7 @@ static void validateref(global_State* g, GCObject* f, TValue* v) } } -static void validatetable(global_State* g, Table* h) +static void validatetable(global_State* g, LuaTable* h) { int sizenode = 1 << h->lsizenode; @@ -290,9 +290,9 @@ static void dumpstring(FILE* f, TString* ts) fprintf(f, "\"}"); } -static void dumptable(FILE* f, Table* h) +static void dumptable(FILE* f, LuaTable* h) { - size_t size = sizeof(Table) + (h->node == &luaH_dummynode ? 0 : sizenode(h) * sizeof(LuaNode)) + h->sizearray * sizeof(TValue); + size_t size = sizeof(LuaTable) + (h->node == &luaH_dummynode ? 0 : sizenode(h) * sizeof(LuaNode)) + h->sizearray * sizeof(TValue); fprintf(f, "{\"type\":\"table\",\"cat\":%d,\"size\":%d", h->memcat, int(size)); @@ -654,9 +654,9 @@ static void enumstring(EnumContext* ctx, TString* ts) enumnode(ctx, obj2gco(ts), ts->len, NULL); } -static void enumtable(EnumContext* ctx, Table* h) +static void enumtable(EnumContext* ctx, LuaTable* h) { - size_t size = sizeof(Table) + (h->node == &luaH_dummynode ? 0 : sizenode(h) * sizeof(LuaNode)) + h->sizearray * sizeof(TValue); + size_t size = sizeof(LuaTable) + (h->node == &luaH_dummynode ? 0 : sizenode(h) * sizeof(LuaNode)) + h->sizearray * sizeof(TValue); // Provide a name for a special registry table enumnode(ctx, obj2gco(h), size, h == hvalue(registry(ctx->L)) ? "registry" : NULL); @@ -754,7 +754,7 @@ static void enumudata(EnumContext* ctx, Udata* u) { const char* name = NULL; - if (Table* h = u->metatable) + if (LuaTable* h = u->metatable) { if (h->node != &luaH_dummynode) { diff --git a/VM/src/lmathlib.cpp b/VM/src/lmathlib.cpp index 3a93abcf..9bd21607 100644 --- a/VM/src/lmathlib.cpp +++ b/VM/src/lmathlib.cpp @@ -7,7 +7,7 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauMathMap) +LUAU_FASTFLAGVARIABLE(LuauMathLerp) #undef PI #define PI (3.14159265358979323846) @@ -418,6 +418,17 @@ static int math_map(lua_State* L) return 1; } +static int math_lerp(lua_State* L) +{ + double a = luaL_checknumber(L, 1); + double b = luaL_checknumber(L, 2); + double t = luaL_checknumber(L, 3); + + double r = (t == 1.0) ? b : a + (b - a) * t; + lua_pushnumber(L, r); + return 1; +} + static const luaL_Reg mathlib[] = { {"abs", math_abs}, {"acos", math_acos}, @@ -451,6 +462,7 @@ static const luaL_Reg mathlib[] = { {"clamp", math_clamp}, {"sign", math_sign}, {"round", math_round}, + {"map", math_map}, {NULL, NULL}, }; @@ -471,10 +483,10 @@ int luaopen_math(lua_State* L) lua_pushnumber(L, HUGE_VAL); lua_setfield(L, -2, "huge"); - if (FFlag::LuauMathMap) + if (FFlag::LuauMathLerp) { - lua_pushcfunction(L, math_map, "map"); - lua_setfield(L, -2, "map"); + lua_pushcfunction(L, math_lerp, "lerp"); + lua_setfield(L, -2, "lerp"); } return 1; diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index f65d79dc..0738840b 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -121,7 +121,7 @@ static_assert(sizeof(LuaNode) == ABISWITCH(32, 32, 32), "size mismatch for table static_assert(offsetof(TString, data) == ABISWITCH(24, 20, 20), "size mismatch for string header"); static_assert(offsetof(Udata, data) == ABISWITCH(16, 16, 12), "size mismatch for userdata header"); -static_assert(sizeof(Table) == ABISWITCH(48, 32, 32), "size mismatch for table header"); +static_assert(sizeof(LuaTable) == ABISWITCH(48, 32, 32), "size mismatch for table header"); static_assert(offsetof(Buffer, data) == ABISWITCH(8, 8, 8), "size mismatch for buffer header"); const size_t kSizeClasses = LUA_SIZECLASSES; @@ -192,7 +192,7 @@ struct SizeClassConfig const SizeClassConfig kSizeClassConfig; // size class for a block of size sz; returns -1 for size=0 because empty allocations take no space -#define sizeclass(sz) (size_t((sz)-1) < kMaxSmallSizeUsed ? kSizeClassConfig.classForSize[sz] : -1) +#define sizeclass(sz) (size_t((sz) - 1) < kMaxSmallSizeUsed ? kSizeClassConfig.classForSize[sz] : -1) // metadata for a block is stored in the first pointer of the block #define metadata(block) (*(void**)(block)) diff --git a/VM/src/lobject.h b/VM/src/lobject.h index 18c69641..6719faaf 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -263,7 +263,7 @@ typedef struct Udata int len; - struct Table* metatable; + struct LuaTable* metatable; union { @@ -390,7 +390,7 @@ typedef struct Closure uint8_t preload; GCObject* gclist; - struct Table* env; + struct LuaTable* env; union { @@ -454,7 +454,7 @@ typedef struct LuaNode } // clang-format off -typedef struct Table +typedef struct LuaTable { CommonHeader; @@ -473,11 +473,11 @@ typedef struct Table }; - struct Table* metatable; + struct LuaTable* metatable; TValue* array; // array part LuaNode* node; GCObject* gclist; -} Table; +} LuaTable; // clang-format on /* diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index 6b7a9aa0..ddb1e12e 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -149,7 +149,7 @@ void lua_resetthread(lua_State* L) L->nCcalls = L->baseCcalls = 0; // clear thread stack if (L->stacksize != BASIC_STACK_SIZE + EXTRA_STACK) - luaD_reallocstack(L, BASIC_STACK_SIZE); + luaD_reallocstack(L, BASIC_STACK_SIZE, 0); for (int i = 0; i < L->stacksize; i++) setnilvalue(L->stack + i); } diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 3f4f9425..ad162391 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -198,7 +198,7 @@ typedef struct global_State struct lua_State* mainthread; UpVal uvhead; // head of double-linked list of all open upvalues - struct Table* mt[LUA_T_COUNT]; // metatables for basic types + struct LuaTable* mt[LUA_T_COUNT]; // metatables for basic types TString* ttname[LUA_T_COUNT]; // names for basic types TString* tmname[TM_N]; // array with tag-method names @@ -217,7 +217,7 @@ typedef struct global_State lua_ExecutionCallbacks ecb; void (*udatagc[LUA_UTAG_LIMIT])(lua_State*, void*); // for each userdata tag, a gc callback to be called immediately before freeing memory - Table* udatamt[LUA_UTAG_LIMIT]; // metatables for tagged userdata + LuaTable* udatamt[LUA_UTAG_LIMIT]; // metatables for tagged userdata TString* lightuserdataname[LUA_LUTAG_LIMIT]; // names for tagged lightuserdata @@ -266,7 +266,7 @@ struct lua_State int cachedslot; // when table operations or INDEX/NEWINDEX is invoked from Luau, what is the expected slot for lookup? - Table* gt; // table of globals + LuaTable* gt; // table of globals UpVal* openupval; // list of open upvalues in this stack GCObject* gclist; @@ -285,7 +285,7 @@ union GCObject struct TString ts; struct Udata u; struct Closure cl; - struct Table h; + struct LuaTable h; struct Proto p; struct UpVal uv; struct lua_State th; // thread diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index dafb2b3f..ee5ae7ec 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -58,7 +58,7 @@ const LuaNode luaH_dummynode = { #define hashstr(t, str) hashpow2(t, (str)->hash) #define hashboolean(t, p) hashpow2(t, p) -static LuaNode* hashpointer(const Table* t, const void* p) +static LuaNode* hashpointer(const LuaTable* t, const void* p) { // we discard the high 32-bit portion of the pointer on 64-bit platforms as it doesn't carry much entropy anyway unsigned int h = unsigned(uintptr_t(p)); @@ -73,7 +73,7 @@ static LuaNode* hashpointer(const Table* t, const void* p) return hashpow2(t, h); } -static LuaNode* hashnum(const Table* t, double n) +static LuaNode* hashnum(const LuaTable* t, double n) { static_assert(sizeof(double) == sizeof(unsigned int) * 2, "expected a 8-byte double"); unsigned int i[2]; @@ -99,7 +99,7 @@ static LuaNode* hashnum(const Table* t, double n) return hashpow2(t, h2); } -static LuaNode* hashvec(const Table* t, const float* v) +static LuaNode* hashvec(const LuaTable* t, const float* v) { unsigned int i[LUA_VECTOR_SIZE]; memcpy(i, v, sizeof(i)); @@ -130,7 +130,7 @@ static LuaNode* hashvec(const Table* t, const float* v) ** returns the `main' position of an element in a table (that is, the index ** of its hash value) */ -static LuaNode* mainposition(const Table* t, const TValue* key) +static LuaNode* mainposition(const LuaTable* t, const TValue* key) { switch (ttype(key)) { @@ -166,7 +166,7 @@ static int arrayindex(double key) ** elements in the array part, then elements in the hash part. The ** beginning of a traversal is signalled by -1. */ -static int findindex(lua_State* L, Table* t, StkId key) +static int findindex(lua_State* L, LuaTable* t, StkId key) { int i; if (ttisnil(key)) @@ -194,7 +194,7 @@ static int findindex(lua_State* L, Table* t, StkId key) } } -int luaH_next(lua_State* L, Table* t, StkId key) +int luaH_next(lua_State* L, LuaTable* t, StkId key) { int i = findindex(L, t, key); // find original element for (i++; i < t->sizearray; i++) @@ -270,7 +270,7 @@ static int countint(double key, int* nums) return 0; } -static int numusearray(const Table* t, int* nums) +static int numusearray(const LuaTable* t, int* nums) { int lg; int ttlg; // 2^lg @@ -298,7 +298,7 @@ static int numusearray(const Table* t, int* nums) return ause; } -static int numusehash(const Table* t, int* nums, int* pnasize) +static int numusehash(const LuaTable* t, int* nums, int* pnasize) { int totaluse = 0; // total number of elements int ause = 0; // summation of `nums' @@ -317,7 +317,7 @@ static int numusehash(const Table* t, int* nums, int* pnasize) return totaluse; } -static void setarrayvector(lua_State* L, Table* t, int size) +static void setarrayvector(lua_State* L, LuaTable* t, int size) { if (size > MAXSIZE) luaG_runerror(L, "table overflow"); @@ -328,7 +328,7 @@ static void setarrayvector(lua_State* L, Table* t, int size) t->sizearray = size; } -static void setnodevector(lua_State* L, Table* t, int size) +static void setnodevector(lua_State* L, LuaTable* t, int size) { int lsize; if (size == 0) @@ -357,9 +357,9 @@ static void setnodevector(lua_State* L, Table* t, int size) t->lastfree = size; // all positions are free } -static TValue* newkey(lua_State* L, Table* t, const TValue* key); +static TValue* newkey(lua_State* L, LuaTable* t, const TValue* key); -static TValue* arrayornewkey(lua_State* L, Table* t, const TValue* key) +static TValue* arrayornewkey(lua_State* L, LuaTable* t, const TValue* key) { if (ttisnumber(key)) { @@ -373,7 +373,7 @@ static TValue* arrayornewkey(lua_State* L, Table* t, const TValue* key) return newkey(L, t, key); } -static void resize(lua_State* L, Table* t, int nasize, int nhsize) +static void resize(lua_State* L, LuaTable* t, int nasize, int nhsize) { if (nasize > MAXSIZE || nhsize > MAXSIZE) luaG_runerror(L, "table overflow"); @@ -424,7 +424,7 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize) luaM_freearray(L, nold, twoto(oldhsize), LuaNode, t->memcat); // free old array } -static int adjustasize(Table* t, int size, const TValue* ek) +static int adjustasize(LuaTable* t, int size, const TValue* ek) { bool tbound = t->node != dummynode || size < t->sizearray; int ekindex = ek && ttisnumber(ek) ? arrayindex(nvalue(ek)) : -1; @@ -434,19 +434,19 @@ static int adjustasize(Table* t, int size, const TValue* ek) return size; } -void luaH_resizearray(lua_State* L, Table* t, int nasize) +void luaH_resizearray(lua_State* L, LuaTable* t, int nasize) { int nsize = (t->node == dummynode) ? 0 : sizenode(t); int asize = adjustasize(t, nasize, NULL); resize(L, t, asize, nsize); } -void luaH_resizehash(lua_State* L, Table* t, int nhsize) +void luaH_resizehash(lua_State* L, LuaTable* t, int nhsize) { resize(L, t, t->sizearray, nhsize); } -static void rehash(lua_State* L, Table* t, const TValue* ek) +static void rehash(lua_State* L, LuaTable* t, const TValue* ek) { int nums[MAXBITS + 1]; // nums[i] = number of keys between 2^(i-1) and 2^i for (int i = 0; i <= MAXBITS; i++) @@ -491,9 +491,9 @@ static void rehash(lua_State* L, Table* t, const TValue* ek) ** }============================================================= */ -Table* luaH_new(lua_State* L, int narray, int nhash) +LuaTable* luaH_new(lua_State* L, int narray, int nhash) { - Table* t = luaM_newgco(L, Table, sizeof(Table), L->activememcat); + LuaTable* t = luaM_newgco(L, LuaTable, sizeof(LuaTable), L->activememcat); luaC_init(L, t, LUA_TTABLE); t->metatable = NULL; t->tmcache = cast_byte(~0); @@ -512,16 +512,16 @@ Table* luaH_new(lua_State* L, int narray, int nhash) return t; } -void luaH_free(lua_State* L, Table* t, lua_Page* page) +void luaH_free(lua_State* L, LuaTable* t, lua_Page* page) { if (t->node != dummynode) luaM_freearray(L, t->node, sizenode(t), LuaNode, t->memcat); if (t->array) luaM_freearray(L, t->array, t->sizearray, TValue, t->memcat); - luaM_freegco(L, t, sizeof(Table), t->memcat, page); + luaM_freegco(L, t, sizeof(LuaTable), t->memcat, page); } -static LuaNode* getfreepos(Table* t) +static LuaNode* getfreepos(LuaTable* t) { while (t->lastfree > 0) { @@ -541,7 +541,7 @@ static LuaNode* getfreepos(Table* t) ** put new key in its main position; otherwise (colliding node is in its main ** position), new key goes to an empty position. */ -static TValue* newkey(lua_State* L, Table* t, const TValue* key) +static TValue* newkey(lua_State* L, LuaTable* t, const TValue* key) { // enforce boundary invariant if (ttisnumber(key) && nvalue(key) == t->sizearray + 1) @@ -601,7 +601,7 @@ static TValue* newkey(lua_State* L, Table* t, const TValue* key) /* ** search function for integers */ -const TValue* luaH_getnum(Table* t, int key) +const TValue* luaH_getnum(LuaTable* t, int key) { // (1 <= key && key <= t->sizearray) if (cast_to(unsigned int, key - 1) < cast_to(unsigned int, t->sizearray)) @@ -627,7 +627,7 @@ const TValue* luaH_getnum(Table* t, int key) /* ** search function for strings */ -const TValue* luaH_getstr(Table* t, TString* key) +const TValue* luaH_getstr(LuaTable* t, TString* key) { LuaNode* n = hashstr(t, key); for (;;) @@ -644,7 +644,7 @@ const TValue* luaH_getstr(Table* t, TString* key) /* ** main search function */ -const TValue* luaH_get(Table* t, const TValue* key) +const TValue* luaH_get(LuaTable* t, const TValue* key) { switch (ttype(key)) { @@ -677,7 +677,7 @@ const TValue* luaH_get(Table* t, const TValue* key) } } -TValue* luaH_set(lua_State* L, Table* t, const TValue* key) +TValue* luaH_set(lua_State* L, LuaTable* t, const TValue* key) { const TValue* p = luaH_get(t, key); invalidateTMcache(t); @@ -687,7 +687,7 @@ TValue* luaH_set(lua_State* L, Table* t, const TValue* key) return luaH_newkey(L, t, key); } -TValue* luaH_newkey(lua_State* L, Table* t, const TValue* key) +TValue* luaH_newkey(lua_State* L, LuaTable* t, const TValue* key) { if (ttisnil(key)) luaG_runerror(L, "table index is nil"); @@ -698,7 +698,7 @@ TValue* luaH_newkey(lua_State* L, Table* t, const TValue* key) return newkey(L, t, key); } -TValue* luaH_setnum(lua_State* L, Table* t, int key) +TValue* luaH_setnum(lua_State* L, LuaTable* t, int key) { // (1 <= key && key <= t->sizearray) if (cast_to(unsigned int, key - 1) < cast_to(unsigned int, t->sizearray)) @@ -715,7 +715,7 @@ TValue* luaH_setnum(lua_State* L, Table* t, int key) } } -TValue* luaH_setstr(lua_State* L, Table* t, TString* key) +TValue* luaH_setstr(lua_State* L, LuaTable* t, TString* key) { const TValue* p = luaH_getstr(t, key); invalidateTMcache(t); @@ -729,7 +729,7 @@ TValue* luaH_setstr(lua_State* L, Table* t, TString* key) } } -static int updateaboundary(Table* t, int boundary) +static int updateaboundary(LuaTable* t, int boundary) { if (boundary < t->sizearray && ttisnil(&t->array[boundary - 1])) { @@ -752,7 +752,7 @@ static int updateaboundary(Table* t, int boundary) ** Try to find a boundary in table `t'. A `boundary' is an integer index ** such that t[i] is non-nil and t[i+1] is nil (and 0 if t[1] is nil). */ -int luaH_getn(Table* t) +int luaH_getn(LuaTable* t) { int boundary = getaboundary(t); @@ -793,9 +793,9 @@ int luaH_getn(Table* t) } } -Table* luaH_clone(lua_State* L, Table* tt) +LuaTable* luaH_clone(lua_State* L, LuaTable* tt) { - Table* t = luaM_newgco(L, Table, sizeof(Table), L->activememcat); + LuaTable* t = luaM_newgco(L, LuaTable, sizeof(LuaTable), L->activememcat); luaC_init(L, t, LUA_TTABLE); t->metatable = tt->metatable; t->tmcache = tt->tmcache; @@ -830,7 +830,7 @@ Table* luaH_clone(lua_State* L, Table* tt) return t; } -void luaH_clear(Table* tt) +void luaH_clear(LuaTable* tt) { // clear array part for (int i = 0; i < tt->sizearray; ++i) diff --git a/VM/src/ltable.h b/VM/src/ltable.h index 021f21bf..50d1e643 100644 --- a/VM/src/ltable.h +++ b/VM/src/ltable.h @@ -14,21 +14,21 @@ // reset cache of absent metamethods, cache is updated in luaT_gettm #define invalidateTMcache(t) t->tmcache = 0 -LUAI_FUNC const TValue* luaH_getnum(Table* t, int key); -LUAI_FUNC TValue* luaH_setnum(lua_State* L, Table* t, int key); -LUAI_FUNC const TValue* luaH_getstr(Table* t, TString* key); -LUAI_FUNC TValue* luaH_setstr(lua_State* L, Table* t, TString* key); -LUAI_FUNC const TValue* luaH_get(Table* t, const TValue* key); -LUAI_FUNC TValue* luaH_set(lua_State* L, Table* t, const TValue* key); -LUAI_FUNC TValue* luaH_newkey(lua_State* L, Table* t, const TValue* key); -LUAI_FUNC Table* luaH_new(lua_State* L, int narray, int lnhash); -LUAI_FUNC void luaH_resizearray(lua_State* L, Table* t, int nasize); -LUAI_FUNC void luaH_resizehash(lua_State* L, Table* t, int nhsize); -LUAI_FUNC void luaH_free(lua_State* L, Table* t, struct lua_Page* page); -LUAI_FUNC int luaH_next(lua_State* L, Table* t, StkId key); -LUAI_FUNC int luaH_getn(Table* t); -LUAI_FUNC Table* luaH_clone(lua_State* L, Table* tt); -LUAI_FUNC void luaH_clear(Table* tt); +LUAI_FUNC const TValue* luaH_getnum(LuaTable* t, int key); +LUAI_FUNC TValue* luaH_setnum(lua_State* L, LuaTable* t, int key); +LUAI_FUNC const TValue* luaH_getstr(LuaTable* t, TString* key); +LUAI_FUNC TValue* luaH_setstr(lua_State* L, LuaTable* t, TString* key); +LUAI_FUNC const TValue* luaH_get(LuaTable* t, const TValue* key); +LUAI_FUNC TValue* luaH_set(lua_State* L, LuaTable* t, const TValue* key); +LUAI_FUNC TValue* luaH_newkey(lua_State* L, LuaTable* t, const TValue* key); +LUAI_FUNC LuaTable* luaH_new(lua_State* L, int narray, int lnhash); +LUAI_FUNC void luaH_resizearray(lua_State* L, LuaTable* t, int nasize); +LUAI_FUNC void luaH_resizehash(lua_State* L, LuaTable* t, int nhsize); +LUAI_FUNC void luaH_free(lua_State* L, LuaTable* t, struct lua_Page* page); +LUAI_FUNC int luaH_next(lua_State* L, LuaTable* t, StkId key); +LUAI_FUNC int luaH_getn(LuaTable* t); +LUAI_FUNC LuaTable* luaH_clone(lua_State* L, LuaTable* tt); +LUAI_FUNC void luaH_clear(LuaTable* tt); #define luaH_setslot(L, t, slot, key) (invalidateTMcache(t), (slot == luaO_nilobject ? luaH_newkey(L, t, key) : cast_to(TValue*, slot))) diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 75d9f400..dbe60e4e 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -53,7 +53,7 @@ static int maxn(lua_State* L) double max = 0; luaL_checktype(L, 1, LUA_TTABLE); - Table* t = hvalue(L->base); + LuaTable* t = hvalue(L->base); for (int i = 0; i < t->sizearray; i++) { @@ -87,8 +87,8 @@ static int getn(lua_State* L) static void moveelements(lua_State* L, int srct, int dstt, int f, int e, int t) { - Table* src = hvalue(L->base + (srct - 1)); - Table* dst = hvalue(L->base + (dstt - 1)); + LuaTable* src = hvalue(L->base + (srct - 1)); + LuaTable* dst = hvalue(L->base + (dstt - 1)); if (dst->readonly) luaG_readonlyerror(L); @@ -213,7 +213,7 @@ static int tmove(lua_State* L) int n = e - f + 1; // number of elements to move luaL_argcheck(L, t <= INT_MAX - n + 1, 4, "destination wrap around"); - Table* dst = hvalue(L->base + (tt - 1)); + LuaTable* dst = hvalue(L->base + (tt - 1)); if (dst->readonly) // also checked in moveelements, but this blocks resizes of r/o tables luaG_readonlyerror(L); @@ -229,7 +229,7 @@ static int tmove(lua_State* L) return 1; } -static void addfield(lua_State* L, luaL_Strbuf* b, int i, Table* t) +static void addfield(lua_State* L, luaL_Strbuf* b, int i, LuaTable* t) { if (t && unsigned(i - 1) < unsigned(t->sizearray) && ttisstring(&t->array[i - 1])) { @@ -253,7 +253,7 @@ static int tconcat(lua_State* L) int i = luaL_optinteger(L, 3, 1); int last = luaL_opt(L, luaL_checkinteger, 4, lua_objlen(L, 1)); - Table* t = hvalue(L->base); + LuaTable* t = hvalue(L->base); luaL_Strbuf b; luaL_buffinit(L, &b); @@ -274,7 +274,7 @@ static int tpack(lua_State* L) int n = lua_gettop(L); // number of elements to pack lua_createtable(L, n, 1); // create result table - Table* t = hvalue(L->top - 1); + LuaTable* t = hvalue(L->top - 1); for (int i = 0; i < n; ++i) { @@ -292,7 +292,7 @@ static int tpack(lua_State* L) static int tunpack(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); - Table* t = hvalue(L->base); + LuaTable* t = hvalue(L->base); int i = luaL_optinteger(L, 2, 1); int e = luaL_opt(L, luaL_checkinteger, 3, lua_objlen(L, 1)); @@ -335,7 +335,7 @@ static int sort_func(lua_State* L, const TValue* l, const TValue* r) return !l_isfalse(L->top); } -inline void sort_swap(lua_State* L, Table* t, int i, int j) +inline void sort_swap(lua_State* L, LuaTable* t, int i, int j) { TValue* arr = t->array; int n = t->sizearray; @@ -348,7 +348,7 @@ inline void sort_swap(lua_State* L, Table* t, int i, int j) setobj2t(L, &arr[j], &temp); } -inline int sort_less(lua_State* L, Table* t, int i, int j, SortPredicate pred) +inline int sort_less(lua_State* L, LuaTable* t, int i, int j, SortPredicate pred) { TValue* arr = t->array; int n = t->sizearray; @@ -363,7 +363,7 @@ inline int sort_less(lua_State* L, Table* t, int i, int j, SortPredicate pred) return res; } -static void sort_siftheap(lua_State* L, Table* t, int l, int u, SortPredicate pred, int root) +static void sort_siftheap(lua_State* L, LuaTable* t, int l, int u, SortPredicate pred, int root) { LUAU_ASSERT(l <= u); int count = u - l + 1; @@ -389,7 +389,7 @@ static void sort_siftheap(lua_State* L, Table* t, int l, int u, SortPredicate pr sort_swap(L, t, l + root, l + lastleft); } -static void sort_heap(lua_State* L, Table* t, int l, int u, SortPredicate pred) +static void sort_heap(lua_State* L, LuaTable* t, int l, int u, SortPredicate pred) { LUAU_ASSERT(l <= u); int count = u - l + 1; @@ -404,7 +404,7 @@ static void sort_heap(lua_State* L, Table* t, int l, int u, SortPredicate pred) } } -static void sort_rec(lua_State* L, Table* t, int l, int u, int limit, SortPredicate pred) +static void sort_rec(lua_State* L, LuaTable* t, int l, int u, int limit, SortPredicate pred) { // sort range [l..u] (inclusive, 0-based) while (l < u) @@ -477,7 +477,7 @@ static void sort_rec(lua_State* L, Table* t, int l, int u, int limit, SortPredic static int tsort(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); - Table* t = hvalue(L->base); + LuaTable* t = hvalue(L->base); int n = luaH_getn(t); if (t->readonly) luaG_readonlyerror(L); @@ -504,7 +504,7 @@ static int tcreate(lua_State* L) if (!lua_isnoneornil(L, 2)) { lua_createtable(L, size, 0); - Table* t = hvalue(L->top - 1); + LuaTable* t = hvalue(L->top - 1); StkId v = L->base + 1; @@ -530,7 +530,7 @@ static int tfind(lua_State* L) if (init < 1) luaL_argerror(L, 3, "index out of range"); - Table* t = hvalue(L->base); + LuaTable* t = hvalue(L->base); StkId v = L->base + 1; for (int i = init;; ++i) @@ -554,7 +554,7 @@ static int tclear(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); - Table* tt = hvalue(L->base); + LuaTable* tt = hvalue(L->base); if (tt->readonly) luaG_readonlyerror(L); @@ -587,7 +587,7 @@ static int tclone(lua_State* L) luaL_checktype(L, 1, LUA_TTABLE); luaL_argcheck(L, !luaL_getmetafield(L, 1, "__metatable"), 1, "table has a protected metatable"); - Table* tt = luaH_clone(L, hvalue(L->base)); + LuaTable* tt = luaH_clone(L, hvalue(L->base)); TValue v; sethvalue(L, &v, tt); diff --git a/VM/src/ltm.cpp b/VM/src/ltm.cpp index f38ab80b..f6b0079a 100644 --- a/VM/src/ltm.cpp +++ b/VM/src/ltm.cpp @@ -86,7 +86,7 @@ void luaT_init(lua_State* L) ** function to be used with macro "fasttm": optimized for absence of ** tag methods. */ -const TValue* luaT_gettm(Table* events, TMS event, TString* ename) +const TValue* luaT_gettm(LuaTable* events, TMS event, TString* ename) { const TValue* tm = luaH_getstr(events, ename); LUAU_ASSERT(event <= TM_EQ); @@ -105,7 +105,7 @@ const TValue* luaT_gettmbyobj(lua_State* L, const TValue* o, TMS event) NB: Tag-methods were replaced by meta-methods in Lua 5.0, but the old names are still around (this function, for example). */ - Table* mt; + LuaTable* mt; switch (ttype(o)) { case LUA_TTABLE: @@ -147,7 +147,7 @@ const TString* luaT_objtypenamestr(lua_State* L, const TValue* o) } // For all types except userdata and table, a global metatable can be set with a global name override - if (Table* mt = L->global->mt[ttype(o)]) + if (LuaTable* mt = L->global->mt[ttype(o)]) { const TValue* type = luaH_getstr(mt, L->global->tmname[TM_TYPE]); diff --git a/VM/src/ltm.h b/VM/src/ltm.h index 7dafd4ed..f3294b64 100644 --- a/VM/src/ltm.h +++ b/VM/src/ltm.h @@ -51,7 +51,7 @@ typedef enum LUAI_DATA const char* const luaT_typenames[]; LUAI_DATA const char* const luaT_eventname[]; -LUAI_FUNC const TValue* luaT_gettm(Table* events, TMS event, TString* ename); +LUAI_FUNC const TValue* luaT_gettm(LuaTable* events, TMS event, TString* ename); LUAI_FUNC const TValue* luaT_gettmbyobj(lua_State* L, const TValue* o, TMS event); LUAI_FUNC const TString* luaT_objtypenamestr(lua_State* L, const TValue* o); diff --git a/VM/src/lveclib.cpp b/VM/src/lveclib.cpp index 2a4e58c6..c08087bd 100644 --- a/VM/src/lveclib.cpp +++ b/VM/src/lveclib.cpp @@ -6,17 +6,19 @@ #include -LUAU_FASTFLAGVARIABLE(LuauVectorMetatable) +LUAU_FASTFLAGVARIABLE(LuauVector2Constructor) static int vector_create(lua_State* L) { + // checking argument count to avoid accepting 'nil' as a valid value + int count = lua_gettop(L); + double x = luaL_checknumber(L, 1); double y = luaL_checknumber(L, 2); - double z = luaL_checknumber(L, 3); + double z = FFlag::LuauVector2Constructor ? (count >= 3 ? luaL_checknumber(L, 3) : 0.0) : luaL_checknumber(L, 3); #if LUA_VECTOR_SIZE == 4 - // checking argument count to avoid accepting 'nil' as a valid value - double w = lua_gettop(L) >= 4 ? luaL_checknumber(L, 4) : 0.0; + double w = count >= 4 ? luaL_checknumber(L, 4) : 0.0; lua_pushvector(L, float(x), float(y), float(z), float(w)); #else @@ -258,8 +260,6 @@ static int vector_max(lua_State* L) static int vector_index(lua_State* L) { - LUAU_ASSERT(FFlag::LuauVectorMetatable); - const float* v = luaL_checkvector(L, 1); size_t namelen = 0; const char* name = luaL_checklstring(L, 2, &namelen); @@ -304,8 +304,6 @@ static const luaL_Reg vectorlib[] = { static void createmetatable(lua_State* L) { - LUAU_ASSERT(FFlag::LuauVectorMetatable); - lua_createtable(L, 0, 1); // create metatable for vectors // push dummy vector @@ -342,8 +340,7 @@ int luaopen_vector(lua_State* L) lua_setfield(L, -2, "one"); #endif - if (FFlag::LuauVectorMetatable) - createmetatable(L); + createmetatable(L); return 1; } diff --git a/VM/src/lvm.h b/VM/src/lvm.h index 0b8690be..6989bcee 100644 --- a/VM/src/lvm.h +++ b/VM/src/lvm.h @@ -26,7 +26,7 @@ LUAI_FUNC int luaV_tostring(lua_State* L, StkId obj); LUAI_FUNC void luaV_gettable(lua_State* L, const TValue* t, TValue* key, StkId val); LUAI_FUNC void luaV_settable(lua_State* L, const TValue* t, TValue* key, StkId val); LUAI_FUNC void luaV_concat(lua_State* L, int total, int last); -LUAI_FUNC void luaV_getimport(lua_State* L, Table* env, TValue* k, StkId res, uint32_t id, bool propagatenil); +LUAI_FUNC void luaV_getimport(lua_State* L, LuaTable* env, TValue* k, StkId res, uint32_t id, bool propagatenil); LUAI_FUNC void luaV_prepareFORN(lua_State* L, StkId plimit, StkId pstep, StkId pinit); LUAI_FUNC void luaV_callTM(lua_State* L, int nparams, int res); LUAI_FUNC void luaV_tryfuncTM(lua_State* L, StkId func); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index d73f6496..ce07d878 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,6 +16,8 @@ #include +LUAU_DYNAMIC_FASTFLAG(LuauPopIncompleteCi) + // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ #if __has_warning("-Wc99-designator") @@ -328,7 +330,7 @@ reentry: LUAU_ASSERT(ttisstring(kv)); // fast-path: value is in expected slot - Table* h = cl->env; + LuaTable* h = cl->env; int slot = LUAU_INSN_C(insn) & h->nodemask8; LuaNode* n = &h->node[slot]; @@ -359,7 +361,7 @@ reentry: LUAU_ASSERT(ttisstring(kv)); // fast-path: value is in expected slot - Table* h = cl->env; + LuaTable* h = cl->env; int slot = LUAU_INSN_C(insn) & h->nodemask8; LuaNode* n = &h->node[slot]; @@ -449,7 +451,7 @@ reentry: // fast-path: built-in table if (LUAU_LIKELY(ttistable(rb))) { - Table* h = hvalue(rb); + LuaTable* h = hvalue(rb); int slot = LUAU_INSN_C(insn) & h->nodemask8; LuaNode* n = &h->node[slot]; @@ -566,7 +568,7 @@ reentry: // fast-path: built-in table if (LUAU_LIKELY(ttistable(rb))) { - Table* h = hvalue(rb); + LuaTable* h = hvalue(rb); int slot = LUAU_INSN_C(insn) & h->nodemask8; LuaNode* n = &h->node[slot]; @@ -640,7 +642,7 @@ reentry: // fast-path: array lookup if (ttistable(rb) && ttisnumber(rc)) { - Table* h = hvalue(rb); + LuaTable* h = hvalue(rb); double indexd = nvalue(rc); int index = int(indexd); @@ -670,7 +672,7 @@ reentry: // fast-path: array assign if (ttistable(rb) && ttisnumber(rc)) { - Table* h = hvalue(rb); + LuaTable* h = hvalue(rb); double indexd = nvalue(rc); int index = int(indexd); @@ -701,7 +703,7 @@ reentry: // fast-path: array lookup if (ttistable(rb)) { - Table* h = hvalue(rb); + LuaTable* h = hvalue(rb); if (LUAU_LIKELY(unsigned(c) < unsigned(h->sizearray) && !h->metatable)) { @@ -729,7 +731,7 @@ reentry: // fast-path: array assign if (ttistable(rb)) { - Table* h = hvalue(rb); + LuaTable* h = hvalue(rb); if (LUAU_LIKELY(unsigned(c) < unsigned(h->sizearray) && !h->metatable && !h->readonly)) { @@ -802,7 +804,7 @@ reentry: if (LUAU_LIKELY(ttistable(rb))) { - Table* h = hvalue(rb); + LuaTable* h = hvalue(rb); // note: we can't use nodemask8 here because we need to query the main position of the table, and 8-bit nodemask8 only works // for predictive lookups LuaNode* n = &h->node[tsvalue(kv)->hash & (sizenode(h) - 1)]; @@ -842,7 +844,7 @@ reentry: } else { - Table* mt = ttisuserdata(rb) ? uvalue(rb)->metatable : L->global->mt[ttype(rb)]; + LuaTable* mt = ttisuserdata(rb) ? uvalue(rb)->metatable : L->global->mt[ttype(rb)]; const TValue* tmi = 0; // fast-path: metatable with __namecall @@ -856,7 +858,7 @@ reentry: } else if ((tmi = fasttm(L, mt, TM_INDEX)) && ttistable(tmi)) { - Table* h = hvalue(tmi); + LuaTable* h = hvalue(tmi); int slot = LUAU_INSN_C(insn) & h->nodemask8; LuaNode* n = &h->node[slot]; @@ -935,7 +937,14 @@ reentry: // note: this reallocs stack, but we don't need to VM_PROTECT this // this is because we're going to modify base/savedpc manually anyhow // crucially, we can't use ra/argtop after this line - luaD_checkstack(L, ccl->stacksize); + if (DFFlag::LuauPopIncompleteCi) + { + luaD_checkstackfornewci(L, ccl->stacksize); + } + else + { + luaD_checkstack(L, ccl->stacksize); + } LUAU_ASSERT(ci->top <= L->stack_last); @@ -2117,7 +2126,7 @@ reentry: // fast-path #1: tables if (LUAU_LIKELY(ttistable(rb))) { - Table* h = hvalue(rb); + LuaTable* h = hvalue(rb); if (fastnotm(h->metatable, TM_LEN)) { @@ -2187,7 +2196,7 @@ reentry: L->top = L->ci->top; } - Table* h = hvalue(ra); + LuaTable* h = hvalue(ra); // TODO: we really don't need this anymore if (!ttistable(ra)) @@ -2272,7 +2281,7 @@ reentry: } else { - Table* mt = ttistable(ra) ? hvalue(ra)->metatable : ttisuserdata(ra) ? uvalue(ra)->metatable : cast_to(Table*, NULL); + LuaTable* mt = ttistable(ra) ? hvalue(ra)->metatable : ttisuserdata(ra) ? uvalue(ra)->metatable : cast_to(LuaTable*, NULL); if (const TValue* fn = fasttm(L, mt, TM_ITER)) { @@ -2331,7 +2340,7 @@ reentry: // TODO: remove the table check per guarantee above if (ttisnil(ra) && ttistable(ra + 1)) { - Table* h = hvalue(ra + 1); + LuaTable* h = hvalue(ra + 1); int index = int(reinterpret_cast(pvalue(ra + 2))); int sizearray = h->sizearray; @@ -3071,7 +3080,14 @@ int luau_precall(lua_State* L, StkId func, int nresults) L->base = ci->base; // Note: L->top is assigned externally - luaD_checkstack(L, ccl->stacksize); + if (DFFlag::LuauPopIncompleteCi) + { + luaD_checkstackfornewci(L, ccl->stacksize); + } + else + { + luaD_checkstack(L, ccl->stacksize); + } LUAU_ASSERT(ci->top <= L->stack_last); if (!ccl->isC) diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index aa248fc1..2a3443eb 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -72,7 +72,7 @@ private: size_t originalThreshold = 0; }; -void luaV_getimport(lua_State* L, Table* env, TValue* k, StkId res, uint32_t id, bool propagatenil) +void luaV_getimport(lua_State* L, LuaTable* env, TValue* k, StkId res, uint32_t id, bool propagatenil) { int count = id >> 30; LUAU_ASSERT(count > 0); @@ -141,7 +141,7 @@ static TString* readString(TempBuffer& strings, const char* data, size return id == 0 ? NULL : strings[id - 1]; } -static void resolveImportSafe(lua_State* L, Table* env, TValue* k, uint32_t id) +static void resolveImportSafe(lua_State* L, LuaTable* env, TValue* k, uint32_t id) { struct ResolveImport { @@ -273,7 +273,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size const ScopedSetGCThreshold pauseGC{L->global, SIZE_MAX}; // env is 0 for current environment and a stack index otherwise - Table* envt = (env == 0) ? L->gt : hvalue(luaA_toobject(L, env)); + LuaTable* envt = (env == 0) ? L->gt : hvalue(luaA_toobject(L, env)); TString* source = luaS_new(L, chunkname); @@ -481,7 +481,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size case LBC_CONSTANT_TABLE: { int keys = readVarInt(data, size, offset); - Table* h = luaH_new(L, 0, keys); + LuaTable* h = luaH_new(L, 0, keys); for (int i = 0; i < keys; ++i) { int key = readVarInt(data, size, offset); diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index 0cf9d206..5c49139f 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -101,7 +101,7 @@ void luaV_gettable(lua_State* L, const TValue* t, TValue* key, StkId val) const TValue* tm; if (ttistable(t)) { // `t' is a table? - Table* h = hvalue(t); + LuaTable* h = hvalue(t); const TValue* res = luaH_get(h, key); // do a primitive get @@ -137,7 +137,7 @@ void luaV_settable(lua_State* L, const TValue* t, TValue* key, StkId val) const TValue* tm; if (ttistable(t)) { // `t' is a table? - Table* h = hvalue(t); + LuaTable* h = hvalue(t); const TValue* oldval = luaH_get(h, key); @@ -185,7 +185,7 @@ static int call_binTM(lua_State* L, const TValue* p1, const TValue* p2, StkId re return 1; } -static const TValue* get_compTM(lua_State* L, Table* mt1, Table* mt2, TMS event) +static const TValue* get_compTM(lua_State* L, LuaTable* mt1, LuaTable* mt2, TMS event) { const TValue* tm1 = fasttm(L, mt1, event); const TValue* tm2; @@ -533,7 +533,7 @@ void luaV_dolen(lua_State* L, StkId ra, const TValue* rb) { case LUA_TTABLE: { - Table* h = hvalue(rb); + LuaTable* h = hvalue(rb); if ((tm = fasttm(L, h->metatable, TM_LEN)) == NULL) { setnvalue(ra, cast_num(luaH_getn(h))); diff --git a/bench/bench_support.lua b/bench/bench_support.lua index da637ac9..b731c2fc 100644 --- a/bench/bench_support.lua +++ b/bench/bench_support.lua @@ -66,7 +66,7 @@ end -- and 'false' otherwise. -- -- Example usage: --- local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +-- local function prequire(name) local success, result = pcall(require, name); return success and result end -- local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") -- function testFunc() -- ... diff --git a/bench/gc/test_BinaryTree.lua b/bench/gc/test_BinaryTree.lua index 36dff9de..b7a36d73 100644 --- a/bench/gc/test_BinaryTree.lua +++ b/bench/gc/test_BinaryTree.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_GC_Boehm_Trees.lua b/bench/gc/test_GC_Boehm_Trees.lua index 8170103d..3a3a3698 100644 --- a/bench/gc/test_GC_Boehm_Trees.lua +++ b/bench/gc/test_GC_Boehm_Trees.lua @@ -1,5 +1,5 @@ --!nonstrict -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local stretchTreeDepth = 18 -- about 16Mb diff --git a/bench/gc/test_GC_Tree_Pruning_Eager.lua b/bench/gc/test_GC_Tree_Pruning_Eager.lua index 38aa7626..7a086254 100644 --- a/bench/gc/test_GC_Tree_Pruning_Eager.lua +++ b/bench/gc/test_GC_Tree_Pruning_Eager.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_GC_Tree_Pruning_Gen.lua b/bench/gc/test_GC_Tree_Pruning_Gen.lua index 85081f70..eb747e77 100644 --- a/bench/gc/test_GC_Tree_Pruning_Gen.lua +++ b/bench/gc/test_GC_Tree_Pruning_Gen.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_GC_Tree_Pruning_Lazy.lua b/bench/gc/test_GC_Tree_Pruning_Lazy.lua index 834ec1ab..16b68083 100644 --- a/bench/gc/test_GC_Tree_Pruning_Lazy.lua +++ b/bench/gc/test_GC_Tree_Pruning_Lazy.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_GC_hashtable_Keyval.lua b/bench/gc/test_GC_hashtable_Keyval.lua index aa7481d3..6e59072c 100644 --- a/bench/gc/test_GC_hashtable_Keyval.lua +++ b/bench/gc/test_GC_hashtable_Keyval.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_LB_mandel.lua b/bench/gc/test_LB_mandel.lua index a8beb4fd..be9977d6 100644 --- a/bench/gc/test_LB_mandel.lua +++ b/bench/gc/test_LB_mandel.lua @@ -21,7 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_LargeTableCtor_array.lua b/bench/gc/test_LargeTableCtor_array.lua index 016dfd2d..35b6f449 100644 --- a/bench/gc/test_LargeTableCtor_array.lua +++ b/bench/gc/test_LargeTableCtor_array.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_LargeTableCtor_hash.lua b/bench/gc/test_LargeTableCtor_hash.lua index c46a7ab4..e2b11b4b 100644 --- a/bench/gc/test_LargeTableCtor_hash.lua +++ b/bench/gc/test_LargeTableCtor_hash.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_Pcall_pcall_yield.lua b/bench/gc/test_Pcall_pcall_yield.lua index ae0a4b46..2ae0baa6 100644 --- a/bench/gc/test_Pcall_pcall_yield.lua +++ b/bench/gc/test_Pcall_pcall_yield.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_SunSpider_3d-raytrace.lua b/bench/gc/test_SunSpider_3d-raytrace.lua index 3c050df7..d8f224c4 100644 --- a/bench/gc/test_SunSpider_3d-raytrace.lua +++ b/bench/gc/test_SunSpider_3d-raytrace.lua @@ -22,7 +22,7 @@ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_TableCreate_nil.lua b/bench/gc/test_TableCreate_nil.lua index 707a2750..546e9d6b 100644 --- a/bench/gc/test_TableCreate_nil.lua +++ b/bench/gc/test_TableCreate_nil.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_TableCreate_number.lua b/bench/gc/test_TableCreate_number.lua index 3e4305bd..fe8437b7 100644 --- a/bench/gc/test_TableCreate_number.lua +++ b/bench/gc/test_TableCreate_number.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_TableCreate_zerofill.lua b/bench/gc/test_TableCreate_zerofill.lua index fed439b4..e2cfda30 100644 --- a/bench/gc/test_TableCreate_zerofill.lua +++ b/bench/gc/test_TableCreate_zerofill.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_TableMarshal_select.lua b/bench/gc/test_TableMarshal_select.lua index 9869da60..df5ebf78 100644 --- a/bench/gc/test_TableMarshal_select.lua +++ b/bench/gc/test_TableMarshal_select.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_TableMarshal_table_pack.lua b/bench/gc/test_TableMarshal_table_pack.lua index 3da855f5..3d0190e7 100644 --- a/bench/gc/test_TableMarshal_table_pack.lua +++ b/bench/gc/test_TableMarshal_table_pack.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_TableMarshal_varargs.lua b/bench/gc/test_TableMarshal_varargs.lua index 64b41b43..b88d8213 100644 --- a/bench/gc/test_TableMarshal_varargs.lua +++ b/bench/gc/test_TableMarshal_varargs.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_AbsSum_abs.lua b/bench/micro_tests/test_AbsSum_abs.lua index 7e85646e..ea473556 100644 --- a/bench/micro_tests/test_AbsSum_abs.lua +++ b/bench/micro_tests/test_AbsSum_abs.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_AbsSum_and_or.lua b/bench/micro_tests/test_AbsSum_and_or.lua index c6ef3dea..6cd5b4d0 100644 --- a/bench/micro_tests/test_AbsSum_and_or.lua +++ b/bench/micro_tests/test_AbsSum_and_or.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_AbsSum_math_abs.lua b/bench/micro_tests/test_AbsSum_math_abs.lua index e95ea674..e02b710a 100644 --- a/bench/micro_tests/test_AbsSum_math_abs.lua +++ b/bench/micro_tests/test_AbsSum_math_abs.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Assert.lua b/bench/micro_tests/test_Assert.lua index 014de8dc..750f411b 100644 --- a/bench/micro_tests/test_Assert.lua +++ b/bench/micro_tests/test_Assert.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Factorial.lua b/bench/micro_tests/test_Factorial.lua index 90cff22a..5dc797ce 100644 --- a/bench/micro_tests/test_Factorial.lua +++ b/bench/micro_tests/test_Factorial.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Failure_pcall_a_bar.lua b/bench/micro_tests/test_Failure_pcall_a_bar.lua index 5b6108ba..95887e58 100644 --- a/bench/micro_tests/test_Failure_pcall_a_bar.lua +++ b/bench/micro_tests/test_Failure_pcall_a_bar.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Failure_pcall_game_Foo.lua b/bench/micro_tests/test_Failure_pcall_game_Foo.lua index 6bd209ae..9966262d 100644 --- a/bench/micro_tests/test_Failure_pcall_game_Foo.lua +++ b/bench/micro_tests/test_Failure_pcall_game_Foo.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Failure_xpcall_a_bar.lua b/bench/micro_tests/test_Failure_xpcall_a_bar.lua index e00a3ca6..44534da4 100644 --- a/bench/micro_tests/test_Failure_xpcall_a_bar.lua +++ b/bench/micro_tests/test_Failure_xpcall_a_bar.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Failure_xpcall_game_Foo.lua b/bench/micro_tests/test_Failure_xpcall_game_Foo.lua index 86dadc90..35659598 100644 --- a/bench/micro_tests/test_Failure_xpcall_game_Foo.lua +++ b/bench/micro_tests/test_Failure_xpcall_game_Foo.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_LargeTableCtor_array.lua b/bench/micro_tests/test_LargeTableCtor_array.lua index 016dfd2d..35b6f449 100644 --- a/bench/micro_tests/test_LargeTableCtor_array.lua +++ b/bench/micro_tests/test_LargeTableCtor_array.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_LargeTableCtor_hash.lua b/bench/micro_tests/test_LargeTableCtor_hash.lua index c46a7ab4..e2b11b4b 100644 --- a/bench/micro_tests/test_LargeTableCtor_hash.lua +++ b/bench/micro_tests/test_LargeTableCtor_hash.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_LargeTableSum_loop_index.lua b/bench/micro_tests/test_LargeTableSum_loop_index.lua index 2aae109e..dd64ca00 100644 --- a/bench/micro_tests/test_LargeTableSum_loop_index.lua +++ b/bench/micro_tests/test_LargeTableSum_loop_index.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_LargeTableSum_loop_ipairs.lua b/bench/micro_tests/test_LargeTableSum_loop_ipairs.lua index 29205e26..54ee888d 100644 --- a/bench/micro_tests/test_LargeTableSum_loop_ipairs.lua +++ b/bench/micro_tests/test_LargeTableSum_loop_ipairs.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_LargeTableSum_loop_iter.lua b/bench/micro_tests/test_LargeTableSum_loop_iter.lua index ea2b157c..fb69470f 100644 --- a/bench/micro_tests/test_LargeTableSum_loop_iter.lua +++ b/bench/micro_tests/test_LargeTableSum_loop_iter.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_LargeTableSum_loop_pairs.lua b/bench/micro_tests/test_LargeTableSum_loop_pairs.lua index 8d789fcf..ffe19a20 100644 --- a/bench/micro_tests/test_LargeTableSum_loop_pairs.lua +++ b/bench/micro_tests/test_LargeTableSum_loop_pairs.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_MethodCalls.lua b/bench/micro_tests/test_MethodCalls.lua index f8b44527..016a4798 100644 --- a/bench/micro_tests/test_MethodCalls.lua +++ b/bench/micro_tests/test_MethodCalls.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_OOP_constructor.lua b/bench/micro_tests/test_OOP_constructor.lua index 9fec3b67..b1c03dfc 100644 --- a/bench/micro_tests/test_OOP_constructor.lua +++ b/bench/micro_tests/test_OOP_constructor.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_OOP_method_call.lua b/bench/micro_tests/test_OOP_method_call.lua index 1e5249c5..09699acb 100644 --- a/bench/micro_tests/test_OOP_method_call.lua +++ b/bench/micro_tests/test_OOP_method_call.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_OOP_virtual_constructor.lua b/bench/micro_tests/test_OOP_virtual_constructor.lua index df99e13b..68dfba61 100644 --- a/bench/micro_tests/test_OOP_virtual_constructor.lua +++ b/bench/micro_tests/test_OOP_virtual_constructor.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Pcall_call_return.lua b/bench/micro_tests/test_Pcall_call_return.lua index 2a612175..45d8ca58 100644 --- a/bench/micro_tests/test_Pcall_call_return.lua +++ b/bench/micro_tests/test_Pcall_call_return.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Pcall_pcall_return.lua b/bench/micro_tests/test_Pcall_pcall_return.lua index 16bdfdd3..09a032df 100644 --- a/bench/micro_tests/test_Pcall_pcall_return.lua +++ b/bench/micro_tests/test_Pcall_pcall_return.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Pcall_pcall_yield.lua b/bench/micro_tests/test_Pcall_pcall_yield.lua index ae0a4b46..2ae0baa6 100644 --- a/bench/micro_tests/test_Pcall_pcall_yield.lua +++ b/bench/micro_tests/test_Pcall_pcall_yield.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Pcall_xpcall_return.lua b/bench/micro_tests/test_Pcall_xpcall_return.lua index 8ac2f0eb..5fb69f1b 100644 --- a/bench/micro_tests/test_Pcall_xpcall_return.lua +++ b/bench/micro_tests/test_Pcall_xpcall_return.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_SqrtSum_exponent.lua b/bench/micro_tests/test_SqrtSum_exponent.lua index bfd6fd72..1bb6a7d2 100644 --- a/bench/micro_tests/test_SqrtSum_exponent.lua +++ b/bench/micro_tests/test_SqrtSum_exponent.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_SqrtSum_math_sqrt.lua b/bench/micro_tests/test_SqrtSum_math_sqrt.lua index 1e1f42c7..7a280460 100644 --- a/bench/micro_tests/test_SqrtSum_math_sqrt.lua +++ b/bench/micro_tests/test_SqrtSum_math_sqrt.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_SqrtSum_sqrt.lua b/bench/micro_tests/test_SqrtSum_sqrt.lua index 96880e7b..ddcddb9d 100644 --- a/bench/micro_tests/test_SqrtSum_sqrt.lua +++ b/bench/micro_tests/test_SqrtSum_sqrt.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_SqrtSum_sqrt_getfenv.lua b/bench/micro_tests/test_SqrtSum_sqrt_getfenv.lua index 55f29e2e..1dd29776 100644 --- a/bench/micro_tests/test_SqrtSum_sqrt_getfenv.lua +++ b/bench/micro_tests/test_SqrtSum_sqrt_getfenv.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_SqrtSum_sqrt_roundabout.lua b/bench/micro_tests/test_SqrtSum_sqrt_roundabout.lua index bbe48a64..0527ea4d 100644 --- a/bench/micro_tests/test_SqrtSum_sqrt_roundabout.lua +++ b/bench/micro_tests/test_SqrtSum_sqrt_roundabout.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_StringInterp.lua b/bench/micro_tests/test_StringInterp.lua index 55430519..d44f5b07 100644 --- a/bench/micro_tests/test_StringInterp.lua +++ b/bench/micro_tests/test_StringInterp.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") bench.runCode(function() diff --git a/bench/micro_tests/test_TableCreate_nil.lua b/bench/micro_tests/test_TableCreate_nil.lua index 707a2750..546e9d6b 100644 --- a/bench/micro_tests/test_TableCreate_nil.lua +++ b/bench/micro_tests/test_TableCreate_nil.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableCreate_number.lua b/bench/micro_tests/test_TableCreate_number.lua index 3e4305bd..fe8437b7 100644 --- a/bench/micro_tests/test_TableCreate_number.lua +++ b/bench/micro_tests/test_TableCreate_number.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableCreate_zerofill.lua b/bench/micro_tests/test_TableCreate_zerofill.lua index fed439b4..e2cfda30 100644 --- a/bench/micro_tests/test_TableCreate_zerofill.lua +++ b/bench/micro_tests/test_TableCreate_zerofill.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableFind_loop_ipairs.lua b/bench/micro_tests/test_TableFind_loop_ipairs.lua index 46560274..ef7f4c81 100644 --- a/bench/micro_tests/test_TableFind_loop_ipairs.lua +++ b/bench/micro_tests/test_TableFind_loop_ipairs.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableFind_table_find.lua b/bench/micro_tests/test_TableFind_table_find.lua index 3f22122f..05882c50 100644 --- a/bench/micro_tests/test_TableFind_table_find.lua +++ b/bench/micro_tests/test_TableFind_table_find.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableInsertion_index_cached.lua b/bench/micro_tests/test_TableInsertion_index_cached.lua index 0c34818f..adb40822 100644 --- a/bench/micro_tests/test_TableInsertion_index_cached.lua +++ b/bench/micro_tests/test_TableInsertion_index_cached.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableInsertion_index_len.lua b/bench/micro_tests/test_TableInsertion_index_len.lua index 120a5e28..797dec80 100644 --- a/bench/micro_tests/test_TableInsertion_index_len.lua +++ b/bench/micro_tests/test_TableInsertion_index_len.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableInsertion_table_insert.lua b/bench/micro_tests/test_TableInsertion_table_insert.lua index 1ad3fe22..632e9080 100644 --- a/bench/micro_tests/test_TableInsertion_table_insert.lua +++ b/bench/micro_tests/test_TableInsertion_table_insert.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableInsertion_table_insert_index.lua b/bench/micro_tests/test_TableInsertion_table_insert_index.lua index 41747139..7b35fe39 100644 --- a/bench/micro_tests/test_TableInsertion_table_insert_index.lua +++ b/bench/micro_tests/test_TableInsertion_table_insert_index.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableIteration.lua b/bench/micro_tests/test_TableIteration.lua index 5f78a48b..2c44f43c 100644 --- a/bench/micro_tests/test_TableIteration.lua +++ b/bench/micro_tests/test_TableIteration.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableMarshal_select.lua b/bench/micro_tests/test_TableMarshal_select.lua index 9869da60..df5ebf78 100644 --- a/bench/micro_tests/test_TableMarshal_select.lua +++ b/bench/micro_tests/test_TableMarshal_select.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableMarshal_table_pack.lua b/bench/micro_tests/test_TableMarshal_table_pack.lua index 3da855f5..3d0190e7 100644 --- a/bench/micro_tests/test_TableMarshal_table_pack.lua +++ b/bench/micro_tests/test_TableMarshal_table_pack.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableMarshal_table_unpack_array.lua b/bench/micro_tests/test_TableMarshal_table_unpack_array.lua index 13d1d1c3..32f2eb9a 100644 --- a/bench/micro_tests/test_TableMarshal_table_unpack_array.lua +++ b/bench/micro_tests/test_TableMarshal_table_unpack_array.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableMarshal_table_unpack_range.lua b/bench/micro_tests/test_TableMarshal_table_unpack_range.lua index e3aa68be..fa53a31c 100644 --- a/bench/micro_tests/test_TableMarshal_table_unpack_range.lua +++ b/bench/micro_tests/test_TableMarshal_table_unpack_range.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableMarshal_varargs.lua b/bench/micro_tests/test_TableMarshal_varargs.lua index 64b41b43..b88d8213 100644 --- a/bench/micro_tests/test_TableMarshal_varargs.lua +++ b/bench/micro_tests/test_TableMarshal_varargs.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableMove_empty_table.lua b/bench/micro_tests/test_TableMove_empty_table.lua index 39335564..18737f74 100644 --- a/bench/micro_tests/test_TableMove_empty_table.lua +++ b/bench/micro_tests/test_TableMove_empty_table.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableMove_same_table.lua b/bench/micro_tests/test_TableMove_same_table.lua index f62022b1..8fc9fa03 100644 --- a/bench/micro_tests/test_TableMove_same_table.lua +++ b/bench/micro_tests/test_TableMove_same_table.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableMove_table_create.lua b/bench/micro_tests/test_TableMove_table_create.lua index f03c4de7..3c0cb9e9 100644 --- a/bench/micro_tests/test_TableMove_table_create.lua +++ b/bench/micro_tests/test_TableMove_table_create.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableRemoval_table_remove.lua b/bench/micro_tests/test_TableRemoval_table_remove.lua index 13410116..3ba3e503 100644 --- a/bench/micro_tests/test_TableRemoval_table_remove.lua +++ b/bench/micro_tests/test_TableRemoval_table_remove.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableSort.lua b/bench/micro_tests/test_TableSort.lua index 502cb2a5..e3276845 100644 --- a/bench/micro_tests/test_TableSort.lua +++ b/bench/micro_tests/test_TableSort.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local arr_months = {"Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"} diff --git a/bench/micro_tests/test_ToNumberString.lua b/bench/micro_tests/test_ToNumberString.lua index 842b7c22..cda886c0 100644 --- a/bench/micro_tests/test_ToNumberString.lua +++ b/bench/micro_tests/test_ToNumberString.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") bench.runCode(function() diff --git a/bench/micro_tests/test_UpvalueCapture.lua b/bench/micro_tests/test_UpvalueCapture.lua index 4a2608c4..6c2f2616 100644 --- a/bench/micro_tests/test_UpvalueCapture.lua +++ b/bench/micro_tests/test_UpvalueCapture.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_VariadicSelect.lua b/bench/micro_tests/test_VariadicSelect.lua index 5a62f2d8..9710e237 100644 --- a/bench/micro_tests/test_VariadicSelect.lua +++ b/bench/micro_tests/test_VariadicSelect.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_string_lib.lua b/bench/micro_tests/test_string_lib.lua index 041f5b15..5f180151 100644 --- a/bench/micro_tests/test_string_lib.lua +++ b/bench/micro_tests/test_string_lib.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") bench.runCode(function() diff --git a/bench/micro_tests/test_table_concat.lua b/bench/micro_tests/test_table_concat.lua index 590b7d4a..879b63fe 100644 --- a/bench/micro_tests/test_table_concat.lua +++ b/bench/micro_tests/test_table_concat.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") bench.runCode(function() diff --git a/bench/micro_tests/test_vector_lib.lua b/bench/micro_tests/test_vector_lib.lua new file mode 100644 index 00000000..59bddc04 --- /dev/null +++ b/bench/micro_tests/test_vector_lib.lua @@ -0,0 +1,14 @@ +local function prequire(name) local success, result = pcall(require, name); return success and result end +local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") + +bench.runCode(function() + for i=1,1000000 do + vector.create(i, 2, 3) + vector.create(i, 2, 3) + vector.create(i, 2, 3) + vector.create(i, 2, 3) + vector.create(i, 2, 3) + end +end, "vector: create") + +-- TODO: add more tests \ No newline at end of file diff --git a/bench/tests/base64.lua b/bench/tests/base64.lua index e580c595..13bfd070 100644 --- a/bench/tests/base64.lua +++ b/bench/tests/base64.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/tests/chess.lua b/bench/tests/chess.lua index f551139e..7e6c9c0c 100644 --- a/bench/tests/chess.lua +++ b/bench/tests/chess.lua @@ -1,5 +1,5 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local RANKS = "12345678" diff --git a/bench/tests/life.lua b/bench/tests/life.lua index d050b013..a61730aa 100644 --- a/bench/tests/life.lua +++ b/bench/tests/life.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/tests/matrixmult.lua b/bench/tests/matrixmult.lua index af38cb64..fa04b864 100644 --- a/bench/tests/matrixmult.lua +++ b/bench/tests/matrixmult.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local function mmul(matrix1, matrix2) diff --git a/bench/tests/mesh-normal-scalar.lua b/bench/tests/mesh-normal-scalar.lua index 05bef373..509e1e62 100644 --- a/bench/tests/mesh-normal-scalar.lua +++ b/bench/tests/mesh-normal-scalar.lua @@ -1,5 +1,5 @@ --!strict -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/tests/mesh-normal-vector.lua b/bench/tests/mesh-normal-vector.lua index bfc0f1c7..ff4f2b46 100644 --- a/bench/tests/mesh-normal-vector.lua +++ b/bench/tests/mesh-normal-vector.lua @@ -1,5 +1,5 @@ --!strict -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/tests/pcmmix.lua b/bench/tests/pcmmix.lua index c98cee2c..1e8e27a5 100644 --- a/bench/tests/pcmmix.lua +++ b/bench/tests/pcmmix.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local samples = 100_000 diff --git a/bench/tests/qsort.lua b/bench/tests/qsort.lua index 566c1b98..37413fa2 100644 --- a/bench/tests/qsort.lua +++ b/bench/tests/qsort.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/tests/sha256.lua b/bench/tests/sha256.lua index 2ac0ab33..e478e763 100644 --- a/bench/tests/sha256.lua +++ b/bench/tests/sha256.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/tests/shootout/ack.lua b/bench/tests/shootout/ack.lua index f7fd43a8..ca8913ac 100644 --- a/bench/tests/shootout/ack.lua +++ b/bench/tests/shootout/ack.lua @@ -23,7 +23,7 @@ SOFTWARE. ]] -- http://www.bagley.org/~doug/shootout/ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/binary-trees.lua b/bench/tests/shootout/binary-trees.lua index 89c5933c..50d40597 100644 --- a/bench/tests/shootout/binary-trees.lua +++ b/bench/tests/shootout/binary-trees.lua @@ -25,7 +25,7 @@ SOFTWARE. -- http://benchmarksgame.alioth.debian.org/ -- contributed by Mike Pall -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/fannkuch-redux.lua b/bench/tests/shootout/fannkuch-redux.lua index 43bc9e41..60f7c3c0 100644 --- a/bench/tests/shootout/fannkuch-redux.lua +++ b/bench/tests/shootout/fannkuch-redux.lua @@ -25,7 +25,7 @@ SOFTWARE. -- http://benchmarksgame.alioth.debian.org/ -- contributed by Mike Pall -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/fixpoint-fact.lua b/bench/tests/shootout/fixpoint-fact.lua index 112acb4a..226c78a8 100644 --- a/bench/tests/shootout/fixpoint-fact.lua +++ b/bench/tests/shootout/fixpoint-fact.lua @@ -21,7 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/heapsort.lua b/bench/tests/shootout/heapsort.lua index 0daf97ab..69c1b885 100644 --- a/bench/tests/shootout/heapsort.lua +++ b/bench/tests/shootout/heapsort.lua @@ -21,7 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/mandel.lua b/bench/tests/shootout/mandel.lua index a3bbb7e5..547741e6 100644 --- a/bench/tests/shootout/mandel.lua +++ b/bench/tests/shootout/mandel.lua @@ -21,7 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/n-body.lua b/bench/tests/shootout/n-body.lua index e0f9c63c..082b7fa0 100644 --- a/bench/tests/shootout/n-body.lua +++ b/bench/tests/shootout/n-body.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/qt.lua b/bench/tests/shootout/qt.lua index d9b4a517..c15accd0 100644 --- a/bench/tests/shootout/qt.lua +++ b/bench/tests/shootout/qt.lua @@ -23,7 +23,7 @@ SOFTWARE. ]] -- Julia sets via interval cell-mapping (quadtree version) -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/queen.lua b/bench/tests/shootout/queen.lua index c3508d60..8f27e06f 100644 --- a/bench/tests/shootout/queen.lua +++ b/bench/tests/shootout/queen.lua @@ -21,7 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/scimark.lua b/bench/tests/shootout/scimark.lua index 1b66df53..dd7cae53 100644 --- a/bench/tests/shootout/scimark.lua +++ b/bench/tests/shootout/scimark.lua @@ -33,7 +33,7 @@ -- Modification to be compatible with Lua 5.3 ------------------------------------------------------------------------------ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/spectral-norm.lua b/bench/tests/shootout/spectral-norm.lua index b5116612..f1acd34c 100644 --- a/bench/tests/shootout/spectral-norm.lua +++ b/bench/tests/shootout/spectral-norm.lua @@ -25,7 +25,7 @@ SOFTWARE. -- http://benchmarksgame.alioth.debian.org/ -- contributed by Mike Pall -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/sieve.lua b/bench/tests/sieve.lua index 1bb45d99..8d8cf82a 100644 --- a/bench/tests/sieve.lua +++ b/bench/tests/sieve.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/tests/sunspider/3d-cube.lua b/bench/tests/sunspider/3d-cube.lua index aac7a156..ea132463 100644 --- a/bench/tests/sunspider/3d-cube.lua +++ b/bench/tests/sunspider/3d-cube.lua @@ -2,7 +2,7 @@ -- http://www.speich.net/computer/moztesting/3d.htm -- Created by Simon Speich -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/sunspider/3d-morph.lua b/bench/tests/sunspider/3d-morph.lua index 8263f015..0dbf1c63 100644 --- a/bench/tests/sunspider/3d-morph.lua +++ b/bench/tests/sunspider/3d-morph.lua @@ -23,7 +23,7 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/sunspider/3d-raytrace.lua b/bench/tests/sunspider/3d-raytrace.lua index 33d464b8..83ca7bd9 100644 --- a/bench/tests/sunspider/3d-raytrace.lua +++ b/bench/tests/sunspider/3d-raytrace.lua @@ -22,7 +22,7 @@ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/sunspider/controlflow-recursive.lua b/bench/tests/sunspider/controlflow-recursive.lua index 1c78a3c2..67c77293 100644 --- a/bench/tests/sunspider/controlflow-recursive.lua +++ b/bench/tests/sunspider/controlflow-recursive.lua @@ -3,7 +3,7 @@ http://shootout.alioth.debian.org/ contributed by Isaac Gouy ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/sunspider/crypto-aes.lua b/bench/tests/sunspider/crypto-aes.lua index 9692cf52..6b23719b 100644 --- a/bench/tests/sunspider/crypto-aes.lua +++ b/bench/tests/sunspider/crypto-aes.lua @@ -9,7 +9,7 @@ * returns byte-array encrypted value (16 bytes) */]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") -- Sbox is pre-computed multiplicative inverse in GF(2^8) used in SubBytes and KeyExpansion [§5.1.1] diff --git a/bench/tests/sunspider/fannkuch.lua b/bench/tests/sunspider/fannkuch.lua index 08cdcc24..24098740 100644 --- a/bench/tests/sunspider/fannkuch.lua +++ b/bench/tests/sunspider/fannkuch.lua @@ -3,7 +3,7 @@ http://shootout.alioth.debian.org/ contributed by Isaac Gouy ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/sunspider/math-cordic.lua b/bench/tests/sunspider/math-cordic.lua index 2b622377..861cc51a 100644 --- a/bench/tests/sunspider/math-cordic.lua +++ b/bench/tests/sunspider/math-cordic.lua @@ -23,7 +23,7 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ]] - local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end + local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/sunspider/math-partial-sums.lua b/bench/tests/sunspider/math-partial-sums.lua index f0b4b0b7..21f63295 100644 --- a/bench/tests/sunspider/math-partial-sums.lua +++ b/bench/tests/sunspider/math-partial-sums.lua @@ -3,7 +3,7 @@ http://shootout.alioth.debian.org/ contributed by Isaac Gouy ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/sunspider/n-body-oop.lua b/bench/tests/sunspider/n-body-oop.lua index e04286c8..469e22c1 100644 --- a/bench/tests/sunspider/n-body-oop.lua +++ b/bench/tests/sunspider/n-body-oop.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") local PI = 3.141592653589793 diff --git a/bench/tests/tictactoe.lua b/bench/tests/tictactoe.lua index 673dcd48..bc3282a0 100644 --- a/bench/tests/tictactoe.lua +++ b/bench/tests/tictactoe.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/tests/trig.lua b/bench/tests/trig.lua index 64bf611c..269fd610 100644 --- a/bench/tests/trig.lua +++ b/bench/tests/trig.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/tests/voxelgen.lua b/bench/tests/voxelgen.lua index b50a4592..813838c1 100644 --- a/bench/tests/voxelgen.lua +++ b/bench/tests/voxelgen.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") -- Based on voxel terrain generator by Stickmasterluke diff --git a/tests/AnyTypeSummary.test.cpp b/tests/AnyTypeSummary.test.cpp index 5c3b4aa3..12e02264 100644 --- a/tests/AnyTypeSummary.test.cpp +++ b/tests/AnyTypeSummary.test.cpp @@ -18,6 +18,10 @@ LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTFLAG(DebugLuauMagicTypes) LUAU_FASTFLAG(StudioReportLuauAny2) +LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope) +LUAU_FASTFLAG(LuauAlwaysFillInFunctionCallDiscriminantTypes) +LUAU_FASTFLAG(LuauStoreCSTData) +LUAU_FASTFLAG(LuauAstTypeGroup) struct ATSFixture : BuiltinsFixture @@ -71,7 +75,22 @@ export type t8 = t0 &((true | any)->('')) LUAU_ASSERT(module->ats.typeInfo.size() == 1); LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::Alias); - LUAU_ASSERT(module->ats.typeInfo[0].node == "export type t8 = t0 &((true | any)->(''))"); + if (FFlag::LuauStoreCSTData && FFlag::LuauAstTypeGroup) + { + LUAU_ASSERT(module->ats.typeInfo[0].node == "export type t8 = t0& (( true | any)->(''))"); + } + else if (FFlag::LuauStoreCSTData) + { + LUAU_ASSERT(module->ats.typeInfo[0].node == "export type t8 = t0 &(( true | any)->(''))"); + } + else if (FFlag::LuauAstTypeGroup) + { + LUAU_ASSERT(module->ats.typeInfo[0].node == "export type t8 = t0& ((true | any)->(''))"); + } + else + { + LUAU_ASSERT(module->ats.typeInfo[0].node == "export type t8 = t0 &((true | any)->(''))"); + } } TEST_CASE_FIXTURE(ATSFixture, "typepacks") @@ -97,7 +116,10 @@ end LUAU_ASSERT(module->ats.typeInfo.size() == 3); LUAU_ASSERT(module->ats.typeInfo[1].code == Pattern::TypePk); - LUAU_ASSERT(module->ats.typeInfo[0].node == "local function fallible(t: number): ...any\n if t > 0 then\n return true, t\n end\n return false, 'must be positive'\nend"); + LUAU_ASSERT( + module->ats.typeInfo[0].node == + "local function fallible(t: number): ...any\n if t > 0 then\n return true, t\n end\n return false, 'must be positive'\nend" + ); } TEST_CASE_FIXTURE(ATSFixture, "typepacks_no_ret") @@ -111,7 +133,7 @@ TEST_CASE_FIXTURE(ATSFixture, "typepacks_no_ret") -- TODO: if partially typed, we'd want to know too local function fallible(t: number) if t > 0 then - return true, t + return true, t end return false, "must be positive" end @@ -406,6 +428,7 @@ TEST_CASE_FIXTURE(ATSFixture, "CannotExtendTable") ScopedFastFlag sff[] = { {FFlag::LuauSolverV2, true}, {FFlag::StudioReportLuauAny2, true}, + {FFlag::LuauAlwaysFillInFunctionCallDiscriminantTypes, true}, }; fileResolver.source["game/Gui/Modules/A"] = R"( @@ -421,7 +444,7 @@ end )"; CheckResult result1 = frontend.check("game/Gui/Modules/A"); - LUAU_REQUIRE_ERROR_COUNT(3, result1); + LUAU_REQUIRE_ERROR_COUNT(1, result1); ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); @@ -498,6 +521,7 @@ TEST_CASE_FIXTURE(ATSFixture, "racing_collision_2") ScopedFastFlag sff[] = { {FFlag::LuauSolverV2, true}, {FFlag::StudioReportLuauAny2, true}, + {FFlag::LuauAlwaysFillInFunctionCallDiscriminantTypes, true}, }; fileResolver.source["game/Gui/Modules/A"] = R"( @@ -561,19 +585,32 @@ initialize() )"; CheckResult result1 = frontend.check("game/Gui/Modules/A"); - LUAU_REQUIRE_ERROR_COUNT(5, result1); + LUAU_REQUIRE_ERROR_COUNT(3, result1); ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); LUAU_ASSERT(module->ats.typeInfo.size() == 11); LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::FuncArg); - LUAU_ASSERT( - module->ats.typeInfo[0].node == - "local function onCharacterAdded(character: Model)\n\n character.DescendantAdded:Connect(function(descendant)\n if " - "descendant:IsA('BasePart')then\n descendant.CollisionGroup = CHARACTER_COLLISION_GROUP\n end\n end)\n\n\n for _, descendant in " - "character:GetDescendants()do\n if descendant:IsA('BasePart')then\n descendant.CollisionGroup = CHARACTER_COLLISION_GROUP\n end\n " - "end\nend" - ); + if (FFlag::LuauStoreCSTData) + { + CHECK_EQ( + module->ats.typeInfo[0].node, + "local function onCharacterAdded(character: Model)\n\n character.DescendantAdded:Connect(function(descendant)\n if " + "descendant:IsA('BasePart') then\n descendant.CollisionGroup = CHARACTER_COLLISION_GROUP\n end\n end)\n\n\n for _, descendant in " + "character:GetDescendants() do\n if descendant:IsA('BasePart') then\n descendant.CollisionGroup = CHARACTER_COLLISION_GROUP\n end\n " + "end\nend" + ); + } + else + { + LUAU_ASSERT( + module->ats.typeInfo[0].node == + "local function onCharacterAdded(character: Model)\n\n character.DescendantAdded:Connect(function(descendant)\n if " + "descendant:IsA('BasePart')then\n descendant.CollisionGroup = CHARACTER_COLLISION_GROUP\n end\n end)\n\n\n for _, descendant in " + "character:GetDescendants()do\n if descendant:IsA('BasePart')then\n descendant.CollisionGroup = CHARACTER_COLLISION_GROUP\n end\n " + "end\nend" + ); + } } TEST_CASE_FIXTURE(ATSFixture, "racing_spawning_1") @@ -581,6 +618,9 @@ TEST_CASE_FIXTURE(ATSFixture, "racing_spawning_1") ScopedFastFlag sff[] = { {FFlag::LuauSolverV2, true}, {FFlag::StudioReportLuauAny2, true}, + // Previously we'd report an error because number <: 'a is not a + // supertype. + {FFlag::LuauTrackInteriorFreeTypesOnScope, true} }; fileResolver.source["game/Gui/Modules/A"] = R"( @@ -632,7 +672,7 @@ initialize() )"; CheckResult result1 = frontend.check("game/Gui/Modules/A"); - LUAU_REQUIRE_ERROR_COUNT(5, result1); + LUAU_REQUIRE_ERROR_COUNT(4, result1); ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); @@ -911,7 +951,7 @@ TEST_CASE_FIXTURE(ATSFixture, "type_alias_any") fileResolver.source["game/Gui/Modules/A"] = R"( type Clear = any - local z: Clear = "zip" + local z: Clear = "zip" )"; CheckResult result1 = frontend.check("game/Gui/Modules/A"); @@ -938,7 +978,7 @@ TEST_CASE_FIXTURE(ATSFixture, "multi_module_any") fileResolver.source["game/B"] = R"( local MyFunc = require(script.Parent.A) type Clear = any - local z: Clear = "zip" + local z: Clear = "zip" )"; fileResolver.source["game/Gui/Modules/A"] = R"( @@ -972,7 +1012,7 @@ TEST_CASE_FIXTURE(ATSFixture, "cast_on_cyclic_req") fileResolver.source["game/B"] = R"( local MyFunc = require(script.Parent.A) :: any type Clear = any - local z: Clear = "zip" + local z: Clear = "zip" )"; CheckResult result = frontend.check("game/B"); diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp index 504e40e4..fd1deccf 100644 --- a/tests/AssemblyBuilderX64.test.cpp +++ b/tests/AssemblyBuilderX64.test.cpp @@ -506,6 +506,7 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXBinaryInstructionForms") SINGLE_COMPARE(vmaxsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x5f, 0xc6); SINGLE_COMPARE(vminsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x5d, 0xc6); + SINGLE_COMPARE(vcmpeqsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0xc2, 0xc6, 0x00); SINGLE_COMPARE(vcmpltsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0xc2, 0xc6, 0x01); } diff --git a/tests/AstJsonEncoder.test.cpp b/tests/AstJsonEncoder.test.cpp index e6e67020..de30be04 100644 --- a/tests/AstJsonEncoder.test.cpp +++ b/tests/AstJsonEncoder.test.cpp @@ -11,6 +11,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauAstTypeGroup) + struct JsonEncoderFixture { Allocator allocator; @@ -471,10 +473,17 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_annotation") { AstStat* statement = expectParseStatement("type T = ((number) -> (string | nil)) & ((string) -> ())"); - std::string_view expected = - R"({"type":"AstStatTypeAlias","location":"0,0 - 0,55","name":"T","generics":[],"genericPacks":[],"value":{"type":"AstTypeIntersection","location":"0,9 - 0,55","types":[{"type":"AstTypeFunction","location":"0,10 - 0,36","attributes":[],"generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,11 - 0,17","name":"number","nameLocation":"0,11 - 0,17","parameters":[]}]},"argNames":[],"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeUnion","location":"0,23 - 0,35","types":[{"type":"AstTypeReference","location":"0,23 - 0,29","name":"string","nameLocation":"0,23 - 0,29","parameters":[]},{"type":"AstTypeReference","location":"0,32 - 0,35","name":"nil","nameLocation":"0,32 - 0,35","parameters":[]}]}]}},{"type":"AstTypeFunction","location":"0,41 - 0,55","attributes":[],"generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,42 - 0,48","name":"string","nameLocation":"0,42 - 0,48","parameters":[]}]},"argNames":[],"returnTypes":{"type":"AstTypeList","types":[]}}]},"exported":false})"; - - CHECK(toJson(statement) == expected); + if (FFlag::LuauAstTypeGroup) + { + std::string_view expected = R"({"type":"AstStatTypeAlias","location":"0,0 - 0,55","name":"T","generics":[],"genericPacks":[],"value":{"type":"AstTypeIntersection","location":"0,9 - 0,55","types":[{"type":"AstTypeGroup","location":"0,9 - 0,36","type":{"type":"AstTypeFunction","location":"0,10 - 0,36","attributes":[],"generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,11 - 0,17","name":"number","nameLocation":"0,11 - 0,17","parameters":[]}]},"argNames":[],"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeGroup","location":"0,22 - 0,36","type":{"type":"AstTypeUnion","location":"0,23 - 0,35","types":[{"type":"AstTypeReference","location":"0,23 - 0,29","name":"string","nameLocation":"0,23 - 0,29","parameters":[]},{"type":"AstTypeReference","location":"0,32 - 0,35","name":"nil","nameLocation":"0,32 - 0,35","parameters":[]}]}}]}}},{"type":"AstTypeGroup","location":"0,40 - 0,55","type":{"type":"AstTypeFunction","location":"0,41 - 0,55","attributes":[],"generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,42 - 0,48","name":"string","nameLocation":"0,42 - 0,48","parameters":[]}]},"argNames":[],"returnTypes":{"type":"AstTypeList","types":[]}}}]},"exported":false})"; + CHECK(toJson(statement) == expected); + } + else + { + std::string_view expected = + R"({"type":"AstStatTypeAlias","location":"0,0 - 0,55","name":"T","generics":[],"genericPacks":[],"value":{"type":"AstTypeIntersection","location":"0,9 - 0,55","types":[{"type":"AstTypeFunction","location":"0,10 - 0,36","attributes":[],"generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,11 - 0,17","name":"number","nameLocation":"0,11 - 0,17","parameters":[]}]},"argNames":[],"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeUnion","location":"0,23 - 0,35","types":[{"type":"AstTypeReference","location":"0,23 - 0,29","name":"string","nameLocation":"0,23 - 0,29","parameters":[]},{"type":"AstTypeReference","location":"0,32 - 0,35","name":"nil","nameLocation":"0,32 - 0,35","parameters":[]}]}]}},{"type":"AstTypeFunction","location":"0,41 - 0,55","attributes":[],"generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,42 - 0,48","name":"string","nameLocation":"0,42 - 0,48","parameters":[]}]},"argNames":[],"returnTypes":{"type":"AstTypeList","types":[]}}]},"exported":false})"; + CHECK(toJson(statement) == expected); + } } TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_type_literal") diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index e730171f..702be46b 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -8,8 +8,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauDocumentationAtPosition) - struct DocumentationSymbolFixture : BuiltinsFixture { std::optional getDocSymbol(const std::string& source, Position position) @@ -167,7 +165,6 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "table_overloaded_function_prop") TEST_CASE_FIXTURE(DocumentationSymbolFixture, "string_metatable_method") { - ScopedFastFlag sff{FFlag::LuauDocumentationAtPosition, true}; std::optional symbol = getDocSymbol( R"( local x: string = "Foo" @@ -181,7 +178,6 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "string_metatable_method") TEST_CASE_FIXTURE(DocumentationSymbolFixture, "parent_class_method") { - ScopedFastFlag sff{FFlag::LuauDocumentationAtPosition, true}; loadDefinition(R"( declare class Foo function bar(self, x: string): number diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 0424e3df..6a8bca05 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2232,7 +2232,7 @@ local ec = e(f@5) TEST_CASE_FIXTURE(ACFixture, "type_correct_suggestion_for_overloads") { if (FFlag::LuauSolverV2) // CLI-116814 Autocomplete needs to populate expected types for function arguments correctly - // (overloads and singletons) + // (overloads and singletons) return; check(R"( local target: ((number) -> string) & ((string) -> number)) @@ -2582,7 +2582,7 @@ end TEST_CASE_FIXTURE(ACFixture, "suggest_table_keys") { if (FFlag::LuauSolverV2) // CLI-116812 AutocompleteTest.suggest_table_keys needs to populate expected types for nested - // tables without an annotation + // tables without an annotation return; check(R"( @@ -3069,7 +3069,7 @@ TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocomplete_on_string_singletons") TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singletons") { if (FFlag::LuauSolverV2) // CLI-116814 Autocomplete needs to populate expected types for function arguments correctly - // (overloads and singletons) + // (overloads and singletons) return; check(R"( @@ -4293,8 +4293,7 @@ end foo(@1) )"); - const std::optional EXPECTED_INSERT = - FFlag::LuauSolverV2 ? "function(...: number): number end" : "function(...): number end"; + const std::optional EXPECTED_INSERT = FFlag::LuauSolverV2 ? "function(...: number): number end" : "function(...): number end"; auto ac = autocomplete('1'); @@ -4305,4 +4304,46 @@ foo(@1) CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_at_end_of_stmt_should_continue_as_part_of_stmt") +{ + check(R"( +local data = { x = 1 } +local var = data.@1 + )"); + auto ac = autocomplete('1'); + CHECK(!ac.entryMap.empty()); + CHECK(ac.entryMap.count("x")); + CHECK_EQ(ac.context, AutocompleteContext::Property); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_after_semicolon_should_complete_a_new_statement") +{ + check(R"( +local data = { x = 1 } +local var = data;@1 + )"); + auto ac = autocomplete('1'); + CHECK(!ac.entryMap.empty()); + CHECK(ac.entryMap.count("table")); + CHECK(ac.entryMap.count("math")); + CHECK_EQ(ac.context, AutocompleteContext::Statement); +} + +TEST_CASE_FIXTURE(ACBuiltinsFixture, "require_tracing") +{ + fileResolver.source["Module/A"] = R"( +return { x = 0 } + )"; + + fileResolver.source["Module/B"] = R"( +local result = require(script.Parent.A) +local x = 1 + result. + )"; + + auto ac = autocomplete("Module/B", Position{2, 21}); + + CHECK(ac.entryMap.size() == 1); + CHECK(ac.entryMap.count("x")); +} + TEST_SUITE_END(); diff --git a/tests/CodeAllocator.test.cpp b/tests/CodeAllocator.test.cpp index 058a1100..c236b49c 100644 --- a/tests/CodeAllocator.test.cpp +++ b/tests/CodeAllocator.test.cpp @@ -3,6 +3,7 @@ #include "Luau/AssemblyBuilderA64.h" #include "Luau/CodeAllocator.h" #include "Luau/CodeBlockUnwind.h" +#include "Luau/CodeGen.h" #include "Luau/UnwindBuilder.h" #include "Luau/UnwindBuilderDwarf2.h" #include "Luau/UnwindBuilderWin.h" diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index b062cbfe..af04fb77 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -3,6 +3,8 @@ #include "Luau/BytecodeBuilder.h" #include "Luau/StringUtils.h" +#include "luacode.h" + #include "ScopedFlags.h" #include "doctest.h" @@ -21,12 +23,47 @@ LUAU_FASTINT(LuauCompileInlineThresholdMaxBoost) LUAU_FASTINT(LuauCompileLoopUnrollThreshold) LUAU_FASTINT(LuauCompileLoopUnrollThresholdMaxBoost) LUAU_FASTINT(LuauRecursionLimit) -LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2) -LUAU_FASTFLAG(LuauCompileVectorTypeInfo) -LUAU_FASTFLAG(LuauCompileOptimizeRevArith) +LUAU_FASTFLAG(LuauVector2Constants) using namespace Luau; +static void luauLibraryConstantLookup(const char* library, const char* member, Luau::CompileConstant* constant) +{ + // While 'vector' is built-in, because of LUA_VECTOR_SIZE VM configuration, compiler cannot provide the right default by itself + if (strcmp(library, "vector") == 0) + { + if (strcmp(member, "zero") == 0) + return Luau::setCompileConstantVector(constant, 0.0f, 0.0f, 0.0f, 0.0f); + + if (strcmp(member, "one") == 0) + return Luau::setCompileConstantVector(constant, 1.0f, 1.0f, 1.0f, 0.0f); + } + + if (strcmp(library, "Vector3") == 0) + { + if (strcmp(member, "one") == 0) + return Luau::setCompileConstantVector(constant, 1.0f, 1.0f, 1.0f, 0.0f); + + if (strcmp(member, "xAxis") == 0) + return Luau::setCompileConstantVector(constant, 1.0f, 0.0f, 0.0f, 0.0f); + } + + if (strcmp(library, "test") == 0) + { + if (strcmp(member, "some_nil") == 0) + return Luau::setCompileConstantNil(constant); + + if (strcmp(member, "some_boolean") == 0) + return Luau::setCompileConstantBoolean(constant, true); + + if (strcmp(member, "some_number") == 0) + return Luau::setCompileConstantNumber(constant, 4.75); + + if (strcmp(member, "some_string") == 0) + return Luau::setCompileConstantString(constant, "test", 4); + } +} + static std::string compileFunction(const char* source, uint32_t id, int optimizationLevel = 1, int typeInfoLevel = 0, bool enableVectors = false) { Luau::BytecodeBuilder bcb; @@ -39,6 +76,12 @@ static std::string compileFunction(const char* source, uint32_t id, int optimiza options.vectorLib = "Vector3"; options.vectorCtor = "new"; } + + static const char* kLibrariesWithConstants[] = {"vector", "Vector3", "test", nullptr}; + options.librariesWithKnownMembers = kLibrariesWithConstants; + + options.libraryMemberConstantCb = luauLibraryConstantLookup; + Luau::compileOrThrow(bcb, source, options); return bcb.dumpFunction(id); @@ -1442,6 +1485,125 @@ RETURN R0 1 )"); } +TEST_CASE("ConstantFoldVectorArith") +{ + CHECK_EQ("\n" + compileFunction("local n = 2; local a, b = vector.create(1, 2, 3), vector.create(2, 4, 8); return a + b", 0, 2), R"( +LOADK R0 K0 [3, 6, 11] +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction("local n = 2; local a, b = vector.create(1, 2, 3), vector.create(2, 4, 8); return a - b", 0, 2), R"( +LOADK R0 K0 [-1, -2, -5] +RETURN R0 1 +)"); + + // Multiplication by infinity cannot be folded as it creates a non-zero value in W + CHECK_EQ( + "\n" + compileFunction( + "local n = 2; local a, b = vector.create(1, 2, 3), vector.create(2, 4, 8); return a * n, a * b, n * b, a * math.huge", 0, 2 + ), + R"( +LOADK R0 K0 [2, 4, 6] +LOADK R1 K1 [2, 8, 24] +LOADK R2 K2 [4, 8, 16] +LOADK R4 K4 [1, 2, 3] +MULK R3 R4 K3 [inf] +RETURN R0 4 +)" + ); + + // Divisions creating an infinity in W cannot be constant-folded + CHECK_EQ( + "\n" + compileFunction( + "local n = 2; local a, b = vector.create(1, 2, 3), vector.create(2, 4, 8); return a / n, a / b, n / b, a / math.huge", 0, 2 + ), + R"( +LOADK R0 K0 [0.5, 1, 1.5] +LOADK R2 K1 [1, 2, 3] +LOADK R3 K2 [2, 4, 8] +DIV R1 R2 R3 +LOADK R3 K2 [2, 4, 8] +DIVRK R2 K3 [2] R3 +LOADK R3 K4 [0, 0, 0] +RETURN R0 4 +)" + ); + + // Divisions creating an infinity in W cannot be constant-folded + CHECK_EQ( + "\n" + compileFunction("local n = 2; local a, b = vector.create(1, 2, 3), vector.create(2, 4, 8); return a // n, a // b, n // b", 0, 2), + R"( +LOADK R0 K0 [0, 1, 1] +LOADK R2 K1 [1, 2, 3] +LOADK R3 K2 [2, 4, 8] +IDIV R1 R2 R3 +LOADN R3 2 +LOADK R4 K2 [2, 4, 8] +IDIV R2 R3 R4 +RETURN R0 3 +)" + ); + + CHECK_EQ("\n" + compileFunction("local a = vector.create(1, 2, 3); return -a", 0, 2), R"( +LOADK R0 K0 [-1, -2, -3] +RETURN R0 1 +)"); +} + +TEST_CASE("ConstantFoldVectorArith4Wide") +{ + CHECK_EQ("\n" + compileFunction("local n = 2; local a, b = vector.create(1, 2, 3, 4), vector.create(2, 4, 8, 1); return a + b", 0, 2), R"( +LOADK R0 K0 [3, 6, 11, 5] +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction("local n = 2; local a, b = vector.create(1, 2, 3, 4), vector.create(2, 4, 8, 1); return a - b", 0, 2), R"( +LOADK R0 K0 [-1, -2, -5, 3] +RETURN R0 1 +)"); + + CHECK_EQ( + "\n" + compileFunction( + "local n = 2; local a, b = vector.create(1, 2, 3, 4), vector.create(2, 4, 8, 1); return a * n, a * b, n * b, a * math.huge", 0, 2 + ), + R"( +LOADK R0 K0 [2, 4, 6, 8] +LOADK R1 K1 [2, 8, 24, 4] +LOADK R2 K2 [4, 8, 16, 2] +LOADK R3 K3 [inf, inf, inf, inf] +RETURN R0 4 +)" + ); + + CHECK_EQ( + "\n" + compileFunction( + "local n = 2; local a, b = vector.create(1, 2, 3, 4), vector.create(2, 4, 8, 1); return a / n, a / b, n / b, a / math.huge", 0, 2 + ), + R"( +LOADK R0 K0 [0.5, 1, 1.5, 2] +LOADK R1 K1 [0.5, 0.5, 0.375, 4] +LOADK R2 K2 [1, 0.5, 0.25, 2] +LOADK R3 K3 [0, 0, 0] +RETURN R0 4 +)" + ); + + CHECK_EQ( + "\n" + compileFunction("local n = 2; local a, b = vector.create(1, 2, 3, 4), vector.create(2, 4, 8, 1); return a // n, a // b, n // b", 0, 2), + R"( +LOADK R0 K0 [0, 1, 1, 2] +LOADK R1 K1 [0, 0, 0, 4] +LOADK R2 K2 [1, 0, 0, 2] +RETURN R0 3 +)" + ); + + CHECK_EQ("\n" + compileFunction("local a = vector.create(1, 2, 3, 4); return -a", 0, 2), R"( +LOADK R0 K0 [-1, -2, -3, -4] +RETURN R0 1 +)"); +} + TEST_CASE("ConstantFoldStringLen") { CHECK_EQ("\n" + compileFunction0("return #'string', #'', #'a', #('b')"), R"( @@ -2804,8 +2966,6 @@ TEST_CASE("TypeAliasing") TEST_CASE("TypeFunction") { - ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - Luau::BytecodeBuilder bcb; Luau::CompileOptions options; Luau::ParseOptions parseOptions; @@ -4931,36 +5091,80 @@ L0: RETURN R3 -1 )"); } -TEST_CASE("VectorLiterals") +TEST_CASE("VectorConstants") { - CHECK_EQ("\n" + compileFunction("return Vector3.new(1, 2, 3)", 0, 2, 0, /*enableVectors*/ true), R"( + ScopedFastFlag luauVector2Constants{FFlag::LuauVector2Constants, true}; + + CHECK_EQ("\n" + compileFunction("return vector.create(1, 2)", 0, 2, 0), R"( +LOADK R0 K0 [1, 2, 0] +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction("return vector.create(1, 2, 3)", 0, 2, 0), R"( LOADK R0 K0 [1, 2, 3] RETURN R0 1 )"); - CHECK_EQ("\n" + compileFunction("print(Vector3.new(1, 2, 3))", 0, 2, 0, /*enableVectors*/ true), R"( + CHECK_EQ("\n" + compileFunction("print(vector.create(1, 2, 3))", 0, 2, 0), R"( GETIMPORT R0 1 [print] LOADK R1 K2 [1, 2, 3] CALL R0 1 0 RETURN R0 0 )"); - CHECK_EQ("\n" + compileFunction("print(Vector3.new(1, 2, 3, 4))", 0, 2, 0, /*enableVectors*/ true), R"( + CHECK_EQ("\n" + compileFunction("print(vector.create(1, 2, 3, 4))", 0, 2, 0), R"( GETIMPORT R0 1 [print] LOADK R1 K2 [1, 2, 3, 4] CALL R0 1 0 RETURN R0 0 )"); - CHECK_EQ("\n" + compileFunction("return Vector3.new(0, 0, 0), Vector3.new(-0, 0, 0)", 0, 2, 0, /*enableVectors*/ true), R"( + CHECK_EQ("\n" + compileFunction("return vector.create(0, 0, 0), vector.create(-0, 0, 0)", 0, 2, 0), R"( LOADK R0 K0 [0, 0, 0] LOADK R1 K1 [-0, 0, 0] RETURN R0 2 )"); - CHECK_EQ("\n" + compileFunction("return type(Vector3.new(0, 0, 0))", 0, 2, 0, /*enableVectors*/ true), R"( + CHECK_EQ("\n" + compileFunction("return type(vector.create(0, 0, 0))", 0, 2, 0), R"( LOADK R0 K0 ['vector'] RETURN R0 1 +)"); + + // test legacy constructor + CHECK_EQ("\n" + compileFunction("return Vector3.new(1, 2, 3)", 0, 2, 0, /*enableVectors*/ true), R"( +LOADK R0 K0 [1, 2, 3] +RETURN R0 1 +)"); +} + +TEST_CASE("VectorConstantFields") +{ + CHECK_EQ("\n" + compileFunction("return vector.one, vector.zero", 0, 2), R"( +LOADK R0 K0 [1, 1, 1] +LOADK R1 K1 [0, 0, 0] +RETURN R0 2 +)"); + + CHECK_EQ("\n" + compileFunction("return Vector3.one, Vector3.xAxis", 0, 2, 0, /*enableVectors*/ true), R"( +LOADK R0 K0 [1, 1, 1] +LOADK R1 K1 [1, 0, 0] +RETURN R0 2 +)"); + + CHECK_EQ("\n" + compileFunction("return vector.one == vector.create(1, 1, 1)", 0, 2), R"( +LOADB R0 1 +RETURN R0 1 +)"); +} + +TEST_CASE("CustomConstantFields") +{ + CHECK_EQ("\n" + compileFunction("return test.some_nil, test.some_boolean, test.some_number, test.some_string", 0, 2), R"( +LOADNIL R0 +LOADB R1 1 +LOADK R2 K0 [4.75] +LOADK R3 K1 ['test'] +RETURN R0 4 )"); } @@ -7686,6 +7890,39 @@ RETURN R0 1 ); } +TEST_CASE("BuiltinFoldingProhibitedInOptions") +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::CompileOptions options; + options.optimizationLevel = 2; + + // math.floor from the test is excluded in this list on purpose + static const char* kDisabledBuiltins[] = {"tostring", "math.abs", "math.sqrt", nullptr}; + options.disabledBuiltins = kDisabledBuiltins; + + Luau::compileOrThrow(bcb, "return math.abs(-42), math.floor(-1.5), math.sqrt(9), (tostring(2))", options); + + std::string result = bcb.dumpFunction(0); + + CHECK_EQ( + "\n" + result, + R"( +GETIMPORT R0 2 [math.abs] +LOADN R1 -42 +CALL R0 1 1 +LOADN R1 -2 +GETIMPORT R2 4 [math.sqrt] +LOADN R3 9 +CALL R2 1 1 +GETIMPORT R3 6 [tostring] +LOADN R4 2 +CALL R3 1 1 +RETURN R0 4 +)" + ); +} + TEST_CASE("LocalReassign") { // locals can be re-assigned and the register gets reused @@ -8424,8 +8661,6 @@ end TEST_CASE("BuiltinTypeVector") { - ScopedFastFlag luauCompileVectorTypeInfo{FFlag::LuauCompileVectorTypeInfo, true}; - CHECK_EQ( "\n" + compileTypeTable(R"( function myfunc(test: Instance, pos: vector) @@ -8522,6 +8757,23 @@ end ); } +TEST_CASE("TypeGroup") +{ + CHECK_EQ( + "\n" + compileTypeTable(R"( +function myfunc(test: (string), foo: nil) +end + +function myfunc2(test: (string | nil), foo: nil) +end +)"), + R"( +0: function(string, nil) +1: function(string?, nil) +)" + ); +} + TEST_CASE("BuiltinFoldMathK") { // we can fold math.pi at optimization level 2 @@ -8847,8 +9099,6 @@ RETURN R0 1 TEST_CASE("ArithRevK") { - ScopedFastFlag sff(FFlag::LuauCompileOptimizeRevArith, true); - // - and / have special optimized form for reverse constants; in absence of type information, we can't optimize other ops CHECK_EQ( "\n" + compileFunction0(R"( diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 58f2aadf..852bfedd 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -31,16 +31,16 @@ extern int optimizationLevel; void luaC_fullgc(lua_State* L); void luaC_validate(lua_State* L); -LUAU_FASTFLAG(LuauMathMap) +LUAU_FASTFLAG(LuauMathLerp) LUAU_FASTFLAG(DebugLuauAbortingChecks) LUAU_FASTINT(CodegenHeuristicsInstructionLimit) LUAU_DYNAMIC_FASTFLAG(LuauStackLimit) -LUAU_FASTFLAG(LuauVectorDefinitions) -LUAU_DYNAMIC_FASTFLAG(LuauDebugInfoInvArgLeftovers) LUAU_FASTFLAG(LuauVectorLibNativeCodegen) LUAU_FASTFLAG(LuauVectorLibNativeDot) -LUAU_FASTFLAG(LuauVectorBuiltins) -LUAU_FASTFLAG(LuauVectorMetatable) +LUAU_FASTFLAG(LuauVector2Constructor) +LUAU_FASTFLAG(LuauBufferBitMethods2) +LUAU_FASTFLAG(LuauCodeGenLimitLiveSlotReuse) +LUAU_FASTFLAG(LuauMathMapDefinition) static lua_CompileOptions defaultOptions() { @@ -654,12 +654,14 @@ TEST_CASE("Basic") TEST_CASE("Buffers") { + ScopedFastFlag luauBufferBitMethods{FFlag::LuauBufferBitMethods2, true}; + runConformance("buffers.luau"); } TEST_CASE("Math") { - ScopedFastFlag LuauMathMap{FFlag::LuauMathMap, true}; + ScopedFastFlag LuauMathLerp{FFlag::LuauMathLerp, true}; runConformance("math.luau"); } @@ -891,10 +893,9 @@ TEST_CASE("Vector") TEST_CASE("VectorLibrary") { - ScopedFastFlag luauVectorBuiltins{FFlag::LuauVectorBuiltins, true}; ScopedFastFlag luauVectorLibNativeCodegen{FFlag::LuauVectorLibNativeCodegen, true}; ScopedFastFlag luauVectorLibNativeDot{FFlag::LuauVectorLibNativeDot, true}; - ScopedFastFlag luauVectorMetatable{FFlag::LuauVectorMetatable, true}; + ScopedFastFlag luauVector2Constructor{FFlag::LuauVector2Constructor, true}; lua_CompileOptions copts = defaultOptions(); @@ -911,7 +912,7 @@ TEST_CASE("VectorLibrary") copts.optimizationLevel = 2; } - runConformance("vector_library.luau", [](lua_State* L) {}, nullptr, nullptr, &copts); + runConformance("vector_library.lua", [](lua_State* L) {}, nullptr, nullptr, &copts); } static void populateRTTI(lua_State* L, Luau::TypeId type) @@ -985,7 +986,9 @@ static void populateRTTI(lua_State* L, Luau::TypeId type) TEST_CASE("Types") { - ScopedFastFlag luauVectorDefinitions{FFlag::LuauVectorDefinitions, true}; + ScopedFastFlag luauVector2Constructor{FFlag::LuauVector2Constructor, true}; + ScopedFastFlag luauMathLerp{FFlag::LuauMathLerp, true}; + ScopedFastFlag luauMathMapDefinition{FFlag::LuauMathMapDefinition, true}; runConformance( "types.luau", @@ -1018,9 +1021,7 @@ TEST_CASE("DateTime") TEST_CASE("Debug") { - ScopedFastFlag luauDebugInfoInvArgLeftovers{DFFlag::LuauDebugInfoInvArgLeftovers, true}; - - runConformance("debug.luau"); + runConformance("debug.lua"); } TEST_CASE("Debugger") @@ -1386,6 +1387,25 @@ TEST_CASE("ApiTables") CHECK(strcmp(lua_tostring(L, -1), "test") == 0); lua_pop(L, 1); + // lua_clonetable + lua_clonetable(L, -1); + + CHECK(lua_getfield(L, -1, "key") == LUA_TNUMBER); + CHECK(lua_tonumber(L, -1) == 123.0); + lua_pop(L, 1); + + // modify clone + lua_pushnumber(L, 456.0); + lua_rawsetfield(L, -2, "key"); + + // remove clone + lua_pop(L, 1); + + // check original + CHECK(lua_getfield(L, -1, "key") == LUA_TNUMBER); + CHECK(lua_tonumber(L, -1) == 123.0); + lua_pop(L, 1); + // lua_cleartable lua_cleartable(L, -1); lua_pushnil(L); @@ -2576,6 +2596,8 @@ TEST_CASE("SafeEnv") TEST_CASE("Native") { + ScopedFastFlag luauCodeGenLimitLiveSlotReuse{FFlag::LuauCodeGenLimitLiveSlotReuse, true}; + // This tests requires code to run natively, otherwise all 'is_native' checks will fail if (!codegen || !luau_codegen_supported()) return; diff --git a/tests/ConstraintGeneratorFixture.cpp b/tests/ConstraintGeneratorFixture.cpp index ef91fdf7..90a8b507 100644 --- a/tests/ConstraintGeneratorFixture.cpp +++ b/tests/ConstraintGeneratorFixture.cpp @@ -22,7 +22,9 @@ ConstraintGeneratorFixture::ConstraintGeneratorFixture() void ConstraintGeneratorFixture::generateConstraints(const std::string& code) { AstStatBlock* root = parse(code); - dfg = std::make_unique(DataFlowGraphBuilder::build(root, NotNull{&ice})); + dfg = std::make_unique( + DataFlowGraphBuilder::build(root, NotNull{&mainModule->defArena}, NotNull{&mainModule->keyArena}, NotNull{&ice}) + ); cg = std::make_unique( mainModule, NotNull{&normalizer}, diff --git a/tests/DataFlowGraph.test.cpp b/tests/DataFlowGraph.test.cpp index 4ea656ee..1b7e243c 100644 --- a/tests/DataFlowGraph.test.cpp +++ b/tests/DataFlowGraph.test.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/DataFlowGraph.h" #include "Fixture.h" +#include "Luau/Def.h" #include "Luau/Error.h" #include "Luau/Parser.h" @@ -18,6 +19,8 @@ struct DataFlowGraphFixture // Only needed to fix the operator== reflexivity of an empty Symbol. ScopedFastFlag dcr{FFlag::LuauSolverV2, true}; + DefArena defArena; + RefinementKeyArena keyArena; InternalErrorReporter handle; Allocator allocator; @@ -32,7 +35,7 @@ struct DataFlowGraphFixture if (!parseResult.errors.empty()) throw ParseErrors(std::move(parseResult.errors)); module = parseResult.root; - graph = DataFlowGraphBuilder::build(module, NotNull{&handle}); + graph = DataFlowGraphBuilder::build(module, NotNull{&defArena}, NotNull{&keyArena}, NotNull{&handle}); } template diff --git a/tests/EqSatSimplification.test.cpp b/tests/EqSatSimplification.test.cpp index aaaec456..6fe2660f 100644 --- a/tests/EqSatSimplification.test.cpp +++ b/tests/EqSatSimplification.test.cpp @@ -3,6 +3,7 @@ #include "Fixture.h" #include "Luau/EqSatSimplification.h" +#include "Luau/Type.h" using namespace Luau; @@ -23,15 +24,11 @@ struct ESFixture : Fixture TypeId genericT = arena_.addType(GenericType{"T"}); TypeId genericU = arena_.addType(GenericType{"U"}); - TypeId numberToString = arena_.addType(FunctionType{ - arena_.addTypePack({builtinTypes->numberType}), - arena_.addTypePack({builtinTypes->stringType}) - }); + TypeId numberToString = + arena_.addType(FunctionType{arena_.addTypePack({builtinTypes->numberType}), arena_.addTypePack({builtinTypes->stringType})}); - TypeId stringToNumber = arena_.addType(FunctionType{ - arena_.addTypePack({builtinTypes->stringType}), - arena_.addTypePack({builtinTypes->numberType}) - }); + TypeId stringToNumber = + arena_.addType(FunctionType{arena_.addTypePack({builtinTypes->stringType}), arena_.addTypePack({builtinTypes->numberType})}); ESFixture() : simplifier(newSimplifier(arena, builtinTypes)) @@ -80,7 +77,7 @@ TEST_CASE_FIXTURE(ESFixture, "number | string") TEST_CASE_FIXTURE(ESFixture, "t1 where t1 = number | t1") { - TypeId ty = arena->freshType(nullptr); + TypeId ty = arena->freshType(builtinTypes, nullptr); asMutable(ty)->ty.emplace(std::vector{builtinTypes->numberType, ty}); CHECK("number" == simplifyStr(ty)); @@ -163,10 +160,11 @@ TEST_CASE_FIXTURE(ESFixture, "never & string") TEST_CASE_FIXTURE(ESFixture, "string & (unknown | never)") { - CHECK("string" == simplifyStr(arena->addType(IntersectionType{{ - builtinTypes->stringType, - arena->addType(UnionType{{builtinTypes->unknownType, builtinTypes->neverType}}) - }}))); + CHECK( + "string" == simplifyStr(arena->addType( + IntersectionType{{builtinTypes->stringType, arena->addType(UnionType{{builtinTypes->unknownType, builtinTypes->neverType}})}} + )) + ); } TEST_CASE_FIXTURE(ESFixture, "true | false") @@ -211,112 +209,97 @@ TEST_CASE_FIXTURE(ESFixture, "error | unknown") TEST_CASE_FIXTURE(ESFixture, "\"hello\" | string") { - CHECK("string" == simplifyStr(arena->addType(UnionType{{ - arena->addType(SingletonType{StringSingleton{"hello"}}), builtinTypes->stringType - }}))); + CHECK("string" == simplifyStr(arena->addType(UnionType{{arena->addType(SingletonType{StringSingleton{"hello"}}), builtinTypes->stringType}}))); } TEST_CASE_FIXTURE(ESFixture, "\"hello\" | \"world\" | \"hello\"") { - CHECK("\"hello\" | \"world\"" == simplifyStr(arena->addType(UnionType{{ - arena->addType(SingletonType{StringSingleton{"hello"}}), - arena->addType(SingletonType{StringSingleton{"world"}}), - arena->addType(SingletonType{StringSingleton{"hello"}}), - }}))); + CHECK( + "\"hello\" | \"world\"" == simplifyStr(arena->addType(UnionType{{ + arena->addType(SingletonType{StringSingleton{"hello"}}), + arena->addType(SingletonType{StringSingleton{"world"}}), + arena->addType(SingletonType{StringSingleton{"hello"}}), + }})) + ); } TEST_CASE_FIXTURE(ESFixture, "nil | boolean | number | string | thread | function | table | class | buffer") { - CHECK("unknown" == simplifyStr(arena->addType(UnionType{{ - builtinTypes->nilType, - builtinTypes->booleanType, - builtinTypes->numberType, - builtinTypes->stringType, - builtinTypes->threadType, - builtinTypes->functionType, - builtinTypes->tableType, - builtinTypes->classType, - builtinTypes->bufferType, - }}))); + CHECK( + "unknown" == simplifyStr(arena->addType(UnionType{{ + builtinTypes->nilType, + builtinTypes->booleanType, + builtinTypes->numberType, + builtinTypes->stringType, + builtinTypes->threadType, + builtinTypes->functionType, + builtinTypes->tableType, + builtinTypes->classType, + builtinTypes->bufferType, + }})) + ); } TEST_CASE_FIXTURE(ESFixture, "Parent & number") { - CHECK("never" == simplifyStr(arena->addType(IntersectionType{{ - parentClass, builtinTypes->numberType - }}))); + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{parentClass, builtinTypes->numberType}}))); } TEST_CASE_FIXTURE(ESFixture, "Child & Parent") { - CHECK("Child" == simplifyStr(arena->addType(IntersectionType{{ - childClass, parentClass - }}))); + CHECK("Child" == simplifyStr(arena->addType(IntersectionType{{childClass, parentClass}}))); } TEST_CASE_FIXTURE(ESFixture, "Child & Unrelated") { - CHECK("never" == simplifyStr(arena->addType(IntersectionType{{ - childClass, unrelatedClass - }}))); + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{childClass, unrelatedClass}}))); } TEST_CASE_FIXTURE(ESFixture, "Child | Parent") { - CHECK("Parent" == simplifyStr(arena->addType(UnionType{{ - childClass, parentClass - }}))); + CHECK("Parent" == simplifyStr(arena->addType(UnionType{{childClass, parentClass}}))); } TEST_CASE_FIXTURE(ESFixture, "class | Child") { - CHECK("class" == simplifyStr(arena->addType(UnionType{{ - builtinTypes->classType, childClass - }}))); + CHECK("class" == simplifyStr(arena->addType(UnionType{{builtinTypes->classType, childClass}}))); } TEST_CASE_FIXTURE(ESFixture, "Parent | class | Child") { - CHECK("class" == simplifyStr(arena->addType(UnionType{{ - parentClass, builtinTypes->classType, childClass - }}))); + CHECK("class" == simplifyStr(arena->addType(UnionType{{parentClass, builtinTypes->classType, childClass}}))); } TEST_CASE_FIXTURE(ESFixture, "Parent | Unrelated") { - CHECK("Parent | Unrelated" == simplifyStr(arena->addType(UnionType{{ - parentClass, unrelatedClass - }}))); + CHECK("Parent | Unrelated" == simplifyStr(arena->addType(UnionType{{parentClass, unrelatedClass}}))); } TEST_CASE_FIXTURE(ESFixture, "never | Parent | Unrelated") { - CHECK("Parent | Unrelated" == simplifyStr(arena->addType(UnionType{{ - builtinTypes->neverType, parentClass, unrelatedClass - }}))); + CHECK("Parent | Unrelated" == simplifyStr(arena->addType(UnionType{{builtinTypes->neverType, parentClass, unrelatedClass}}))); } TEST_CASE_FIXTURE(ESFixture, "never | Parent | (number & string) | Unrelated") { - CHECK("Parent | Unrelated" == simplifyStr(arena->addType(UnionType{{ - builtinTypes->neverType, parentClass, - arena->addType(IntersectionType{{builtinTypes->numberType, builtinTypes->stringType}}), - unrelatedClass - }}))); + CHECK( + "Parent | Unrelated" == simplifyStr(arena->addType(UnionType{ + {builtinTypes->neverType, + parentClass, + arena->addType(IntersectionType{{builtinTypes->numberType, builtinTypes->stringType}}), + unrelatedClass} + })) + ); } TEST_CASE_FIXTURE(ESFixture, "T & U") { - CHECK("T & U" == simplifyStr(arena->addType(IntersectionType{{ - genericT, genericU - }}))); + CHECK("T & U" == simplifyStr(arena->addType(IntersectionType{{genericT, genericU}}))); } TEST_CASE_FIXTURE(ESFixture, "boolean & true") { - CHECK("true" == simplifyStr(arena->addType(IntersectionType{{ - builtinTypes->booleanType, builtinTypes->trueType - }}))); + CHECK("true" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->booleanType, builtinTypes->trueType}}))); } TEST_CASE_FIXTURE(ESFixture, "boolean & (true | number | string | thread | function | table | class | buffer)") @@ -332,23 +315,17 @@ TEST_CASE_FIXTURE(ESFixture, "boolean & (true | number | string | thread | funct builtinTypes->bufferType, }}); - CHECK("true" == simplifyStr(arena->addType(IntersectionType{{ - builtinTypes->booleanType, truthy - }}))); + CHECK("true" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->booleanType, truthy}}))); } TEST_CASE_FIXTURE(ESFixture, "boolean & ~(false?)") { - CHECK("true" == simplifyStr(arena->addType(IntersectionType{{ - builtinTypes->booleanType, builtinTypes->truthyType - }}))); + CHECK("true" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->booleanType, builtinTypes->truthyType}}))); } TEST_CASE_FIXTURE(ESFixture, "false & ~(false?)") { - CHECK("never" == simplifyStr(arena->addType(IntersectionType{{ - builtinTypes->falseType, builtinTypes->truthyType - }}))); + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->falseType, builtinTypes->truthyType}}))); } TEST_CASE_FIXTURE(ESFixture, "(number) -> string & (number) -> string") @@ -399,28 +376,25 @@ TEST_CASE_FIXTURE(ESFixture, "(number) -> string | (string) -> number") TEST_CASE_FIXTURE(ESFixture, "add") { - CHECK("number" == simplifyStr(arena->addType( - TypeFunctionInstanceType{builtinTypeFunctions().addFunc, { - builtinTypes->numberType, builtinTypes->numberType - }} - ))); + CHECK( + "number" == + simplifyStr(arena->addType(TypeFunctionInstanceType{builtinTypeFunctions().addFunc, {builtinTypes->numberType, builtinTypes->numberType}})) + ); } TEST_CASE_FIXTURE(ESFixture, "union") { - CHECK("number" == simplifyStr(arena->addType( - TypeFunctionInstanceType{builtinTypeFunctions().unionFunc, { - builtinTypes->numberType, builtinTypes->numberType - }} - ))); + CHECK( + "number" == + simplifyStr(arena->addType(TypeFunctionInstanceType{builtinTypeFunctions().unionFunc, {builtinTypes->numberType, builtinTypes->numberType}})) + ); } TEST_CASE_FIXTURE(ESFixture, "never & ~string") { - CHECK("never" == simplifyStr(arena->addType(IntersectionType{{ - builtinTypes->neverType, - arena->addType(NegationType{builtinTypes->stringType}) - }}))); + CHECK( + "never" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->neverType, arena->addType(NegationType{builtinTypes->stringType})}})) + ); } TEST_CASE_FIXTURE(ESFixture, "blocked & never") @@ -444,7 +418,9 @@ TEST_CASE_FIXTURE(ESFixture, "blocked & ~number & function") TEST_CASE_FIXTURE(ESFixture, "(number | boolean | string | nil | table) & (false | nil)") { - const TypeId t1 = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->booleanType, builtinTypes->stringType, builtinTypes->nilType, builtinTypes->tableType}}); + const TypeId t1 = arena->addType( + UnionType{{builtinTypes->numberType, builtinTypes->booleanType, builtinTypes->stringType, builtinTypes->nilType, builtinTypes->tableType}} + ); CHECK("false?" == simplifyStr(arena->addType(IntersectionType{{t1, builtinTypes->falsyType}}))); } @@ -475,7 +451,7 @@ TEST_CASE_FIXTURE(ESFixture, "(boolean | nil) & (false | nil)") TEST_CASE_FIXTURE(ESFixture, "free & string & number") { Scope scope{builtinTypes->anyTypePack}; - const TypeId freeTy = arena->addType(FreeType{&scope}); + const TypeId freeTy = arena->freshType(builtinTypes, &scope); CHECK("never" == simplifyStr(arena->addType(IntersectionType{{freeTy, builtinTypes->numberType, builtinTypes->stringType}}))); } @@ -493,26 +469,17 @@ TEST_CASE_FIXTURE(ESFixture, "(blocked & number) | (blocked & number)") TEST_CASE_FIXTURE(ESFixture, "{} & unknown") { - CHECK("{ }" == simplifyStr(arena->addType(IntersectionType{{ - tbl({}), - builtinTypes->unknownType - }}))); + CHECK("{ }" == simplifyStr(arena->addType(IntersectionType{{tbl({}), builtinTypes->unknownType}}))); } TEST_CASE_FIXTURE(ESFixture, "{} & table") { - CHECK("{ }" == simplifyStr(arena->addType(IntersectionType{{ - tbl({}), - builtinTypes->tableType - }}))); + CHECK("{ }" == simplifyStr(arena->addType(IntersectionType{{tbl({}), builtinTypes->tableType}}))); } TEST_CASE_FIXTURE(ESFixture, "{} & ~(false?)") { - CHECK("{ }" == simplifyStr(arena->addType(IntersectionType{{ - tbl({}), - builtinTypes->truthyType - }}))); + CHECK("{ }" == simplifyStr(arena->addType(IntersectionType{{tbl({}), builtinTypes->truthyType}}))); } TEST_CASE_FIXTURE(ESFixture, "{x: number?} & {x: number}") @@ -606,10 +573,7 @@ TEST_CASE_FIXTURE(ESFixture, "{ x: number } & ~boolean") { const TypeId tblTy = tbl(TableType::Props{{"x", builtinTypes->numberType}}); - const TypeId ty = arena->addType(IntersectionType{{ - tblTy, - arena->addType(NegationType{builtinTypes->booleanType}) - }}); + const TypeId ty = arena->addType(IntersectionType{{tblTy, arena->addType(NegationType{builtinTypes->booleanType})}}); CHECK("{ x: number }" == simplifyStr(ty)); } @@ -634,44 +598,41 @@ TEST_CASE_FIXTURE(ESFixture, "string & (\"hi\" | \"bye\")") const TypeId hi = arena->addType(SingletonType{StringSingleton{"hi"}}); const TypeId bye = arena->addType(SingletonType{StringSingleton{"bye"}}); - CHECK("\"bye\" | \"hi\"" == simplifyStr(arena->addType(IntersectionType{{ - builtinTypes->stringType, - arena->addType(UnionType{{hi, bye}}) - }}))); + CHECK("\"bye\" | \"hi\"" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, arena->addType(UnionType{{hi, bye}})}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(\"err\" | \"ok\") & ~\"ok\"") +{ + TypeId err = arena->addType(SingletonType{StringSingleton{"err"}}); + TypeId ok1 = arena->addType(SingletonType{StringSingleton{"ok"}}); + TypeId ok2 = arena->addType(SingletonType{StringSingleton{"ok"}}); + + TypeId ty = arena->addType(IntersectionType{{arena->addType(UnionType{{err, ok1}}), arena->addType(NegationType{ok2})}}); + + CHECK("\"err\"" == simplifyStr(ty)); } TEST_CASE_FIXTURE(ESFixture, "(Child | Unrelated) & ~Child") { - const TypeId ty = arena->addType(IntersectionType{{ - arena->addType(UnionType{{childClass, unrelatedClass}}), - arena->addType(NegationType{childClass}) - }}); + const TypeId ty = + arena->addType(IntersectionType{{arena->addType(UnionType{{childClass, unrelatedClass}}), arena->addType(NegationType{childClass})}}); CHECK("Unrelated" == simplifyStr(ty)); } TEST_CASE_FIXTURE(ESFixture, "string & ~Child") { - CHECK("string" == simplifyStr(arena->addType(IntersectionType{{ - builtinTypes->stringType, - arena->addType(NegationType{childClass}) - }}))); + CHECK("string" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, arena->addType(NegationType{childClass})}}))); } TEST_CASE_FIXTURE(ESFixture, "(Child | Unrelated) & Child") { - CHECK("Child" == simplifyStr(arena->addType(IntersectionType{{ - arena->addType(UnionType{{childClass, unrelatedClass}}), - childClass - }}))); + CHECK("Child" == simplifyStr(arena->addType(IntersectionType{{arena->addType(UnionType{{childClass, unrelatedClass}}), childClass}}))); } TEST_CASE_FIXTURE(ESFixture, "(Child | AnotherChild) & ~Child") { - CHECK("Child" == simplifyStr(arena->addType(IntersectionType{{ - arena->addType(UnionType{{childClass, anotherChild}}), - childClass - }}))); + CHECK("Child" == simplifyStr(arena->addType(IntersectionType{{arena->addType(UnionType{{childClass, anotherChild}}), childClass}}))); } TEST_CASE_FIXTURE(ESFixture, "{ tag: \"Part\", x: never }") @@ -692,11 +653,7 @@ TEST_CASE_FIXTURE(ESFixture, "{ tag: \"Part\", x: number? } & { x: string }") TEST_CASE_FIXTURE(ESFixture, "Child & add") { const TypeId u = arena->addType(UnionType{{childClass, anotherChild, builtinTypes->stringType}}); - const TypeId intersectTf = arena->addType(TypeFunctionInstanceType{ - builtinTypeFunctions().addFunc, - {u, parentClass}, - {} - }); + const TypeId intersectTf = arena->addType(TypeFunctionInstanceType{builtinTypeFunctions().addFunc, {u, parentClass}, {}}); const TypeId intersection = arena->addType(IntersectionType{{childClass, intersectTf}}); @@ -706,17 +663,46 @@ TEST_CASE_FIXTURE(ESFixture, "Child & add TEST_CASE_FIXTURE(ESFixture, "Child & intersect") { const TypeId u = arena->addType(UnionType{{childClass, anotherChild, builtinTypes->stringType}}); - const TypeId intersectTf = arena->addType(TypeFunctionInstanceType{ - builtinTypeFunctions().intersectFunc, - {u, parentClass}, - {} - }); + const TypeId intersectTf = arena->addType(TypeFunctionInstanceType{builtinTypeFunctions().intersectFunc, {u, parentClass}, {}}); const TypeId intersection = arena->addType(IntersectionType{{childClass, intersectTf}}); CHECK("Child" == simplifyStr(intersection)); } +TEST_CASE_FIXTURE(ESFixture, "lt == boolean") +{ + std::vector> cases{ + {builtinTypes->numberType, arena->addType(BlockedType{})}, + {builtinTypes->stringType, arena->addType(BlockedType{})}, + {arena->addType(BlockedType{}), builtinTypes->numberType}, + {arena->addType(BlockedType{}), builtinTypes->stringType}, + }; + + for (const auto& [lhs, rhs] : cases) + { + const TypeId tfun = arena->addType(TypeFunctionInstanceType{builtinTypeFunctions().ltFunc, {lhs, rhs}}); + CHECK("boolean" == simplifyStr(tfun)); + } +} + +TEST_CASE_FIXTURE(ESFixture, "unknown & ~string") +{ + CHECK_EQ( + "~string", simplifyStr(arena->addType(IntersectionType{{builtinTypes->unknownType, arena->addType(NegationType{builtinTypes->stringType})}})) + ); +} + +TEST_CASE_FIXTURE(ESFixture, "string & ~\"foo\"") +{ + CHECK_EQ( + "string & ~\"foo\"", + simplifyStr(arena->addType( + IntersectionType{{builtinTypes->stringType, arena->addType(NegationType{arena->addType(SingletonType{StringSingleton{"foo"}})})}} + )) + ); +} + // {someKey: ~any} // // Maybe something we could do here is to try to reduce the key, get the diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 5a2f9319..3a8d2cfc 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -25,6 +25,7 @@ static const char* mainModuleName = "MainModule"; LUAU_FASTFLAG(LuauSolverV2); +LUAU_FASTFLAG(LuauVector2Constructor) LUAU_FASTFLAG(DebugLuauLogSolverToJsonFile) LUAU_FASTFLAGVARIABLE(DebugLuauForceAllNewSolverTests); @@ -341,8 +342,11 @@ ParseResult Fixture::matchParseErrorPrefix(const std::string& source, const std: return result; } -ModulePtr Fixture::getMainModule() +ModulePtr Fixture::getMainModule(bool forAutocomplete) { + if (forAutocomplete && !FFlag::LuauSolverV2) + return frontend.moduleResolverForAutocomplete.getModule(fromString(mainModuleName)); + return frontend.moduleResolver.getModule(fromString(mainModuleName)); } @@ -365,9 +369,9 @@ std::optional Fixture::getPrimitiveType(TypeId ty) return std::nullopt; } -std::optional Fixture::getType(const std::string& name) +std::optional Fixture::getType(const std::string& name, bool forAutocomplete) { - ModulePtr module = getMainModule(); + ModulePtr module = getMainModule(forAutocomplete); REQUIRE(module); if (!module->hasModuleScope()) @@ -519,6 +523,9 @@ void Fixture::registerTestTypes() void Fixture::dumpErrors(const CheckResult& cr) { + if (hasDumpedErrors) + return; + hasDumpedErrors = true; std::string error = getErrors(cr); if (!error.empty()) MESSAGE(error); @@ -526,6 +533,9 @@ void Fixture::dumpErrors(const CheckResult& cr) void Fixture::dumpErrors(const ModulePtr& module) { + if (hasDumpedErrors) + return; + hasDumpedErrors = true; std::stringstream ss; dumpErrors(ss, module->errors); if (!ss.str().empty()) @@ -534,6 +544,9 @@ void Fixture::dumpErrors(const ModulePtr& module) void Fixture::dumpErrors(const Module& module) { + if (hasDumpedErrors) + return; + hasDumpedErrors = true; std::stringstream ss; dumpErrors(ss, module.errors); if (!ss.str().empty()) @@ -580,6 +593,8 @@ LoadDefinitionFileResult Fixture::loadDefinition(const std::string& source, bool BuiltinsFixture::BuiltinsFixture(bool prepareAutocomplete) : Fixture(prepareAutocomplete) { + ScopedFastFlag luauVector2Constructor{FFlag::LuauVector2Constructor, true}; + Luau::unfreeze(frontend.globals.globalTypes); Luau::unfreeze(frontend.globalsForAutocomplete.globalTypes); diff --git a/tests/Fixture.h b/tests/Fixture.h index ba038403..c202075b 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -27,10 +27,8 @@ LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTFLAG(DebugLuauForceAllNewSolverTests) -LUAU_FASTFLAG(LuauVectorDefinitionsExtra) -#define DOES_NOT_PASS_NEW_SOLVER_GUARD_IMPL(line) \ - ScopedFastFlag sff_##line{FFlag::LuauSolverV2, FFlag::DebugLuauForceAllNewSolverTests}; +#define DOES_NOT_PASS_NEW_SOLVER_GUARD_IMPL(line) ScopedFastFlag sff_##line{FFlag::LuauSolverV2, FFlag::DebugLuauForceAllNewSolverTests}; #define DOES_NOT_PASS_NEW_SOLVER_GUARD() DOES_NOT_PASS_NEW_SOLVER_GUARD_IMPL(__LINE__) @@ -90,11 +88,11 @@ struct Fixture // Verify a parse error occurs and the parse error message has the specified prefix ParseResult matchParseErrorPrefix(const std::string& source, const std::string& prefix); - ModulePtr getMainModule(); + ModulePtr getMainModule(bool forAutocomplete = false); SourceModule* getMainSourceModule(); std::optional getPrimitiveType(TypeId ty); - std::optional getType(const std::string& name); + std::optional getType(const std::string& name, bool forAutocomplete = false); TypeId requireType(const std::string& name); TypeId requireType(const ModuleName& moduleName, const std::string& name); TypeId requireType(const ModulePtr& module, const std::string& name); @@ -114,8 +112,6 @@ struct Fixture // In that case, flag can be forced to 'true' using the example below: // ScopedFastFlag sff_LuauExampleFlagDefinition{FFlag::LuauExampleFlagDefinition, true}; - ScopedFastFlag sff_LuauVectorDefinitionsExtra{FFlag::LuauVectorDefinitionsExtra, true}; - // Arena freezing marks the `TypeArena`'s underlying memory as read-only, raising an access violation whenever you mutate it. // This is useful for tracking down violations of Luau's memory model. ScopedFastFlag sff_DebugLuauFreezeArena{FFlag::DebugLuauFreezeArena, true}; @@ -143,11 +139,14 @@ struct Fixture void registerTestTypes(); LoadDefinitionFileResult loadDefinition(const std::string& source, bool forAutocomplete = false); + +private: + bool hasDumpedErrors = false; }; struct BuiltinsFixture : Fixture { - BuiltinsFixture(bool prepareAutocomplete = false); + explicit BuiltinsFixture(bool prepareAutocomplete = false); }; std::optional pathExprToModuleName(const ModuleName& currentModuleName, const std::vector& segments); @@ -181,6 +180,18 @@ std::optional linearSearchForBinding(Scope* scope, const char* name); void registerHiddenTypes(Frontend* frontend); void createSomeClasses(Frontend* frontend); +template +const E* findError(const CheckResult& result) +{ + for (const auto& e : result.errors) + { + if (auto p = get(e)) + return p; + } + + return nullptr; +} + template struct DifferFixtureGeneric : BaseFixture { @@ -330,3 +341,51 @@ using DifferFixtureWithBuiltins = DifferFixtureGeneric; } \ } \ } while (false) + +#define LUAU_REQUIRE_ERROR(result, Type) \ + do \ + { \ + using T = Type; \ + const auto& res = (result); \ + if (!findError(res)) \ + { \ + dumpErrors(res); \ + REQUIRE_MESSAGE(false, "Expected to find " #Type " error"); \ + } \ + } while (false) + +#define LUAU_CHECK_ERROR(result, Type) \ + do \ + { \ + using T = Type; \ + const auto& res = (result); \ + if (!findError(res)) \ + { \ + dumpErrors(res); \ + CHECK_MESSAGE(false, "Expected to find " #Type " error"); \ + } \ + } while (false) + +#define LUAU_REQUIRE_NO_ERROR(result, Type) \ + do \ + { \ + using T = Type; \ + const auto& res = (result); \ + if (findError(res)) \ + { \ + dumpErrors(res); \ + REQUIRE_MESSAGE(false, "Expected to find no " #Type " error"); \ + } \ + } while (false) + +#define LUAU_CHECK_NO_ERROR(result, Type) \ + do \ + { \ + using T = Type; \ + const auto& res = (result); \ + if (findError(res)) \ + { \ + dumpErrors(res); \ + CHECK_MESSAGE(false, "Expected to find no " #Type " error"); \ + } \ + } while (false) diff --git a/tests/FragmentAutocomplete.test.cpp b/tests/FragmentAutocomplete.test.cpp index 42f2bf09..58bbc16a 100644 --- a/tests/FragmentAutocomplete.test.cpp +++ b/tests/FragmentAutocomplete.test.cpp @@ -9,6 +9,7 @@ #include "Luau/Common.h" #include "Luau/Frontend.h" #include "Luau/AutocompleteTypes.h" +#include "Luau/Type.h" #include #include @@ -24,7 +25,17 @@ LUAU_FASTFLAG(LuauAllowFragmentParsing); LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete) LUAU_FASTFLAG(LuauSymbolEquality); LUAU_FASTFLAG(LuauStoreSolverTypeOnModule); -LUAU_FASTFLAG(LexerResumesFromPosition) +LUAU_FASTFLAG(LexerResumesFromPosition2) +LUAU_FASTFLAG(LuauIncrementalAutocompleteCommentDetection) +LUAU_FASTINT(LuauParseErrorLimit) +LUAU_FASTFLAG(LuauCloneIncrementalModule) + +LUAU_FASTFLAG(LuauIncrementalAutocompleteBugfixes) +LUAU_FASTFLAG(LuauReferenceAllocatorInNewSolver) +LUAU_FASTFLAG(LuauMixedModeDefFinderTraversesTypeOf) +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) + +LUAU_FASTFLAG(LuauBetterReverseDependencyTracking) static std::optional nullCallback(std::string tag, std::optional ptr, std::optional contents) { @@ -44,15 +55,25 @@ static FrontendOptions getOptions() return options; } +static ModuleResolver& getModuleResolver(Luau::Frontend& frontend) +{ + return FFlag::LuauSolverV2 ? frontend.moduleResolver : frontend.moduleResolverForAutocomplete; +} + template struct FragmentAutocompleteFixtureImpl : BaseType { - ScopedFastFlag sffs[5] = { + static_assert(std::is_base_of_v, "BaseType must be a descendant of Fixture"); + + ScopedFastFlag sffs[8] = { {FFlag::LuauAllowFragmentParsing, true}, {FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete, true}, {FFlag::LuauStoreSolverTypeOnModule, true}, {FFlag::LuauSymbolEquality, true}, - {FFlag::LexerResumesFromPosition, true} + {FFlag::LexerResumesFromPosition2, true}, + {FFlag::LuauReferenceAllocatorInNewSolver, true}, + {FFlag::LuauIncrementalAutocompleteBugfixes, true}, + {FFlag::LuauBetterReverseDependencyTracking, true}, }; FragmentAutocompleteFixtureImpl() @@ -68,7 +89,7 @@ struct FragmentAutocompleteFixtureImpl : BaseType } - FragmentParseResult parseFragment( + std::optional parseFragment( const std::string& document, const Position& cursorPos, std::optional fragmentEndPosition = std::nullopt @@ -91,7 +112,8 @@ struct FragmentAutocompleteFixtureImpl : BaseType std::optional fragmentEndPosition = std::nullopt ) { - return Luau::typecheckFragment(this->frontend, "MainModule", cursorPos, getOptions(), document, fragmentEndPosition); + auto [_, result] = Luau::typecheckFragment(this->frontend, "MainModule", cursorPos, getOptions(), document, fragmentEndPosition); + return result; } FragmentAutocompleteResult autocompleteFragment( @@ -125,6 +147,26 @@ struct FragmentAutocompleteFixtureImpl : BaseType result = autocompleteFragment(updated, cursorPos, fragmentEndPosition); assertions(result); } + + std::pair typecheckFragmentForModule( + const ModuleName& module, + const std::string& document, + Position cursorPos, + std::optional fragmentEndPosition = std::nullopt + ) + { + return Luau::typecheckFragment(this->frontend, module, cursorPos, getOptions(), document, fragmentEndPosition); + } + + FragmentAutocompleteResult autocompleteFragmentForModule( + const ModuleName& module, + const std::string& document, + Position cursorPos, + std::optional fragmentEndPosition = std::nullopt + ) + { + return Luau::fragmentAutocomplete(this->frontend, document, module, cursorPos, getOptions(), nullCallback, fragmentEndPosition); + } }; struct FragmentAutocompleteFixture : FragmentAutocompleteFixtureImpl @@ -159,9 +201,13 @@ end // 'for autocomplete'. loadDefinition(fakeVecDecl); loadDefinition(fakeVecDecl, /* For Autocomplete Module */ true); + + addGlobalBinding(frontend.globals, "game", Binding{builtinTypes->anyType}); + addGlobalBinding(frontend.globalsForAutocomplete, "game", Binding{builtinTypes->anyType}); } }; +// NOLINTBEGIN(bugprone-unchecked-optional-access) TEST_SUITE_BEGIN("FragmentAutocompleteTraversalTests"); TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "just_two_locals") @@ -284,13 +330,23 @@ TEST_SUITE_END(); TEST_SUITE_BEGIN("FragmentAutocompleteParserTests"); +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "thrown_parse_error_leads_to_null_root") +{ + check("type A = "); + ScopedFastInt sfi{FInt::LuauParseErrorLimit, 1}; + auto fragment = parseFragment("type A = <>function<> more garbage here", Position(0, 39)); + CHECK(fragment == std::nullopt); +} + TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "local_initializer") { ScopedFastFlag sff{FFlag::LuauSolverV2, true}; check("local a ="); auto fragment = parseFragment("local a =", Position(0, 10)); - CHECK_EQ("local a =", fragment.fragmentToParse); - CHECK_EQ(Location{Position{0, 0}, 9}, fragment.root->location); + + REQUIRE(fragment.has_value()); + CHECK_EQ("local a =", fragment->fragmentToParse); + CHECK_EQ(Location{Position{0, 0}, 9}, fragment->root->location); } TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "statement_in_empty_fragment_is_non_null") @@ -308,11 +364,12 @@ TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "statement_in_empty_fragment_is_n )", Position(1, 0) ); - CHECK_EQ("\n", fragment.fragmentToParse); - CHECK_EQ(2, fragment.ancestry.size()); - REQUIRE(fragment.root); - CHECK_EQ(0, fragment.root->body.size); - auto statBody = fragment.root->as(); + REQUIRE(fragment.has_value()); + CHECK_EQ("\n", fragment->fragmentToParse); + CHECK_EQ(2, fragment->ancestry.size()); + REQUIRE(fragment->root); + CHECK_EQ(0, fragment->root->body.size); + auto statBody = fragment->root->as(); CHECK(statBody != nullptr); } @@ -337,13 +394,15 @@ local z = x + y Position{3, 15} ); - CHECK_EQ(Location{Position{2, 0}, Position{3, 15}}, fragment.root->location); + REQUIRE(fragment.has_value()); - CHECK_EQ("local y = 5\nlocal z = x + y", fragment.fragmentToParse); - CHECK_EQ(5, fragment.ancestry.size()); - REQUIRE(fragment.root); - CHECK_EQ(2, fragment.root->body.size); - auto stat = fragment.root->body.data[1]->as(); + CHECK_EQ(Location{Position{2, 0}, Position{3, 15}}, fragment->root->location); + + CHECK_EQ("local y = 5\nlocal z = x + y", fragment->fragmentToParse); + CHECK_EQ(5, fragment->ancestry.size()); + REQUIRE(fragment->root); + CHECK_EQ(2, fragment->root->body.size); + auto stat = fragment->root->body.data[1]->as(); REQUIRE(stat); CHECK_EQ(1, stat->vars.size); CHECK_EQ(1, stat->values.size); @@ -382,12 +441,14 @@ local y = 5 Position{2, 15} ); - CHECK_EQ("local z = x + y", fragment.fragmentToParse); - CHECK_EQ(5, fragment.ancestry.size()); - REQUIRE(fragment.root); - CHECK_EQ(Location{Position{2, 0}, Position{2, 15}}, fragment.root->location); - CHECK_EQ(1, fragment.root->body.size); - auto stat = fragment.root->body.data[0]->as(); + REQUIRE(fragment.has_value()); + + CHECK_EQ("local z = x + y", fragment->fragmentToParse); + CHECK_EQ(5, fragment->ancestry.size()); + REQUIRE(fragment->root); + CHECK_EQ(Location{Position{2, 0}, Position{2, 15}}, fragment->root->location); + CHECK_EQ(1, fragment->root->body.size); + auto stat = fragment->root->body.data[0]->as(); REQUIRE(stat); CHECK_EQ(1, stat->vars.size); CHECK_EQ(1, stat->values.size); @@ -427,7 +488,9 @@ TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_in_correct_scope") Position{6, 0} ); - CHECK_EQ("\n ", fragment.fragmentToParse); + REQUIRE(fragment.has_value()); + + CHECK_EQ("\n ", fragment->fragmentToParse); } TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_single_line_fragment_override") @@ -446,17 +509,19 @@ abc("bar") Position{1, 10} ); - CHECK_EQ("function abc(foo: string) end\nabc(\"foo\")", callFragment.fragmentToParse); - CHECK(callFragment.nearestStatement->is()); + REQUIRE(callFragment.has_value()); - CHECK_GE(callFragment.ancestry.size(), 2); + CHECK_EQ("function abc(foo: string) end\nabc(\"foo\")", callFragment->fragmentToParse); + CHECK(callFragment->nearestStatement->is()); - AstNode* back = callFragment.ancestry.back(); + CHECK_GE(callFragment->ancestry.size(), 2); + + AstNode* back = callFragment->ancestry.back(); CHECK(back->is()); CHECK_EQ(Position{1, 4}, back->location.begin); CHECK_EQ(Position{1, 9}, back->location.end); - AstNode* parent = callFragment.ancestry.rbegin()[1]; + AstNode* parent = callFragment->ancestry.rbegin()[1]; CHECK(parent->is()); CHECK_EQ(Position{1, 0}, parent->location.begin); CHECK_EQ(Position{1, 10}, parent->location.end); @@ -471,12 +536,14 @@ abc("bar") Position{1, 9} ); - CHECK_EQ("function abc(foo: string) end\nabc(\"foo\")", stringFragment.fragmentToParse); - CHECK(stringFragment.nearestStatement->is()); + REQUIRE(stringFragment.has_value()); - CHECK_GE(stringFragment.ancestry.size(), 1); + CHECK_EQ("function abc(foo: string) end\nabc(\"foo\")", stringFragment->fragmentToParse); + CHECK(stringFragment->nearestStatement->is()); - back = stringFragment.ancestry.back(); + CHECK_GE(stringFragment->ancestry.size(), 1); + + back = stringFragment->ancestry.back(); auto asString = back->as(); CHECK(asString); @@ -506,17 +573,19 @@ abc("bar") Position{3, 1} ); - CHECK_EQ("function abc(foo: string) end\nabc(\n\"foo\"\n)", fragment.fragmentToParse); - CHECK(fragment.nearestStatement->is()); + REQUIRE(fragment.has_value()); - CHECK_GE(fragment.ancestry.size(), 2); + CHECK_EQ("function abc(foo: string) end\nabc(\n\"foo\"\n)", fragment->fragmentToParse); + CHECK(fragment->nearestStatement->is()); - AstNode* back = fragment.ancestry.back(); + CHECK_GE(fragment->ancestry.size(), 2); + + AstNode* back = fragment->ancestry.back(); CHECK(back->is()); CHECK_EQ(Position{2, 0}, back->location.begin); CHECK_EQ(Position{2, 5}, back->location.end); - AstNode* parent = fragment.ancestry.rbegin()[1]; + AstNode* parent = fragment->ancestry.rbegin()[1]; CHECK(parent->is()); CHECK_EQ(Position{1, 0}, parent->location.begin); CHECK_EQ(Position{3, 1}, parent->location.end); @@ -547,6 +616,7 @@ t } TEST_SUITE_END(); +// NOLINTEND(bugprone-unchecked-optional-access) TEST_SUITE_BEGIN("FragmentAutocompleteTypeCheckerTests"); @@ -682,10 +752,132 @@ tbl. CHECK_EQ(AutocompleteContext::Property, fragment.acResults.context); } +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "typecheck_fragment_handles_stale_module") +{ + const std::string sourceName = "MainModule"; + fileResolver.source[sourceName] = "local x = 5"; + + CheckResult checkResult = frontend.check(sourceName, getOptions()); + LUAU_REQUIRE_NO_ERRORS(checkResult); + + auto [result, _] = typecheckFragmentForModule(sourceName, fileResolver.source[sourceName], Luau::Position(0, 0)); + CHECK_EQ(result, FragmentTypeCheckStatus::Success); + + frontend.markDirty(sourceName); + frontend.parse(sourceName); + + CHECK_NE(frontend.getSourceModule(sourceName), nullptr); + + auto [result2, __] = typecheckFragmentForModule(sourceName, fileResolver.source[sourceName], Luau::Position(0, 0)); + CHECK_EQ(result2, FragmentTypeCheckStatus::SkipAutocomplete); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteBuiltinsFixture, "typecheck_fragment_handles_unusable_module") +{ + const std::string sourceA = "MainModule"; + fileResolver.source[sourceA] = R"( +local Modules = game:GetService('Gui').Modules +local B = require(Modules.B) +return { hello = B } +)"; + + const std::string sourceB = "game/Gui/Modules/B"; + fileResolver.source[sourceB] = R"(return {hello = "hello"})"; + + CheckResult result = frontend.check(sourceA, getOptions()); + CHECK(!frontend.isDirty(sourceA, getOptions().forAutocomplete)); + + std::weak_ptr weakModule = getModuleResolver(frontend).getModule(sourceB); + REQUIRE(!weakModule.expired()); + + frontend.markDirty(sourceB); + CHECK(frontend.isDirty(sourceA, getOptions().forAutocomplete)); + + frontend.check(sourceB, getOptions()); + CHECK(weakModule.expired()); + + auto [status, _] = typecheckFragmentForModule(sourceA, fileResolver.source[sourceA], Luau::Position(0, 0)); + CHECK_EQ(status, FragmentTypeCheckStatus::SkipAutocomplete); + + auto [status2, _2] = typecheckFragmentForModule(sourceB, fileResolver.source[sourceB], Luau::Position(3, 20)); + CHECK_EQ(status2, FragmentTypeCheckStatus::Success); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("FragmentAutocompleteTests"); +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "multiple_fragment_autocomplete") +{ + ToStringOptions opt; + opt.exhaustive = true; + opt.exhaustive = true; + opt.functionTypeArguments = true; + opt.maxTableLength = 0; + opt.maxTypeLength = 0; + + auto checkAndExamine = [&](const std::string& src, const std::string& idName, const std::string& idString) + { + check(src, getOptions()); + auto id = getType(idName, true); + LUAU_ASSERT(id); + CHECK_EQ(Luau::toString(*id, opt), idString); + }; + + auto getTypeFromModule = [](ModulePtr module, const std::string& name) -> std::optional + { + if (!module->hasModuleScope()) + return std::nullopt; + return lookupName(module->getModuleScope(), name); + }; + + auto fragmentACAndCheck = [&](const std::string& updated, + const Position& pos, + const std::string& idName, + const std::string& srcIdString, + const std::string& fragIdString) + { + FragmentAutocompleteResult result = autocompleteFragment(updated, pos, std::nullopt); + auto fragId = getTypeFromModule(result.incrementalModule, idName); + LUAU_ASSERT(fragId); + CHECK_EQ(Luau::toString(*fragId, opt), fragIdString); + + auto srcId = getType(idName, true); + LUAU_ASSERT(srcId); + CHECK_EQ(Luau::toString(*srcId, opt), srcIdString); + }; + + const std::string source = R"(local module = {} +f +return module)"; + + const std::string updated1 = R"(local module = {} +function module.a +return module)"; + + const std::string updated2 = R"(local module = {} +function module.ab +return module)"; + + { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + checkAndExamine(source, "module", "{ }"); + // [TODO] CLI-140762 we shouldn't mutate stale module in autocompleteFragment + // early return since the following checking will fail, which it shouldn't! + // fragmentACAndCheck(updated1, Position{1, 17}, "module", "{ }", "{ a: (%error-id%: unknown) -> () }"); + // fragmentACAndCheck(updated2, Position{1, 18}, "module", "{ }", "{ ab: (%error-id%: unknown) -> () }"); + } + { + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + checkAndExamine(source, "module", "{ }"); + // [TODO] CLI-140762 we shouldn't mutate stale module in autocompleteFragment + // early return since the following checking will fail, which it shouldn't! + return; + fragmentACAndCheck(updated1, Position{1, 17}, "module", "{ }", "{ a: (%error-id%: unknown) -> () }"); + fragmentACAndCheck(updated2, Position{1, 18}, "module", "{ }", "{ ab: (%error-id%: unknown) -> () }"); + } +} + TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_autocomplete_simple_property_access") { @@ -1383,4 +1575,382 @@ t ); } +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "no_recs_for_comments_simple") +{ + const std::string source = R"( +-- sel +-- retur +-- fo +-- if +-- end +-- the +)"; + ScopedFastFlag sff{FFlag::LuauIncrementalAutocompleteCommentDetection, true}; + autocompleteFragmentInBothSolvers( + source, + source, + Position{4, 6}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.empty()); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "no_recs_for_comments_blocks") +{ + const std::string source = R"( +--[[ +comment 1 +]] local +-- [[ comment 2]] +-- +-- sdfsdfsdf +--[[comment 3]] +--[[ +foo +bar +baz +]] +)"; + ScopedFastFlag sff{FFlag::LuauIncrementalAutocompleteCommentDetection, true}; + autocompleteFragmentInBothSolvers( + source, + source, + Position{3, 0}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.empty()); + } + ); + + autocompleteFragmentInBothSolvers( + source, + source, + Position{3, 2}, + [](FragmentAutocompleteResult& result) + { + CHECK(!result.acResults.entryMap.empty()); + } + ); + + autocompleteFragmentInBothSolvers( + source, + source, + Position{8, 6}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.empty()); + } + ); + + autocompleteFragmentInBothSolvers( + source, + source, + Position{10, 0}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.empty()); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "no_recs_for_comments") +{ + const std::string source = R"( +-- sel +-- retur +-- fo +--[[ sel ]] +local -- hello +)"; + ScopedFastFlag sff{FFlag::LuauIncrementalAutocompleteCommentDetection, true}; + autocompleteFragmentInBothSolvers( + source, + source, + Position{1, 7}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.empty()); + } + ); + + autocompleteFragmentInBothSolvers( + source, + source, + Position{2, 9}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.empty()); + } + ); + + autocompleteFragmentInBothSolvers( + source, + source, + Position{3, 6}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.empty()); + } + ); + + autocompleteFragmentInBothSolvers( + source, + source, + Position{4, 9}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.empty()); + } + ); + + autocompleteFragmentInBothSolvers( + source, + source, + Position{5, 6}, + [](FragmentAutocompleteResult& result) + { + CHECK(!result.acResults.entryMap.empty()); + } + ); + + autocompleteFragmentInBothSolvers( + source, + source, + Position{5, 14}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.empty()); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "no_recs_for_comments_in_incremental_fragment") +{ + const std::string source = R"( +local x = 5 +if x == 5 +)"; + const std::string updated = R"( +local x = 5 +if x == 5 then -- a comment +)"; + ScopedFastFlag sff{FFlag::LuauIncrementalAutocompleteCommentDetection, true}; + autocompleteFragmentInBothSolvers( + source, + updated, + Position{2, 28}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.empty()); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "fragment_autocomplete_handles_parse_errors") +{ + + ScopedFastInt sfi{FInt::LuauParseErrorLimit, 1}; + const std::string source = R"( + +)"; + const std::string updated = R"( +type A = <>random non code text here +)"; + + autocompleteFragmentInBothSolvers( + source, + updated, + Position{1, 38}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.empty()); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "fragment_autocomplete_handles_stale_module") +{ + const std::string sourceName = "MainModule"; + fileResolver.source[sourceName] = "local x = 5"; + + frontend.check(sourceName, getOptions()); + frontend.markDirty(sourceName); + frontend.parse(sourceName); + + FragmentAutocompleteResult result = autocompleteFragmentForModule(sourceName, fileResolver.source[sourceName], Luau::Position(0, 0)); + CHECK(result.acResults.entryMap.empty()); + CHECK_EQ(result.incrementalModule, nullptr); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteBuiltinsFixture, "require_tracing") +{ + fileResolver.source["MainModule/A"] = R"( +return { x = 0 } + )"; + + fileResolver.source["MainModule"] = R"( +local result = require(script.A) +local x = 1 + result. + )"; + + autocompleteFragmentInBothSolvers( + fileResolver.source["MainModule"], + fileResolver.source["MainModule"], + Position{2, 21}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.size() == 1); + CHECK(result.acResults.entryMap.count("x")); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteBuiltinsFixture, "fragment_ac_must_traverse_typeof_and_not_ice") +{ + // This test ensures that we traverse typeof expressions for defs that are being referred to in the fragment + // In this case, we want to ensure we populate the incremental environment with the reference to `m` + // Without this, we would ice as we will refer to the local `m` before it's declaration + ScopedFastFlag sff{FFlag::LuauMixedModeDefFinderTraversesTypeOf, true}; + const std::string source = R"( +--!strict +local m = {} +-- and here +function m:m1() end +type nt = typeof(m) + +return m +)"; + const std::string updated = R"( +--!strict +local m = {} +-- and here +function m:m1() end +type nt = typeof(m) +l +return m +)"; + + autocompleteFragmentInBothSolvers(source, updated, Position{6, 2}, [](auto& _) {}); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteBuiltinsFixture, "generalization_crash_when_old_solver_freetypes_have_no_bounds_set") +{ + ScopedFastFlag sff{FFlag::LuauFreeTypesMustHaveBounds, true}; + const std::string source = R"( +local UserInputService = game:GetService("UserInputService"); + +local Camera = workspace.CurrentCamera; + +UserInputService.InputBegan:Connect(function(Input) + if (Input.KeyCode == Enum.KeyCode.One) then + local Up = Input.Foo + local Vector = -(Up:Unit) + end +end) +)"; + + const std::string dest = R"( +local UserInputService = game:GetService("UserInputService"); + +local Camera = workspace.CurrentCamera; + +UserInputService.InputBegan:Connect(function(Input) + if (Input.KeyCode == Enum.KeyCode.One) then + local Up = Input.Foo + local Vector = -(Up:Unit()) + end +end) +)"; + + autocompleteFragmentInBothSolvers(source, dest, Position{8, 36}, [](auto& _) {}); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "fragment_autocomplete_ensures_memory_isolation") +{ + ScopedFastFlag sff{FFlag::LuauCloneIncrementalModule, true}; + ToStringOptions opt; + opt.exhaustive = true; + opt.exhaustive = true; + opt.functionTypeArguments = true; + opt.maxTableLength = 0; + opt.maxTypeLength = 0; + + auto checkAndExamine = [&](const std::string& src, const std::string& idName, const std::string& idString) + { + check(src, getOptions()); + auto id = getType(idName, true); + LUAU_ASSERT(id); + CHECK_EQ(Luau::toString(*id, opt), idString); + }; + + auto getTypeFromModule = [](ModulePtr module, const std::string& name) -> std::optional + { + if (!module->hasModuleScope()) + return std::nullopt; + return lookupName(module->getModuleScope(), name); + }; + + auto fragmentACAndCheck = [&](const std::string& updated, const Position& pos, const std::string& idName) + { + FragmentAutocompleteResult result = autocompleteFragment(updated, pos, std::nullopt); + auto fragId = getTypeFromModule(result.incrementalModule, idName); + LUAU_ASSERT(fragId); + + auto srcId = getType(idName, true); + LUAU_ASSERT(srcId); + + CHECK((*fragId)->owningArena != (*srcId)->owningArena); + CHECK(&(result.incrementalModule->internalTypes) == (*fragId)->owningArena); + }; + + const std::string source = R"(local module = {} +f +return module)"; + + const std::string updated1 = R"(local module = {} +function module.a +return module)"; + + const std::string updated2 = R"(local module = {} +function module.ab +return module)"; + + { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + checkAndExamine(source, "module", "{ }"); + // [TODO] CLI-140762 we shouldn't mutate stale module in autocompleteFragment + // early return since the following checking will fail, which it shouldn't! + fragmentACAndCheck(updated1, Position{1, 17}, "module"); + fragmentACAndCheck(updated2, Position{1, 18}, "module"); + } + + { + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + checkAndExamine(source, "module", "{ }"); + // [TODO] CLI-140762 we shouldn't mutate stale module in autocompleteFragment + // early return since the following checking will fail, which it shouldn't! + fragmentACAndCheck(updated1, Position{1, 17}, "module"); + fragmentACAndCheck(updated2, Position{1, 18}, "module"); + } +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "fragment_autocomplete_shouldnt_crash_on_cross_module_mutation") +{ + ScopedFastFlag sff{FFlag::LuauCloneIncrementalModule, true}; + const std::string source = R"(local module = {} +function module. +return module +)"; + + const std::string updated = R"(local module = {} +function module.f +return module +)"; + + autocompleteFragmentInBothSolvers(source, updated, Position{1, 18}, [](FragmentAutocompleteResult& result) {}); +} + + TEST_SUITE_END(); diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index bfa69fe4..9491e28a 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/AstQuery.h" #include "Luau/BuiltinDefinitions.h" +#include "Luau/DenseHash.h" #include "Luau/Frontend.h" #include "Luau/RequireTracer.h" @@ -15,6 +16,9 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAG(DebugLuauFreezeArena); LUAU_FASTFLAG(DebugLuauMagicTypes); +LUAU_FASTFLAG(LuauReferenceAllocatorInNewSolver); +LUAU_FASTFLAG(LuauSelectivelyRetainDFGArena) +LUAU_FASTFLAG(LuauBetterReverseDependencyTracking); namespace { @@ -1522,4 +1526,255 @@ TEST_CASE_FIXTURE(FrontendFixture, "get_required_scripts_dirty") CHECK(requiredScripts[0] == "game/workspace/MyModuleScript"); } +TEST_CASE_FIXTURE(FrontendFixture, "check_module_references_allocator") +{ + ScopedFastFlag sff{FFlag::LuauReferenceAllocatorInNewSolver, true}; + fileResolver.source["game/workspace/MyScript"] = R"( + print("Hello World") + )"; + + frontend.check("game/workspace/MyScript"); + + ModulePtr module = frontend.moduleResolver.getModule("game/workspace/MyScript"); + SourceModule* source = frontend.getSourceModule("game/workspace/MyScript"); + CHECK(module); + CHECK(source); + + CHECK_EQ(module->allocator.get(), source->allocator.get()); + CHECK_EQ(module->names.get(), source->names.get()); +} + +TEST_CASE_FIXTURE(FrontendFixture, "dfg_data_cleared_on_retain_type_graphs_unset") +{ + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauSelectivelyRetainDFGArena, true} + }; + fileResolver.source["game/A"] = R"( +local a = 1 +local b = 2 +local c = 3 +return {x = a, y = b, z = c} +)"; + + frontend.options.retainFullTypeGraphs = true; + frontend.check("game/A"); + + auto mod = frontend.moduleResolver.getModule("game/A"); + CHECK(!mod->defArena.allocator.empty()); + CHECK(!mod->keyArena.allocator.empty()); + + // We should check that the dfg arena is empty once retainFullTypeGraphs is unset + frontend.options.retainFullTypeGraphs = false; + frontend.markDirty("game/A"); + frontend.check("game/A"); + + mod = frontend.moduleResolver.getModule("game/A"); + CHECK(mod->defArena.allocator.empty()); + CHECK(mod->keyArena.allocator.empty()); +} + +TEST_CASE_FIXTURE(FrontendFixture, "test_traverse_dependents") +{ + ScopedFastFlag dependencyTracking{FFlag::LuauBetterReverseDependencyTracking, true}; + + fileResolver.source["game/Gui/Modules/A"] = "return {hello=5, world=true}"; + fileResolver.source["game/Gui/Modules/B"] = R"( + return require(game:GetService('Gui').Modules.A) + )"; + fileResolver.source["game/Gui/Modules/C"] = R"( + local Modules = game:GetService('Gui').Modules + local B = require(Modules.B) + return {c_value = B.hello} + )"; + fileResolver.source["game/Gui/Modules/D"] = R"( + local Modules = game:GetService('Gui').Modules + local C = require(Modules.C) + return {d_value = C.c_value} + )"; + + frontend.check("game/Gui/Modules/D"); + + std::vector visited; + frontend.traverseDependents( + "game/Gui/Modules/B", + [&visited](SourceNode& node) + { + visited.push_back(node.name); + return true; + } + ); + + CHECK_EQ(std::vector{"game/Gui/Modules/B", "game/Gui/Modules/C", "game/Gui/Modules/D"}, visited); +} + +TEST_CASE_FIXTURE(FrontendFixture, "test_traverse_dependents_early_exit") +{ + ScopedFastFlag dependencyTracking{FFlag::LuauBetterReverseDependencyTracking, true}; + + fileResolver.source["game/Gui/Modules/A"] = "return {hello=5, world=true}"; + fileResolver.source["game/Gui/Modules/B"] = R"( + return require(game:GetService('Gui').Modules.A) + )"; + fileResolver.source["game/Gui/Modules/C"] = R"( + local Modules = game:GetService('Gui').Modules + local B = require(Modules.B) + return {c_value = B.hello} + )"; + + frontend.check("game/Gui/Modules/C"); + + std::vector visited; + frontend.traverseDependents( + "game/Gui/Modules/A", + [&visited](SourceNode& node) + { + visited.push_back(node.name); + return node.name != "game/Gui/Modules/B"; + } + ); + + CHECK_EQ(std::vector{"game/Gui/Modules/A", "game/Gui/Modules/B"}, visited); +} + +TEST_CASE_FIXTURE(FrontendFixture, "test_dependents_stored_on_node_as_graph_updates") +{ + ScopedFastFlag dependencyTracking{FFlag::LuauBetterReverseDependencyTracking, true}; + + auto updateSource = [&](const std::string& name, const std::string& source) + { + fileResolver.source[name] = source; + frontend.markDirty(name); + }; + + auto validateMatchesRequireLists = [&](const std::string& message) + { + DenseHashMap> dependents{{}}; + for (const auto& module : frontend.sourceNodes) + { + for (const auto& dep : module.second->requireSet) + dependents[dep].push_back(module.first); + } + + for (const auto& module : frontend.sourceNodes) + { + Set& dependentsForModule = module.second->dependents; + for (const auto& dep : dependents[module.first]) + CHECK_MESSAGE(1 == dependentsForModule.count(dep), "Mismatch in dependents for " << module.first << ": " << message); + } + }; + + auto validateSecondDependsOnFirst = [&](const std::string& from, const std::string& to, bool expected) + { + SourceNode& fromNode = *frontend.sourceNodes[from]; + CHECK_MESSAGE( + fromNode.dependents.count(to) == int(expected), + "Expected " << from << " to " << (expected ? std::string() : std::string("not ")) << "have a reverse dependency on " << to + ); + }; + + // C -> B -> A + { + updateSource("game/Gui/Modules/A", "return {hello=5, world=true}"); + updateSource("game/Gui/Modules/B", R"( + return require(game:GetService('Gui').Modules.A) + )"); + updateSource("game/Gui/Modules/C", R"( + local Modules = game:GetService('Gui').Modules + local B = require(Modules.B) + return {c_value = B} + )"); + frontend.check("game/Gui/Modules/C"); + + validateMatchesRequireLists("Initial check"); + + validateSecondDependsOnFirst("game/Gui/Modules/A", "game/Gui/Modules/B", true); + validateSecondDependsOnFirst("game/Gui/Modules/B", "game/Gui/Modules/C", true); + validateSecondDependsOnFirst("game/Gui/Modules/C", "game/Gui/Modules/A", false); + } + + // C -> B, A + { + updateSource("game/Gui/Modules/B", R"( + return 1 + )"); + frontend.check("game/Gui/Modules/C"); + + validateMatchesRequireLists("Removing dependency B->A"); + validateSecondDependsOnFirst("game/Gui/Modules/A", "game/Gui/Modules/B", false); + } + + // C -> B -> A + { + updateSource("game/Gui/Modules/B", R"( + return require(game:GetService('Gui').Modules.A) + )"); + frontend.check("game/Gui/Modules/C"); + + validateMatchesRequireLists("Adding back B->A"); + validateSecondDependsOnFirst("game/Gui/Modules/A", "game/Gui/Modules/B", true); + } + + // C -> B -> A, D -> (C,B,A) + { + updateSource("game/Gui/Modules/D", R"( + local C = require(game:GetService('Gui').Modules.C) + local B = require(game:GetService('Gui').Modules.B) + local A = require(game:GetService('Gui').Modules.A) + return {d_value = C.c_value} + )"); + frontend.check("game/Gui/Modules/D"); + + validateMatchesRequireLists("Adding D->C, D->B, D->A"); + validateSecondDependsOnFirst("game/Gui/Modules/A", "game/Gui/Modules/D", true); + validateSecondDependsOnFirst("game/Gui/Modules/B", "game/Gui/Modules/D", true); + validateSecondDependsOnFirst("game/Gui/Modules/C", "game/Gui/Modules/D", true); + } + + // B -> A, C <-> D + { + updateSource("game/Gui/Modules/D", "return require(game:GetService('Gui').Modules.C)"); + updateSource("game/Gui/Modules/C", "return require(game:GetService('Gui').Modules.D)"); + frontend.check("game/Gui/Modules/D"); + + validateMatchesRequireLists("Adding cycle D->C, C->D"); + validateSecondDependsOnFirst("game/Gui/Modules/C", "game/Gui/Modules/D", true); + validateSecondDependsOnFirst("game/Gui/Modules/D", "game/Gui/Modules/C", true); + } + + // B -> A, C -> D, D -> error + { + updateSource("game/Gui/Modules/D", "return require(game:GetService('Gui').Modules.C.)"); + frontend.check("game/Gui/Modules/D"); + + validateMatchesRequireLists("Adding error dependency D->C."); + validateSecondDependsOnFirst("game/Gui/Modules/D", "game/Gui/Modules/C", true); + validateSecondDependsOnFirst("game/Gui/Modules/C", "game/Gui/Modules/D", false); + } +} + +TEST_CASE_FIXTURE(FrontendFixture, "test_invalid_dependency_tracking_per_module_resolver") +{ + ScopedFastFlag dependencyTracking{FFlag::LuauBetterReverseDependencyTracking, true}; + ScopedFastFlag newSolver{FFlag::LuauSolverV2, false}; + + fileResolver.source["game/Gui/Modules/A"] = "return {hello=5, world=true}"; + fileResolver.source["game/Gui/Modules/B"] = "return require(game:GetService('Gui').Modules.A)"; + + FrontendOptions opts; + opts.forAutocomplete = false; + + frontend.check("game/Gui/Modules/B", opts); + CHECK(frontend.allModuleDependenciesValid("game/Gui/Modules/B", opts.forAutocomplete)); + CHECK(!frontend.allModuleDependenciesValid("game/Gui/Modules/B", !opts.forAutocomplete)); + + opts.forAutocomplete = true; + frontend.check("game/Gui/Modules/A", opts); + + CHECK(!frontend.allModuleDependenciesValid("game/Gui/Modules/B", opts.forAutocomplete)); + CHECK(frontend.allModuleDependenciesValid("game/Gui/Modules/B", !opts.forAutocomplete)); + CHECK(frontend.allModuleDependenciesValid("game/Gui/Modules/A", !opts.forAutocomplete)); + CHECK(frontend.allModuleDependenciesValid("game/Gui/Modules/A", opts.forAutocomplete)); +} + TEST_SUITE_END(); diff --git a/tests/Generalization.test.cpp b/tests/Generalization.test.cpp index 1388b900..b9e4eaf1 100644 --- a/tests/Generalization.test.cpp +++ b/tests/Generalization.test.cpp @@ -179,9 +179,9 @@ TEST_CASE_FIXTURE(GeneralizationFixture, "functions_containing_cyclic_tables_can TEST_CASE_FIXTURE(GeneralizationFixture, "union_type_traversal_doesnt_crash") { // t1 where t1 = ('h <: (t1 <: 'i)) | ('j <: (t1 <: 'i)) - TypeId i = arena.addType(FreeType{NotNull{globalScope.get()}}); - TypeId h = arena.addType(FreeType{NotNull{globalScope.get()}}); - TypeId j = arena.addType(FreeType{NotNull{globalScope.get()}}); + TypeId i = arena.freshType(NotNull{&builtinTypes}, globalScope.get()); + TypeId h = arena.freshType(NotNull{&builtinTypes}, globalScope.get()); + TypeId j = arena.freshType(NotNull{&builtinTypes}, globalScope.get()); TypeId unionType = arena.addType(UnionType{{h, j}}); getMutable(h)->upperBound = i; getMutable(h)->lowerBound = builtinTypes.neverType; @@ -196,9 +196,9 @@ TEST_CASE_FIXTURE(GeneralizationFixture, "union_type_traversal_doesnt_crash") TEST_CASE_FIXTURE(GeneralizationFixture, "intersection_type_traversal_doesnt_crash") { // t1 where t1 = ('h <: (t1 <: 'i)) & ('j <: (t1 <: 'i)) - TypeId i = arena.addType(FreeType{NotNull{globalScope.get()}}); - TypeId h = arena.addType(FreeType{NotNull{globalScope.get()}}); - TypeId j = arena.addType(FreeType{NotNull{globalScope.get()}}); + TypeId i = arena.freshType(NotNull{&builtinTypes}, globalScope.get()); + TypeId h = arena.freshType(NotNull{&builtinTypes}, globalScope.get()); + TypeId j = arena.freshType(NotNull{&builtinTypes}, globalScope.get()); TypeId intersectionType = arena.addType(IntersectionType{{h, j}}); getMutable(h)->upperBound = i; diff --git a/tests/Instantiation2.test.cpp b/tests/Instantiation2.test.cpp index fff98e60..fcd136fb 100644 --- a/tests/Instantiation2.test.cpp +++ b/tests/Instantiation2.test.cpp @@ -4,6 +4,7 @@ #include "Fixture.h" #include "ClassFixture.h" +#include "Luau/Type.h" #include "ScopedFlags.h" #include "doctest.h" @@ -29,7 +30,7 @@ TEST_CASE_FIXTURE(Fixture, "weird_cyclic_instantiation") DenseHashMap genericSubstitutions{nullptr}; DenseHashMap genericPackSubstitutions{nullptr}; - TypeId freeTy = arena.freshType(&scope); + TypeId freeTy = arena.freshType(builtinTypes, &scope); FreeType* ft = getMutable(freeTy); REQUIRE(ft); ft->lowerBound = idTy; diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index ba4e7f04..629c3696 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -13,8 +13,6 @@ #include LUAU_FASTFLAG(DebugLuauAbortingChecks) -LUAU_FASTFLAG(LuauCodeGenVectorDeadStoreElim) -LUAU_FASTFLAG(LuauCodeGenArithOpt) using namespace Luau::CodeGen; @@ -1725,8 +1723,6 @@ bb_fallback_1: TEST_CASE_FIXTURE(IrBuilderFixture, "NumericSimplifications") { - ScopedFastFlag luauCodeGenArithOpt{FFlag::LuauCodeGenArithOpt, true}; - IrOp block = build.block(IrBlockKind::Internal); build.beginBlock(block); @@ -4472,8 +4468,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "VectorOverNumber") { - ScopedFastFlag luauCodeGenVectorDeadStoreElim{FFlag::LuauCodeGenVectorDeadStoreElim, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -4497,8 +4491,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "VectorOverVector") { - ScopedFastFlag luauCodeGenVectorDeadStoreElim{FFlag::LuauCodeGenVectorDeadStoreElim, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -4522,8 +4514,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "NumberOverVector") { - ScopedFastFlag luauCodeGenVectorDeadStoreElim{FFlag::LuauCodeGenVectorDeadStoreElim, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -4547,8 +4537,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "NumberOverNil") { - ScopedFastFlag luauCodeGenVectorDeadStoreElim{FFlag::LuauCodeGenVectorDeadStoreElim, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -4571,8 +4559,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "VectorOverNil") { - ScopedFastFlag luauCodeGenVectorDeadStoreElim{FFlag::LuauCodeGenVectorDeadStoreElim, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -4595,8 +4581,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "NumberOverCombinedVector") { - ScopedFastFlag luauCodeGenVectorDeadStoreElim{FFlag::LuauCodeGenVectorDeadStoreElim, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -4622,8 +4606,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "VectorOverCombinedVector") { - ScopedFastFlag luauCodeGenVectorDeadStoreElim{FFlag::LuauCodeGenVectorDeadStoreElim, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); @@ -4649,8 +4631,6 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "VectorOverCombinedNumber") { - ScopedFastFlag luauCodeGenVectorDeadStoreElim{FFlag::LuauCodeGenVectorDeadStoreElim, true}; - IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); diff --git a/tests/IrLowering.test.cpp b/tests/IrLowering.test.cpp index 27376777..2a5c23fd 100644 --- a/tests/IrLowering.test.cpp +++ b/tests/IrLowering.test.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "lua.h" #include "lualib.h" +#include "luacode.h" #include "Luau/BytecodeBuilder.h" #include "Luau/CodeGen.h" @@ -15,10 +16,65 @@ #include #include -static std::string getCodegenAssembly(const char* source, bool includeIrTypes = false, int debugLevel = 1) +static void luauLibraryConstantLookup(const char* library, const char* member, Luau::CompileConstant* constant) { - Luau::CodeGen::AssemblyOptions options; + // While 'vector' library constants are a Luau built-in, their constant value depends on the embedder LUA_VECTOR_SIZE value + if (strcmp(library, "vector") == 0) + { + if (strcmp(member, "zero") == 0) + return Luau::setCompileConstantVector(constant, 0.0f, 0.0f, 0.0f, 0.0f); + if (strcmp(member, "one") == 0) + return Luau::setCompileConstantVector(constant, 1.0f, 1.0f, 1.0f, 0.0f); + } + + if (strcmp(library, "Vector3") == 0) + { + if (strcmp(member, "xAxis") == 0) + return Luau::setCompileConstantVector(constant, 1.0f, 0.0f, 0.0f, 0.0f); + + if (strcmp(member, "yAxis") == 0) + return Luau::setCompileConstantVector(constant, 0.0f, 1.0f, 0.0f, 0.0f); + } +} + +static void luauLibraryConstantLookupC(const char* library, const char* member, lua_CompileConstant* constant) +{ + if (strcmp(library, "test") == 0) + { + if (strcmp(member, "some_nil") == 0) + return luau_set_compile_constant_nil(constant); + + if (strcmp(member, "some_boolean") == 0) + return luau_set_compile_constant_boolean(constant, 1); + + if (strcmp(member, "some_number") == 0) + return luau_set_compile_constant_number(constant, 4.75); + + if (strcmp(member, "some_vector") == 0) + return luau_set_compile_constant_vector(constant, 1.0f, 2.0f, 4.0f, 8.0f); + + if (strcmp(member, "some_string") == 0) + return luau_set_compile_constant_string(constant, "test", 4); + } +} + +static int luauLibraryTypeLookup(const char* library, const char* member) +{ + if (strcmp(library, "Vector3") == 0) + { + if (strcmp(member, "xAxis") == 0) + return LuauBytecodeType::LBC_TYPE_VECTOR; + + if (strcmp(member, "yAxis") == 0) + return LuauBytecodeType::LBC_TYPE_VECTOR; + } + + return LuauBytecodeType::LBC_TYPE_ANY; +} + +static void setupAssemblyOptions(Luau::CodeGen::AssemblyOptions& options, bool includeIrTypes) +{ options.compilationOptions.hooks.vectorAccessBytecodeType = vectorAccessBytecodeType; options.compilationOptions.hooks.vectorNamecallBytecodeType = vectorNamecallBytecodeType; options.compilationOptions.hooks.vectorAccess = vectorAccess; @@ -44,35 +100,10 @@ static std::string getCodegenAssembly(const char* source, bool includeIrTypes = options.includeUseInfo = Luau::CodeGen::IncludeUseInfo::No; options.includeCfgInfo = Luau::CodeGen::IncludeCfgInfo::No; options.includeRegFlowInfo = Luau::CodeGen::IncludeRegFlowInfo::No; +} - Luau::Allocator allocator; - Luau::AstNameTable names(allocator); - Luau::ParseResult result = Luau::Parser::parse(source, strlen(source), names, allocator); - - if (!result.errors.empty()) - throw Luau::ParseErrors(result.errors); - - Luau::CompileOptions copts = {}; - - copts.optimizationLevel = 2; - copts.debugLevel = debugLevel; - copts.typeInfoLevel = 1; - copts.vectorCtor = "vector"; - copts.vectorType = "vector"; - - static const char* kUserdataCompileTypes[] = {"vec2", "color", "mat3", nullptr}; - copts.userdataTypes = kUserdataCompileTypes; - - Luau::BytecodeBuilder bcb; - Luau::compileOrThrow(bcb, result, names, copts); - - std::string bytecode = bcb.getBytecode(); - std::unique_ptr globalState(luaL_newstate(), lua_close); - lua_State* L = globalState.get(); - - // Runtime mapping is specifically created to NOT match the compilation mapping - options.compilationOptions.userdataTypes = kUserdataRunTypes; - +static void initializeCodegen(lua_State* L) +{ if (Luau::CodeGen::isSupported()) { // Type remapper requires the codegen runtime @@ -101,9 +132,95 @@ static std::string getCodegenAssembly(const char* source, bool includeIrTypes = } ); } +} + +static std::string getCodegenAssembly(const char* source, bool includeIrTypes = false, int debugLevel = 1) +{ + Luau::Allocator allocator; + Luau::AstNameTable names(allocator); + Luau::ParseResult result = Luau::Parser::parse(source, strlen(source), names, allocator); + + if (!result.errors.empty()) + throw Luau::ParseErrors(result.errors); + + Luau::CompileOptions copts = {}; + + copts.optimizationLevel = 2; + copts.debugLevel = debugLevel; + copts.typeInfoLevel = 1; + copts.vectorCtor = "vector"; + copts.vectorType = "vector"; + + static const char* kUserdataCompileTypes[] = {"vec2", "color", "mat3", nullptr}; + copts.userdataTypes = kUserdataCompileTypes; + + static const char* kLibrariesWithConstants[] = {"vector", "Vector3", nullptr}; + copts.librariesWithKnownMembers = kLibrariesWithConstants; + + copts.libraryMemberTypeCb = luauLibraryTypeLookup; + copts.libraryMemberConstantCb = luauLibraryConstantLookup; + + Luau::BytecodeBuilder bcb; + Luau::compileOrThrow(bcb, result, names, copts); + + std::string bytecode = bcb.getBytecode(); + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + initializeCodegen(L); if (luau_load(L, "name", bytecode.data(), bytecode.size(), 0) == 0) + { + Luau::CodeGen::AssemblyOptions options; + setupAssemblyOptions(options, includeIrTypes); + + // Runtime mapping is specifically created to NOT match the compilation mapping + options.compilationOptions.userdataTypes = kUserdataRunTypes; + return Luau::CodeGen::getAssembly(L, -1, options, nullptr); + } + + FAIL("Failed to load bytecode"); + return ""; +} + +static std::string getCodegenAssemblyUsingCApi(const char* source, bool includeIrTypes = false, int debugLevel = 1) +{ + lua_CompileOptions copts = {}; + + copts.optimizationLevel = 2; + copts.debugLevel = debugLevel; + copts.typeInfoLevel = 1; + + static const char* kLibrariesWithConstants[] = {"test", nullptr}; + copts.librariesWithKnownMembers = kLibrariesWithConstants; + + copts.libraryMemberTypeCb = luauLibraryTypeLookup; + copts.libraryMemberConstantCb = luauLibraryConstantLookupC; + + size_t bytecodeSize = 0; + char* bytecode = luau_compile(source, strlen(source), &copts, &bytecodeSize); + REQUIRE(bytecode); + + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + initializeCodegen(L); + + if (luau_load(L, "name", bytecode, bytecodeSize, 0) == 0) + { + free(bytecode); + + Luau::CodeGen::AssemblyOptions options; + setupAssemblyOptions(options, includeIrTypes); + + // Runtime mapping is specifically created to NOT match the compilation mapping + options.compilationOptions.userdataTypes = kUserdataRunTypes; + + return Luau::CodeGen::getAssembly(L, -1, options, nullptr); + } + + free(bytecode); FAIL("Failed to load bytecode"); return ""; @@ -401,9 +518,9 @@ bb_bytecode_0: JUMP bb_2 bb_2: CHECK_SAFE_ENV exit(3) - JUMP_EQ_TAG K1, tnil, bb_fallback_4, bb_3 + JUMP_EQ_TAG K1 (nil), tnil, bb_fallback_4, bb_3 bb_3: - %9 = LOAD_TVALUE K1 + %9 = LOAD_TVALUE K1 (nil) STORE_TVALUE R1, %9 JUMP bb_5 bb_5: @@ -456,7 +573,7 @@ bb_0: bb_2: JUMP bb_bytecode_1 bb_bytecode_1: - %4 = LOAD_TVALUE K0, 0i, tvector + %4 = LOAD_TVALUE K0 (1, 2, 3), 0i, tvector %11 = LOAD_TVALUE R0 %12 = ADD_VEC %4, %11 %13 = TAG_VECTOR %12 @@ -483,7 +600,7 @@ bb_0: bb_2: JUMP bb_bytecode_1 bb_bytecode_1: - FALLBACK_NAMECALL 0u, R1, R0, K0 + FALLBACK_NAMECALL 0u, R1, R0, K0 ('Abs') INTERRUPT 2u SET_SAVEDPC 3u CALL R1, 1i, -1i @@ -509,8 +626,8 @@ bb_0: bb_2: JUMP bb_bytecode_1 bb_bytecode_1: - FALLBACK_GETTABLEKS 0u, R3, R0, K0 - FALLBACK_GETTABLEKS 2u, R4, R0, K1 + FALLBACK_GETTABLEKS 0u, R3, R0, K0 ('XX') + FALLBACK_GETTABLEKS 2u, R4, R0, K1 ('YY') CHECK_TAG R3, tnumber, bb_fallback_3 CHECK_TAG R4, tnumber, bb_fallback_3 %14 = LOAD_DOUBLE R3 @@ -520,7 +637,7 @@ bb_bytecode_1: JUMP bb_4 bb_4: CHECK_TAG R0, tvector, exit(5) - FALLBACK_GETTABLEKS 5u, R3, R0, K2 + FALLBACK_GETTABLEKS 5u, R3, R0, K2 ('ZZ') CHECK_TAG R2, tnumber, bb_fallback_5 CHECK_TAG R3, tnumber, bb_fallback_5 %30 = LOAD_DOUBLE R2 @@ -738,8 +855,8 @@ bb_2: JUMP bb_bytecode_1 bb_bytecode_1: %8 = LOAD_POINTER R0 - %9 = GET_SLOT_NODE_ADDR %8, 0u, K1 - CHECK_SLOT_MATCH %9, K1, bb_fallback_3 + %9 = GET_SLOT_NODE_ADDR %8, 0u, K1 ('n') + CHECK_SLOT_MATCH %9, K1 ('n'), bb_fallback_3 %11 = LOAD_TVALUE %9, 0i STORE_TVALUE R3, %11 JUMP bb_4 @@ -766,8 +883,8 @@ bb_4: STORE_VECTOR R3, %30, %33, %36 CHECK_TAG R0, ttable, exit(6) %41 = LOAD_POINTER R0 - %42 = GET_SLOT_NODE_ADDR %41, 6u, K3 - CHECK_SLOT_MATCH %42, K3, bb_fallback_5 + %42 = GET_SLOT_NODE_ADDR %41, 6u, K3 ('b') + CHECK_SLOT_MATCH %42, K3 ('b'), bb_fallback_5 %44 = LOAD_TVALUE %42, 0i STORE_TVALUE R5, %44 JUMP bb_6 @@ -810,8 +927,8 @@ bb_0: bb_2: JUMP bb_bytecode_1 bb_bytecode_1: - FALLBACK_GETTABLEKS 0u, R2, R0, K0 - FALLBACK_GETTABLEKS 2u, R3, R0, K1 + FALLBACK_GETTABLEKS 0u, R2, R0, K0 ('x') + FALLBACK_GETTABLEKS 2u, R3, R0, K1 ('y') CHECK_TAG R2, tnumber, bb_fallback_3 CHECK_TAG R3, tnumber, bb_fallback_3 %14 = LOAD_DOUBLE R2 @@ -845,9 +962,9 @@ bb_2: bb_bytecode_1: STORE_DOUBLE R1, 3 STORE_TAG R1, tnumber - FALLBACK_SETTABLEKS 1u, R1, R0, K0 + FALLBACK_SETTABLEKS 1u, R1, R0, K0 ('x') STORE_DOUBLE R1, 4 - FALLBACK_SETTABLEKS 4u, R1, R0, K1 + FALLBACK_SETTABLEKS 4u, R1, R0, K1 ('y') INTERRUPT 6u RETURN R0, 0i )" @@ -870,11 +987,11 @@ bb_0: bb_2: JUMP bb_bytecode_1 bb_bytecode_1: - FALLBACK_NAMECALL 0u, R2, R0, K0 + FALLBACK_NAMECALL 0u, R2, R0, K0 ('GetX') INTERRUPT 2u SET_SAVEDPC 3u CALL R2, 1i, 1i - FALLBACK_NAMECALL 3u, R3, R0, K1 + FALLBACK_NAMECALL 3u, R3, R0, K1 ('GetY') INTERRUPT 5u SET_SAVEDPC 6u CALL R3, 1i, 1i @@ -1248,8 +1365,8 @@ bb_bytecode_1: bb_4: CHECK_TAG R2, ttable, exit(1) %23 = LOAD_POINTER R2 - %24 = GET_SLOT_NODE_ADDR %23, 1u, K0 - CHECK_SLOT_MATCH %24, K0, bb_fallback_5 + %24 = GET_SLOT_NODE_ADDR %23, 1u, K0 ('pos') + CHECK_SLOT_MATCH %24, K0 ('pos'), bb_fallback_5 %26 = LOAD_TVALUE %24, 0i STORE_TVALUE R4, %26 JUMP bb_6 @@ -1357,13 +1474,13 @@ bb_bytecode_1: bb_4: CHECK_TAG R3, ttable, bb_fallback_5 %23 = LOAD_POINTER R3 - %24 = GET_SLOT_NODE_ADDR %23, 1u, K0 - CHECK_SLOT_MATCH %24, K0, bb_fallback_5 + %24 = GET_SLOT_NODE_ADDR %23, 1u, K0 ('normal') + CHECK_SLOT_MATCH %24, K0 ('normal'), bb_fallback_5 %26 = LOAD_TVALUE %24, 0i STORE_TVALUE R2, %26 JUMP bb_6 bb_6: - %31 = LOAD_TVALUE K1, 0i, tvector + %31 = LOAD_TVALUE K1 (0.707000017, 0, 0.707000017), 0i, tvector STORE_TVALUE R4, %31 CHECK_TAG R2, tvector, exit(4) %37 = LOAD_FLOAT R2, 0i @@ -1484,9 +1601,9 @@ bb_bytecode_1: STORE_DOUBLE R1, 0 STORE_TAG R1, tnumber CHECK_SAFE_ENV exit(1) - JUMP_EQ_TAG K1, tnil, bb_fallback_6, bb_5 + JUMP_EQ_TAG K1 (nil), tnil, bb_fallback_6, bb_5 bb_5: - %9 = LOAD_TVALUE K1 + %9 = LOAD_TVALUE K1 (nil) STORE_TVALUE R2, %9 JUMP bb_7 bb_7: @@ -1508,8 +1625,8 @@ bb_9: bb_bytecode_2: CHECK_TAG R6, ttable, exit(6) %35 = LOAD_POINTER R6 - %36 = GET_SLOT_NODE_ADDR %35, 6u, K2 - CHECK_SLOT_MATCH %36, K2, bb_fallback_10 + %36 = GET_SLOT_NODE_ADDR %35, 6u, K2 ('pos') + CHECK_SLOT_MATCH %36, K2 ('pos'), bb_fallback_10 %38 = LOAD_TVALUE %36, 0i STORE_TVALUE R8, %38 JUMP bb_11 @@ -1710,8 +1827,8 @@ bb_0: bb_2: JUMP bb_bytecode_1 bb_bytecode_1: - FALLBACK_GETTABLEKS 0u, R2, R0, K0 - FALLBACK_GETTABLEKS 2u, R3, R0, K1 + FALLBACK_GETTABLEKS 0u, R2, R0, K0 ('Row1') + FALLBACK_GETTABLEKS 2u, R3, R0, K1 ('Row2') CHECK_TAG R2, tvector, exit(4) CHECK_TAG R3, tvector, exit(4) %14 = LOAD_TVALUE R2 @@ -1994,4 +2111,103 @@ bb_bytecode_1: ); } +TEST_CASE("LibraryFieldTypesAndConstants") +{ + CHECK_EQ( + "\n" + getCodegenAssembly( + R"( +local function foo(a: vector) + return Vector3.xAxis * a + Vector3.yAxis +end +)", + /* includeIrTypes */ true + ), + R"( +; function foo($arg0) line 2 +; R0: vector [argument] +; R2: vector from 3 to 4 +; R3: vector from 1 to 2 +; R3: vector from 3 to 4 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %4 = LOAD_TVALUE K0 (1, 0, 0), 0i, tvector + %11 = LOAD_TVALUE R0 + %12 = MUL_VEC %4, %11 + %15 = LOAD_TVALUE K1 (0, 1, 0), 0i, tvector + %23 = ADD_VEC %12, %15 + %24 = TAG_VECTOR %23 + STORE_TVALUE R1, %24 + INTERRUPT 4u + RETURN R1, 1i +)" + ); +} + +TEST_CASE("LibraryFieldTypesAndConstants") +{ + CHECK_EQ( + "\n" + getCodegenAssembly( + R"( +local function foo(a: vector) + local x = vector.zero + x += a + return x +end +)", + /* includeIrTypes */ true + ), + R"( +; function foo($arg0) line 2 +; R0: vector [argument] +; R1: vector from 0 to 3 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %4 = LOAD_TVALUE K0 (0, 0, 0), 0i, tvector + %11 = LOAD_TVALUE R0 + %12 = ADD_VEC %4, %11 + %13 = TAG_VECTOR %12 + STORE_TVALUE R1, %13 + INTERRUPT 2u + RETURN R1, 1i +)" + ); +} + +TEST_CASE("LibraryFieldTypesAndConstantsCApi") +{ + CHECK_EQ( + "\n" + getCodegenAssemblyUsingCApi( + R"( +local function foo() + return test.some_nil, test.some_boolean, test.some_number, test.some_vector, test.some_string +end +)", + /* includeIrTypes */ true + ), + R"( +; function foo() line 2 +bb_bytecode_0: + STORE_TAG R0, tnil + STORE_INT R1, 1i + STORE_TAG R1, tboolean + STORE_DOUBLE R2, 4.75 + STORE_TAG R2, tnumber + %5 = LOAD_TVALUE K1 (1, 2, 4), 0i, tvector + STORE_TVALUE R3, %5 + %7 = LOAD_TVALUE K2 ('test'), 0i, tstring + STORE_TVALUE R4, %7 + INTERRUPT 5u + RETURN R0, 5i +)" + ); +} + TEST_SUITE_END(); diff --git a/tests/Lexer.test.cpp b/tests/Lexer.test.cpp index e0716e4c..6133305d 100644 --- a/tests/Lexer.test.cpp +++ b/tests/Lexer.test.cpp @@ -8,6 +8,8 @@ using namespace Luau; +LUAU_FASTFLAG(LexerFixInterpStringStart) + TEST_SUITE_BEGIN("LexerTests"); TEST_CASE("broken_string_works") @@ -153,6 +155,8 @@ TEST_CASE("string_interpolation_basic") Lexeme interpEnd = lexer.next(); CHECK_EQ(interpEnd.type, Lexeme::InterpStringEnd); + // The InterpStringEnd should start with }, not `. + CHECK_EQ(interpEnd.location.begin.column, FFlag::LexerFixInterpStringStart ? 11 : 12); } TEST_CASE("string_interpolation_full") @@ -173,6 +177,7 @@ TEST_CASE("string_interpolation_full") Lexeme interpMid = lexer.next(); CHECK_EQ(interpMid.type, Lexeme::InterpStringMid); CHECK_EQ(interpMid.toString(), "} {"); + CHECK_EQ(interpMid.location.begin.column, FFlag::LexerFixInterpStringStart ? 11 : 12); Lexeme quote2 = lexer.next(); CHECK_EQ(quote2.type, Lexeme::QuotedString); @@ -181,6 +186,7 @@ TEST_CASE("string_interpolation_full") Lexeme interpEnd = lexer.next(); CHECK_EQ(interpEnd.type, Lexeme::InterpStringEnd); CHECK_EQ(interpEnd.toString(), "} end`"); + CHECK_EQ(interpEnd.location.begin.column, FFlag::LexerFixInterpStringStart ? 19 : 20); } TEST_CASE("string_interpolation_double_brace") @@ -242,4 +248,185 @@ TEST_CASE("string_interpolation_with_unicode_escape") CHECK_EQ(lexer.next().type, Lexeme::Eof); } +TEST_CASE("single_quoted_string") +{ + const std::string testInput = "'test'"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + CHECK_EQ(lexeme.type, Lexeme::QuotedString); + CHECK_EQ(lexeme.getQuoteStyle(), Lexeme::QuoteStyle::Single); +} + +TEST_CASE("double_quoted_string") +{ + const std::string testInput = R"("test")"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + CHECK_EQ(lexeme.type, Lexeme::QuotedString); + CHECK_EQ(lexeme.getQuoteStyle(), Lexeme::QuoteStyle::Double); +} + +TEST_CASE("lexer_determines_string_block_depth_0") +{ + const std::string testInput = "[[ test ]]"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 0); +} + +TEST_CASE("lexer_determines_string_block_depth_0_multiline_1") +{ + const std::string testInput = R"([[ test + ]])"; + + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 0); +} + +TEST_CASE("lexer_determines_string_block_depth_0_multiline_2") +{ + const std::string testInput = R"([[ + test + ]])"; + + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 0); +} + +TEST_CASE("lexer_determines_string_block_depth_0_multiline_3") +{ + const std::string testInput = R"([[ + test ]])"; + + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 0); +} + +TEST_CASE("lexer_determines_string_block_depth_1") +{ + const std::string testInput = "[=[[%s]]=]"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 1); +} + +TEST_CASE("lexer_determines_string_block_depth_2") +{ + const std::string testInput = "[==[ test ]==]"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 2); +} + +TEST_CASE("lexer_determines_string_block_depth_2_multiline_1") +{ + const std::string testInput = R"([==[ test + ]==])"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 2); +} + +TEST_CASE("lexer_determines_string_block_depth_2_multiline_2") +{ + const std::string testInput = R"([==[ + test + ]==])"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 2); +} + +TEST_CASE("lexer_determines_string_block_depth_2_multiline_3") +{ + const std::string testInput = R"([==[ + + test ]==])"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 2); +} + + +TEST_CASE("lexer_determines_comment_block_depth_0") +{ + const std::string testInput = "--[[ test ]]"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::BlockComment); + CHECK_EQ(lexeme.getBlockDepth(), 0); +} + +TEST_CASE("lexer_determines_string_block_depth_1") +{ + const std::string testInput = "--[=[ μέλλον ]=]"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::BlockComment); + CHECK_EQ(lexeme.getBlockDepth(), 1); +} + +TEST_CASE("lexer_determines_string_block_depth_2") +{ + const std::string testInput = "--[==[ test ]==]"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::BlockComment); + CHECK_EQ(lexeme.getBlockDepth(), 2); +} + TEST_SUITE_END(); diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 025fa7fd..08b0bb0d 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -14,7 +14,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAG(DebugLuauFreezeArena); LUAU_FASTINT(LuauTypeCloneIterationLimit); - +LUAU_FASTFLAG(LuauOldSolverCreatesChildScopePointers) TEST_SUITE_BEGIN("ModuleTests"); TEST_CASE_FIXTURE(Fixture, "is_within_comment") @@ -540,4 +540,28 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "clone_a_bound_typepack_to_a_persistent_typep REQUIRE(res == follow(boundTo)); } +TEST_CASE_FIXTURE(Fixture, "old_solver_correctly_populates_child_scopes") +{ + ScopedFastFlag sff{FFlag::LuauOldSolverCreatesChildScopePointers, true}; + check(R"( +--!strict +if true then +end + +if false then +end + +if true then +else +end + +local x = {} +for i,v in x do +end +)"); + + auto& module = frontend.moduleResolver.getModule("MainModule"); + CHECK(module->getModuleScope()->children.size() == 7); +} + TEST_SUITE_END(); diff --git a/tests/NonStrictTypeChecker.test.cpp b/tests/NonStrictTypeChecker.test.cpp index 8d13ebde..f613e750 100644 --- a/tests/NonStrictTypeChecker.test.cpp +++ b/tests/NonStrictTypeChecker.test.cpp @@ -15,6 +15,7 @@ #include LUAU_FASTFLAG(LuauCountSelfCallsNonstrict) +LUAU_FASTFLAG(LuauVector2Constructor) using namespace Luau; @@ -581,7 +582,8 @@ buffer.readi8(b, 0) TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "nonstrict_method_calls") { - ScopedFastFlag sff{FFlag::LuauCountSelfCallsNonstrict, true}; + ScopedFastFlag luauCountSelfCallsNonstrict{FFlag::LuauCountSelfCallsNonstrict, true}; + ScopedFastFlag luauVector2Constructor{FFlag::LuauVector2Constructor, true}; Luau::unfreeze(frontend.globals.globalTypes); Luau::unfreeze(frontend.globalsForAutocomplete.globalTypes); diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 24186c0a..0e026edf 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -12,7 +12,7 @@ LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTINT(LuauTypeInferRecursionLimit) -LUAU_FASTFLAG(LuauNormalizationTracksCyclicPairsThroughInhabitance) +LUAU_FASTFLAG(LuauFixNormalizedIntersectionOfNegatedClass) using namespace Luau; namespace @@ -27,7 +27,9 @@ struct IsSubtypeFixture : Fixture if (!module->hasModuleScope()) FAIL("isSubtype: module scope data is not available"); - return ::Luau::isSubtype(a, b, NotNull{module->getModuleScope().get()}, builtinTypes, ice); + SimplifierPtr simplifier = newSimplifier(NotNull{&module->internalTypes}, builtinTypes); + + return ::Luau::isSubtype(a, b, NotNull{module->getModuleScope().get()}, builtinTypes, NotNull{simplifier.get()}, ice); } }; } // namespace @@ -849,17 +851,17 @@ TEST_CASE_FIXTURE(NormalizeFixture, "crazy_metatable") TEST_CASE_FIXTURE(NormalizeFixture, "negations_of_classes") { + ScopedFastFlag _{FFlag::LuauFixNormalizedIntersectionOfNegatedClass, true}; createSomeClasses(&frontend); CHECK("(Parent & ~Child) | Unrelated" == toString(normal("(Parent & Not) | Unrelated"))); CHECK("((class & ~Child) | boolean | buffer | function | number | string | table | thread)?" == toString(normal("Not"))); - CHECK("Child" == toString(normal("Not & Child"))); + CHECK("never" == toString(normal("Not & Child"))); CHECK("((class & ~Parent) | Child | boolean | buffer | function | number | string | table | thread)?" == toString(normal("Not | Child"))); CHECK("(boolean | buffer | function | number | string | table | thread)?" == toString(normal("Not"))); CHECK( "(Parent | Unrelated | boolean | buffer | function | number | string | table | thread)?" == toString(normal("Not & Not & Not>")) ); - CHECK("Child" == toString(normal("(Child | Unrelated) & Not"))); } @@ -960,7 +962,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "final_types_are_cached") TEST_CASE_FIXTURE(NormalizeFixture, "non_final_types_can_be_normalized_but_are_not_cached") { - TypeId a = arena.freshType(&globalScope); + TypeId a = arena.freshType(builtinTypes, &globalScope); std::shared_ptr na1 = normalizer.normalize(a); std::shared_ptr na2 = normalizer.normalize(a); @@ -1032,7 +1034,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "normalizer_should_be_able_to_detect_cyclic_t if (!FFlag::LuauSolverV2) return; ScopedFastInt sfi{FInt::LuauTypeInferRecursionLimit, 0}; - ScopedFastFlag sff{FFlag::LuauNormalizationTracksCyclicPairsThroughInhabitance, true}; + CheckResult result = check(R"( --!strict diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index b35466cb..2395efb6 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -16,11 +16,13 @@ LUAU_FASTINT(LuauRecursionLimit) LUAU_FASTINT(LuauTypeLengthLimit) LUAU_FASTINT(LuauParseErrorLimit) LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2) -LUAU_FASTFLAG(LuauUserDefinedTypeFunParseExport) LUAU_FASTFLAG(LuauAllowComplexTypesInGenericParams) LUAU_FASTFLAG(LuauErrorRecoveryForTableTypes) LUAU_FASTFLAG(LuauErrorRecoveryForClassNames) +LUAU_FASTFLAG(LuauFixFunctionNameStartPosition) +LUAU_FASTFLAG(LuauExtendStatEndPosWithSemicolon) +LUAU_FASTFLAG(LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) +LUAU_FASTFLAG(LuauAstTypeGroup) namespace { @@ -369,7 +371,10 @@ TEST_CASE_FIXTURE(Fixture, "return_type_is_an_intersection_type_if_led_with_one_ AstTypeIntersection* returnAnnotation = annotation->returnTypes.types.data[0]->as(); REQUIRE(returnAnnotation != nullptr); - CHECK(returnAnnotation->types.data[0]->as()); + if (FFlag::LuauAstTypeGroup) + CHECK(returnAnnotation->types.data[0]->as()); + else + CHECK(returnAnnotation->types.data[0]->as()); CHECK(returnAnnotation->types.data[1]->as()); } @@ -448,38 +453,62 @@ TEST_CASE_FIXTURE(Fixture, "type_alias_span_is_correct") TEST_CASE_FIXTURE(Fixture, "parse_error_messages") { - matchParseError(R"( + matchParseError( + R"( local a: (number, number) -> (string - )", "Expected ')' (to close '(' at line 2), got "); + )", + "Expected ')' (to close '(' at line 2), got " + ); - matchParseError(R"( + matchParseError( + R"( local a: (number, number) -> ( string - )", "Expected ')' (to close '(' at line 2), got "); + )", + "Expected ')' (to close '(' at line 2), got " + ); - matchParseError(R"( + matchParseError( + R"( local a: (number, number) - )", "Expected '->' when parsing function type, got "); + )", + "Expected '->' when parsing function type, got " + ); - matchParseError(R"( + matchParseError( + R"( local a: (number, number - )", "Expected ')' (to close '(' at line 2), got "); + )", + "Expected ')' (to close '(' at line 2), got " + ); - matchParseError(R"( + matchParseError( + R"( local a: {foo: string, - )", "Expected identifier when parsing table field, got "); + )", + "Expected identifier when parsing table field, got " + ); - matchParseError(R"( + matchParseError( + R"( local a: {foo: string - )", "Expected '}' (to close '{' at line 2), got "); + )", + "Expected '}' (to close '{' at line 2), got " + ); - matchParseError(R"( + matchParseError( + R"( local a: { [string]: number, [number]: string } - )", "Cannot have more than one table indexer"); + )", + "Cannot have more than one table indexer" + ); - matchParseError(R"( + matchParseError( + R"( type T = foo - )", "Expected '(' when parsing function parameters, got 'foo'"); + )", + "Expected '(' when parsing function parameters, got 'foo'" + ); } TEST_CASE_FIXTURE(Fixture, "mixed_intersection_and_union_not_allowed") @@ -614,9 +643,12 @@ TEST_CASE_FIXTURE(Fixture, "vertical_space") TEST_CASE_FIXTURE(Fixture, "parse_error_type_name") { - matchParseError(R"( + matchParseError( + R"( local a: Foo.= - )", "Expected identifier when parsing field name, got '='"); + )", + "Expected identifier when parsing field name, got '='" + ); } TEST_CASE_FIXTURE(Fixture, "parse_numbers_decimal") @@ -678,9 +710,12 @@ TEST_CASE_FIXTURE(Fixture, "break_return_not_last_error") TEST_CASE_FIXTURE(Fixture, "error_on_unicode") { - matchParseError(R"( + matchParseError( + R"( local ☃ = 10 - )", "Expected identifier when parsing variable name, got Unicode character U+2603"); + )", + "Expected identifier when parsing variable name, got Unicode character U+2603" + ); } TEST_CASE_FIXTURE(Fixture, "allow_unicode_in_string") @@ -691,9 +726,12 @@ TEST_CASE_FIXTURE(Fixture, "allow_unicode_in_string") TEST_CASE_FIXTURE(Fixture, "error_on_confusable") { - matchParseError(R"( + matchParseError( + R"( local pi = 3․13 - )", "Expected identifier when parsing expression, got Unicode character U+2024 (did you mean '.'?)"); + )", + "Expected identifier when parsing expression, got Unicode character U+2024 (did you mean '.'?)" + ); } TEST_CASE_FIXTURE(Fixture, "error_on_non_utf8_sequence") @@ -2342,9 +2380,6 @@ TEST_CASE_FIXTURE(Fixture, "invalid_type_forms") TEST_CASE_FIXTURE(Fixture, "parse_user_defined_type_functions") { - ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - ScopedFastFlag sff2{FFlag::LuauUserDefinedTypeFunParseExport, true}; - AstStat* stat = parse(R"( type function foo() return types.number @@ -2363,8 +2398,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_user_defined_type_functions") TEST_CASE_FIXTURE(Fixture, "parse_nested_type_function") { - ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - AstStat* stat = parse(R"( local v1 = 1 type function foo() @@ -2386,12 +2419,95 @@ TEST_CASE_FIXTURE(Fixture, "parse_nested_type_function") TEST_CASE_FIXTURE(Fixture, "invalid_user_defined_type_functions") { - ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - matchParseError("local foo = 1; type function bar() print(foo) end", "Type function cannot reference outer local 'foo'"); matchParseError("type function foo() local v1 = 1; type function bar() print(v1) end end", "Type function cannot reference outer local 'v1'"); } +TEST_CASE_FIXTURE(Fixture, "leading_union_intersection_with_single_type_preserves_the_union_intersection_ast_node") +{ + ScopedFastFlag _{FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType, true}; + AstStatBlock* block = parse(R"( + type Foo = | string + type Bar = & number + )"); + + REQUIRE_EQ(2, block->body.size); + + const auto alias1 = block->body.data[0]->as(); + REQUIRE(alias1); + + const auto unionType = alias1->type->as(); + REQUIRE(unionType); + CHECK_EQ(1, unionType->types.size); + + const auto alias2 = block->body.data[1]->as(); + REQUIRE(alias2); + + const auto intersectionType = alias2->type->as(); + REQUIRE(intersectionType); + CHECK_EQ(1, intersectionType->types.size); +} + +TEST_CASE_FIXTURE(Fixture, "parse_simple_ast_type_group") +{ + ScopedFastFlag _{FFlag::LuauAstTypeGroup, true}; + + AstStatBlock* stat = parse(R"( + type Foo = (string) + )"); + REQUIRE(stat); + REQUIRE_EQ(1, stat->body.size); + + auto alias1 = stat->body.data[0]->as(); + REQUIRE(alias1); + + auto group1 = alias1->type->as(); + REQUIRE(group1); + CHECK(group1->type->is()); +} + +TEST_CASE_FIXTURE(Fixture, "parse_nested_ast_type_group") +{ + ScopedFastFlag _{FFlag::LuauAstTypeGroup, true}; + + AstStatBlock* stat = parse(R"( + type Foo = ((string)) + )"); + REQUIRE(stat); + REQUIRE_EQ(1, stat->body.size); + + auto alias1 = stat->body.data[0]->as(); + REQUIRE(alias1); + + auto group1 = alias1->type->as(); + REQUIRE(group1); + + auto group2 = group1->type->as(); + REQUIRE(group2); + CHECK(group2->type->is()); +} + +TEST_CASE_FIXTURE(Fixture, "parse_return_type_ast_type_group") +{ + ScopedFastFlag _{FFlag::LuauAstTypeGroup, true}; + + AstStatBlock* stat = parse(R"( + type Foo = () -> (string) + )"); + REQUIRE(stat); + REQUIRE_EQ(1, stat->body.size); + + auto alias1 = stat->body.data[0]->as(); + REQUIRE(alias1); + + auto funcType = alias1->type->as(); + REQUIRE(funcType); + + REQUIRE_EQ(1, funcType->returnTypes.types.size); + REQUIRE(!funcType->returnTypes.tailType); + CHECK(funcType->returnTypes.types.data[0]->is()); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("ParseErrorRecovery"); @@ -3662,7 +3778,14 @@ TEST_CASE_FIXTURE(Fixture, "grouped_function_type") auto unionTy = paramTy.type->as(); LUAU_ASSERT(unionTy); CHECK_EQ(unionTy->types.size, 2); - CHECK(unionTy->types.data[0]->is()); // () -> () + if (FFlag::LuauAstTypeGroup) + { + auto groupTy = unionTy->types.data[0]->as(); // (() -> ()) + REQUIRE(groupTy); + CHECK(groupTy->type->is()); // () -> () + } + else + CHECK(unionTy->types.data[0]->is()); // () -> () CHECK(unionTy->types.data[1]->is()); // nil } @@ -3708,13 +3831,65 @@ TEST_CASE_FIXTURE(Fixture, "recover_from_bad_table_type") ScopedFastFlag _{FFlag::LuauErrorRecoveryForTableTypes, true}; ParseOptions opts; opts.allowDeclarationSyntax = true; - const auto result = tryParse(R"( + const auto result = tryParse( + R"( declare class Widget state: {string: function(string, Widget)} end - )", opts); + )", + opts + ); CHECK_EQ(result.errors.size(), 2); } +TEST_CASE_FIXTURE(Fixture, "function_name_has_correct_start_location") +{ + ScopedFastFlag _{FFlag::LuauFixFunctionNameStartPosition, true}; + AstStatBlock* block = parse(R"( + function simple() + end + + function T:complex() + end + )"); + + REQUIRE_EQ(2, block->body.size); + + const auto function1 = block->body.data[0]->as(); + LUAU_ASSERT(function1); + CHECK_EQ(Position{1, 17}, function1->name->location.begin); + + const auto function2 = block->body.data[1]->as(); + LUAU_ASSERT(function2); + CHECK_EQ(Position{4, 17}, function2->name->location.begin); +} + +TEST_CASE_FIXTURE(Fixture, "stat_end_includes_semicolon_position") +{ + ScopedFastFlag _{FFlag::LuauExtendStatEndPosWithSemicolon, true}; + AstStatBlock* block = parse(R"( + local x = 1 + local y = 2; + local z = 3 ; + )"); + + REQUIRE_EQ(3, block->body.size); + + const auto stat1 = block->body.data[0]; + LUAU_ASSERT(stat1); + CHECK_FALSE(stat1->hasSemicolon); + CHECK_EQ(Position{1, 19}, stat1->location.end); + + const auto stat2 = block->body.data[1]; + LUAU_ASSERT(stat2); + CHECK(stat2->hasSemicolon); + CHECK_EQ(Position{2, 20}, stat2->location.end); + + const auto stat3 = block->body.data[2]; + LUAU_ASSERT(stat3); + CHECK(stat3->hasSemicolon); + CHECK_EQ(Position{3, 22}, stat3->location.end); +} + TEST_SUITE_END(); diff --git a/tests/Repl.test.cpp b/tests/Repl.test.cpp index 71a46878..85d53390 100644 --- a/tests/Repl.test.cpp +++ b/tests/Repl.test.cpp @@ -2,7 +2,7 @@ #include "lua.h" #include "lualib.h" -#include "Repl.h" +#include "Luau/Repl.h" #include "ScopedFlags.h" #include "doctest.h" @@ -13,8 +13,6 @@ #include #include -LUAU_FASTFLAG(LuauMathMap) - struct Completion { std::string completion; @@ -175,7 +173,7 @@ TEST_CASE_FIXTURE(ReplFixture, "CompleteGlobalVariables") CHECK(checkCompletion(completions, prefix, "myvariable1")); CHECK(checkCompletion(completions, prefix, "myvariable2")); } - if (FFlag::LuauMathMap) + { // Try completing some builtin functions CompletionSet completions = getCompletionSet("math.m"); diff --git a/tests/RequireByString.test.cpp b/tests/RequireByString.test.cpp index bc1161b0..59a1af3b 100644 --- a/tests/RequireByString.test.cpp +++ b/tests/RequireByString.test.cpp @@ -1,17 +1,25 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Common.h" +#include "Luau/Config.h" + #include "ScopedFlags.h" #include "lua.h" #include "lualib.h" -#include "Repl.h" -#include "FileUtils.h" +#include "Luau/Repl.h" +#include "Luau/FileUtils.h" #include "doctest.h" #include +#include #include #include +#include +#include +#include +#include +#include #if __APPLE__ #include @@ -112,7 +120,7 @@ public: for (int i = 0; i < 20; ++i) { bool engineTestDir = isDirectory(luauDirAbs + "/Client/Luau/tests"); - bool luauTestDir = isDirectory(luauDirAbs + "/luau/tests/require"); + bool luauTestDir = isDirectory(luauDirAbs + "/tests/require"); if (engineTestDir || luauTestDir) { @@ -121,12 +129,6 @@ public: luauDirRel += "/Client/Luau"; luauDirAbs += "/Client/Luau"; } - else - { - luauDirRel += "/luau"; - luauDirAbs += "/luau"; - } - if (type == PathType::Relative) return luauDirRel; @@ -217,21 +219,43 @@ TEST_CASE("PathResolution") std::string prefix = "/"; #endif - CHECK(resolvePath(prefix + "Users/modules/module.luau", "") == prefix + "Users/modules/module.luau"); - CHECK(resolvePath(prefix + "Users/modules/module.luau", "a/string/that/should/be/ignored") == prefix + "Users/modules/module.luau"); - CHECK(resolvePath(prefix + "Users/modules/module.luau", "./a/string/that/should/be/ignored") == prefix + "Users/modules/module.luau"); - CHECK(resolvePath(prefix + "Users/modules/module.luau", "/a/string/that/should/be/ignored") == prefix + "Users/modules/module.luau"); - CHECK(resolvePath(prefix + "Users/modules/module.luau", "/Users/modules") == prefix + "Users/modules/module.luau"); + // tuple format: {inputPath, inputBaseFilePath, expected} + std::vector> tests = { + // 1. Basic path resolution + // a. Relative to a relative path that begins with './' + {"./dep", "./src/modules/module.luau", "./src/modules/dep"}, + {"../dep", "./src/modules/module.luau", "./src/dep"}, + {"../../dep", "./src/modules/module.luau", "./dep"}, + {"../../", "./src/modules/module.luau", "./"}, - CHECK(resolvePath("../module", "") == "../module"); - CHECK(resolvePath("../../module", "") == "../../module"); - CHECK(resolvePath("../module/..", "") == "../"); - CHECK(resolvePath("../module/../..", "") == "../../"); + // b. Relative to a relative path that begins with '../' + {"./dep", "../src/modules/module.luau", "../src/modules/dep"}, + {"../dep", "../src/modules/module.luau", "../src/dep"}, + {"../../dep", "../src/modules/module.luau", "../dep"}, + {"../../", "../src/modules/module.luau", "../"}, - CHECK(resolvePath("../dependency", prefix + "Users/modules/module.luau") == prefix + "Users/dependency"); - CHECK(resolvePath("../dependency/", prefix + "Users/modules/module.luau") == prefix + "Users/dependency"); - CHECK(resolvePath("../../../../../Users/dependency", prefix + "Users/modules/module.luau") == prefix + "Users/dependency"); - CHECK(resolvePath("../..", prefix + "Users/modules/module.luau") == prefix); + // c. Relative to an absolute path + {"./dep", prefix + "src/modules/module.luau", prefix + "src/modules/dep"}, + {"../dep", prefix + "src/modules/module.luau", prefix + "src/dep"}, + {"../../dep", prefix + "src/modules/module.luau", prefix + "dep"}, + {"../../", prefix + "src/modules/module.luau", prefix}, + + + // 2. Check behavior for extraneous ".." + // a. Relative paths retain '..' and append if needed + {"../../../", "./src/modules/module.luau", "../"}, + {"../../../", "../src/modules/module.luau", "../../"}, + + // b. Absolute paths ignore '..' if already at root + {"../../../", prefix + "src/modules/module.luau", prefix}, + }; + + for (const auto& [inputPath, inputBaseFilePath, expected] : tests) + { + std::optional resolved = resolvePath(inputPath, inputBaseFilePath); + CHECK(resolved); + CHECK_EQ(resolved, expected); + } } TEST_CASE("PathNormalization") @@ -242,34 +266,57 @@ TEST_CASE("PathNormalization") std::string prefix = "/"; #endif - // Relative path - std::optional result = normalizePath("../../modules/module"); - CHECK(result); - std::string normalized = *result; - std::vector variants = { - "./.././.././modules/./module/", "placeholder/../../../modules/module", "../placeholder/placeholder2/../../../modules/module" - }; - for (const std::string& variant : variants) - { - result = normalizePath(variant); - CHECK(result); - CHECK(normalized == *result); - } + // pair format: {input, expected} + std::vector> tests = { + // 1. Basic formatting checks + {"", "./"}, + {".", "./"}, + {"..", "../"}, + {"a/relative/path", "./a/relative/path"}, - // Absolute path - result = normalizePath(prefix + "Users/modules/module"); - CHECK(result); - normalized = *result; - variants = { - "Users/Users/Users/.././.././modules/./module/", - "placeholder/../Users/..//Users/modules/module", - "Users/../placeholder/placeholder2/../../Users/modules/module" + + // 2. Paths containing extraneous '.' and '/' symbols + {"./remove/extraneous/symbols/", "./remove/extraneous/symbols"}, + {"./remove/extraneous//symbols", "./remove/extraneous/symbols"}, + {"./remove/extraneous/symbols/.", "./remove/extraneous/symbols"}, + {"./remove/extraneous/./symbols", "./remove/extraneous/symbols"}, + + {"../remove/extraneous/symbols/", "../remove/extraneous/symbols"}, + {"../remove/extraneous//symbols", "../remove/extraneous/symbols"}, + {"../remove/extraneous/symbols/.", "../remove/extraneous/symbols"}, + {"../remove/extraneous/./symbols", "../remove/extraneous/symbols"}, + + {prefix + "remove/extraneous/symbols/", prefix + "remove/extraneous/symbols"}, + {prefix + "remove/extraneous//symbols", prefix + "remove/extraneous/symbols"}, + {prefix + "remove/extraneous/symbols/.", prefix + "remove/extraneous/symbols"}, + {prefix + "remove/extraneous/./symbols", prefix + "remove/extraneous/symbols"}, + + + // 3. Paths containing '..' + // a. '..' removes the erasable component before it + {"./remove/me/..", "./remove"}, + {"./remove/me/../", "./remove"}, + + {"../remove/me/..", "../remove"}, + {"../remove/me/../", "../remove"}, + + {prefix + "remove/me/..", prefix + "remove"}, + {prefix + "remove/me/../", prefix + "remove"}, + + // b. '..' stays if path is relative and component is non-erasable + {"./..", "../"}, + {"./../", "../"}, + + {"../..", "../../"}, + {"../../", "../../"}, + + // c. '..' disappears if path is absolute and component is non-erasable + {prefix + "..", prefix}, }; - for (const std::string& variant : variants) + + for (const auto& [input, expected] : tests) { - result = normalizePath(prefix + variant); - CHECK(result); - CHECK(normalized == *result); + CHECK_EQ(normalizePath(input), expected); } } @@ -491,4 +538,63 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "AliasHasIllegalFormat") assertOutputContainsAll({"false", " is not a valid alias"}); } +TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireFromLuauBinary") +{ + char executable[] = "luau"; + std::vector paths = { + getLuauDirectory(PathType::Relative) + "/tests/require/without_config/dependency.luau", + getLuauDirectory(PathType::Absolute) + "/tests/require/without_config/dependency.luau" + }; + + for (const std::string& path : paths) + { + std::vector pathStr(path.size() + 1); + strncpy(pathStr.data(), path.c_str(), path.size()); + pathStr[path.size()] = '\0'; + + char* argv[2] = {executable, pathStr.data()}; + CHECK_EQ(replMain(2, argv), 0); + } +} + +TEST_CASE("ParseAliases") +{ + std::string configJson = R"({ + "aliases": { + "MyAlias": "/my/alias/path", + } +})"; + + Luau::Config config; + + Luau::ConfigOptions::AliasOptions aliasOptions; + aliasOptions.configLocation = "/default/location"; + aliasOptions.overwriteAliases = true; + + Luau::ConfigOptions options{false, aliasOptions}; + + std::optional error = Luau::parseConfig(configJson, config, options); + REQUIRE(!error); + + auto checkContents = [](Luau::Config& config) -> void + { + CHECK(config.aliases.size() == 1); + REQUIRE(config.aliases.contains("myalias")); + + Luau::Config::AliasInfo& aliasInfo = config.aliases["myalias"]; + CHECK(aliasInfo.value == "/my/alias/path"); + CHECK(aliasInfo.originalCase == "MyAlias"); + }; + + checkContents(config); + + // Ensure that copied Configs retain the same information + Luau::Config copyConstructedConfig = config; + checkContents(copyConstructedConfig); + + Luau::Config copyAssignedConfig; + copyAssignedConfig = config; + checkContents(copyAssignedConfig); +} + TEST_SUITE_END(); diff --git a/tests/Subtyping.test.cpp b/tests/Subtyping.test.cpp index 27b2f6e7..76efc835 100644 --- a/tests/Subtyping.test.cpp +++ b/tests/Subtyping.test.cpp @@ -65,6 +65,7 @@ struct SubtypeFixture : Fixture TypeArena arena; InternalErrorReporter iceReporter; UnifierSharedState sharedState{&ice}; + SimplifierPtr simplifier = newSimplifier(NotNull{&arena}, builtinTypes); Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; TypeCheckLimits limits; TypeFunctionRuntime typeFunctionRuntime{NotNull{&iceReporter}, NotNull{&limits}}; @@ -79,7 +80,9 @@ struct SubtypeFixture : Fixture Subtyping mkSubtyping() { - return Subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter}}; + return Subtyping{ + builtinTypes, NotNull{&arena}, NotNull{simplifier.get()}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter} + }; } TypePackId pack(std::initializer_list tys) diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp index c9eb3450..11027e6f 100644 --- a/tests/ToDot.test.cpp +++ b/tests/ToDot.test.cpp @@ -324,8 +324,7 @@ TEST_CASE_FIXTURE(Fixture, "free") { DOES_NOT_PASS_NEW_SOLVER_GUARD(); - Type type{TypeVariant{FreeType{TypeLevel{0, 0}}}}; - + Type type{TypeVariant{FreeType{TypeLevel{0, 0}, builtinTypes->neverType, builtinTypes->unknownType}}}; ToDotOptions opts; opts.showPointers = false; CHECK_EQ( diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index db018eda..536a4081 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -211,8 +211,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "exhaustive_toString_of_cyclic_table") CHECK( "t2 where " "t1 = { __index: t1, __mul: ((t2, number) -> t2) & ((t2, t2) -> t2), new: () -> t2 } ; " - "t2 = { @metatable t1, { x: number, y: number, z: number } }" == - a + "t2 = { @metatable t1, { x: number, y: number, z: number } }" == a ); } else diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index 188d9682..3505f96d 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -12,7 +12,10 @@ using namespace Luau; -LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2) +LUAU_FASTFLAG(LuauStoreCSTData) +LUAU_FASTFLAG(LuauExtendStatEndPosWithSemicolon) +LUAU_FASTFLAG(LuauAstTypeGroup); +LUAU_FASTFLAG(LexerFixInterpStringStart) TEST_SUITE_BEGIN("TranspilerTests"); @@ -44,6 +47,37 @@ TEST_CASE("string_literals_containing_utf8") CHECK_EQ(code, transpile(code).code); } +TEST_CASE("if_stmt_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string one = R"( if This then Once() end)"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( if This then Once() end)"; + CHECK_EQ(two, transpile(two).code); + + const std::string three = R"( if This then Once() end)"; + CHECK_EQ(three, transpile(three).code); + + const std::string four = R"( if This then Once() end)"; + CHECK_EQ(four, transpile(four).code); + + const std::string five = R"( if This then Once() else Other() end)"; + CHECK_EQ(five, transpile(five).code); + + const std::string six = R"( if This then Once() else Other() end)"; + CHECK_EQ(six, transpile(six).code); + + const std::string seven = R"( if This then Once() elseif true then Other() end)"; + CHECK_EQ(seven, transpile(seven).code); + + const std::string eight = R"( if This then Once() elseif true then Other() end)"; + CHECK_EQ(eight, transpile(eight).code); + + const std::string nine = R"( if This then Once() elseif true then Other() end)"; + CHECK_EQ(nine, transpile(nine).code); +} + TEST_CASE("elseif_chains_indent_sensibly") { const std::string code = R"( @@ -64,17 +98,31 @@ TEST_CASE("elseif_chains_indent_sensibly") TEST_CASE("strips_type_annotations") { const std::string code = R"( local s: string= 'hello there' )"; - const std::string expected = R"( local s = 'hello there' )"; - - CHECK_EQ(expected, transpile(code).code); + if (FFlag::LuauStoreCSTData) + { + const std::string expected = R"( local s = 'hello there' )"; + CHECK_EQ(expected, transpile(code).code); + } + else + { + const std::string expected = R"( local s = 'hello there' )"; + CHECK_EQ(expected, transpile(code).code); + } } TEST_CASE("strips_type_assertion_expressions") { const std::string code = R"( local s= some_function() :: any+ something_else() :: number )"; - const std::string expected = R"( local s= some_function() + something_else() )"; - - CHECK_EQ(expected, transpile(code).code); + if (FFlag::LuauStoreCSTData) + { + const std::string expected = R"( local s= some_function() + something_else() )"; + CHECK_EQ(expected, transpile(code).code); + } + else + { + const std::string expected = R"( local s= some_function() + something_else() )"; + CHECK_EQ(expected, transpile(code).code); + } } TEST_CASE("function_taking_ellipsis") @@ -99,24 +147,89 @@ TEST_CASE("for_loop") CHECK_EQ(code, transpile(code).code); } +TEST_CASE("for_loop_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string one = R"( for index = 1, 10 do call(index) end )"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( for index = 1 , 10 do call(index) end )"; + CHECK_EQ(two, transpile(two).code); + + const std::string three = R"( for index = 1, 10 , 3 do call(index) end )"; + CHECK_EQ(three, transpile(three).code); + + const std::string four = R"( for index = 1, 10 do call(index) end )"; + CHECK_EQ(four, transpile(four).code); + + const std::string five = R"( for index = 1, 10 do call(index) end )"; + CHECK_EQ(five, transpile(five).code); +} + TEST_CASE("for_in_loop") { const std::string code = R"( for k, v in ipairs(x)do end )"; CHECK_EQ(code, transpile(code).code); } +TEST_CASE("for_in_loop_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string one = R"( for k, v in ipairs(x) do end )"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( for k, v in ipairs(x) do end )"; + CHECK_EQ(two, transpile(two).code); + + const std::string three = R"( for k , v in ipairs(x) do end )"; + CHECK_EQ(three, transpile(three).code); + + const std::string four = R"( for k, v in next , t do end )"; + CHECK_EQ(four, transpile(four).code); + + const std::string five = R"( for k, v in ipairs(x) do end )"; + CHECK_EQ(five, transpile(five).code); +} + TEST_CASE("while_loop") { const std::string code = R"( while f(x)do print() end )"; CHECK_EQ(code, transpile(code).code); } +TEST_CASE("while_loop_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string one = R"( while f(x) do print() end )"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( while f(x) do print() end )"; + CHECK_EQ(two, transpile(two).code); + + const std::string three = R"( while f(x) do print() end )"; + CHECK_EQ(three, transpile(three).code); + + const std::string four = R"( while f(x) do print() end )"; + CHECK_EQ(four, transpile(four).code); +} + TEST_CASE("repeat_until_loop") { const std::string code = R"( repeat print() until f(x) )"; CHECK_EQ(code, transpile(code).code); } +TEST_CASE("repeat_until_loop_condition_on_new_line") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + repeat + print() + until + f(x) )"; + CHECK_EQ(code, transpile(code).code); +} + TEST_CASE("lambda") { const std::string one = R"( local p=function(o, m, g) return 77 end )"; @@ -126,6 +239,43 @@ TEST_CASE("lambda") CHECK_EQ(two, transpile(two).code); } +TEST_CASE("local_assignment") +{ + const std::string one = R"( local x = 1 )"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( local x, y, z = 1, 2, 3 )"; + CHECK_EQ(two, transpile(two).code); + + const std::string three = R"( local x )"; + CHECK_EQ(three, transpile(three).code); +} + +TEST_CASE("local_assignment_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string one = R"( local x = 1 )"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( local x = 1 )"; + CHECK_EQ(two, transpile(two).code); + + const std::string three = R"( local x = 1 )"; + CHECK_EQ(three, transpile(three).code); + + const std::string four = R"( local x , y = 1, 2 )"; + CHECK_EQ(four, transpile(four).code); + + const std::string five = R"( local x, y = 1, 2 )"; + CHECK_EQ(five, transpile(five).code); + + const std::string six = R"( local x, y = 1 , 2 )"; + CHECK_EQ(six, transpile(six).code); + + const std::string seven = R"( local x, y = 1, 2 )"; + CHECK_EQ(seven, transpile(seven).code); +} + TEST_CASE("local_function") { const std::string one = R"( local function p(o, m, g) return 77 end )"; @@ -135,6 +285,16 @@ TEST_CASE("local_function") CHECK_EQ(two, transpile(two).code); } +TEST_CASE("local_function_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string one = R"( local function p(o, m, ...) end )"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( local function p(o, m, ...) end )"; + CHECK_EQ(two, transpile(two).code); +} + TEST_CASE("function") { const std::string one = R"( function p(o, m, g) return 77 end )"; @@ -144,6 +304,19 @@ TEST_CASE("function") CHECK_EQ(two, transpile(two).code); } +TEST_CASE("returns_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string one = R"( return 1 )"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( return 1 , 2 )"; + CHECK_EQ(two, transpile(two).code); + + const std::string three = R"( return 1, 2 )"; + CHECK_EQ(three, transpile(three).code); +} + TEST_CASE("table_literals") { const std::string code = R"( local t={1, 2, 3, foo='bar', baz=99,[5.5]='five point five', 'end'} )"; @@ -186,6 +359,59 @@ TEST_CASE("table_literal_closing_brace_at_correct_position") CHECK_EQ(code, transpile(code).code); } +TEST_CASE("table_literal_with_semicolon_separators") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + local t = { x = 1; y = 2 } + )"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("table_literal_with_trailing_separators") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + local t = { x = 1, y = 2, } + )"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("table_literal_with_spaces_around_separator") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + local t = { x = 1 , y = 2 } + )"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("table_literal_with_spaces_around_equals") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + local t = { x = 1 } + )"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("table_literal_multiline_with_indexers") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + local t = { + ["my first value"] = "x"; + ["my second value"] = "y"; + } + )"; + + CHECK_EQ(code, transpile(code).code); +} + TEST_CASE("method_calls") { const std::string code = R"( foo.bar.baz:quux() )"; @@ -203,8 +429,15 @@ TEST_CASE("spaces_between_keywords_even_if_it_pushes_the_line_estimation_off") // Luau::Parser doesn't exactly preserve the string representation of numbers in Lua, so we can find ourselves // falling out of sync with the original code. We need to push keywords out so that there's at least one space between them. const std::string code = R"( if math.abs(raySlope) < .01 then return 0 end )"; - const std::string expected = R"( if math.abs(raySlope) < 0.01 then return 0 end)"; - CHECK_EQ(expected, transpile(code).code); + if (FFlag::LuauStoreCSTData) + { + CHECK_EQ(code, transpile(code).code); + } + else + { + const std::string expected = R"( if math.abs(raySlope) < 0.01 then return 0 end)"; + CHECK_EQ(expected, transpile(code).code); + } } TEST_CASE("numbers") @@ -216,8 +449,70 @@ TEST_CASE("numbers") TEST_CASE("infinity") { const std::string code = R"( local a = 1e500 local b = 1e400 )"; - const std::string expected = R"( local a = 1e500 local b = 1e500 )"; - CHECK_EQ(expected, transpile(code).code); + if (FFlag::LuauStoreCSTData) + { + CHECK_EQ(code, transpile(code).code); + } + else + { + const std::string expected = R"( local a = 1e500 local b = 1e500 )"; + CHECK_EQ(expected, transpile(code).code); + } +} + +TEST_CASE("numbers_with_separators") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( local a = 123_456_789 )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("hexadecimal_numbers") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( local a = 0xFFFF )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("binary_numbers") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( local a = 0b0101 )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("single_quoted_strings") +{ + const std::string code = R"( local a = 'hello world' )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("double_quoted_strings") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( local a = "hello world" )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("simple_interp_string") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( local a = `hello world` )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("raw_strings") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( local a = [[ hello world ]] )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("raw_strings_with_blocks") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( local a = [==[ hello world ]==] )"; + CHECK_EQ(code, transpile(code).code); } TEST_CASE("escaped_strings") @@ -232,6 +527,33 @@ TEST_CASE("escaped_strings_2") CHECK_EQ(code, transpile(code).code); } +TEST_CASE("escaped_strings_newline") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + print("foo \ + bar") + )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("escaped_strings_raw") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( local x = [=[\v<((do|load)file|require)\s*\(?['"]\zs[^'"]+\ze['"]]=] )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("position_correctly_updated_when_writing_multiline_string") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + call([[ + testing + ]]) )"; + CHECK_EQ(code, transpile(code).code); +} + TEST_CASE("need_a_space_between_number_literals_and_dots") { const std::string code = R"( return point and math.ceil(point* 100000* 100)/ 100000 .. '%'or '' )"; @@ -244,6 +566,86 @@ TEST_CASE("binary_keywords") CHECK_EQ(code, transpile(code).code); } +TEST_CASE("function_call_parentheses_no_args") +{ + const std::string code = R"( call() )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_parentheses_one_arg") +{ + const std::string code = R"( call(arg) )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_parentheses_multiple_args") +{ + const std::string code = R"( call(arg1, arg3, arg3) )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_parentheses_multiple_args_no_space") +{ + const std::string code = R"( call(arg1,arg3,arg3) )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_parentheses_multiple_args_space_before_commas") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( call(arg1 ,arg3 ,arg3) )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_spaces_before_parentheses") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( call () )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_spaces_within_parentheses") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( call( ) )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_string_double_quotes") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( call "string" )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_string_single_quotes") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( call 'string' )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_string_no_space") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( call'string' )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_table_literal") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( call { x = 1 } )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_table_literal_no_space") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( call{x=1} )"; + CHECK_EQ(code, transpile(code).code); +} + TEST_CASE("do_blocks") { const std::string code = R"( @@ -260,6 +662,19 @@ TEST_CASE("do_blocks") CHECK_EQ(code, transpile(code).code); } +TEST_CASE("nested_do_block") +{ + const std::string code = R"( + do + do + local x = 1 + end + end + )"; + + CHECK_EQ(code, transpile(code).code); +} + TEST_CASE("emit_a_do_block_in_cases_of_potentially_ambiguous_syntax") { const std::string code = R"( @@ -269,6 +684,106 @@ TEST_CASE("emit_a_do_block_in_cases_of_potentially_ambiguous_syntax") CHECK_EQ(code, transpile(code).code); } +TEST_CASE_FIXTURE(Fixture, "parentheses_multiline") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( +local test = ( + x +) + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "stmt_semicolon") +{ + ScopedFastFlag flags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LuauExtendStatEndPosWithSemicolon, true}, + }; + std::string code = R"( local test = 1; )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( local test = 1 ; )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "do_block_ending_with_semicolon") +{ + std::string code = R"( + do + return; + end; + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "if_stmt_semicolon") +{ + ScopedFastFlag flags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LuauExtendStatEndPosWithSemicolon, true}, + }; + std::string code = R"( + if init then + x = string.sub(x, utf8.offset(x, init)); + end; + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "if_stmt_semicolon_2") +{ + ScopedFastFlag flags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LuauExtendStatEndPosWithSemicolon, true}, + }; + std::string code = R"( + if (t < 1) then return c/2*t*t + b end; + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "for_loop_stmt_semicolon") +{ + ScopedFastFlag flags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LuauExtendStatEndPosWithSemicolon, true}, + }; + std::string code = R"( + for i,v in ... do + end; + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "while_do_semicolon") +{ + ScopedFastFlag flags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LuauExtendStatEndPosWithSemicolon, true}, + }; + std::string code = R"( + while true do + end; + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "function_definition_semicolon") +{ + ScopedFastFlag flags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LuauExtendStatEndPosWithSemicolon, true}, + }; + std::string code = R"( + function foo() + end; + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_CASE("roundtrip_types") { const std::string code = R"( @@ -339,9 +854,16 @@ TEST_CASE("a_table_key_can_be_the_empty_string") TEST_CASE("always_emit_a_space_after_local_keyword") { std::string code = "do local aZZZZ = Workspace.P1.Shape local bZZZZ = Enum.PartType.Cylinder end"; - std::string expected = "do local aZZZZ = Workspace.P1 .Shape local bZZZZ= Enum.PartType.Cylinder end"; - CHECK_EQ(expected, transpile(code).code); + if (FFlag::LuauStoreCSTData) + { + CHECK_EQ(code, transpile(code).code); + } + else + { + std::string expected = "do local aZZZZ = Workspace.P1 .Shape local bZZZZ= Enum.PartType.Cylinder end"; + CHECK_EQ(expected, transpile(code).code); + } } TEST_CASE_FIXTURE(Fixture, "types_should_not_be_considered_cyclic_if_they_are_not_recursive") @@ -429,6 +951,80 @@ TEST_CASE_FIXTURE(Fixture, "transpile_if_then_else") CHECK_EQ(code, transpile(code).code); } +TEST_CASE_FIXTURE(Fixture, "transpile_if_then_else_multiple_conditions") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = "local a = if 1 then 2 elseif 3 then 4 else 5"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_if_then_else_multiple_conditions_2") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( + local x = if yes + then nil + else if no + then if this + then that + else other + else nil + )"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE_FIXTURE(Fixture, "if_then_else_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = "local a = if 1 then 2 else 3"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 else 3"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 else 3"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 else 3"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 else 3"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 elseif 3 then 4 else 5"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 elseif 3 then 4 else 5"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 elseif 3 then 4 else 5"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 elseif 3 then 4 else 5"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 elseif 3 then 4 else 5"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 elseif 3 then 4 else 5"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE_FIXTURE(Fixture, "if_then_else_spaces_between_else_if") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( + return + if a then "was a" else + if b then "was b" else + if c then "was c" else + "was nothing!" + )"; + CHECK_EQ(code, transpile(code).code); +} + TEST_CASE_FIXTURE(Fixture, "transpile_type_reference_import") { fileResolver.source["game/A"] = R"( @@ -444,6 +1040,34 @@ local a: Import.Type CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "transpile_type_reference_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( local _: Foo.Type )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( local _: Foo .Type )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( local _: Foo. Type )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( local _: Type <> )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( local _: Type< > )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( local _: Type< number> )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( local _: Type )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( local _: Type )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_CASE_FIXTURE(Fixture, "transpile_type_packs") { std::string code = R"( @@ -473,7 +1097,10 @@ TEST_CASE_FIXTURE(Fixture, "transpile_union_type_nested_3") { std::string code = "local a: nil | (string & number)"; - CHECK_EQ("local a: ( string & number)?", transpile(code, {}, true).code); + if (FFlag::LuauAstTypeGroup) + CHECK_EQ("local a: (string & number)?", transpile(code, {}, true).code); + else + CHECK_EQ("local a: ( string & number)?", transpile(code, {}, true).code); } TEST_CASE_FIXTURE(Fixture, "transpile_intersection_type_nested") @@ -497,6 +1124,26 @@ TEST_CASE_FIXTURE(Fixture, "transpile_varargs") CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "index_name_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string one = "local _ = a.name"; + CHECK_EQ(one, transpile(one, {}, true).code); + + std::string two = "local _ = a .name"; + CHECK_EQ(two, transpile(two, {}, true).code); + + std::string three = "local _ = a. name"; + CHECK_EQ(three, transpile(three, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "index_name_ends_with_digit") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = "sparkles.Color = Color3.new()"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_CASE_FIXTURE(Fixture, "transpile_index_expr") { std::string code = "local a = {1, 2, 3} local b = a[2]"; @@ -504,6 +1151,22 @@ TEST_CASE_FIXTURE(Fixture, "transpile_index_expr") CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "index_expr_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string one = "local _ = a[2]"; + CHECK_EQ(one, transpile(one, {}, true).code); + + std::string two = "local _ = a [2]"; + CHECK_EQ(two, transpile(two, {}, true).code); + + std::string three = "local _ = a[ 2]"; + CHECK_EQ(three, transpile(three, {}, true).code); + + std::string four = "local _ = a[2 ]"; + CHECK_EQ(four, transpile(four, {}, true).code); +} + TEST_CASE_FIXTURE(Fixture, "transpile_unary") { std::string code = R"( @@ -518,6 +1181,32 @@ local d = #e CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "unary_spaces_around_tokens") +{ + std::string code = R"( +local _ = -1 +local _ = - 1 +local _ = not true +local _ = not true +local _ = #e +local _ = # e + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "binary_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( +local _ = 1+1 +local _ = 1 +1 +local _ = 1+ 1 + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_CASE_FIXTURE(Fixture, "transpile_break_continue") { std::string code = R"( @@ -548,6 +1237,16 @@ a ..= ' - result' CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "compound_assignment_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string one = R"( a += 1 )"; + CHECK_EQ(one, transpile(one, {}, true).code); + + std::string two = R"( a += 1 )"; + CHECK_EQ(two, transpile(two, {}, true).code); +} + TEST_CASE_FIXTURE(Fixture, "transpile_assign_multiple") { std::string code = "a, b, c = 1, 2, 3"; @@ -555,6 +1254,31 @@ TEST_CASE_FIXTURE(Fixture, "transpile_assign_multiple") CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "transpile_assign_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string one = "a = 1"; + CHECK_EQ(one, transpile(one).code); + + std::string two = "a = 1"; + CHECK_EQ(two, transpile(two).code); + + std::string three = "a = 1"; + CHECK_EQ(three, transpile(three).code); + + std::string four = "a , b = 1, 2"; + CHECK_EQ(four, transpile(four).code); + + std::string five = "a, b = 1, 2"; + CHECK_EQ(five, transpile(five).code); + + std::string six = "a, b = 1 , 2"; + CHECK_EQ(six, transpile(six).code); + + std::string seven = "a, b = 1, 2"; + CHECK_EQ(seven, transpile(seven).code); +} + TEST_CASE_FIXTURE(Fixture, "transpile_generic_function") { std::string code = R"( @@ -684,13 +1408,58 @@ TEST_CASE_FIXTURE(Fixture, "transpile_for_in_multiple_types") TEST_CASE_FIXTURE(Fixture, "transpile_string_interp") { + ScopedFastFlag fflags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LexerFixInterpStringStart, true}, + }; std::string code = R"( local _ = `hello {name}` )"; CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "transpile_string_interp_multiline") +{ + ScopedFastFlag fflags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LexerFixInterpStringStart, true}, + }; + std::string code = R"( local _ = `hello { + name + }!` )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_string_interp_on_new_line") +{ + ScopedFastFlag fflags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LexerFixInterpStringStart, true}, + }; + std::string code = R"( + error( + `a {b} c` + ) + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_string_interp_multiline_escape") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( local _ = `hello \ + world!` )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_CASE_FIXTURE(Fixture, "transpile_string_literal_escape") { + ScopedFastFlag fflags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LexerFixInterpStringStart, true}, + }; std::string code = R"( local _ = ` bracket = \{, backtick = \` = {'ok'} ` )"; CHECK_EQ(code, transpile(code, {}, true).code); @@ -698,11 +1467,191 @@ TEST_CASE_FIXTURE(Fixture, "transpile_string_literal_escape") TEST_CASE_FIXTURE(Fixture, "transpile_type_functions") { - ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - std::string code = R"( type function foo(arg1, arg2) if arg1 == arg2 then return arg1 end return arg2 end )"; CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "transpile_typeof_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( type X = typeof(x) )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type X = typeof(x) )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type X = typeof (x) )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type X = typeof( x) )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type X = typeof(x ) )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_single_quoted_string_types") +{ + const std::string code = R"( type a = 'hello world' )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_double_quoted_string_types") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( type a = "hello world" )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_raw_string_types") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( type a = [[ hello world ]] )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type a = [==[ hello world ]==] )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_escaped_string_types") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( type a = "\\b\\t\\n\\\\" )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_type_table_semicolon_separators") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + type Foo = { + bar: number; + baz: number; + } + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_type_table_access_modifiers") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( + type Foo = { + read bar: number, + write baz: number, + } + )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { read string } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { + read [string]: number, + read ["property"]: number + } )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_type_table_spaces_between_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( type Foo = { bar: number, } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { bar: number, } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { bar : number, } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { bar: number, } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { bar: number , } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { bar: number, } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { bar: number } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { [string]: number } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { [string]: number } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { [ string]: number } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { [string ]: number } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { [string] : number } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { [string]: number } )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_type_table_preserve_original_indexer_style") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( + type Foo = { + [number]: string + } + )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( + type Foo = { { number } } + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_type_table_preserve_indexer_location") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( + type Foo = { + [number]: string, + property: number, + } + )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( + type Foo = { + property: number, + [number]: string, + } + )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( + type Foo = { + property: number, + [number]: string, + property2: number, + } + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_type_table_preserve_property_definition_style") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( + type Foo = { + ["$$typeof1"]: string, + ['$$typeof2']: string, + } + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_SUITE_END(); diff --git a/tests/TypeFunction.test.cpp b/tests/TypeFunction.test.cpp index 4aa6b680..096b3876 100644 --- a/tests/TypeFunction.test.cpp +++ b/tests/TypeFunction.test.cpp @@ -14,6 +14,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) LUAU_DYNAMIC_FASTINT(LuauTypeFamilyApplicationCartesianProductLimit) +LUAU_FASTFLAG(LuauMetatableTypeFunctions) struct TypeFunctionFixture : Fixture { @@ -33,20 +34,20 @@ struct TypeFunctionFixture : Fixture if (isString(param)) { - return TypeFunctionReductionResult{ctx->builtins->numberType, false, {}, {}}; + return TypeFunctionReductionResult{ctx->builtins->numberType, Reduction::MaybeOk, {}, {}}; } else if (isNumber(param)) { - return TypeFunctionReductionResult{ctx->builtins->stringType, false, {}, {}}; + return TypeFunctionReductionResult{ctx->builtins->stringType, Reduction::MaybeOk, {}, {}}; } else if (is(param) || is(param) || is(param) || (ctx->solver && ctx->solver->hasUnresolvedConstraints(param))) { - return TypeFunctionReductionResult{std::nullopt, false, {param}, {}}; + return TypeFunctionReductionResult{std::nullopt, Reduction::MaybeOk, {param}, {}}; } else { - return TypeFunctionReductionResult{std::nullopt, true, {}, {}}; + return TypeFunctionReductionResult{std::nullopt, Reduction::Erroneous, {}, {}}; } } }; @@ -1262,4 +1263,195 @@ TEST_CASE_FIXTURE(Fixture, "fuzz_len_type_function_follow") )"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_type_function_assigns_correct_metatable") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + type Identity = setmetatable<{}, { __index: {} }> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId id = requireTypeAlias("Identity"); + CHECK_EQ(toString(id, {true}), "{ @metatable { __index: { } }, { } }"); + const MetatableType* mt = get(id); + REQUIRE(mt); + CHECK_EQ(toString(mt->metatable), "{ __index: { } }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_type_function_assigns_correct_metatable_2") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + type Identity = setmetatable<{}, { __index: {} }> + type FooBar = setmetatable<{}, Identity> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId id = requireTypeAlias("Identity"); + CHECK_EQ(toString(id, {true}), "{ @metatable { __index: { } }, { } }"); + const MetatableType* mt = get(id); + REQUIRE(mt); + CHECK_EQ(toString(mt->metatable), "{ __index: { } }"); + + TypeId foobar = requireTypeAlias("FooBar"); + const MetatableType* mt2 = get(foobar); + REQUIRE(mt2); + CHECK_EQ(toString(mt2->metatable, {true}), "{ @metatable { __index: { } }, { } }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_type_function_errors_on_metatable_with_metatable_metamethod") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + type Identity = setmetatable<{}, { __metatable: "blocked" }> + type Bad = setmetatable + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeId id = requireTypeAlias("Identity"); + CHECK_EQ(toString(id, {true}), "{ @metatable { __metatable: \"blocked\" }, { } }"); + const MetatableType* mt = get(id); + REQUIRE(mt); + CHECK_EQ(toString(mt->metatable), "{ __metatable: \"blocked\" }"); +} + + +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_type_function_errors_on_invalid_set") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + type Identity = setmetatable + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_type_function_errors_on_nontable_metatable") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + type Identity = setmetatable<{}, string> + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "getmetatable_type_function_returns_nil_if_no_metatable") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + type TableWithNoMetatable = getmetatable<{}> + type NumberWithNoMetatable = getmetatable + type BooleanWithNoMetatable = getmetatable + type BooleanLiteralWithNoMetatable = getmetatable + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + auto tableResult = requireTypeAlias("TableWithNoMetatable"); + CHECK_EQ(toString(tableResult), "nil"); + + auto numberResult = requireTypeAlias("NumberWithNoMetatable"); + CHECK_EQ(toString(numberResult), "nil"); + + auto booleanResult = requireTypeAlias("BooleanWithNoMetatable"); + CHECK_EQ(toString(booleanResult), "nil"); + + auto booleanLiteralResult = requireTypeAlias("BooleanLiteralWithNoMetatable"); + CHECK_EQ(toString(booleanLiteralResult), "nil"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "getmetatable_returns_correct_metatable") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + local metatable = { __index = { w = 4 } } + local obj = setmetatable({x = 1, y = 2, z = 3}, metatable) + type Metatable = getmetatable + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireTypeAlias("Metatable"), {true}), "{ __index: { w: number } }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "getmetatable_returns_correct_metatable_for_union") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + type Identity = setmetatable<{}, {}> + type Metatable = getmetatable + type IntersectMetatable = getmetatable + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const PrimitiveType* stringType = get(builtinTypes->stringType); + REQUIRE(stringType->metatable); + + TypeArena arena = TypeArena{}; + + std::string expected1 = toString(arena.addType(UnionType{{*stringType->metatable, builtinTypes->emptyTableType}}), {true}); + CHECK_EQ(toString(requireTypeAlias("Metatable"), {true}), expected1); + + std::string expected2 = toString(arena.addType(IntersectionType{{*stringType->metatable, builtinTypes->emptyTableType}}), {true}); + CHECK_EQ(toString(requireTypeAlias("IntersectMetatable"), {true}), expected2); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "getmetatable_returns_correct_metatable_for_string") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + type Metatable = getmetatable + type Metatable2 = getmetatable<"foo"> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const PrimitiveType* stringType = get(builtinTypes->stringType); + REQUIRE(stringType->metatable); + + std::string expected = toString(*stringType->metatable, {true}); + + CHECK_EQ(toString(requireTypeAlias("Metatable"), {true}), expected); + CHECK_EQ(toString(requireTypeAlias("Metatable2"), {true}), expected); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "getmetatable_respects_metatable_metamethod") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + local metatable = { __metatable = "Test" } + local obj = setmetatable({x = 1, y = 2, z = 3}, metatable) + type Metatable = getmetatable + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireTypeAlias("Metatable")), "string"); +} + TEST_SUITE_END(); diff --git a/tests/TypeFunction.user.test.cpp b/tests/TypeFunction.user.test.cpp index 145772fd..bdde63f5 100644 --- a/tests/TypeFunction.user.test.cpp +++ b/tests/TypeFunction.user.test.cpp @@ -8,22 +8,16 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2) -LUAU_FASTFLAG(LuauUserTypeFunFixRegister) -LUAU_FASTFLAG(LuauUserTypeFunFixNoReadWrite) -LUAU_FASTFLAG(LuauUserTypeFunFixMetatable) -LUAU_FASTFLAG(LuauUserDefinedTypeFunctionResetState) -LUAU_FASTFLAG(LuauUserTypeFunNonstrict) -LUAU_FASTFLAG(LuauUserTypeFunExportedAndLocal) -LUAU_FASTFLAG(LuauUserDefinedTypeFunParseExport) -LUAU_FASTFLAG(LuauUserTypeFunThreadBuffer) +LUAU_FASTFLAG(LuauUserTypeFunFixInner) +LUAU_FASTFLAG(LuauUserTypeFunGenerics) +LUAU_FASTFLAG(LuauUserTypeFunCloneTail) +LUAU_FASTFLAG(DebugLuauEqSatSimplification) TEST_SUITE_BEGIN("UserDefinedTypeFunctionTests"); TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_nil_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_nil(arg) @@ -39,7 +33,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_nil_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_nil_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getnil() @@ -59,7 +52,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_nil_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_unknown_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_unknown(arg) @@ -75,7 +67,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_unknown_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_unknown_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getunknown() @@ -95,7 +86,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_unknown_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_never_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_never(arg) @@ -111,7 +101,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_never_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_never_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getnever() @@ -131,7 +120,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_never_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_any_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_any(arg) @@ -147,7 +135,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_any_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_any_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getany() @@ -167,7 +154,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_any_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolean_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_bool(arg) @@ -183,7 +169,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolean_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolean_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getboolean() @@ -203,7 +188,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolean_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_number_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_num(arg) @@ -219,7 +203,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_number_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_number_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getnumber() @@ -239,9 +222,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_number_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "thread_and_buffer_types") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; - ScopedFastFlag luauUserTypeFunThreadBuffer{FFlag::LuauUserTypeFunThreadBuffer, true}; LUAU_REQUIRE_NO_ERRORS(check(R"( type function work_with_thread(x) @@ -269,7 +249,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "thread_and_buffer_types") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_string_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_str(arg) @@ -285,7 +264,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_string_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_string_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getstring() @@ -305,7 +283,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_string_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolsingleton_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_boolsingleton(arg) @@ -321,7 +298,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolsingleton_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolsingleton_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getboolsingleton() @@ -341,7 +317,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolsingleton_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_strsingleton_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_strsingleton(arg) @@ -357,7 +332,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_strsingleton_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_strsingleton_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getstrsingleton() @@ -377,7 +351,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_strsingleton_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_union_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_union(arg) @@ -397,7 +370,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_union_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_union_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getunion() @@ -426,7 +398,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_union_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_intersection_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_intersection(arg) @@ -446,7 +417,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_intersection_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_intersection_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getintersection() @@ -481,7 +451,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_intersection_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_negation_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getnegation() @@ -503,10 +472,31 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_negation_methods_work") CHECK(toString(tpm->givenTp) == "~string"); } +TEST_CASE_FIXTURE(ClassFixture, "udtf_negation_inner") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunFixInner{FFlag::LuauUserTypeFunFixInner, true}; + + CheckResult result = check(R"( +type function pass(t) + return types.negationof(t):inner() +end + +type function fail(t) + return t:inner() +end + +local function ok(idx: pass): number return idx end +local function notok(idx: fail): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"('fail' type function errored at runtime: [string "fail"]:7: type.inner: cannot call inner method on non-negation type: `number` type)"); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_table_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_table(arg) @@ -526,7 +516,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_table_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_table_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function gettable() @@ -565,7 +554,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_table_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_metatable_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getmetatable() @@ -598,7 +586,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_metatable_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_function_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_func(arg) @@ -614,7 +601,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_function_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_function_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getfunction() @@ -644,7 +630,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_function_methods_work") TEST_CASE_FIXTURE(ClassFixture, "udtf_class_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_class(arg) @@ -659,7 +644,6 @@ TEST_CASE_FIXTURE(ClassFixture, "udtf_class_serialization_works") TEST_CASE_FIXTURE(ClassFixture, "udtf_class_methods_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( @@ -682,9 +666,6 @@ TEST_CASE_FIXTURE(ClassFixture, "udtf_class_methods_works") TEST_CASE_FIXTURE(ClassFixture, "write_of_readonly_is_nil") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - ScopedFastFlag udtfRwFix{FFlag::LuauUserTypeFunFixNoReadWrite, true}; - CheckResult result = check(R"( type function getclass(arg) @@ -711,7 +692,6 @@ TEST_CASE_FIXTURE(ClassFixture, "write_of_readonly_is_nil") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_check_mutability") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function checkmut() @@ -743,7 +723,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_check_mutability") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_copy_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getcopy() @@ -776,7 +755,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_copy_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_simple_cyclic_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_cycle(arg) @@ -797,7 +775,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_simple_cyclic_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_createtable_bad_metatable") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function badmetatable() @@ -806,7 +783,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_createtable_bad_metatable") local function bad(arg: badmetatable<>) end )"); - LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + LUAU_REQUIRE_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error UserDefinedTypeFunctionError* e = get(result.errors[0]); REQUIRE(e); CHECK( @@ -818,7 +795,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_createtable_bad_metatable") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_complex_cyclic_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_cycle2(arg) @@ -847,7 +823,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_complex_cyclic_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_user_error_is_reported") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function errors_if_string(arg) @@ -860,7 +835,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_user_error_is_reported") local function ok(idx: errors_if_string): nil return idx end )"); - LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + LUAU_REQUIRE_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error UserDefinedTypeFunctionError* e = get(result.errors[0]); REQUIRE(e); CHECK(e->message == "'errors_if_string' type function errored at runtime: [string \"errors_if_string\"]:5: We are in a math class! not english"); @@ -869,7 +844,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_user_error_is_reported") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_type_overrides_call_metamethod") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function hello(arg) @@ -878,7 +852,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_type_overrides_call_metamethod") local function ok(idx: hello): nil return idx end )"); - LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + LUAU_REQUIRE_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error UserDefinedTypeFunctionError* e = get(result.errors[0]); REQUIRE(e); CHECK(e->message == "'hello' type function errored at runtime: [string \"hello\"]:3: userdata"); @@ -887,7 +861,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_type_overrides_call_metamethod") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_type_overrides_eq_metamethod") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function hello() @@ -912,7 +885,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_type_overrides_eq_metamethod") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_function_type_cant_call_get_props") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function hello(arg) @@ -921,7 +893,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_function_type_cant_call_get_props") local function ok(idx: hello<() -> ()>): nil return idx end )"); - LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + LUAU_REQUIRE_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error UserDefinedTypeFunctionError* e = get(result.errors[0]); REQUIRE(e); CHECK( @@ -933,7 +905,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_function_type_cant_call_get_props") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_calling_each_other") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function foo() @@ -945,16 +916,66 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_calling_each_other") local function ok(idx: bar<>): nil return idx end )"); - LUAU_CHECK_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR_COUNT(1, result); TypePackMismatch* tpm = get(result.errors[0]); REQUIRE(tpm); CHECK(toString(tpm->givenTp) == "\"hi\""); } +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_calling_each_other_2") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function first(arg) + return arg + end + type function second(arg) + return types.singleton(first(arg)) + end + type function third() + return second("hi") + end + local function ok(idx: third<>): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "\"hi\""); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_calling_each_other_3") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + -- this function should not see 'fourth' function when invoked from 'third' that sees it + type function first(arg) + return fourth(arg) + end + type function second(arg) + return types.singleton(first(arg)) + end + + do + type function fourth(arg) + return arg + end + type function third() + return second("hi") + end + local function ok(idx: third<>): nil return idx end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"('third' type function errored at runtime: [string "first"]:4: attempt to call a nil value)"); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_no_shared_state") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function foo() @@ -974,7 +995,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_no_shared_state") )"); // We are only checking first errors, others are mostly duplicates - LUAU_CHECK_ERROR_COUNT(8, result); + LUAU_REQUIRE_ERROR_COUNT(8, result); CHECK(toString(result.errors[0]) == R"('bar' type function errored at runtime: [string "foo"]:4: attempt to modify a readonly table)"); CHECK(toString(result.errors[1]) == R"(Type function instance bar<"x"> is uninhabited)"); } @@ -982,8 +1003,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_no_shared_state") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_math_reset") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - ScopedFastFlag luauUserDefinedTypeFunctionResetState{FFlag::LuauUserDefinedTypeFunctionResetState, true}; CheckResult result = check(R"( type function foo(x) @@ -998,7 +1017,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_math_reset") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_optionify") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function optionify(tbl) @@ -1018,7 +1036,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_optionify") local function ok(idx: optionify): nil return idx end )"); - LUAU_CHECK_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR_COUNT(1, result); TypePackMismatch* tpm = get(result.errors[0]); REQUIRE(tpm); CHECK(toString(tpm->givenTp) == "{ age: number?, alive: boolean?, name: string? }"); @@ -1027,7 +1045,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_optionify") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_calling_illegal_global") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function illegal(arg) @@ -1039,7 +1056,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_calling_illegal_global") local function ok(idx: illegal): nil return idx end )"); - LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + LUAU_REQUIRE_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error UserDefinedTypeFunctionError* e = get(result.errors[0]); REQUIRE(e); CHECK(e->message == "'illegal' type function errored at runtime: [string \"illegal\"]:3: this function is not supported in type functions"); @@ -1048,7 +1065,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_calling_illegal_global") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_recursion_and_gc") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function foo(tbl) @@ -1065,7 +1081,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_recursion_and_gc") local function ok(idx: foo): nil return idx end )"); - LUAU_CHECK_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR_COUNT(1, result); TypePackMismatch* tpm = get(result.errors[0]); REQUIRE(tpm); } @@ -1073,7 +1089,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_recursion_and_gc") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_recovery_no_upvalues") { ScopedFastFlag solverV2{FFlag::LuauSolverV2, true}; - ScopedFastFlag userDefinedTypeFunctionsSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( local var @@ -1089,14 +1104,13 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_recovery_no_upvalues") end )"); - LUAU_CHECK_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK(toString(result.errors[0]) == R"(Type function cannot reference outer local 'var')"); } TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_follow") { ScopedFastFlag solverV2{FFlag::LuauSolverV2, true}; - ScopedFastFlag userDefinedTypeFunctionsSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type t0 = any @@ -1105,14 +1119,13 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_follow") end )"); - LUAU_CHECK_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK(toString(result.errors[0]) == R"(Redefinition of type 't0', previously defined at line 2)"); } TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_strip_indexer") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function stripindexer(tbl) @@ -1137,8 +1150,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_strip_indexer") TEST_CASE_FIXTURE(BuiltinsFixture, "no_type_methods_on_types") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; CheckResult result = check(R"( type function test(x) @@ -1154,8 +1165,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "no_type_methods_on_types") TEST_CASE_FIXTURE(BuiltinsFixture, "no_types_functions_on_type") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; CheckResult result = check(R"( type function test(x) @@ -1171,8 +1180,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "no_types_functions_on_type") TEST_CASE_FIXTURE(BuiltinsFixture, "no_metatable_writes") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; CheckResult result = check(R"( type function test(x) @@ -1190,8 +1197,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "no_metatable_writes") TEST_CASE_FIXTURE(BuiltinsFixture, "no_eq_field") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; CheckResult result = check(R"( type function test(x) @@ -1207,8 +1212,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "no_eq_field") TEST_CASE_FIXTURE(BuiltinsFixture, "tag_field") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; CheckResult result = check(R"( type function test(x) @@ -1229,9 +1232,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "tag_field") TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_serialization") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; - ScopedFastFlag luauUserTypeFunFixMetatable{FFlag::LuauUserTypeFunFixMetatable, true}; CheckResult result = check(R"( type function makemttbl() @@ -1260,9 +1260,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_serialization") TEST_CASE_FIXTURE(BuiltinsFixture, "nonstrict_mode") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; - ScopedFastFlag luauUserTypeFunNonstrict{FFlag::LuauUserTypeFunNonstrict, true}; CheckResult result = check(R"( --!nonstrict @@ -1275,9 +1272,6 @@ local a: foo<> = "a" TEST_CASE_FIXTURE(BuiltinsFixture, "implicit_export") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; - ScopedFastFlag luauUserTypeFunExportedAndLocal{FFlag::LuauUserTypeFunExportedAndLocal, true}; fileResolver.source["game/A"] = R"( type function concat(a, b) @@ -1305,9 +1299,6 @@ local b: Test.Concat<'third', 'fourth'> TEST_CASE_FIXTURE(BuiltinsFixture, "local_scope") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; - ScopedFastFlag luauUserTypeFunExportedAndLocal{FFlag::LuauUserTypeFunExportedAndLocal, true}; CheckResult result = check(R"( type function foo() @@ -1330,10 +1321,6 @@ local a = test() TEST_CASE_FIXTURE(BuiltinsFixture, "explicit_export") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; - ScopedFastFlag luauUserTypeFunExportedAndLocal{FFlag::LuauUserTypeFunExportedAndLocal, true}; - ScopedFastFlag luauUserDefinedTypeFunParseExport{FFlag::LuauUserDefinedTypeFunParseExport, true}; fileResolver.source["game/A"] = R"( export type function concat(a, b) @@ -1357,4 +1344,554 @@ local b: Test.concat<'third', 'fourth'> CHECK(toString(requireType("b")) == R"("thirdfourth")"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "print_to_error") +{ + ScopedFastFlag solverV2{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function t0(a) + print("Where does this go") + print(a.tag) + return types.any + end + local a: t0 + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK(toString(result.errors[0]) == R"(Where does this go)"); + CHECK(toString(result.errors[1]) == R"(string)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "print_to_error_plus_error") +{ + ScopedFastFlag solverV2{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function t0(a) + print("Where does this go") + print(a.tag) + error("test") + end + local a: t0 + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"(Where does this go)"); + CHECK(toString(result.errors[1]) == R"(string)"); + CHECK(toString(result.errors[2]) == R"('t0' type function errored at runtime: [string "t0"]:5: test)"); + CHECK(toString(result.errors[3]) == R"(Type function instance t0 is uninhabited)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "print_to_error_plus_no_result") +{ + ScopedFastFlag solverV2{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function t0(a) + print("Where does this go") + print(a.tag) + end + local a: t0 + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"(Where does this go)"); + CHECK(toString(result.errors[1]) == R"(string)"); + CHECK(toString(result.errors[2]) == R"('t0' type function: returned a non-type value)"); + CHECK(toString(result.errors[3]) == R"(Type function instance t0 is uninhabited)"); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_serialization_1") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; + + CheckResult result = check(R"( +type function pass(arg) + return arg +end + +type test = (T, { x: (y: T) -> (), y: U }, U) -> () + +local function ok(idx: pass): test return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_serialization_2") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; + + CheckResult result = check(R"( +type function pass(arg) + return arg +end + +type test = (T) -> (T, U...) + +local function ok(idx: pass): test return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_serialization_3") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; + + CheckResult result = check(R"( +type function pass(arg) + return arg +end + +local function m(a, b) + return {x = a, y = b} +end + +type test = typeof(m) + +local function ok(idx: pass): test return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_cloning_1") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; + + CheckResult result = check(R"( +type function pass(arg) + return types.copy(arg) +end + +type test = (T, { x: (y: T) -> (), y: U }, U) -> () + +local function ok(idx: pass): test return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_cloning_2") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; + ScopedFastFlag luauUserTypeFunCloneTail{FFlag::LuauUserTypeFunCloneTail, true}; + + CheckResult result = check(R"( +type function pass(arg) + return types.copy(arg) +end + +type test = (T) -> (T, U...) + +local function ok(idx: pass): test return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_equality") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; + + CheckResult result = check(R"( +type function pass(arg) + return types.singleton(types.copy(arg) == arg) +end + +type test = (T) -> (T, U...) + +local function ok(idx: pass): true return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_1") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; + + CheckResult result = check(R"( +type function pass(arg) + local generics = arg:generics() + local T = generics[1] + return types.newfunction({ head = {T} }, { head = {T} }, {T}) +end + +type test = (T, { x: (y: T) -> (), y: U }, U) -> () + +local function ok(idx: pass): (T) -> (T) return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_2") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; + + CheckResult result = check(R"( +type function pass(arg) + local generics = arg:generics() + local T = generics[1] + local f = types.newfunction() + f:setparameters({T, T}); + f:setreturns({T}); + f:setgenerics({T}); + return f +end + +type test = (T, { x: (y: T) -> (), y: U }, U) -> () + +local function ok(idx: pass): (T, T) -> (T) return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_3") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; + + CheckResult result = check(R"( +type function pass() + local T = types.generic("T") + assert(T.tag == "generic") + assert(T:name() == "T") + assert(T:ispack() == false) + + local Us, Vs = types.generic("U", true), types.generic("V", true) + assert(Us.tag == "generic") + assert(Us:name() == "U") + assert(Us:ispack() == true) + + local f = types.newfunction() + f:setparameters({T}, Us); + f:setreturns({T}, Vs); + f:setgenerics({T, Us, Vs}); + return f +end + +local function ok(idx: pass<>): (T, U...) -> (T, V...) return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_4") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; + + CheckResult result = check(R"( +type function pass() + local T, U = types.generic("T"), types.generic("U") + + -- (T) -> () + local func = types.newfunction({ head = {T} }, {}, {T}); + + -- { x: (T) -> (), y: U } + local tbl = types.newtable({ [types.singleton("x")] = func, [types.singleton("y")] = U }) + + -- (T, { x: (T) -> (), y: U }, U) -> () + return types.newfunction({ head = {T, tbl, U } }, {}, {T, U}) +end + +type test = (T, { x: (y: T) -> (), y: U }, U) -> () + +local function ok(idx: pass<>): test return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_5") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; + + CheckResult result = check(R"( +type function pass() + local T = types.generic("T") + return types.newfunction({ head = {T} }, {}, {types.copy(T)}) +end + +local function ok(idx: pass<>): (T) -> () return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_6") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; + + CheckResult result = check(R"( +type function pass(arg) + local generics = arg:generics() + local T, U = generics[1], generics[2] + local f = types.newfunction() + f:setparameters({T}); + f:setreturns({U}); + f:setgenerics({T, U}); + return f +end + +local function m(a, b) + return {x = a, y = b} +end + +type test = typeof(m) + +local function ok(idx: pass): (T) -> (U) return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_7") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; + + CheckResult result = check(R"( +type function pass(arg) + local p, r = arg:parameters(), arg:returns() + local f = types.newfunction() + f:setparameters(p.head, p.tail); + f:setreturns(r.head, r.tail); + f:setgenerics(arg:generics()); + return f +end + +type test = (T, U...) -> (T, U...) + +local function ok(idx: pass): (T, U...) -> (T, U...) return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_8") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; + + CheckResult result = check(R"( +type function pass(arg) + local p, r = arg:parameters(), arg:returns() + local f = types.newfunction() + f:setparameters(p.head, p.tail); + f:setreturns(r.head, r.tail); + f:setgenerics(arg:generics()); + return f +end + +type test = (U...) -> (U...) + +local function ok(idx: pass): (T, T) -> (T) return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_equality_2") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; + + CheckResult result = check(R"( +type function get() + local T, Us = types.generic("T"), types.generic("U", true) + + local tbl1 = types.newtable({ [types.singleton("x")] = T }) + local tbl2 = types.newtable({ [types.singleton("x")] = Us }) -- it is possible to have invalid types in-flight + + return types.singleton(tbl1 == tbl2) +end + +local function ok(idx: get<>): false return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_error_1") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; + + CheckResult result = check(R"( +type function get() + local T, Us = types.generic("T"), types.generic("U", true) + return types.newfunction({}, {}, {Us, T}) +end +local function ok(idx: get<>): false return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK( + toString(result.errors[0]) == + R"('get' type function errored at runtime: [string "get"]:4: types.newfunction: generic type cannot follow a generic pack)" + ); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_error_2") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; + + CheckResult result = check(R"( +type function get() + local T, Us = types.generic("T"), types.generic("U", true) + return types.newfunction({ head = {T} }, {}, {}) +end +local function ok(idx: get<>): false return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"(Generic type 'T' is not in a scope of the active generic function)"); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_error_3") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; + + CheckResult result = check(R"( +type function get() + local T, U = types.generic("T"), types.generic("U") + + -- (U) -> () + local func = types.newfunction({ head = {U} }, {}, {U}); + + -- broken: (T, (U) -> (), U) -> () + return types.newfunction({ head = {T, func, U } }, {}, {T}) +end +local function ok(idx: get<>): false return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"(Generic type 'U' is not in a scope of the active generic function)"); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_error_4") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; + + CheckResult result = check(R"( +type function get() + local T, Us = types.generic("T"), types.generic("U", true) + return types.newfunction({ head = {T} }, { tail = Us }, {T, T}) +end +local function ok(idx: get<>): false return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"(Duplicate type parameter 'T')"); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_error_5") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; + + CheckResult result = check(R"( +type function get() + local T, Ts = types.generic("T"), types.generic("T", true) + return types.newfunction({ head = {T} }, { tail = Ts }, {T, Ts}) +end +local function ok(idx: get<>): false return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"(Duplicate type parameter 'T')"); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_error_6") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; + + CheckResult result = check(R"( +type function get() + local T, Us = types.generic("T"), types.generic("U", true) + return types.newfunction({ head = {Us} }, {}, {T, Us}) +end +local function ok(idx: get<>): false return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"(Generic type pack 'U...' cannot be placed in a type position)"); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_error_7") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; + + CheckResult result = check(R"( +type function get() + local T, Us = types.generic("T"), types.generic("U", true) + return types.newfunction({ tail = Us }, {}, {T}) +end +local function ok(idx: get<>): false return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"(Generic type pack 'U...' is not in a scope of the active generic function)"); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_variadic_api") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; + + CheckResult result = check(R"( +type function pass(arg) + local p, r = arg:parameters(), arg:returns() + local f = types.newfunction() + f:setparameters({p.tail}, p.head[1]); + f:setreturns({r.tail}, r.head[1]); + return f +end + +type test = (string, ...number) -> (number, ...string) + +local function ok(idx: pass): (number, ...string) -> (string, ...number) return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_eqsat_opaque") +{ + ScopedFastFlag sffs[] = {{FFlag::LuauSolverV2, true}, {FFlag::LuauUserTypeFunGenerics, true}, {FFlag::DebugLuauEqSatSimplification, true}}; + + CheckResult _ = check(R"( + type function t0(a) + error("test") + end + local v: t0 + )"); + TypeArena arena; + auto ty = requireType("v"); + auto simplifier = EqSatSimplification::newSimplifier(NotNull{&arena}, frontend.builtinTypes); + auto simplified = eqSatSimplify(NotNull{simplifier.get()}, ty); + REQUIRE(simplified); + CHECK_EQ("t0", toString(simplified->result)); // NOLINT(bugprone-unchecked-optional-access) +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 2b5e64be..3972fd6b 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -2,6 +2,7 @@ #include "Fixture.h" +#include "ScopedFlags.h" #include "doctest.h" #include "Luau/BuiltinDefinitions.h" #include "Luau/AstQuery.h" @@ -9,7 +10,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2) +LUAU_FASTFLAG(LuauFixInfiniteRecursionInNormalization) TEST_SUITE_BEGIN("TypeAliases"); @@ -1179,4 +1180,33 @@ TEST_CASE_FIXTURE(Fixture, "bound_type_in_alias_segfault") )")); } +TEST_CASE_FIXTURE(BuiltinsFixture, "gh1632_no_infinite_recursion_in_normalization") +{ + ScopedFastFlag flags[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauFixInfiniteRecursionInNormalization, true}, + }; + + CheckResult result = check(R"( + type Node = { + value: T, + next: Node?, + -- remove `prev`, solves issue + prev: Node?, + }; + + type List = { + head: Node? + } + + local function IsFront(list: List, nodeB: Node) + -- remove if statement below, solves issue + if (list.head == nodeB) then + end + end + )"); + + LUAU_CHECK_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index d9c4c13e..96443aeb 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -10,8 +10,9 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauTypestateBuiltins2) -LUAU_FASTFLAG(LuauStringFormatArityFix) +LUAU_FASTFLAG(LuauTableCloneClonesType3) +LUAU_FASTFLAG(LuauStringFormatErrorSuppression) +LUAU_FASTFLAG(LuauFreezeIgnorePersistent) TEST_SUITE_BEGIN("BuiltinTests"); @@ -807,8 +808,6 @@ TEST_CASE_FIXTURE(Fixture, "string_format_as_method") TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_trivial_arity") { - ScopedFastFlag sff{FFlag::LuauStringFormatArityFix, true}; - CheckResult result = check(R"( string.format() )"); @@ -1132,15 +1131,13 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauSolverV2 && FFlag::LuauTypestateBuiltins2) + if (FFlag::LuauSolverV2) CHECK("Key 'b' not found in table '{ read a: number }'" == toString(result.errors[0])); - else if (FFlag::LuauSolverV2) - CHECK("Key 'b' not found in table '{ a: number }'" == toString(result.errors[0])); else CHECK_EQ("Key 'b' not found in table '{| a: number |}'", toString(result.errors[0])); CHECK(Location({13, 18}, {13, 23}) == result.errors[0].location); - if (FFlag::LuauSolverV2 && FFlag::LuauTypestateBuiltins2) + if (FFlag::LuauSolverV2) { CHECK_EQ("{ read a: number }", toString(requireTypeAtPosition({15, 19}))); CHECK_EQ("{ read b: string }", toString(requireTypeAtPosition({16, 19}))); @@ -1176,8 +1173,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_does_not_retroactively_block_mu LUAU_REQUIRE_NO_ERRORS(result); - - if (FFlag::LuauTypestateBuiltins2) + if (FFlag::LuauSolverV2) { CHECK_EQ("{ a: number, q: string } | { read a: number, read q: string }", toString(requireType("t1"), {/*exhaustive */ true})); // before the assignment, it's `t1` @@ -1207,8 +1203,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_no_generic_table") end )"); - if (FFlag::LuauTypestateBuiltins2) - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_on_metatable") @@ -1235,13 +1230,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_errors_on_no_args") table.freeze() )"); - // this does not error in the new solver without the typestate builtins functionality. - if (FFlag::LuauSolverV2 && !FFlag::LuauTypestateBuiltins2) - { - LUAU_REQUIRE_NO_ERRORS(result); - return; - } - LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK(get(result.errors[0])); @@ -1254,25 +1242,40 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_errors_on_non_tables") table.freeze(42) )"); - // this does not error in the new solver without the typestate builtins functionality. - if (FFlag::LuauSolverV2 && !FFlag::LuauTypestateBuiltins2) - { - LUAU_REQUIRE_NO_ERRORS(result); - return; - } - LUAU_REQUIRE_ERROR_COUNT(1, result); TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - if (FFlag::LuauSolverV2 && FFlag::LuauTypestateBuiltins2) + if (FFlag::LuauSolverV2) CHECK_EQ(toString(tm->wantedType), "table"); else CHECK_EQ(toString(tm->wantedType), "{- -}"); CHECK_EQ(toString(tm->givenType), "number"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_persistent_skip") +{ + ScopedFastFlag luauFreezeIgnorePersistent{FFlag::LuauFreezeIgnorePersistent, true}; + + CheckResult result = check(R"( + table.freeze(table) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_clone_persistent_skip") +{ + ScopedFastFlag luauFreezeIgnorePersistent{FFlag::LuauFreezeIgnorePersistent, true}; + + CheckResult result = check(R"( + table.clone(table) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "set_metatable_needs_arguments") { // In the new solver, nil can certainly be used where a generic is required, so all generic parameters are optional. @@ -1586,4 +1589,77 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "string_find_should_not_crash") )")); } +TEST_CASE_FIXTURE(BuiltinsFixture, "table_dot_clone_type_states") +{ + CheckResult result = check(R"( + local t1 = {} + t1.x = 5 + local t2 = table.clone(t1) + t2.y = 6 + t1.z = 3 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::LuauTableCloneClonesType3) + { + CHECK_EQ(toString(requireType("t1"), {true}), "{ x: number, z: number }"); + CHECK_EQ(toString(requireType("t2"), {true}), "{ x: number, y: number }"); + } + else + { + CHECK_EQ(toString(requireType("t1"), {true}), "{ x: number, y: number, z: number }"); + CHECK_EQ(toString(requireType("t2"), {true}), "{ x: number, y: number, z: number }"); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_clone_should_not_break") +{ + CheckResult result = check(R"( + local Immutable = {} + + function Immutable.Set(dictionary, key, value) + local new = table.clone(dictionary) + + new[key] = value + + return new + end + + return Immutable + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_clone_should_not_break_2") +{ + CheckResult result = check(R"( + function set(dictionary, key, value) + local new = table.clone(dictionary) + + new[key] = value + + return new + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_should_support_any") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + local x: any = "world" + print(string.format("Hello, %s!", x)) + )"); + + if (FFlag::LuauStringFormatErrorSuppression) + LUAU_REQUIRE_NO_ERRORS(result); + else + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index b81ac010..53f1396d 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -665,12 +665,11 @@ TEST_CASE_FIXTURE(ClassFixture, "indexable_classes") )"); if (FFlag::LuauSolverV2) - CHECK( - "Type 'boolean' could not be converted into 'number | string'" == toString(result.errors.at(0)) - ); + CHECK("Type 'boolean' could not be converted into 'number | string'" == toString(result.errors.at(0))); else CHECK_EQ( - toString(result.errors.at(0)), "Type 'boolean' could not be converted into 'number | string'; none of the union options are compatible" + toString(result.errors.at(0)), + "Type 'boolean' could not be converted into 'number | string'; none of the union options are compatible" ); } { @@ -680,12 +679,11 @@ TEST_CASE_FIXTURE(ClassFixture, "indexable_classes") )"); if (FFlag::LuauSolverV2) - CHECK( - "Type 'boolean' could not be converted into 'number | string'" == toString(result.errors.at(0)) - ); + CHECK("Type 'boolean' could not be converted into 'number | string'" == toString(result.errors.at(0))); else CHECK_EQ( - toString(result.errors.at(0)), "Type 'boolean' could not be converted into 'number | string'; none of the union options are compatible" + toString(result.errors.at(0)), + "Type 'boolean' could not be converted into 'number | string'; none of the union options are compatible" ); } diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index 2ab90ab5..ce1cef29 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -10,6 +10,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauNewSolverPrePopulateClasses) +LUAU_FASTFLAG(LuauClipNestedAndRecursiveUnion) TEST_SUITE_BEGIN("DefinitionTests"); @@ -541,4 +542,20 @@ TEST_CASE_FIXTURE(Fixture, "definition_file_has_source_module_name_set") CHECK_EQ(ctv->definitionModuleName, "@test"); } +TEST_CASE_FIXTURE(Fixture, "recursive_redefinition_reduces_rightfully") +{ + ScopedFastFlag _{FFlag::LuauClipNestedAndRecursiveUnion, true}; + + LUAU_REQUIRE_NO_ERRORS(check(R"( + local t: {[string]: string} = {} + + local function f() + t = t + end + + t = t + )")); +} + + TEST_SUITE_END(); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 3686f2d4..942ef6a7 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -19,8 +19,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTINT(LuauTarjanChildLimit) -LUAU_FASTFLAG(LuauRetrySubtypingWithoutHiddenPack) -LUAU_FASTFLAG(LuauDontRefCountTypesInTypeFunctions) +LUAU_FASTFLAG(DebugLuauEqSatSimplification) TEST_SUITE_BEGIN("TypeInferFunctions"); @@ -2565,10 +2564,7 @@ end TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_return_type") { - ScopedFastFlag sffs[] = { - {FFlag::LuauSolverV2, true}, - {FFlag::LuauDontRefCountTypesInTypeFunctions, true} - }; + ScopedFastFlag _{FFlag::LuauSolverV2, true}; // CLI-114134: This test: // a) Has a kind of weird result (suggesting `number | false` is not great); @@ -2880,8 +2876,6 @@ TEST_CASE_FIXTURE(Fixture, "fuzzer_missing_follow_in_ast_stat_fun") TEST_CASE_FIXTURE(Fixture, "unifier_should_not_bind_free_types") { - ScopedFastFlag _{FFlag::LuauDontRefCountTypesInTypeFunctions, true}; - CheckResult result = check(R"( function foo(player) local success,result = player:thing() @@ -3017,9 +3011,6 @@ local u,v = id(3), id(id(44)) TEST_CASE_FIXTURE(Fixture, "hidden_variadics_should_not_break_subtyping") { - // Only applies to new solver. - ScopedFastFlag sff{FFlag::LuauRetrySubtypingWithoutHiddenPack, true}; - CheckResult result = check(R"( --!strict type FooType = { diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index a9109e1d..e5fdbdd3 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -12,7 +12,6 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauTypestateBuiltins2) LUAU_FASTFLAG(LuauNewSolverPopulateTableLocations) using namespace Luau; @@ -185,10 +184,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "cross_module_table_freeze") ModulePtr b = frontend.moduleResolver.getModule("game/B"); REQUIRE(b != nullptr); // confirm that no cross-module mutation happened here! - if (FFlag::LuauSolverV2 && FFlag::LuauTypestateBuiltins2) + if (FFlag::LuauSolverV2) CHECK(toString(b->returnType) == "{ read a: number }"); - else if (FFlag::LuauSolverV2) - CHECK(toString(b->returnType) == "{ a: number }"); else CHECK(toString(b->returnType) == "{| a: number |}"); } diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 5f4b730e..7460434e 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -17,7 +17,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauMetatableFollow) +LUAU_FASTFLAG(LuauDoNotGeneralizeInTypeFunctions) TEST_SUITE_BEGIN("TypeInferOperators"); @@ -801,7 +801,10 @@ TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown") "Operator '+' could not be applied to operands of types unknown and unknown; there is no corresponding overload for __add", toString(result.errors[0]) ); - CHECK_EQ("Operator '-' could not be applied to operands of types unknown and unknown; there is no corresponding overload for __sub", toString(result.errors[1])); + CHECK_EQ( + "Operator '-' could not be applied to operands of types unknown and unknown; there is no corresponding overload for __sub", + toString(result.errors[1]) + ); } else { @@ -812,19 +815,19 @@ TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown") TEST_CASE_FIXTURE(BuiltinsFixture, "and_binexps_dont_unify") { - CheckResult result = check(R"( - --!strict - local t = {} - while true and t[1] do - print(t[1].test) - end - )"); + ScopedFastFlag _{FFlag::LuauDoNotGeneralizeInTypeFunctions, true}; - // This infers a type for `t` of `{unknown}`, and so it makes sense that `t[1].test` would error. - if (FFlag::LuauSolverV2) - LUAU_REQUIRE_ERROR_COUNT(1, result); - else - LUAU_REQUIRE_NO_ERRORS(result); + // `t` will be inferred to be of type `{ { test: unknown } }` which is + // reasonable, in that it's empty with no bounds on its members. Optimally + // we might emit an error here that the `print(...)` expression is + // unreachable. + LUAU_REQUIRE_NO_ERRORS(check(R"( + --!strict + local t = {} + while true and t[1] do + print(t[1].test) + end + )")); } TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators") @@ -1614,8 +1617,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "compound_operator_on_upvalue") TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_operator_follow") { - ScopedFastFlag luauMetatableFollow{FFlag::LuauMetatableFollow, true}; - CheckResult result = check(R"( local t1 = {} local t2 = {} diff --git a/tests/TypeInfer.primitives.test.cpp b/tests/TypeInfer.primitives.test.cpp index 0c14a448..2c76f123 100644 --- a/tests/TypeInfer.primitives.test.cpp +++ b/tests/TypeInfer.primitives.test.cpp @@ -12,8 +12,6 @@ #include "doctest.h" -LUAU_FASTFLAG(LuauVectorDefinitions) - using namespace Luau; TEST_SUITE_BEGIN("TypeInferPrimitives"); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index b2cc6713..ee7d713d 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -1,4 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + #include "Luau/TypeInfer.h" #include "Luau/RecursionCounter.h" @@ -11,6 +12,8 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2); +LUAU_FASTFLAG(DebugLuauEqSatSimplification); +LUAU_FASTFLAG(LuauStoreCSTData); LUAU_FASTINT(LuauNormalizeCacheLimit); LUAU_FASTINT(LuauTarjanChildLimit); LUAU_FASTINT(LuauTypeInferIterationLimit); @@ -45,7 +48,16 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") end )"; - const std::string expected = R"( + const std::string expected = FFlag::LuauStoreCSTData ? R"( + function f(a:{fn:()->(a,b...)}): () + if type(a) == 'boolean' then + local a1:boolean=a + elseif a.fn() then + local a2:{fn:()->(a,b...)}=a + end + end + )" + : R"( function f(a:{fn:()->(a,b...)}): () if type(a) == 'boolean'then local a1:boolean=a @@ -55,7 +67,16 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") end )"; - const std::string expectedWithNewSolver = R"( + const std::string expectedWithNewSolver = FFlag::LuauStoreCSTData ? R"( + function f(a:{fn:()->(unknown,...unknown)}): () + if type(a) == 'boolean' then + local a1:{fn:()->(unknown,...unknown)}&boolean=a + elseif a.fn() then + local a2:{fn:()->(unknown,...unknown)}&(class|function|nil|number|string|thread|buffer|table)=a + end + end + )" + : R"( function f(a:{fn:()->(unknown,...unknown)}): () if type(a) == 'boolean'then local a1:{fn:()->(unknown,...unknown)}&boolean=a @@ -65,8 +86,29 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") end )"; - if (FFlag::LuauSolverV2) + const std::string expectedWithEqSat = FFlag::LuauStoreCSTData ? R"( + function f(a:{fn:()->(unknown,...unknown)}): () + if type(a) == 'boolean' then + local a1:{fn:()->(unknown,...unknown)}&boolean=a + elseif a.fn() then + local a2:{fn:()->(unknown,...unknown)}&negate=a + end + end + )" + : R"( + function f(a:{fn:()->(unknown,...unknown)}): () + if type(a) == 'boolean'then + local a1:{fn:()->(unknown,...unknown)}&boolean=a + elseif a.fn()then + local a2:{fn:()->(unknown,...unknown)}&negate=a + end + end + )"; + + if (FFlag::LuauSolverV2 && !FFlag::DebugLuauEqSatSimplification) CHECK_EQ(expectedWithNewSolver, decorateWithTypes(code)); + else if (FFlag::LuauSolverV2 && FFlag::DebugLuauEqSatSimplification) + CHECK_EQ(expectedWithEqSat, decorateWithTypes(code)); else CHECK_EQ(expected, decorateWithTypes(code)); } @@ -522,10 +564,10 @@ TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") std::unique_ptr scope = std::make_unique(builtinTypes->anyTypePack); - TypeId free1 = arena.addType(FreeType{scope.get()}); + TypeId free1 = arena.freshType(builtinTypes, scope.get()); TypeId option1 = arena.addType(UnionType{{nilType, free1}}); - TypeId free2 = arena.addType(FreeType{scope.get()}); + TypeId free2 = arena.freshType(builtinTypes, scope.get()); TypeId option2 = arena.addType(UnionType{{nilType, free2}}); InternalErrorReporter iceHandler; @@ -653,13 +695,15 @@ struct IsSubtypeFixture : Fixture { bool isSubtype(TypeId a, TypeId b) { + SimplifierPtr simplifier = newSimplifier(NotNull{&getMainModule()->internalTypes}, builtinTypes); + ModulePtr module = getMainModule(); REQUIRE(module); if (!module->hasModuleScope()) FAIL("isSubtype: module scope data is not available"); - return ::Luau::isSubtype(a, b, NotNull{module->getModuleScope().get()}, builtinTypes, ice); + return ::Luau::isSubtype(a, b, NotNull{module->getModuleScope().get()}, builtinTypes, NotNull{simplifier.get()}, ice); } }; } // namespace @@ -950,10 +994,10 @@ TEST_CASE_FIXTURE(Fixture, "free_options_can_be_unified_together") std::unique_ptr scope = std::make_unique(builtinTypes->anyTypePack); - TypeId free1 = arena.addType(FreeType{scope.get()}); + TypeId free1 = arena.freshType(builtinTypes, scope.get()); TypeId option1 = arena.addType(UnionType{{nilType, free1}}); - TypeId free2 = arena.addType(FreeType{scope.get()}); + TypeId free2 = arena.freshType(builtinTypes, scope.get()); TypeId option2 = arena.addType(UnionType{{nilType, free2}}); InternalErrorReporter iceHandler; @@ -1269,7 +1313,7 @@ TEST_CASE_FIXTURE(Fixture, "table_containing_non_final_type_is_erroneously_cache TableType* table = getMutable(tableTy); REQUIRE(table); - TypeId freeTy = arena.freshType(&globalScope); + TypeId freeTy = arena.freshType(builtinTypes, &globalScope); table->props["foo"] = Property::rw(freeTy); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index dcbc712e..535b9961 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -8,57 +8,71 @@ #include "doctest.h" LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(DebugLuauEqSatSimplification) +LUAU_FASTFLAG(InferGlobalTypes) +LUAU_FASTFLAG(LuauGeneralizationRemoveRecursiveUpperBound) using namespace Luau; namespace { -std::optional> magicFunctionInstanceIsA( - TypeChecker& typeChecker, - const ScopePtr& scope, - const AstExprCall& expr, - WithPredicate withPredicate -) + +struct MagicInstanceIsA final : MagicFunction { - if (expr.args.size != 1) - return std::nullopt; + std::optional> handleOldSolver( + TypeChecker& typeChecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate + ) override + { + if (expr.args.size != 1) + return std::nullopt; - auto index = expr.func->as(); - auto str = expr.args.data[0]->as(); - if (!index || !str) - return std::nullopt; + auto index = expr.func->as(); + auto str = expr.args.data[0]->as(); + if (!index || !str) + return std::nullopt; - std::optional lvalue = tryGetLValue(*index->expr); - std::optional tfun = scope->lookupType(std::string(str->value.data, str->value.size)); - if (!lvalue || !tfun) - return std::nullopt; + std::optional lvalue = tryGetLValue(*index->expr); + std::optional tfun = scope->lookupType(std::string(str->value.data, str->value.size)); + if (!lvalue || !tfun) + return std::nullopt; - ModulePtr module = typeChecker.currentModule; - TypePackId booleanPack = module->internalTypes.addTypePack({typeChecker.booleanType}); - return WithPredicate{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; -} + ModulePtr module = typeChecker.currentModule; + TypePackId booleanPack = module->internalTypes.addTypePack({typeChecker.booleanType}); + return WithPredicate{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; + } -void dcrMagicRefinementInstanceIsA(const MagicRefinementContext& ctx) -{ - if (ctx.callSite->args.size != 1 || ctx.discriminantTypes.empty()) - return; + bool infer(const MagicFunctionCallContext&) override + { + return false; + } - auto index = ctx.callSite->func->as(); - auto str = ctx.callSite->args.data[0]->as(); - if (!index || !str) - return; + void refine(const MagicRefinementContext& ctx) override + { + if (ctx.callSite->args.size != 1 || ctx.discriminantTypes.empty()) + return; - std::optional discriminantTy = ctx.discriminantTypes[0]; - if (!discriminantTy) - return; + auto index = ctx.callSite->func->as(); + auto str = ctx.callSite->args.data[0]->as(); + if (!index || !str) + return; + + std::optional discriminantTy = ctx.discriminantTypes[0]; + if (!discriminantTy) + return; + + std::optional tfun = ctx.scope->lookupType(std::string(str->value.data, str->value.size)); + if (!tfun) + return; + + LUAU_ASSERT(get(*discriminantTy)); + asMutable(*discriminantTy)->ty.emplace(tfun->type); + } +}; - std::optional tfun = ctx.scope->lookupType(std::string(str->value.data, str->value.size)); - if (!tfun) - return; - LUAU_ASSERT(get(*discriminantTy)); - asMutable(*discriminantTy)->ty.emplace(tfun->type); -} struct RefinementClassFixture : BuiltinsFixture { @@ -82,8 +96,7 @@ struct RefinementClassFixture : BuiltinsFixture TypePackId isAParams = arena.addTypePack({inst, builtinTypes->stringType}); TypePackId isARets = arena.addTypePack({builtinTypes->booleanType}); TypeId isA = arena.addType(FunctionType{isAParams, isARets}); - getMutable(isA)->magicFunction = magicFunctionInstanceIsA; - getMutable(isA)->dcrMagicRefinement = dcrMagicRefinementInstanceIsA; + getMutable(isA)->magic = std::make_shared(); getMutable(inst)->props = { {"Name", Property{builtinTypes->stringType}}, @@ -448,10 +461,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "call_an_incompatible_function_after_using_ty LUAU_REQUIRE_ERROR_COUNT(2, result); CHECK("Type 'string' could not be converted into 'number'" == toString(result.errors[0])); - CHECK(Location{{ 7, 18}, {7, 19}} == result.errors[0].location); + CHECK(Location{{7, 18}, {7, 19}} == result.errors[0].location); CHECK("Type 'string' could not be converted into 'number'" == toString(result.errors[1])); - CHECK(Location{{ 13, 18}, {13, 19}} == result.errors[1].location); + CHECK(Location{{13, 18}, {13, 19}} == result.errors[1].location); } TEST_CASE_FIXTURE(BuiltinsFixture, "impossible_type_narrow_is_not_an_error") @@ -488,8 +501,15 @@ TEST_CASE_FIXTURE(Fixture, "truthy_constraint_on_properties") if (FFlag::LuauSolverV2) { - // CLI-115281 - Types produced by refinements don't always get simplified - CHECK("{ x: number? } & { x: ~(false?) }" == toString(requireTypeAtPosition({4, 23}))); + if (FFlag::DebugLuauEqSatSimplification) + { + CHECK("{ x: number }" == toString(requireTypeAtPosition({4, 23}))); + } + else + { + // CLI-115281 - Types produced by refinements don't always get simplified + CHECK("{ x: number? } & { x: ~(false?) }" == toString(requireTypeAtPosition({4, 23}))); + } CHECK("number" == toString(requireTypeAtPosition({5, 26}))); } @@ -732,11 +752,15 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "nonoptional_type_can_narrow_to_nil_if_sense_ if (FFlag::LuauSolverV2) { // CLI-115281 Types produced by refinements do not consistently get simplified - CHECK_EQ("(nil & string)?", toString(requireTypeAtPosition({4, 24}))); // type(v) == "nil" - CHECK_EQ("(boolean | buffer | class | function | number | string | table | thread) & string", toString(requireTypeAtPosition({6, 24}))); // type(v) ~= "nil" + CHECK_EQ("(nil & string)?", toString(requireTypeAtPosition({4, 24}))); // type(v) == "nil" + CHECK_EQ( + "(boolean | buffer | class | function | number | string | table | thread) & string", toString(requireTypeAtPosition({6, 24})) + ); // type(v) ~= "nil" - CHECK_EQ("(nil & string)?", toString(requireTypeAtPosition({10, 24}))); // equivalent to type(v) == "nil" - CHECK_EQ("(boolean | buffer | class | function | number | string | table | thread) & string", toString(requireTypeAtPosition({12, 24}))); // equivalent to type(v) ~= "nil" + CHECK_EQ("(nil & string)?", toString(requireTypeAtPosition({10, 24}))); // equivalent to type(v) == "nil" + CHECK_EQ( + "(boolean | buffer | class | function | number | string | table | thread) & string", toString(requireTypeAtPosition({12, 24})) + ); // equivalent to type(v) ~= "nil" } else { @@ -1857,6 +1881,8 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "refine_a_param_that_got_resolved_duri TEST_CASE_FIXTURE(Fixture, "refine_a_property_of_some_global") { + ScopedFastFlag sff{FFlag::InferGlobalTypes, true}; + CheckResult result = check(R"( foo = { bar = 5 :: number? } @@ -1867,9 +1893,8 @@ TEST_CASE_FIXTURE(Fixture, "refine_a_property_of_some_global") if (FFlag::LuauSolverV2) { - LUAU_REQUIRE_ERROR_COUNT(3, result); - - CHECK_EQ("*error-type* | buffer | class | function | number | string | table | thread | true", toString(requireTypeAtPosition({4, 30}))); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("number", toString(requireTypeAtPosition({4, 30}))); } } @@ -2263,37 +2288,37 @@ TEST_CASE_FIXTURE(Fixture, "more_complex_long_disjunction_of_refinements_shouldn { CHECK_NOTHROW(check(R"( script:connect(function(obj) - if script.Parent.SeatNumber.Value == "1D" or - script.Parent.SeatNumber.Value == "2D" or - script.Parent.SeatNumber.Value == "3D" or - script.Parent.SeatNumber.Value == "4D" or - script.Parent.SeatNumber.Value == "5D" or - script.Parent.SeatNumber.Value == "6D" or - script.Parent.SeatNumber.Value == "7D" or - script.Parent.SeatNumber.Value == "8D" or - script.Parent.SeatNumber.Value == "9D" or - script.Parent.SeatNumber.Value == "10D" or - script.Parent.SeatNumber.Value == "11D" or - script.Parent.SeatNumber.Value == "12D" or - script.Parent.SeatNumber.Value == "13D" or - script.Parent.SeatNumber.Value == "14D" or - script.Parent.SeatNumber.Value == "15D" or - script.Parent.SeatNumber.Value == "16D" or - script.Parent.SeatNumber.Value == "1C" or - script.Parent.SeatNumber.Value == "2C" or - script.Parent.SeatNumber.Value == "3C" or - script.Parent.SeatNumber.Value == "4C" or - script.Parent.SeatNumber.Value == "5C" or - script.Parent.SeatNumber.Value == "6C" or - script.Parent.SeatNumber.Value == "7C" or - script.Parent.SeatNumber.Value == "8C" or - script.Parent.SeatNumber.Value == "9C" or - script.Parent.SeatNumber.Value == "10C" or - script.Parent.SeatNumber.Value == "11C" or - script.Parent.SeatNumber.Value == "12C" or - script.Parent.SeatNumber.Value == "13C" or - script.Parent.SeatNumber.Value == "14C" or - script.Parent.SeatNumber.Value == "15C" or + if script.Parent.SeatNumber.Value == "1D" or + script.Parent.SeatNumber.Value == "2D" or + script.Parent.SeatNumber.Value == "3D" or + script.Parent.SeatNumber.Value == "4D" or + script.Parent.SeatNumber.Value == "5D" or + script.Parent.SeatNumber.Value == "6D" or + script.Parent.SeatNumber.Value == "7D" or + script.Parent.SeatNumber.Value == "8D" or + script.Parent.SeatNumber.Value == "9D" or + script.Parent.SeatNumber.Value == "10D" or + script.Parent.SeatNumber.Value == "11D" or + script.Parent.SeatNumber.Value == "12D" or + script.Parent.SeatNumber.Value == "13D" or + script.Parent.SeatNumber.Value == "14D" or + script.Parent.SeatNumber.Value == "15D" or + script.Parent.SeatNumber.Value == "16D" or + script.Parent.SeatNumber.Value == "1C" or + script.Parent.SeatNumber.Value == "2C" or + script.Parent.SeatNumber.Value == "3C" or + script.Parent.SeatNumber.Value == "4C" or + script.Parent.SeatNumber.Value == "5C" or + script.Parent.SeatNumber.Value == "6C" or + script.Parent.SeatNumber.Value == "7C" or + script.Parent.SeatNumber.Value == "8C" or + script.Parent.SeatNumber.Value == "9C" or + script.Parent.SeatNumber.Value == "10C" or + script.Parent.SeatNumber.Value == "11C" or + script.Parent.SeatNumber.Value == "12C" or + script.Parent.SeatNumber.Value == "13C" or + script.Parent.SeatNumber.Value == "14C" or + script.Parent.SeatNumber.Value == "15C" or script.Parent.SeatNumber.Value == "16C" then end) )")); @@ -2418,4 +2443,23 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeof_instance_isa_refinement") CHECK_EQ("string", toString(requireTypeAtPosition({8, 28}))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "remove_recursive_upper_bound_when_generalizing") +{ + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::DebugLuauEqSatSimplification, true}, + {FFlag::LuauGeneralizationRemoveRecursiveUpperBound, true}, + }; + + LUAU_REQUIRE_NO_ERRORS(check(R"( + local t = {"hello"} + local v = t[2] + if type(v) == "nil" then + local foo = v + end + )")); + + CHECK_EQ("(nil & string)?", toString(requireTypeAtPosition({4, 24}))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index ee561e2f..ea893528 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -18,9 +18,11 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauFixIndexerSubtypingOrdering) -LUAU_FASTFLAG(LuauRetrySubtypingWithoutHiddenPack) LUAU_FASTFLAG(LuauTableKeysAreRValues) LUAU_FASTFLAG(LuauAllowNilAssignmentToIndexer) +LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope) +LUAU_FASTFLAG(LuauTrackInteriorFreeTablesOnScope) +LUAU_FASTFLAG(LuauDontInPlaceMutateTableType) TEST_SUITE_BEGIN("TableTests"); @@ -2375,7 +2377,7 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table local c : string = t.m("hi") )"); - if (FFlag::LuauSolverV2 && FFlag::LuauRetrySubtypingWithoutHiddenPack) + if (FFlag::LuauSolverV2) { LUAU_REQUIRE_ERROR_COUNT(1, result); @@ -2384,15 +2386,6 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table // This is not actually the expected behavior, but the typemismatch we were seeing before was for the wrong reason. // The behavior of this test is just regressed generally in the new solver, and will need to be consciously addressed. } - else if (FFlag::LuauSolverV2) - { - LUAU_REQUIRE_ERROR_COUNT(2, result); - - CHECK(get(result.errors[0])); - CHECK(Location{{6, 45}, {6, 46}} == result.errors[0].location); - - CHECK(get(result.errors[1])); - } // TODO: test behavior is wrong with LuauInstantiateInSubtyping until we can re-enable the covariant requirement for instantiation in subtyping else if (FFlag::LuauInstantiateInSubtyping) @@ -3815,6 +3808,11 @@ TEST_CASE_FIXTURE(Fixture, "a_free_shape_can_turn_into_a_scalar_if_it_is_compati TEST_CASE_FIXTURE(Fixture, "a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible") { + ScopedFastFlag sffs[] = { + {FFlag::LuauTrackInteriorFreeTypesOnScope, true}, + {FFlag::LuauTrackInteriorFreeTablesOnScope, true}, + }; + CheckResult result = check(R"( local function f(s): string local foo = s:absolutely_no_scalar_has_this_method() @@ -3824,17 +3822,14 @@ TEST_CASE_FIXTURE(Fixture, "a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_ if (FFlag::LuauSolverV2) { - LUAU_REQUIRE_ERROR_COUNT(4, result); + LUAU_REQUIRE_ERROR_COUNT(3, result); CHECK(toString(result.errors[0]) == "Parameter 's' has been reduced to never. This function is not callable with any possible value."); - // FIXME: These free types should have been generalized by now. CHECK( toString(result.errors[1]) == - "Parameter 's' is required to be a subtype of '{- read absolutely_no_scalar_has_this_method: ('a <: (never) -> ('b, c...)) -}' here." + "Parameter 's' is required to be a subtype of '{ read absolutely_no_scalar_has_this_method: (never) -> (unknown, ...unknown) }' here." ); CHECK(toString(result.errors[2]) == "Parameter 's' is required to be a subtype of 'string' here."); - CHECK(get(result.errors[3])); - CHECK_EQ("(never) -> string", toString(requireType("f"))); } else @@ -5002,7 +4997,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_union_type") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ( - "Cannot add indexer to table '{ @metatable t1, (nil & ~(false?)) | { } } where t1 = { new: (a) -> { @metatable t1, (a & ~(false?)) | { } } }'", + "Cannot add indexer to table '{ @metatable t1, (nil & ~(false?)) | { } } where t1 = { new: (a) -> { @metatable t1, (a & ~(false?)) | { " + "} } }'", toString(result.errors[0]) ); } @@ -5066,4 +5062,33 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "read_only_property_reads") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "multiple_fields_in_literal") +{ + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauDontInPlaceMutateTableType, true}, + }; + + auto result = check(R"( + type Foo = { + [string]: { + Min: number, + Max: number + } + } + local Foos: Foo = { + ["Foo"] = { + Min = -1, + Max = 1 + }, + ["Foo"] = { + Min = -1, + Max = 1 + } + } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 80dddc67..2ff97a25 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -23,8 +23,8 @@ LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTINT(LuauNormalizeCacheLimit); LUAU_FASTINT(LuauRecursionLimit); LUAU_FASTINT(LuauTypeInferRecursionLimit); -LUAU_FASTFLAG(LuauNewSolverVisitErrorExprLvalues) -LUAU_FASTFLAG(LuauDontRefCountTypesInTypeFunctions) +LUAU_FASTFLAG(InferGlobalTypes) +LUAU_FASTFLAG(LuauAstTypeGroup) using namespace Luau; @@ -877,7 +877,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions1") CheckResult result = check(R"(local a = if true then "true" else "false")"); LUAU_REQUIRE_NO_ERRORS(result); TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveType::String); + CHECK("string" == toString(aType)); } TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions2") @@ -888,7 +888,7 @@ local a = if false then "a" elseif false then "b" else "c" )"); LUAU_REQUIRE_NO_ERRORS(result); TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveType::String); + CHECK("string" == toString(aType)); } TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_type_union") @@ -1197,13 +1197,26 @@ TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_normalizer") validateErrors(result.errors); REQUIRE_MESSAGE(!result.errors.empty(), getErrors(result)); - CHECK(1 == result.errors.size()); - if (FFlag::LuauSolverV2) - CHECK(Location{{3, 22}, {3, 42}} == result.errors[0].location); + { + CHECK(3 == result.errors.size()); + CHECK(Location{{2, 22}, {2, 41}} == result.errors[0].location); + CHECK(Location{{3, 22}, {3, 42}} == result.errors[1].location); + if (FFlag::LuauAstTypeGroup) + CHECK(Location{{3, 22}, {3, 40}} == result.errors[2].location); + else + CHECK(Location{{3, 23}, {3, 40}} == result.errors[2].location); + CHECK_EQ("Code is too complex to typecheck! Consider simplifying the code around this area", toString(result.errors[0])); + CHECK_EQ("Code is too complex to typecheck! Consider simplifying the code around this area", toString(result.errors[1])); + CHECK_EQ("Code is too complex to typecheck! Consider simplifying the code around this area", toString(result.errors[2])); + } else + { + CHECK(1 == result.errors.size()); + CHECK(Location{{3, 12}, {3, 46}} == result.errors[0].location); - CHECK_EQ("Code is too complex to typecheck! Consider simplifying the code around this area", toString(result.errors[0])); + CHECK_EQ("Code is too complex to typecheck! Consider simplifying the code around this area", toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "type_infer_cache_limit_normalizer") @@ -1708,10 +1721,7 @@ TEST_CASE_FIXTURE(Fixture, "react_lua_follow_free_type_ub") TEST_CASE_FIXTURE(Fixture, "visit_error_nodes_in_lvalue") { - ScopedFastFlag sffs[] = { - {FFlag::LuauSolverV2, true}, - {FFlag::LuauNewSolverVisitErrorExprLvalues, true} - }; + ScopedFastFlag _{FFlag::LuauSolverV2, true}; // This should always fail to parse, but shouldn't assert. Previously this // would assert as we end up _roughly_ parsing this (with a lot of error @@ -1727,16 +1737,13 @@ TEST_CASE_FIXTURE(Fixture, "visit_error_nodes_in_lvalue") // in lvalue positions. LUAU_REQUIRE_ERRORS(check(R"( --!strict - (::, + (::, )")); } TEST_CASE_FIXTURE(Fixture, "avoid_blocking_type_function") { - ScopedFastFlag sffs[] = { - {FFlag::LuauSolverV2, true}, - {FFlag::LuauDontRefCountTypesInTypeFunctions, true} - }; + ScopedFastFlag _{FFlag::LuauSolverV2, true}; LUAU_CHECK_NO_ERRORS(check(R"( --!strict @@ -1749,10 +1756,7 @@ TEST_CASE_FIXTURE(Fixture, "avoid_blocking_type_function") TEST_CASE_FIXTURE(Fixture, "avoid_double_reference_to_free_type") { - ScopedFastFlag sffs[] = { - {FFlag::LuauSolverV2, true}, - {FFlag::LuauDontRefCountTypesInTypeFunctions, true} - }; + ScopedFastFlag _{FFlag::LuauSolverV2, true}; LUAU_CHECK_NO_ERRORS(check(R"( --!strict @@ -1763,4 +1767,21 @@ TEST_CASE_FIXTURE(Fixture, "avoid_double_reference_to_free_type") )")); } +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_types_of_globals") +{ + ScopedFastFlag sff_LuauSolverV2{FFlag::LuauSolverV2, true}; + ScopedFastFlag sff_InferGlobalTypes{FFlag::InferGlobalTypes, true}; + + CheckResult result = check(R"( + --!strict + foo = 5 + print(foo) + )"); + + CHECK_EQ("number", toString(requireTypeAtPosition({3, 14}))); + + REQUIRE_EQ(1, result.errors.size()); + CHECK_EQ("Unknown global 'foo'", toString(result.errors[0])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index ccfa6923..48f5d3ea 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -42,12 +42,13 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "primitives_unify") TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") { - Type functionOne{TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->numberType})) + Type functionOne{TypeVariant{ + FunctionType(arena.addTypePack({arena.freshType(builtinTypes, globalScope->level)}), arena.addTypePack({builtinTypes->numberType})) }}; - Type functionTwo{ - TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({arena.freshType(globalScope->level)}))} - }; + Type functionTwo{TypeVariant{FunctionType( + arena.addTypePack({arena.freshType(builtinTypes, globalScope->level)}), arena.addTypePack({arena.freshType(builtinTypes, globalScope->level)}) + )}}; state.tryUnify(&functionTwo, &functionOne); CHECK(!state.failure); @@ -60,14 +61,16 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_functions_are_preserved") { - TypePackVar argPackOne{TypePack{{arena.freshType(globalScope->level)}, std::nullopt}}; - Type functionOne{TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->numberType})) + TypePackVar argPackOne{TypePack{{arena.freshType(builtinTypes, globalScope->level)}, std::nullopt}}; + Type functionOne{TypeVariant{ + FunctionType(arena.addTypePack({arena.freshType(builtinTypes, globalScope->level)}), arena.addTypePack({builtinTypes->numberType})) }}; Type functionOneSaved = functionOne.clone(); - TypePackVar argPackTwo{TypePack{{arena.freshType(globalScope->level)}, std::nullopt}}; - Type functionTwo{TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->stringType})) + TypePackVar argPackTwo{TypePack{{arena.freshType(builtinTypes, globalScope->level)}, std::nullopt}}; + Type functionTwo{TypeVariant{ + FunctionType(arena.addTypePack({arena.freshType(builtinTypes, globalScope->level)}), arena.addTypePack({builtinTypes->stringType})) }}; Type functionTwoSaved = functionTwo.clone(); @@ -83,11 +86,11 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_functions_are_preserved") TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") { Type tableOne{TypeVariant{ - TableType{{{"foo", {arena.freshType(globalScope->level)}}}, std::nullopt, globalScope->level, TableState::Unsealed}, + TableType{{{"foo", {arena.freshType(builtinTypes, globalScope->level)}}}, std::nullopt, globalScope->level, TableState::Unsealed}, }}; Type tableTwo{TypeVariant{ - TableType{{{"foo", {arena.freshType(globalScope->level)}}}, std::nullopt, globalScope->level, TableState::Unsealed}, + TableType{{{"foo", {arena.freshType(builtinTypes, globalScope->level)}}}, std::nullopt, globalScope->level, TableState::Unsealed}, }}; CHECK_NE(*getMutable(&tableOne)->props["foo"].type(), *getMutable(&tableTwo)->props["foo"].type()); @@ -106,7 +109,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") { Type tableOne{TypeVariant{ TableType{ - {{"foo", {arena.freshType(globalScope->level)}}, {"bar", {builtinTypes->numberType}}}, + {{"foo", {arena.freshType(builtinTypes, globalScope->level)}}, {"bar", {builtinTypes->numberType}}}, std::nullopt, globalScope->level, TableState::Unsealed @@ -115,7 +118,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") Type tableTwo{TypeVariant{ TableType{ - {{"foo", {arena.freshType(globalScope->level)}}, {"bar", {builtinTypes->stringType}}}, + {{"foo", {arena.freshType(builtinTypes, globalScope->level)}}, {"bar", {builtinTypes->stringType}}}, std::nullopt, globalScope->level, TableState::Unsealed @@ -295,7 +298,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "free_tail_is_grown_properly") TEST_CASE_FIXTURE(TryUnifyFixture, "recursive_metatable_getmatchtag") { - Type redirect{FreeType{TypeLevel{}}}; + Type redirect{FreeType{TypeLevel{}, builtinTypes->neverType, builtinTypes->unknownType}}; Type table{TableType{}}; Type metatable{MetatableType{&redirect, &table}}; redirect = BoundType{&metatable}; // Now we have a metatable that is recursive on the table type @@ -318,7 +321,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "cli_50320_follow_in_any_unification") TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_type_owner") { - TypeId a = arena.addType(Type{FreeType{TypeLevel{}}}); + TypeId a = arena.freshType(builtinTypes, TypeLevel{}); TypeId b = builtinTypes->numberType; state.tryUnify(a, b); @@ -381,7 +384,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "fuzz_tail_unification_issue") TypePackVar packTmp{TypePack{{builtinTypes->anyType}, &variadicAny}}; TypePackVar packSub{TypePack{{builtinTypes->anyType, builtinTypes->anyType}, &packTmp}}; - Type freeTy{FreeType{TypeLevel{}}}; + Type freeTy{FreeType{TypeLevel{}, builtinTypes->neverType, builtinTypes->unknownType}}; TypePackVar freeTp{FreeTypePack{TypeLevel{}}}; TypePackVar packSuper{TypePack{{&freeTy}, &freeTp}}; @@ -438,10 +441,10 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "unifying_two_unions_under_dcr_does_not_creat const std::shared_ptr scope = globalScope; const std::shared_ptr nestedScope = std::make_shared(scope); - const TypeId outerType = arena.freshType(scope.get()); - const TypeId outerType2 = arena.freshType(scope.get()); + const TypeId outerType = arena.freshType(builtinTypes, scope.get()); + const TypeId outerType2 = arena.freshType(builtinTypes, scope.get()); - const TypeId innerType = arena.freshType(nestedScope.get()); + const TypeId innerType = arena.freshType(builtinTypes, nestedScope.get()); state.enableNewSolver(); diff --git a/tests/TypeInfer.typePacks.test.cpp b/tests/TypeInfer.typePacks.test.cpp index 9cf8f153..858c3052 100644 --- a/tests/TypeInfer.typePacks.test.cpp +++ b/tests/TypeInfer.typePacks.test.cpp @@ -953,11 +953,10 @@ a = b if (FFlag::LuauSolverV2) { - const std::string expected = - "Type\n" - " '() -> (number, ...boolean)'\n" - "could not be converted into\n" - " '() -> (number, ...string)'; at returns().tail().variadic(), boolean is not a subtype of string"; + const std::string expected = "Type\n" + " '() -> (number, ...boolean)'\n" + "could not be converted into\n" + " '() -> (number, ...string)'; at returns().tail().variadic(), boolean is not a subtype of string"; CHECK(expected == toString(result.errors[0])); } diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 49ec61e7..247894d1 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -621,7 +621,7 @@ TEST_CASE_FIXTURE(Fixture, "indexing_into_a_cyclic_union_doesnt_crash") TypeArena& arena = frontend.globals.globalTypes; unfreeze(arena); - TypeId badCyclicUnionTy = arena.freshType(frontend.globals.globalScope.get()); + TypeId badCyclicUnionTy = arena.freshType(builtinTypes, frontend.globals.globalScope.get()); UnionType u; u.options.push_back(badCyclicUnionTy); @@ -881,7 +881,9 @@ TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_union_types") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("(({ read x: unknown } & { x: number }) | ({ read x: unknown } & { x: string })) -> { x: number } | { x: string }", toString(requireType("f"))); + CHECK_EQ( + "(({ read x: unknown } & { x: number }) | ({ read x: unknown } & { x: string })) -> { x: number } | { x: string }", toString(requireType("f")) + ); } TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_union_types_2") diff --git a/tests/TypePath.test.cpp b/tests/TypePath.test.cpp index bf831621..b281dcab 100644 --- a/tests/TypePath.test.cpp +++ b/tests/TypePath.test.cpp @@ -17,6 +17,7 @@ using namespace Luau::TypePath; LUAU_FASTFLAG(LuauSolverV2); LUAU_DYNAMIC_FASTINT(LuauTypePathMaximumTraverseSteps); +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds); struct TypePathFixture : Fixture { @@ -277,7 +278,7 @@ TEST_CASE_FIXTURE(TypePathFixture, "bounds") TypeArena& arena = frontend.globals.globalTypes; unfreeze(arena); - TypeId ty = arena.freshType(frontend.globals.globalScope.get()); + TypeId ty = arena.freshType(frontend.builtinTypes, frontend.globals.globalScope.get()); FreeType* ft = getMutable(ty); SUBCASE("upper") diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 9e21b1e0..1e5fdaf1 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -219,7 +219,7 @@ TEST_CASE_FIXTURE(Fixture, "UnionTypeIterator_with_only_cyclic_union") */ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") { - Type ftv11{FreeType{TypeLevel{}}}; + Type ftv11{FreeType{TypeLevel{}, builtinTypes->neverType, builtinTypes->unknownType}}; TypePackVar tp24{TypePack{{&ftv11}}}; TypePackVar tp17{TypePack{}}; @@ -469,8 +469,8 @@ TEST_CASE("content_reassignment") myAny.documentationSymbol = "@global/any"; TypeArena arena; - - TypeId futureAny = arena.addType(FreeType{TypeLevel{}}); + BuiltinTypes builtinTypes; + TypeId futureAny = arena.freshType(NotNull{&builtinTypes}, TypeLevel{}); asMutable(futureAny)->reassign(myAny); CHECK(get(futureAny) != nullptr); diff --git a/tests/VisitType.test.cpp b/tests/VisitType.test.cpp index 186afaa5..86063ae8 100644 --- a/tests/VisitType.test.cpp +++ b/tests/VisitType.test.cpp @@ -4,6 +4,7 @@ #include "Luau/RecursionCounter.h" +#include "Luau/Type.h" #include "doctest.h" using namespace Luau; @@ -54,7 +55,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_throw_when_limit_is_high_enough") TEST_CASE_FIXTURE(Fixture, "some_free_types_do_not_have_bounds") { - Type t{FreeType{TypeLevel{}}}; + Type t{FreeType{TypeLevel{}, builtinTypes->neverType, builtinTypes->unknownType}}; (void)toString(&t); } diff --git a/tests/conformance/buffers.luau b/tests/conformance/buffers.luau index 5da2a688..370fb8a8 100644 --- a/tests/conformance/buffers.luau +++ b/tests/conformance/buffers.luau @@ -599,6 +599,90 @@ end misc(table.create(16, 0)) +local function bitops(size, base) + local b = buffer.create(size) + + buffer.writeu32(b, base / 8, 0x12345678) + + assert(buffer.readbits(b, base, 8) == buffer.readu8(b, base / 8)) + assert(buffer.readbits(b, base, 16) == buffer.readu16(b, base / 8)) + assert(buffer.readbits(b, base, 32) == buffer.readu32(b, base / 8)) + + buffer.writebits(b, base, 32, 0) + + buffer.writebits(b, base, 1, 1) + assert(buffer.readi8(b, base / 8) == 1) + + buffer.writebits(b, base + 1, 1, 1) + assert(buffer.readi8(b, base / 8) == 3) + + -- construct 00000010 00000000_01000000_00010000_00001000 00001000_00010000_01000010_00100101 + buffer.writebits(b, base + 0, 1, 0b1) + buffer.writebits(b, base + 1, 2, 0b10) + buffer.writebits(b, base + 3, 3, 0b100) + buffer.writebits(b, base + 6, 4, 0b1000) + buffer.writebits(b, base + 10, 5, 0b10000) + buffer.writebits(b, base + 15, 6, 0b100000) + buffer.writebits(b, base + 21, 7, 0b1000000) + buffer.writebits(b, base + 28, 8, 0b10000000) + buffer.writebits(b, base + 36, 9, 0b100000000) + buffer.writebits(b, base + 45, 10, 0b1000000000) + buffer.writebits(b, base + 55, 11, 0b10000000000) + + assert(buffer.readbits(b, base + 0, 32) == 0b00001000_00010000_01000010_00100101) + assert(buffer.readbits(b, base + 32, 32) == 0b00000000_01000000_00010000_00001000) + + assert(buffer.readu32(b, base / 8 + 0) == 0b00001000_00010000_01000010_00100101) + assert(buffer.readu32(b, base / 8 + 4) == 0b00000000_01000000_00010000_00001000) + + -- slide the window to touch 5 bytes + assert(buffer.readbits(b, base + 1, 32) == 0b00000100000010000010000100010010) + assert(buffer.readbits(b, base + 2, 32) == 0b00000010000001000001000010001001) + assert(buffer.readbits(b, base + 3, 32) == 0b00000001000000100000100001000100) + assert(buffer.readbits(b, base + 4, 32) == 0b10000000100000010000010000100010) + assert(buffer.readbits(b, base + 5, 32) == 0b01000000010000001000001000010001) + assert(buffer.readbits(b, base + 6, 32) == 0b00100000001000000100000100001000) + assert(buffer.readbits(b, base + 7, 32) == 0b00010000000100000010000010000100) + assert(buffer.readbits(b, base + 8, 32) == 0b00001000000010000001000001000010) + + assert(buffer.readbits(b, base + 1, 15) == 0b010000100010010) + assert(buffer.readbits(b, base + 2, 15) == 0b001000010001001) + assert(buffer.readbits(b, base + 3, 15) == 0b000100001000100) + assert(buffer.readbits(b, base + 4, 15) == 0b000010000100010) + assert(buffer.readbits(b, base + 5, 15) == 0b000001000010001) + assert(buffer.readbits(b, base + 6, 15) == 0b100000100001000) + assert(buffer.readbits(b, base + 7, 15) == 0b010000010000100) + assert(buffer.readbits(b, base + 8, 15) == 0b001000001000010) + + -- zero bit + buffer.writebits(b, base, 0, 0b1) + assert(buffer.readbits(b, base, 32) == 0b00001000_00010000_01000010_00100101) + assert(buffer.readbits(b, base, 0) == 0) + assert(buffer.readbits(b, size * 8, 0) == 0) + + -- bounds + assert(ecall(function() buffer.readbits(b, -1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readbits(b, size * 8, 1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readbits(b, size * 8 - 1, 2) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readbits(b, 0, 64) end) == "bit count is out of range of [0; 32]") + + assert(ecall(function() buffer.writebits(b, -1, 0, 1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writebits(b, size * 8, 1, 1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writebits(b, size * 8 - 1, 2, 1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writebits(b, 0, 64, 1) end) == "bit count is out of range of [0; 32]") + + + return b +end + +do + bitops(16, 0) + bitops(17, 8) + + -- a very large buffer and bit offsets can now be over 32 bits + bitops(1024 * 1024 * 1024, 6 * 1024 * 1024 * 1024) +end + local function testslowcalls() getfenv() @@ -619,6 +703,7 @@ local function testslowcalls() fromtostring() fill() misc(table.create(16, 0)) + bitops(16, 0) end testslowcalls() diff --git a/tests/conformance/calls.luau b/tests/conformance/calls.luau index 6555f93e..63ad81e1 100644 --- a/tests/conformance/calls.luau +++ b/tests/conformance/calls.luau @@ -237,7 +237,7 @@ if not limitedstack then end -- testing deep nested calls with a large thread stack -do +if not limitedstack then function recurse(n, ...) return n <= 1 and (1 + #{...}) or recurse(n-1, table.unpack(table.create(4000, 1))) + 1 end local ok, msg = pcall(recurse, 19000) diff --git a/tests/conformance/math.luau b/tests/conformance/math.luau index 97c44462..586023ed 100644 --- a/tests/conformance/math.luau +++ b/tests/conformance/math.luau @@ -402,6 +402,23 @@ assert(math.map(4, 4, 1, 2, 0) == 2) assert(math.map(-8, 0, 4, 0, 2) == -4) assert(math.map(16, 0, 4, 0, 2) == 8) +-- lerp basics +assert(math.lerp(1, 5, 0) == 1) +assert(math.lerp(1, 5, 1) == 5) +assert(math.lerp(1, 5, 0.5) == 3) +assert(math.lerp(1, 5, 1.5) == 7) +assert(math.lerp(1, 5, -0.5) == -1) +assert(math.lerp(1, 5, noinline(0.5)) == 3) + +-- lerp properties +local sq2, sq3 = math.sqrt(2), math.sqrt(3) +assert(math.lerp(sq2, sq3, 0) == sq2) -- exact at 0 +assert(math.lerp(sq2, sq3, 1) == sq3) -- exact at 1 +assert(math.lerp(-sq3, sq2, 1) == sq2) -- exact at 1 (fails for a + t*(b-a)) +assert(math.lerp(sq2, sq2, sq2 / 2) <= math.lerp(sq2, sq2, 1)) -- monotonic (fails for a*t + b*(1-t)) +assert(math.lerp(-sq3, sq2, 1) <= math.sqrt(2)) -- bounded (fails for a + t*(b-a)) +assert(math.lerp(sq2, sq2, sq2 / 2) == sq2) -- consistent (fails for a*t + b*(1-t)) + assert(tostring(math.pow(-2, 0.5)) == "nan") -- test that fastcalls return correct number of results @@ -464,5 +481,6 @@ assert(math.sign("2") == 1) assert(math.sign("-2") == -1) assert(math.sign("0") == 0) assert(math.round("1.8") == 2) +assert(math.lerp("1", "5", 0.5) == 3) return('OK') diff --git a/tests/conformance/native.luau b/tests/conformance/native.luau index 03845013..16172bab 100644 --- a/tests/conformance/native.luau +++ b/tests/conformance/native.luau @@ -513,4 +513,68 @@ end assert(extramath3(2) == "number") assert(extramath3("2") == "number") +local function slotcachelimit1() + local tbl = { + f1 = function() return 1 end, + f2 = function() return 2 end, + f3 = function() return 3 end, + f4 = function() return 4 end, + f5 = function() return 5 end, + f6 = function() return 6 end, + f7 = function() return 7 end, + f8 = function() return 8 end, + f9 = function() return 9 end, + f10 = function() return 10 end, + f11 = function() return 11 end, + f12 = function() return 12 end, + f13 = function() return 13 end, + f14 = function() return 14 end, + f15 = function() return 15 end, + f16 = function() return 16 end, + } + + local lookup = { + [tbl.f1] = 1, + [tbl.f2] = 2, + [tbl.f3] = 3, + [tbl.f4] = 4, + [tbl.f5] = 5, + [tbl.f6] = 6, + [tbl.f7] = 7, + [tbl.f8] = 8, + [tbl.f9] = 9, + [tbl.f10] = 10, + [tbl.f11] = 11, + [tbl.f12] = 12, + [tbl.f13] = 13, + [tbl.f14] = 14, + [tbl.f15] = 15, + [tbl.f16] = 16, + } + + assert(is_native()) + + return lookup +end + +slotcachelimit1() + +local function slotcachelimit2(foo, size) + local c1 = foo(vector.create(size.X, size.Y, size.Z)) + local c2 = foo(vector.create(-size.X, size.Y, size.Z)) + local c3 = foo(vector.create(-size.X, -size.Y, size.Z)) + local c4 = foo(vector.create(-size.X, -size.Y, -size.Z)) + local c5 = foo(vector.create(size.X, -size.Y, -size.Z)) + local c6 = foo(vector.create(size.X, size.Y, -size.Z)) + local c7 = foo(vector.create(size.X, -size.Y, size.Z)) + local c8 = foo(vector.create(-size.X, size.Y, -size.Z)) + local max = vector.create(math.max(c1.X, c2.X, c3.X, c4.X, c5.X, c6.X, c7.X, c8.X), math.max(c1.Y, c2.Y, c3.Y, c4.Y, c5.Y, c6.Y, c7.Y, c8.Y), math.max(c1.Z, c2.Z, c3.Z, c4.Z, c5.Z, c6.Z, c7.Z, c8.Z)) + local min = vector.create(math.min(c1.X, c2.X, c3.X, c4.X, c5.X, c6.X, c7.X, c8.X), math.min(c1.Y, c2.Y, c3.Y, c4.Y, c5.Y, c6.Y, c7.Y, c8.Y), math.min(c1.Z, c2.Z, c3.Z, c4.Z, c5.Z, c6.Z, c7.Z, c8.Z)) + + assert(is_native()) + return max - min +end + +slotcachelimit2(function(a) return -a end, vector.create(1, 2, 3)) + return('OK') diff --git a/tests/conformance/vector_library.luau b/tests/conformance/vector_library.luau index 3f30d900..dd5f2d1b 100644 --- a/tests/conformance/vector_library.luau +++ b/tests/conformance/vector_library.luau @@ -11,8 +11,15 @@ function ecall(fn, ...) end -- make sure we cover both builtin and C impl +assert(vector.create(1, 2) == vector.create("1", "2")) assert(vector.create(1, 2, 4) == vector.create("1", "2", "4")) +-- 'create' +local v12 = vector.create(1, 2) +local v123 = vector.create(1, 2, 3) +assert(v12.x == 1 and v12.y == 2 and v12.z == 0) +assert(v123.x == 1 and v123.y == 2 and v123.z == 3) + -- testing 'dot' with error handling and different call kinds to mostly check details in the codegen assert(vector.dot(vector.create(1, 2, 4), vector.create(5, 6, 7)) == 45) assert(ecall(function() vector.dot(vector.create(1, 2, 4)) end) == "missing argument #2 to 'dot' (vector expected)") diff --git a/tests/main.cpp b/tests/main.cpp index bd5a0517..005a3e61 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Common.h" +#include "Luau/CodeGenCommon.h" + #define DOCTEST_CONFIG_IMPLEMENT // Our calls to parseOption/parseFlag don't provide a prefix so set the prefix to the empty string. #define DOCTEST_CONFIG_OPTIONS_PREFIX "" diff --git a/tools/natvis/VM.natvis b/tools/natvis/VM.natvis index 59bc43c4..adf603eb 100644 --- a/tools/natvis/VM.natvis +++ b/tools/natvis/VM.natvis @@ -77,7 +77,7 @@ --- - + table metatable