diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index db2f6712..7dc38835 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -65,10 +65,7 @@ TypeId makeFunction( // Polymorphic bool checked = false ); -void attachMagicFunction(TypeId ty, MagicFunction fn); -void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn); -void attachDcrMagicRefinement(TypeId ty, DcrMagicRefinement fn); -void attachDcrMagicFunctionTypeCheck(TypeId ty, DcrMagicFunctionTypeCheck fn); +void attachMagicFunction(TypeId ty, std::shared_ptr fn); Property makeProperty(TypeId ty, std::optional documentationSymbol = std::nullopt); void assignPropDocumentationSymbols(TableType::Props& props, const std::string& baseName); diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index bb358abb..b8eaac56 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -166,7 +166,7 @@ struct ConstraintSolver **/ void finalizeTypeFunctions(); - bool isDone(); + bool isDone() const; private: /** @@ -298,10 +298,10 @@ public: // FIXME: This use of a boolean for the return result is an appalling // interface. bool blockOnPendingTypes(TypeId target, NotNull constraint); - bool blockOnPendingTypes(TypePackId target, NotNull constraint); + bool blockOnPendingTypes(TypePackId targetPack, NotNull constraint); void unblock(NotNull progressed); - void unblock(TypeId progressed, Location location); + void unblock(TypeId ty, Location location); void unblock(TypePackId progressed, Location location); void unblock(const std::vector& types, Location location); void unblock(const std::vector& packs, Location location); @@ -336,7 +336,7 @@ public: * @param location the location where the require is taking place; used for * error locations. **/ - TypeId resolveModule(const ModuleInfo& module, const Location& location); + TypeId resolveModule(const ModuleInfo& info, const Location& location); void reportError(TypeErrorData&& data, const Location& location); void reportError(TypeError e); diff --git a/Analysis/include/Luau/DataFlowGraph.h b/Analysis/include/Luau/DataFlowGraph.h index 83dfa4b7..7c0e81ac 100644 --- a/Analysis/include/Luau/DataFlowGraph.h +++ b/Analysis/include/Luau/DataFlowGraph.h @@ -6,6 +6,7 @@ #include "Luau/ControlFlow.h" #include "Luau/DenseHash.h" #include "Luau/Def.h" +#include "Luau/NotNull.h" #include "Luau/Symbol.h" #include "Luau/TypedAllocator.h" @@ -48,13 +49,13 @@ struct DataFlowGraph const RefinementKey* getRefinementKey(const AstExpr* expr) const; private: - DataFlowGraph() = default; + DataFlowGraph(NotNull defArena, NotNull keyArena); DataFlowGraph(const DataFlowGraph&) = delete; DataFlowGraph& operator=(const DataFlowGraph&) = delete; - DefArena defArena; - RefinementKeyArena keyArena; + NotNull defArena; + NotNull keyArena; DenseHashMap astDefs{nullptr}; @@ -110,30 +111,22 @@ using ScopeStack = std::vector; struct DataFlowGraphBuilder { - static DataFlowGraph build(AstStatBlock* root, NotNull handle); - - /** - * This method is identical to the build method above, but returns a pair of dfg, scopes as the data flow graph - * here is intended to live on the module between runs of typechecking. Before, the DFG only needed to live as - * long as the typecheck, but in a world with incremental typechecking, we need the information on the dfg to incrementally - * typecheck small fragments of code. - * @param block - pointer to the ast to build the dfg for - * @param handle - for raising internal errors while building the dfg - */ - static std::pair, std::vector>> buildShared( + static DataFlowGraph build( AstStatBlock* block, - NotNull handle + NotNull defArena, + NotNull keyArena, + NotNull handle ); private: - DataFlowGraphBuilder() = default; + DataFlowGraphBuilder(NotNull defArena, NotNull keyArena); DataFlowGraphBuilder(const DataFlowGraphBuilder&) = delete; DataFlowGraphBuilder& operator=(const DataFlowGraphBuilder&) = delete; DataFlowGraph graph; - NotNull defArena{&graph.defArena}; - NotNull keyArena{&graph.keyArena}; + NotNull defArena; + NotNull keyArena; struct InternalErrorReporter* handle = nullptr; diff --git a/Analysis/include/Luau/EqSatSimplificationImpl.h b/Analysis/include/Luau/EqSatSimplificationImpl.h index 2e704e98..c9f49d52 100644 --- a/Analysis/include/Luau/EqSatSimplificationImpl.h +++ b/Analysis/include/Luau/EqSatSimplificationImpl.h @@ -105,6 +105,9 @@ private: std::vector storage; }; +template +using Node = EqSat::Node; + using EType = EqSat::Language< TNil, TBoolean, @@ -171,6 +174,9 @@ struct Subst Id eclass; Id newClass; + // The node into eclass which is boring, if any + std::optional boringIndex; + std::string desc; Subst(Id eclass, Id newClass, std::string desc = ""); @@ -211,6 +217,7 @@ struct Simplifier void subst(Id from, Id to); void subst(Id from, Id to, const std::string& ruleName); void subst(Id from, Id to, const std::string& ruleName, const std::unordered_map& forceNodes); + void subst(Id from, size_t boringIndex, Id to, const std::string& ruleName, const std::unordered_map& forceNodes); void unionClasses(std::vector& hereParts, Id there); @@ -295,13 +302,13 @@ QueryIterator::QueryIterator(EGraph* egraph_, Id eclass) for (const auto& enode : ecl.nodes) { - if (enode.index() < idx) + if (enode.node.index() < idx) ++index; else break; } - if (index >= ecl.nodes.size() || ecl.nodes[index].index() != idx) + if (index >= ecl.nodes.size() || ecl.nodes[index].node.index() != idx) { egraph = nullptr; index = 0; @@ -331,7 +338,7 @@ std::pair QueryIterator::operator*() const EGraph::EClassT& ecl = (*egraph)[eclass]; LUAU_ASSERT(index < ecl.nodes.size()); - auto& enode = ecl.nodes[index]; + auto& enode = ecl.nodes[index].node; Tag* result = enode.template get(); LUAU_ASSERT(result); return {result, index}; @@ -343,12 +350,16 @@ QueryIterator& QueryIterator::operator++() { const auto& ecl = (*egraph)[eclass]; - ++index; - if (index >= ecl.nodes.size() || ecl.nodes[index].index() != EType::VariantTy::getTypeId()) + do { - egraph = nullptr; - index = 0; - } + ++index; + if (index >= ecl.nodes.size() || ecl.nodes[index].node.index() != EType::VariantTy::getTypeId()) + { + egraph = nullptr; + index = 0; + break; + } + } while (ecl.nodes[index].boring); return *this; } diff --git a/Analysis/include/Luau/FragmentAutocomplete.h b/Analysis/include/Luau/FragmentAutocomplete.h index 2125cc41..bf67b8b6 100644 --- a/Analysis/include/Luau/FragmentAutocomplete.h +++ b/Analysis/include/Luau/FragmentAutocomplete.h @@ -17,8 +17,8 @@ struct FrontendOptions; enum class FragmentTypeCheckStatus { - Success, SkipAutocomplete, + Success, }; struct FragmentAutocompleteAncestryResult @@ -56,7 +56,7 @@ struct FragmentAutocompleteResult FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos); -FragmentParseResult parseFragment( +std::optional parseFragment( const SourceModule& srcModule, std::string_view src, const Position& cursorPos, diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 3f3e69f1..7346a422 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -139,6 +139,11 @@ struct Module TypePackId returnType = nullptr; std::unordered_map exportedTypeBindings; + // Arenas related to the DFG must persist after the DFG no longer exists, as + // Module objects maintain raw pointers to objects in these arenas. + DefArena defArena; + RefinementKeyArena keyArena; + bool hasModuleScope() const; ScopePtr getModuleScope() const; diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 701fe051..558a5110 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -131,14 +131,14 @@ struct BlockedType BlockedType(); int index; - Constraint* getOwner() const; - void setOwner(Constraint* newOwner); - void replaceOwner(Constraint* newOwner); + const Constraint* getOwner() const; + void setOwner(const Constraint* newOwner); + void replaceOwner(const Constraint* newOwner); private: // The constraint that is intended to unblock this type. Other constraints // should block on this constraint if present. - Constraint* owner = nullptr; + const Constraint* owner = nullptr; }; struct PrimitiveType @@ -279,9 +279,6 @@ struct WithPredicate } }; -using MagicFunction = std::function>(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate)>; - struct MagicFunctionCallContext { NotNull solver; @@ -291,7 +288,6 @@ struct MagicFunctionCallContext TypePackId result; }; -using DcrMagicFunction = std::function; struct MagicRefinementContext { NotNull scope; @@ -308,8 +304,29 @@ struct MagicFunctionTypeCheckContext NotNull checkScope; }; -using DcrMagicRefinement = void (*)(const MagicRefinementContext&); -using DcrMagicFunctionTypeCheck = std::function; +struct MagicFunction +{ + virtual std::optional> handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) = 0; + + // Callback to allow custom typechecking of builtin function calls whose argument types + // will only be resolved after constraint solving. For example, the arguments to string.format + // have types that can only be decided after parsing the format string and unifying + // with the passed in values, but the correctness of the call can only be decided after + // all the types have been finalized. + virtual bool infer(const MagicFunctionCallContext&) = 0; + virtual void refine(const MagicRefinementContext&) {} + + // If a magic function needs to do its own special typechecking, do it here. + // Returns true if magic typechecking was performed. Return false if the + // default typechecking logic should run. + virtual bool typeCheck(const MagicFunctionTypeCheckContext&) + { + return false; + } + + virtual ~MagicFunction() {} +}; + struct FunctionType { // Global monomorphic function @@ -367,16 +384,7 @@ struct FunctionType Scope* scope = nullptr; TypePackId argTypes; TypePackId retTypes; - MagicFunction magicFunction = nullptr; - DcrMagicFunction dcrMagicFunction = nullptr; - DcrMagicRefinement dcrMagicRefinement = nullptr; - - // Callback to allow custom typechecking of builtin function calls whose argument types - // will only be resolved after constraint solving. For example, the arguments to string.format - // have types that can only be decided after parsing the format string and unifying - // with the passed in values, but the correctness of the call can only be decided after - // all the types have been finalized. - DcrMagicFunctionTypeCheck dcrMagicTypeCheck = nullptr; + std::shared_ptr magic = nullptr; bool hasSelf; // `hasNoFreeOrGenericTypes` should be true if and only if the type does not have any free or generic types present inside it. diff --git a/Analysis/include/Luau/VisitType.h b/Analysis/include/Luau/VisitType.h index 0e5475a7..a9685462 100644 --- a/Analysis/include/Luau/VisitType.h +++ b/Analysis/include/Luau/VisitType.h @@ -85,6 +85,8 @@ struct GenericTypeVisitor { } + virtual ~GenericTypeVisitor() {} + virtual void cycle(TypeId) {} virtual void cycle(TypePackId) {} diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index 96c4ea10..815164d8 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -13,6 +13,8 @@ LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauExtendStatEndPosWithSemicolon) + namespace Luau { @@ -41,11 +43,26 @@ struct AutocompleteNodeFinder : public AstVisitor bool visit(AstStat* stat) override { - if (stat->location.begin < pos && pos <= stat->location.end) + if (FFlag::LuauExtendStatEndPosWithSemicolon) { - ancestry.push_back(stat); - return true; + // Consider 'local myLocal = 4;|' and 'local myLocal = 4', where '|' is the cursor position. In both cases, the cursor position is equal + // to `AstStatLocal.location.end`. However, in the first case (semicolon), we are starting a new statement, whilst in the second case + // (no semicolon) we are still part of the AstStatLocal, hence the different comparison check. + if (stat->location.begin < pos && (stat->hasSemicolon ? pos < stat->location.end : pos <= stat->location.end)) + { + ancestry.push_back(stat); + return true; + } } + else + { + if (stat->location.begin < pos && pos <= stat->location.end) + { + ancestry.push_back(stat); + return true; + } + } + return false; } diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index fba3c964..78503fed 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -34,46 +34,78 @@ LUAU_FASTFLAGVARIABLE(LuauStringFormatArityFix) LUAU_FASTFLAGVARIABLE(LuauStringFormatErrorSuppression) LUAU_FASTFLAG(AutocompleteRequirePathSuggestions2) LUAU_FASTFLAG(LuauVectorDefinitionsExtra) +LUAU_FASTFLAGVARIABLE(LuauTableCloneClonesType) +LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope) namespace Luau { -static std::optional> magicFunctionSelect( - TypeChecker& typechecker, - const ScopePtr& scope, - const AstExprCall& expr, - WithPredicate withPredicate -); -static std::optional> magicFunctionSetMetaTable( - TypeChecker& typechecker, - const ScopePtr& scope, - const AstExprCall& expr, - WithPredicate withPredicate -); -static std::optional> magicFunctionAssert( - TypeChecker& typechecker, - const ScopePtr& scope, - const AstExprCall& expr, - WithPredicate withPredicate -); -static std::optional> magicFunctionPack( - TypeChecker& typechecker, - const ScopePtr& scope, - const AstExprCall& expr, - WithPredicate withPredicate -); -static std::optional> magicFunctionRequire( - TypeChecker& typechecker, - const ScopePtr& scope, - const AstExprCall& expr, - WithPredicate withPredicate -); +struct MagicSelect final : MagicFunction +{ + std::optional> handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; +struct MagicSetMetatable final : MagicFunction +{ + std::optional> handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; -static bool dcrMagicFunctionSelect(MagicFunctionCallContext context); -static bool dcrMagicFunctionRequire(MagicFunctionCallContext context); -static bool dcrMagicFunctionPack(MagicFunctionCallContext context); -static bool dcrMagicFunctionFreeze(MagicFunctionCallContext context); +struct MagicAssert final : MagicFunction +{ + std::optional> handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; + +struct MagicPack final : MagicFunction +{ + std::optional> handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; + +struct MagicRequire final : MagicFunction +{ + std::optional> handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; + +struct MagicClone final : MagicFunction +{ + std::optional> handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; + +struct MagicFreeze final : MagicFunction +{ + std::optional> handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; + +struct MagicFormat final : MagicFunction +{ + std::optional> handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; + bool typeCheck(const MagicFunctionTypeCheckContext& ctx) override; +}; + +struct MagicMatch final : MagicFunction +{ + std::optional> handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; + +struct MagicGmatch final : MagicFunction +{ + std::optional> handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; + +struct MagicFind final : MagicFunction +{ + std::optional> handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; TypeId makeUnion(TypeArena& arena, std::vector&& types) { @@ -168,34 +200,10 @@ TypeId makeFunction( return arena.addType(std::move(ftv)); } -void attachMagicFunction(TypeId ty, MagicFunction fn) +void attachMagicFunction(TypeId ty, std::shared_ptr magic) { if (auto ftv = getMutable(ty)) - ftv->magicFunction = fn; - else - LUAU_ASSERT(!"Got a non functional type"); -} - -void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn) -{ - if (auto ftv = getMutable(ty)) - ftv->dcrMagicFunction = fn; - else - LUAU_ASSERT(!"Got a non functional type"); -} - -void attachDcrMagicRefinement(TypeId ty, DcrMagicRefinement fn) -{ - if (auto ftv = getMutable(ty)) - ftv->dcrMagicRefinement = fn; - else - LUAU_ASSERT(!"Got a non functional type"); -} - -void attachDcrMagicFunctionTypeCheck(TypeId ty, DcrMagicFunctionTypeCheck fn) -{ - if (auto ftv = getMutable(ty)) - ftv->dcrMagicTypeCheck = fn; + ftv->magic = std::move(magic); else LUAU_ASSERT(!"Got a non functional type"); } @@ -396,7 +404,7 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC } } - attachMagicFunction(getGlobalBinding(globals, "assert"), magicFunctionAssert); + attachMagicFunction(getGlobalBinding(globals, "assert"), std::make_shared()); if (FFlag::LuauSolverV2) { @@ -412,9 +420,8 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC addGlobalBinding(globals, "assert", assertTy, "@luau"); } - attachMagicFunction(getGlobalBinding(globals, "setmetatable"), magicFunctionSetMetaTable); - attachMagicFunction(getGlobalBinding(globals, "select"), magicFunctionSelect); - attachDcrMagicFunction(getGlobalBinding(globals, "select"), dcrMagicFunctionSelect); + attachMagicFunction(getGlobalBinding(globals, "setmetatable"), std::make_shared()); + attachMagicFunction(getGlobalBinding(globals, "select"), std::make_shared()); if (TableType* ttv = getMutable(getGlobalBinding(globals, "table"))) { @@ -445,23 +452,22 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC ttv->props["foreach"].deprecated = true; ttv->props["foreachi"].deprecated = true; - attachMagicFunction(ttv->props["pack"].type(), magicFunctionPack); - attachDcrMagicFunction(ttv->props["pack"].type(), dcrMagicFunctionPack); + attachMagicFunction(ttv->props["pack"].type(), std::make_shared()); + if (FFlag::LuauTableCloneClonesType) + attachMagicFunction(ttv->props["clone"].type(), std::make_shared()); if (FFlag::LuauTypestateBuiltins2) - attachDcrMagicFunction(ttv->props["freeze"].type(), dcrMagicFunctionFreeze); + attachMagicFunction(ttv->props["freeze"].type(), std::make_shared()); } if (FFlag::AutocompleteRequirePathSuggestions2) { TypeId requireTy = getGlobalBinding(globals, "require"); attachTag(requireTy, kRequireTagName); - attachMagicFunction(requireTy, magicFunctionRequire); - attachDcrMagicFunction(requireTy, dcrMagicFunctionRequire); + attachMagicFunction(requireTy, std::make_shared()); } else { - attachMagicFunction(getGlobalBinding(globals, "require"), magicFunctionRequire); - attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire); + attachMagicFunction(getGlobalBinding(globals, "require"), std::make_shared()); } } @@ -501,7 +507,7 @@ static std::vector parseFormatString(NotNull builtinTypes, return result; } -std::optional> magicFunctionFormat( +std::optional> MagicFormat::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -551,7 +557,7 @@ std::optional> magicFunctionFormat( return WithPredicate{arena.addTypePack({typechecker.stringType})}; } -static bool dcrMagicFunctionFormat(MagicFunctionCallContext context) +bool MagicFormat::infer(const MagicFunctionCallContext& context) { TypeArena* arena = context.solver->arena; @@ -595,7 +601,7 @@ static bool dcrMagicFunctionFormat(MagicFunctionCallContext context) return true; } -static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext context) +bool MagicFormat::typeCheck(const MagicFunctionTypeCheckContext& context) { AstExprConstantString* fmt = nullptr; if (auto index = context.callSite->func->as(); index && context.callSite->self) @@ -615,7 +621,7 @@ static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext contex context.typechecker->reportError( CountMismatch{1, std::nullopt, 0, CountMismatch::Arg, true, "string.format"}, context.callSite->location ); - return; + return true; } std::vector expected = parseFormatString(context.builtinTypes, fmt->value.data, fmt->value.size); @@ -657,6 +663,8 @@ static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext contex } } } + + return true; } static std::vector parsePatternString(NotNull builtinTypes, const char* data, size_t size) @@ -719,7 +727,7 @@ static std::vector parsePatternString(NotNull builtinTypes return result; } -static std::optional> magicFunctionGmatch( +std::optional> MagicGmatch::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -755,7 +763,7 @@ static std::optional> magicFunctionGmatch( return WithPredicate{arena.addTypePack({iteratorType})}; } -static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context) +bool MagicGmatch::infer(const MagicFunctionCallContext& context) { const auto& [params, tail] = flatten(context.arguments); @@ -788,7 +796,7 @@ static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context) return true; } -static std::optional> magicFunctionMatch( +std::optional> MagicMatch::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -828,7 +836,7 @@ static std::optional> magicFunctionMatch( return WithPredicate{returnList}; } -static bool dcrMagicFunctionMatch(MagicFunctionCallContext context) +bool MagicMatch::infer(const MagicFunctionCallContext& context) { const auto& [params, tail] = flatten(context.arguments); @@ -864,7 +872,7 @@ static bool dcrMagicFunctionMatch(MagicFunctionCallContext context) return true; } -static std::optional> magicFunctionFind( +std::optional> MagicFind::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -922,7 +930,7 @@ static std::optional> magicFunctionFind( return WithPredicate{returnList}; } -static bool dcrMagicFunctionFind(MagicFunctionCallContext context) +bool MagicFind::infer(const MagicFunctionCallContext& context) { const auto& [params, tail] = flatten(context.arguments); @@ -999,11 +1007,9 @@ TypeId makeStringMetatable(NotNull builtinTypes) FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack}; - formatFTV.magicFunction = &magicFunctionFormat; formatFTV.isCheckedFunction = true; const TypeId formatFn = arena->addType(formatFTV); - attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); - attachDcrMagicFunctionTypeCheck(formatFn, dcrMagicFunctionTypeCheckFormat); + attachMagicFunction(formatFn, std::make_shared()); const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ true); @@ -1017,16 +1023,14 @@ TypeId makeStringMetatable(NotNull builtinTypes) makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}, /* checked */ false); const TypeId gmatchFunc = makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}, /* checked */ true); - attachMagicFunction(gmatchFunc, magicFunctionGmatch); - attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch); + attachMagicFunction(gmatchFunc, std::make_shared()); FunctionType matchFuncTy{ arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}}) }; matchFuncTy.isCheckedFunction = true; const TypeId matchFunc = arena->addType(matchFuncTy); - attachMagicFunction(matchFunc, magicFunctionMatch); - attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch); + attachMagicFunction(matchFunc, std::make_shared()); FunctionType findFuncTy{ arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), @@ -1034,8 +1038,7 @@ TypeId makeStringMetatable(NotNull builtinTypes) }; findFuncTy.isCheckedFunction = true; const TypeId findFunc = arena->addType(findFuncTy); - attachMagicFunction(findFunc, magicFunctionFind); - attachDcrMagicFunction(findFunc, dcrMagicFunctionFind); + attachMagicFunction(findFunc, std::make_shared()); // string.byte : string -> number? -> number? -> ...number FunctionType stringDotByte{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList}; @@ -1096,7 +1099,7 @@ TypeId makeStringMetatable(NotNull builtinTypes) return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); } -static std::optional> magicFunctionSelect( +std::optional> MagicSelect::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -1141,7 +1144,7 @@ static std::optional> magicFunctionSelect( return std::nullopt; } -static bool dcrMagicFunctionSelect(MagicFunctionCallContext context) +bool MagicSelect::infer(const MagicFunctionCallContext& context) { if (context.callSite->args.size <= 0) { @@ -1186,7 +1189,7 @@ static bool dcrMagicFunctionSelect(MagicFunctionCallContext context) return false; } -static std::optional> magicFunctionSetMetaTable( +std::optional> MagicSetMetatable::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -1268,7 +1271,12 @@ static std::optional> magicFunctionSetMetaTable( return WithPredicate{arena.addTypePack({target})}; } -static std::optional> magicFunctionAssert( +bool MagicSetMetatable::infer(const MagicFunctionCallContext&) +{ + return false; +} + +std::optional> MagicAssert::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -1302,7 +1310,12 @@ static std::optional> magicFunctionAssert( return WithPredicate{arena.addTypePack(TypePack{std::move(head), tail})}; } -static std::optional> magicFunctionPack( +bool MagicAssert::infer(const MagicFunctionCallContext&) +{ + return false; +} + +std::optional> MagicPack::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -1345,7 +1358,7 @@ static std::optional> magicFunctionPack( return WithPredicate{arena.addTypePack({packedTable})}; } -static bool dcrMagicFunctionPack(MagicFunctionCallContext context) +bool MagicPack::infer(const MagicFunctionCallContext& context) { TypeArena* arena = context.solver->arena; @@ -1385,7 +1398,68 @@ static bool dcrMagicFunctionPack(MagicFunctionCallContext context) return true; } -static std::optional freezeTable(TypeId inputType, MagicFunctionCallContext& context) +std::optional> MagicClone::handleOldSolver( + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +) +{ + LUAU_ASSERT(FFlag::LuauTableCloneClonesType); + + auto [paramPack, _predicates] = withPredicate; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + const auto& [paramTypes, paramTail] = flatten(paramPack); + if (paramTypes.empty() || expr.args.size == 0) + { + typechecker.reportError(expr.argLocation, CountMismatch{1, std::nullopt, 0}); + return std::nullopt; + } + + TypeId inputType = follow(paramTypes[0]); + + CloneState cloneState{typechecker.builtinTypes}; + TypeId resultType = shallowClone(inputType, arena, cloneState); + + TypePackId clonedTypePack = arena.addTypePack({resultType}); + return WithPredicate{clonedTypePack}; +} + +bool MagicClone::infer(const MagicFunctionCallContext& context) +{ + LUAU_ASSERT(FFlag::LuauTableCloneClonesType); + + TypeArena* arena = context.solver->arena; + + const auto& [paramTypes, paramTail] = flatten(context.arguments); + if (paramTypes.empty() || context.callSite->args.size == 0) + { + context.solver->reportError(CountMismatch{1, std::nullopt, 0}, context.callSite->argLocation); + return false; + } + + TypeId inputType = follow(paramTypes[0]); + + CloneState cloneState{context.solver->builtinTypes}; + TypeId resultType = shallowClone(inputType, *arena, cloneState); + + if (auto tableType = getMutable(resultType)) + { + tableType->scope = context.constraint->scope.get(); + } + + if (FFlag::LuauTrackInteriorFreeTypesOnScope) + trackInteriorFreeType(context.constraint->scope.get(), resultType); + + TypePackId clonedTypePack = arena->addTypePack({resultType}); + asMutable(context.result)->ty.emplace(clonedTypePack); + + return true; +} + +static std::optional freezeTable(TypeId inputType, const MagicFunctionCallContext& context) { TypeArena* arena = context.solver->arena; @@ -1430,7 +1504,12 @@ static std::optional freezeTable(TypeId inputType, MagicFunctionCallCont return std::nullopt; } -static bool dcrMagicFunctionFreeze(MagicFunctionCallContext context) +std::optional> MagicFreeze::handleOldSolver(struct TypeChecker &, const std::shared_ptr &, const class AstExprCall &, WithPredicate) +{ + return std::nullopt; +} + +bool MagicFreeze::infer(const MagicFunctionCallContext& context) { LUAU_ASSERT(FFlag::LuauTypestateBuiltins2); @@ -1491,7 +1570,7 @@ static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr) return good; } -static std::optional> magicFunctionRequire( +std::optional> MagicRequire::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -1537,7 +1616,7 @@ static bool checkRequirePathDcr(NotNull solver, AstExpr* expr) return good; } -static bool dcrMagicFunctionRequire(MagicFunctionCallContext context) +bool MagicRequire::infer(const MagicFunctionCallContext& context) { if (context.callSite->args.size != 1) { diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp index cde566d8..b0f7c432 100644 --- a/Analysis/src/Constraint.cpp +++ b/Analysis/src/Constraint.cpp @@ -3,8 +3,6 @@ #include "Luau/Constraint.h" #include "Luau/VisitType.h" -LUAU_FASTFLAGVARIABLE(LuauDontRefCountTypesInTypeFunctions) - namespace Luau { @@ -60,7 +58,7 @@ struct ReferenceCountInitializer : TypeOnceVisitor // // The default behavior here is `true` for "visit the child types" // of this type, hence: - return !FFlag::LuauDontRefCountTypesInTypeFunctions; + return false; } }; diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index e6e54916..1d0dc41a 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -75,7 +75,7 @@ size_t HashBlockedConstraintId::operator()(const BlockedConstraintId& bci) const { if (auto blocked = get(ty)) { - Constraint* owner = blocked->getOwner(); + const Constraint* owner = blocked->getOwner(); LUAU_ASSERT(owner); return owner == constraint; } @@ -446,7 +446,7 @@ void ConstraintSolver::run() if (success) { unblock(c); - unsolvedConstraints.erase(unsolvedConstraints.begin() + i); + unsolvedConstraints.erase(unsolvedConstraints.begin() + ptrdiff_t(i)); // decrement the referenced free types for this constraint if we dispatched successfully! for (auto ty : c->getMaybeMutatedFreeTypes()) @@ -553,7 +553,7 @@ void ConstraintSolver::finalizeTypeFunctions() } } -bool ConstraintSolver::isDone() +bool ConstraintSolver::isDone() const { return unsolvedConstraints.empty(); } @@ -1293,11 +1293,11 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulldcrMagicFunction) - usedMagic = ftv->dcrMagicFunction(MagicFunctionCallContext{NotNull{this}, constraint, c.callSite, c.argsPack, result}); - - if (ftv->dcrMagicRefinement) - ftv->dcrMagicRefinement(MagicRefinementContext{constraint->scope, c.callSite, c.discriminantTypes}); + if (ftv->magic) + { + usedMagic = ftv->magic->infer(MagicFunctionCallContext{NotNull{this}, constraint, c.callSite, c.argsPack, result}); + ftv->magic->refine(MagicRefinementContext{constraint->scope, c.callSite, c.discriminantTypes}); + } } if (!usedMagic) @@ -1702,7 +1702,7 @@ bool ConstraintSolver::tryDispatchHasIndexer( for (TypeId part : parts) { TypeId r = arena->addType(BlockedType{}); - getMutable(r)->setOwner(const_cast(constraint.get())); + getMutable(r)->setOwner(constraint.get()); bool ok = tryDispatchHasIndexer(recursionDepth, constraint, part, indexType, r, seen); // If we've cut a recursive loop short, skip it. @@ -1734,7 +1734,7 @@ bool ConstraintSolver::tryDispatchHasIndexer( for (TypeId part : parts) { TypeId r = arena->addType(BlockedType{}); - getMutable(r)->setOwner(const_cast(constraint.get())); + getMutable(r)->setOwner(constraint.get()); bool ok = tryDispatchHasIndexer(recursionDepth, constraint, part, indexType, r, seen); // If we've cut a recursive loop short, skip it. @@ -2874,10 +2874,10 @@ bool ConstraintSolver::blockOnPendingTypes(TypeId target, NotNull constraint) +bool ConstraintSolver::blockOnPendingTypes(TypePackId targetPack, NotNull constraint) { Blocker blocker{NotNull{this}, constraint}; - blocker.traverse(pack); + blocker.traverse(targetPack); return !blocker.blocked; } diff --git a/Analysis/src/DataFlowGraph.cpp b/Analysis/src/DataFlowGraph.cpp index 9925f29c..3f724f2c 100644 --- a/Analysis/src/DataFlowGraph.cpp +++ b/Analysis/src/DataFlowGraph.cpp @@ -62,6 +62,12 @@ const RefinementKey* RefinementKeyArena::node(const RefinementKey* parent, DefId return allocator.allocate(RefinementKey{parent, def, propName}); } +DataFlowGraph::DataFlowGraph(NotNull defArena, NotNull keyArena) + : defArena{defArena} + , keyArena{keyArena} +{ +} + DefId DataFlowGraph::getDef(const AstExpr* expr) const { auto def = astDefs.find(expr); @@ -178,11 +184,23 @@ bool DfgScope::canUpdateDefinition(DefId def, const std::string& key) const return true; } -DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull handle) +DataFlowGraphBuilder::DataFlowGraphBuilder(NotNull defArena, NotNull keyArena) + : graph{defArena, keyArena} + , defArena{defArena} + , keyArena{keyArena} +{ +} + +DataFlowGraph DataFlowGraphBuilder::build( + AstStatBlock* block, + NotNull defArena, + NotNull keyArena, + NotNull handle +) { LUAU_TIMETRACE_SCOPE("DataFlowGraphBuilder::build", "Typechecking"); - DataFlowGraphBuilder builder; + DataFlowGraphBuilder builder(defArena, keyArena); builder.handle = handle; DfgScope* moduleScope = builder.makeChildScope(); PushScope ps{builder.scopeStack, moduleScope}; @@ -198,30 +216,6 @@ DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull, std::vector>> DataFlowGraphBuilder::buildShared( - AstStatBlock* block, - NotNull handle -) -{ - - LUAU_TIMETRACE_SCOPE("DataFlowGraphBuilder::build", "Typechecking"); - - DataFlowGraphBuilder builder; - builder.handle = handle; - DfgScope* moduleScope = builder.makeChildScope(); - PushScope ps{builder.scopeStack, moduleScope}; - builder.visitBlockWithoutChildScope(block); - builder.resolveCaptures(); - - if (FFlag::DebugLuauFreezeArena) - { - builder.defArena->allocator.freeze(); - builder.keyArena->allocator.freeze(); - } - - return {std::make_shared(std::move(builder.graph)), std::move(builder.scopes)}; -} - void DataFlowGraphBuilder::resolveCaptures() { for (const auto& [_, capture] : captures) diff --git a/Analysis/src/EqSatSimplification.cpp b/Analysis/src/EqSatSimplification.cpp index 71a5d2a7..709dafb5 100644 --- a/Analysis/src/EqSatSimplification.cpp +++ b/Analysis/src/EqSatSimplification.cpp @@ -206,7 +206,7 @@ static bool isTerminal(const EGraph& egraph, Id eclass) nodes.end(), [](auto& a) { - return isTerminal(a); + return isTerminal(a.node); } ); } @@ -464,7 +464,7 @@ static size_t computeCost(std::unordered_map& bestNodes, const EGrap if (auto it = costs.find(id); it != costs.end()) return it->second; - const std::vector& nodes = egraph[id].nodes; + const std::vector>& nodes = egraph[id].nodes; size_t minCost = std::numeric_limits::max(); size_t bestNode = std::numeric_limits::max(); @@ -481,7 +481,7 @@ static size_t computeCost(std::unordered_map& bestNodes, const EGrap // First, quickly scan for a terminal type. If we can find one, it is obviously the best. for (size_t index = 0; index < nodes.size(); ++index) { - if (isTerminal(nodes[index])) + if (isTerminal(nodes[index].node)) { minCost = 1; bestNode = index; @@ -533,44 +533,44 @@ static size_t computeCost(std::unordered_map& bestNodes, const EGrap { const auto& node = nodes[index]; - if (node.get()) + if (node.node.get()) updateCost(BOUND_PENALTY, index); // TODO: This could probably be an assert now that we don't need rewrite rules to handle TBound. - else if (node.get()) + else if (node.node.get()) { minCost = 1; bestNode = index; } - else if (auto tbl = node.get()) + else if (auto tbl = node.node.get()) { // TODO: We could make the penalty a parameter to computeChildren. std::optional maybeCost = computeChildren(tbl->operands(), minCost); if (maybeCost) updateCost(TABLE_TYPE_PENALTY + *maybeCost, index); } - else if (node.get()) + else if (node.node.get()) { minCost = IMPORTED_TABLE_PENALTY; bestNode = index; } - else if (auto u = node.get()) + else if (auto u = node.node.get()) { std::optional maybeCost = computeChildren(u->operands(), minCost); if (maybeCost) updateCost(SET_TYPE_PENALTY + *maybeCost, index); } - else if (auto i = node.get()) + else if (auto i = node.node.get()) { std::optional maybeCost = computeChildren(i->operands(), minCost); if (maybeCost) updateCost(SET_TYPE_PENALTY + *maybeCost, index); } - else if (auto negation = node.get()) + else if (auto negation = node.node.get()) { std::optional maybeCost = computeChildren(negation->operands(), minCost); if (maybeCost) updateCost(NEGATION_PENALTY + *maybeCost, index); } - else if (auto tfun = node.get()) + else if (auto tfun = node.node.get()) { std::optional maybeCost = computeChildren(tfun->operands(), minCost); if (maybeCost) @@ -643,7 +643,7 @@ TypeId flattenTableNode( for (size_t i = 0; i < eclass.nodes.size(); ++i) { - if (eclass.nodes[i].get()) + if (eclass.nodes[i].node.get()) { found = true; index = i; @@ -660,13 +660,13 @@ TypeId flattenTableNode( } const auto& node = eclass.nodes[index]; - if (const TTable* ttable = node.get()) + if (const TTable* ttable = node.node.get()) { stack.push_back(ttable); id = ttable->getBasis(); continue; } - else if (const TImportedTable* ti = node.get()) + else if (const TImportedTable* ti = node.node.get()) { importedTable = ti; break; @@ -718,7 +718,7 @@ TypeId fromId( size_t index = bestNodes.at(rootId); LUAU_ASSERT(index <= egraph[rootId].nodes.size()); - const EType& node = egraph[rootId].nodes[index]; + const EType& node = egraph[rootId].nodes[index].node; if (node.get()) return builtinTypes->nilType; @@ -1025,8 +1025,9 @@ std::string toDot(const StringCache& strings, const EGraph& egraph) for (const auto& [id, eclass] : egraph.getAllClasses()) { - for (const auto& node : eclass.nodes) + for (const auto& n : eclass.nodes) { + const EType& node = n.node; if (!node.operands().empty()) populated.insert(id); for (Id op : node.operands()) @@ -1047,7 +1048,7 @@ std::string toDot(const StringCache& strings, const EGraph& egraph) for (size_t index = 0; index < eclass.nodes.size(); ++index) { - const auto& node = eclass.nodes[index]; + const auto& node = eclass.nodes[index].node; const std::string label = getNodeName(strings, node); const std::string nodeName = "n" + std::to_string(uint32_t(id)) + "_" + std::to_string(index); @@ -1062,7 +1063,7 @@ std::string toDot(const StringCache& strings, const EGraph& egraph) { for (size_t index = 0; index < eclass.nodes.size(); ++index) { - const auto& node = eclass.nodes[index]; + const auto& node = eclass.nodes[index].node; const std::string label = getNodeName(strings, node); const std::string nodeName = "n" + std::to_string(uint32_t(egraph.find(id))) + "_" + std::to_string(index); @@ -1098,7 +1099,7 @@ static Tag const* isTag(const EGraph& egraph, Id id) { for (const auto& node : egraph[id].nodes) { - if (auto n = isTag(node)) + if (auto n = isTag(node.node)) return n; } return nullptr; @@ -1134,7 +1135,7 @@ protected: { for (const auto& node : (*egraph)[id].nodes) { - if (auto n = node.get()) + if (auto n = node.node.get()) return n; } return nullptr; @@ -1322,8 +1323,10 @@ const EType* findSubtractableClass(const EGraph& egraph, std::unordered_set& const EType* bestUnion = nullptr; std::optional unionSize; - for (const auto& node : egraph[id].nodes) + for (const auto& n : egraph[id].nodes) { + const EType& node = n.node; + if (isTerminal(node)) return &node; @@ -1439,14 +1442,14 @@ bool subtract(EGraph& egraph, CanonicalizedType& ct, Id part) return true; } -Id fromCanonicalized(EGraph& egraph, CanonicalizedType& ct) +static std::pair fromCanonicalized(EGraph& egraph, CanonicalizedType& ct) { if (ct.isUnknown()) { if (ct.errorPart) - return egraph.add(TAny{}); + return {egraph.add(TAny{}), 1}; else - return egraph.add(TUnknown{}); + return {egraph.add(TUnknown{}), 1}; } std::vector parts; @@ -1484,7 +1487,12 @@ Id fromCanonicalized(EGraph& egraph, CanonicalizedType& ct) parts.insert(parts.end(), ct.functionParts.begin(), ct.functionParts.end()); parts.insert(parts.end(), ct.otherParts.begin(), ct.otherParts.end()); - return mkUnion(egraph, std::move(parts)); + std::sort(parts.begin(), parts.end()); + auto it = std::unique(parts.begin(), parts.end()); + parts.erase(it, parts.end()); + + const size_t size = parts.size(); + return {mkUnion(egraph, std::move(parts)), size}; } void addChildren(const EGraph& egraph, const EType* enode, VecDeque& worklist) @@ -1530,7 +1538,7 @@ const Tag* Simplifier::isTag(Id id) const { for (const auto& node : get(id).nodes) { - if (const Tag* ty = node.get()) + if (const Tag* ty = node.node.get()) return ty; } @@ -1564,6 +1572,16 @@ void Simplifier::subst(Id from, Id to, const std::string& ruleName, const std::u substs.emplace_back(from, to, desc); } +void Simplifier::subst(Id from, size_t boringIndex, Id to, const std::string& ruleName, const std::unordered_map& forceNodes) +{ + std::string desc; + if (FFlag::DebugLuauLogSimplification) + desc = mkDesc(egraph, stringCache, arena, builtinTypes, from, to, forceNodes, ruleName); + + egraph.markBoring(from, boringIndex); + substs.emplace_back(from, to, desc); +} + void Simplifier::unionClasses(std::vector& hereParts, Id there) { if (1 == hereParts.size() && isTag(hereParts[0])) @@ -1614,9 +1632,12 @@ void Simplifier::simplifyUnion(Id id) for (Id part : u->operands()) unionWithType(egraph, canonicalized, find(part)); - Id resultId = fromCanonicalized(egraph, canonicalized); + const auto [resultId, newSize] = fromCanonicalized(egraph, canonicalized); - subst(id, resultId, "simplifyUnion", {{id, unionIndex}}); + if (newSize < u->operands().size()) + subst(id, unionIndex, resultId, "simplifyUnion", {{id, unionIndex}}); + else + subst(id, resultId, "simplifyUnion", {{id, unionIndex}}); } } @@ -1824,7 +1845,7 @@ void Simplifier::uninhabitedIntersection(Id id) const auto& partNodes = egraph[partId].nodes; for (size_t partIndex = 0; partIndex < partNodes.size(); ++partIndex) { - const EType& N = partNodes[partIndex]; + const EType& N = partNodes[partIndex].node; if (std::optional intersection = intersectOne(egraph, accumulator, &accumulatorNode, partId, &N)) { if (isTag(*intersection)) @@ -1847,9 +1868,14 @@ void Simplifier::uninhabitedIntersection(Id id) if ((unsimplified.empty() || !isTag(accumulator)) && find(accumulator) != id) unsimplified.push_back(accumulator); + const bool isSmaller = unsimplified.size() < parts.size(); + const Id result = mkIntersection(egraph, std::move(unsimplified)); - subst(id, result, "uninhabitedIntersection", {{id, index}}); + if (isSmaller) + subst(id, index, result, "uninhabitedIntersection", {{id, index}}); + else + subst(id, result, "uninhabitedIntersection", {{id, index}}); } } @@ -1880,7 +1906,7 @@ void Simplifier::intersectWithNegatedClass(Id id) const auto& iNodes = egraph[iId].nodes; for (size_t iIndex = 0; iIndex < iNodes.size(); ++iIndex) { - const EType& iNode = iNodes[iIndex]; + const EType& iNode = iNodes[iIndex].node; if (isTag(iNode) || isTag(iNode) || isTag(iNode) || isTag(iNode) || isTag(iNode) || isTag(iNode) || // isTag(iNode) || // I'm not sure about this one. @@ -1923,7 +1949,7 @@ void Simplifier::intersectWithNegatedClass(Id id) newParts.push_back(part); } - Id substId = egraph.add(Intersection{newParts.begin(), newParts.end()}); + Id substId = mkIntersection(egraph, newParts); subst( id, substId, @@ -1965,7 +1991,7 @@ void Simplifier::intersectWithNegatedAtom(Id id) { for (size_t negationOperandIndex = 0; negationOperandIndex < egraph[negation->operands()[0]].nodes.size(); ++negationOperandIndex) { - const EType* negationOperand = &egraph[negation->operands()[0]].nodes[negationOperandIndex]; + const EType* negationOperand = &egraph[negation->operands()[0]].nodes[negationOperandIndex].node; if (!isTerminal(*negationOperand) || negationOperand->get()) continue; @@ -1976,7 +2002,7 @@ void Simplifier::intersectWithNegatedAtom(Id id) for (size_t jNodeIndex = 0; jNodeIndex < egraph[intersectionOperands[j]].nodes.size(); ++jNodeIndex) { - const EType* jNode = &egraph[intersectionOperands[j]].nodes[jNodeIndex]; + const EType* jNode = &egraph[intersectionOperands[j]].nodes[jNodeIndex].node; if (!isTerminal(*jNode) || jNode->get()) continue; @@ -2001,7 +2027,7 @@ void Simplifier::intersectWithNegatedAtom(Id id) subst( id, - egraph.add(Intersection{newOperands}), + mkIntersection(egraph, std::move(newOperands)), "intersectWithNegatedAtom", {{id, intersectionIndex}, {intersectionOperands[i], negationIndex}, {intersectionOperands[j], jNodeIndex}} ); @@ -2178,7 +2204,7 @@ void Simplifier::expandNegation(Id id) if (!ok) continue; - subst(id, fromCanonicalized(egraph, canonicalized), "expandNegation", {{id, index}}); + subst(id, fromCanonicalized(egraph, canonicalized).first, "expandNegation", {{id, index}}); } } @@ -2576,9 +2602,9 @@ std::optional eqSatSimplify(NotNull simpl // try to run any rules on it. bool shouldAbort = false; - for (const EType& enode : egraph[id].nodes) + for (const auto& enode : egraph[id].nodes) { - if (isTerminal(enode)) + if (isTerminal(enode.node)) { shouldAbort = true; break; @@ -2588,8 +2614,8 @@ std::optional eqSatSimplify(NotNull simpl if (shouldAbort) continue; - for (const EType& enode : egraph[id].nodes) - addChildren(egraph, &enode, worklist); + for (const auto& enode : egraph[id].nodes) + addChildren(egraph, &enode.node, worklist); for (Simplifier::RewriteRuleFn rule : rules) (simplifier.get()->*rule)(id); diff --git a/Analysis/src/FragmentAutocomplete.cpp b/Analysis/src/FragmentAutocomplete.cpp index 7687847b..cde8125a 100644 --- a/Analysis/src/FragmentAutocomplete.cpp +++ b/Analysis/src/FragmentAutocomplete.cpp @@ -200,7 +200,7 @@ ScopePtr findClosestScope(const ModulePtr& module, const AstStat* nearestStateme return closest; } -FragmentParseResult parseFragment( +std::optional parseFragment( const SourceModule& srcModule, std::string_view src, const Position& cursorPos, @@ -245,6 +245,10 @@ FragmentParseResult parseFragment( opts.captureComments = true; opts.parseFragment = FragmentParseResumeSettings{std::move(result.localMap), std::move(result.localStack), startPos}; ParseResult p = Luau::Parser::parse(srcStart, parseLength, *nameTbl, *fragmentResult.alloc.get(), opts); + // This means we threw a ParseError and we should decline to offer autocomplete here. + if (p.root == nullptr) + return std::nullopt; + std::vector fabricatedAncestry = std::move(result.ancestry); // Get the ancestry for the fragment at the offset cursor position. @@ -366,7 +370,8 @@ FragmentTypeCheckResult typecheckFragment_( TypeFunctionRuntime typeFunctionRuntime(iceHandler, NotNull{&limits}); /// Create a DataFlowGraph just for the surrounding context - auto dfg = DataFlowGraphBuilder::build(root, iceHandler); + DataFlowGraph dfg = DataFlowGraphBuilder::build(root, NotNull{&incrementalModule->defArena}, NotNull{&incrementalModule->keyArena}, iceHandler); + SimplifierPtr simplifier = newSimplifier(NotNull{&incrementalModule->internalTypes}, frontend.builtinTypes); FrontendModuleResolver& resolver = getModuleResolver(frontend, opts); @@ -468,7 +473,13 @@ std::pair typecheckFragment( return {}; } - FragmentParseResult parseResult = parseFragment(*sourceModule, src, cursorPos, fragmentEndPosition); + auto tryParse = parseFragment(*sourceModule, src, cursorPos, fragmentEndPosition); + + if (!tryParse) + return {FragmentTypeCheckStatus::SkipAutocomplete, {}}; + + FragmentParseResult& parseResult = *tryParse; + if (isWithinComment(parseResult.commentLocations, fragmentEndPosition.value_or(cursorPos))) return {FragmentTypeCheckStatus::SkipAutocomplete, {}}; diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index f7164256..14ff1f5e 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -13,6 +13,7 @@ #include "Luau/EqSatSimplification.h" #include "Luau/FileResolver.h" #include "Luau/NonStrictTypeChecker.h" +#include "Luau/NotNull.h" #include "Luau/Parser.h" #include "Luau/Scope.h" #include "Luau/StringUtils.h" @@ -1338,7 +1339,7 @@ ModulePtr check( } } - DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, iceHandler); + DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, NotNull{&result->defArena}, NotNull{&result->keyArena}, iceHandler); UnifierSharedState unifierState{iceHandler}; unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index 4b6d1115..d9d0d3b0 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -61,9 +61,7 @@ TypeId Instantiation::clean(TypeId ty) LUAU_ASSERT(ftv); FunctionType clone = FunctionType{level, scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; - clone.magicFunction = ftv->magicFunction; - clone.dcrMagicFunction = ftv->dcrMagicFunction; - clone.dcrMagicRefinement = ftv->dcrMagicRefinement; + clone.magic = ftv->magic; clone.tags = ftv->tags; clone.argNames = ftv->argNames; TypeId result = addType(std::move(clone)); diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index e8357f48..e00f0d3d 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -98,9 +98,7 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a FunctionType clone = FunctionType{a.level, a.scope, a.argTypes, a.retTypes, a.definition, a.hasSelf}; clone.generics = a.generics; clone.genericPacks = a.genericPacks; - clone.magicFunction = a.magicFunction; - clone.dcrMagicFunction = a.dcrMagicFunction; - clone.dcrMagicRefinement = a.dcrMagicRefinement; + clone.magic = a.magic; clone.tags = a.tags; clone.argNames = a.argNames; clone.isCheckedFunction = a.isCheckedFunction; diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index a5298ee5..aee91ec3 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -554,12 +554,12 @@ BlockedType::BlockedType() { } -Constraint* BlockedType::getOwner() const +const Constraint* BlockedType::getOwner() const { return owner; } -void BlockedType::setOwner(Constraint* newOwner) +void BlockedType::setOwner(const Constraint* newOwner) { LUAU_ASSERT(owner == nullptr); @@ -569,7 +569,7 @@ void BlockedType::setOwner(Constraint* newOwner) owner = newOwner; } -void BlockedType::replaceOwner(Constraint* newOwner) +void BlockedType::replaceOwner(const Constraint* newOwner) { owner = newOwner; } diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 3019bf01..1e78acf7 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -1454,10 +1454,11 @@ void TypeChecker2::visitCall(AstExprCall* call) TypePackId argsTp = module->internalTypes.addTypePack(args); if (auto ftv = get(follow(*originalCallTy))) { - if (ftv->dcrMagicTypeCheck) + if (ftv->magic) { - ftv->dcrMagicTypeCheck(MagicFunctionTypeCheckContext{NotNull{this}, builtinTypes, call, argsTp, scope}); - return; + bool usedMagic = ftv->magic->typeCheck(MagicFunctionTypeCheckContext{NotNull{this}, builtinTypes, call, argsTp, scope}); + if (usedMagic) + return; } } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index addd3445..4a243856 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -4506,10 +4506,10 @@ std::unique_ptr> TypeChecker::checkCallOverload( // When this function type has magic functions and did return something, we select that overload instead. // TODO: pass in a Unifier object to the magic functions? This will allow the magic functions to cooperate with overload resolution. - if (ftv->magicFunction) + if (ftv->magic) { // TODO: We're passing in the wrong TypePackId. Should be argPack, but a unit test fails otherwise. CLI-40458 - if (std::optional> ret = ftv->magicFunction(*this, scope, expr, argListResult)) + if (std::optional> ret = ftv->magic->handleOldSolver(*this, scope, expr, argListResult)) return std::make_unique>(std::move(*ret)); } diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index f618bc06..0d91f5d5 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -23,6 +23,7 @@ LUAU_FASTFLAGVARIABLE(LuauAllowComplexTypesInGenericParams) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryForTableTypes) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryForClassNames) LUAU_FASTFLAGVARIABLE(LuauFixFunctionNameStartPosition) +LUAU_FASTFLAGVARIABLE(LuauExtendStatEndPosWithSemicolon) namespace Luau { @@ -288,6 +289,10 @@ AstStatBlock* Parser::parseBlockNoScope() { nextLexeme(); stat->hasSemicolon = true; + if (FFlag::LuauExtendStatEndPosWithSemicolon) + { + stat->location.end = lexer.previousLocation().end; + } } body.push_back(stat); diff --git a/Common/include/Luau/ExperimentalFlags.h b/Common/include/Luau/ExperimentalFlags.h index 93fbd8d7..62a0d77d 100644 --- a/Common/include/Luau/ExperimentalFlags.h +++ b/Common/include/Luau/ExperimentalFlags.h @@ -14,6 +14,7 @@ inline bool isFlagExperimental(const char* flag) "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code "LuauFixIndexerSubtypingOrdering", // requires some small fixes to lua-apps code since this fixes a false negative "StudioReportLuauAny2", // takes telemetry data for usage of any types + "LuauTableCloneClonesType", // requires fixes in lua-apps code, terrifyingly "LuauSolverV2", // makes sure we always have at least one entry nullptr, diff --git a/EqSat/include/Luau/EGraph.h b/EqSat/include/Luau/EGraph.h index 6af79b77..2703ad9d 100644 --- a/EqSat/include/Luau/EGraph.h +++ b/EqSat/include/Luau/EGraph.h @@ -51,13 +51,70 @@ struct Analysis final } }; +template +struct Node +{ + L node; + bool boring = false; + + struct Hash + { + size_t operator()(const Node& node) const + { + return typename L::Hash{}(node.node); + } + }; +}; + +template +struct NodeIterator +{ +private: + using iterator = std::vector>; + iterator iter; + +public: + L& operator*() + { + return iter->node; + } + + const L& operator*() const + { + return iter->node; + } + + iterator& operator++() + { + ++iter; + return *this; + } + + iterator operator++(int) + { + iterator copy = *this; + ++*this; + return copy; + } + + bool operator==(const iterator& rhs) const + { + return iter == rhs.iter; + } + + bool operator!=(const iterator& rhs) const + { + return iter != rhs.iter; + } +}; + /// Each e-class is a set of e-nodes representing equivalent terms from a given language, /// and an e-node is a function symbol paired with a list of children e-classes. template struct EClass final { Id id; - std::vector nodes; + std::vector> nodes; D data; std::vector> parents; }; @@ -125,9 +182,9 @@ struct EGraph final std::sort( eclass1.nodes.begin(), eclass1.nodes.end(), - [](const L& left, const L& right) + [](const Node& left, const Node& right) { - return left.index() < right.index(); + return left.node.index() < right.node.index(); } ); @@ -177,6 +234,11 @@ struct EGraph final return classes; } + void markBoring(Id id, size_t index) + { + get(id).nodes[index].boring = true; + } + private: Analysis analysis; @@ -225,7 +287,7 @@ private: id, EClassT{ id, - {enode}, + {Node{enode, false}}, analysis.make(*this, enode), {}, } @@ -264,18 +326,18 @@ private: std::vector> parents = get(id).parents; for (auto& pair : parents) { - L& enode = pair.first; - Id id = pair.second; + L& parentNode = pair.first; + Id parentId = pair.second; // By removing the old enode from the hashcons map, we will always find our new canonicalized eclass id. - hashcons.erase(enode); - canonicalize(enode); - hashcons.insert_or_assign(enode, find(id)); + hashcons.erase(parentNode); + canonicalize(parentNode); + hashcons.insert_or_assign(parentNode, find(parentId)); - if (auto it = newParents.find(enode); it != newParents.end()) - merge(id, it->second); + if (auto it = newParents.find(parentNode); it != newParents.end()) + merge(parentId, it->second); - newParents.insert_or_assign(enode, find(id)); + newParents.insert_or_assign(parentNode, find(parentId)); } // We reacquire the pointer because the prior loop potentially merges @@ -287,22 +349,30 @@ private: for (const auto& [node, id] : newParents) eclass->parents.emplace_back(std::move(node), std::move(id)); - std::unordered_set newNodes; - for (L node : eclass->nodes) + std::unordered_map newNodes; + for (Node node : eclass->nodes) { - canonicalize(node); - newNodes.insert(std::move(node)); + canonicalize(node.node); + + bool& b = newNodes[std::move(node.node)]; + b = b || node.boring; } - eclass->nodes.assign(newNodes.begin(), newNodes.end()); + eclass->nodes.clear(); + + while (!newNodes.empty()) + { + auto n = newNodes.extract(newNodes.begin()); + eclass->nodes.push_back(Node{n.key(), n.mapped()}); + } // FIXME: Extract into sortByTag() std::sort( eclass->nodes.begin(), eclass->nodes.end(), - [](const L& left, const L& right) + [](const Node& left, const Node& right) { - return left.index() < right.index(); + return left.node.index() < right.node.index(); } ); } diff --git a/VM/include/lua.h b/VM/include/lua.h index c4f5f714..2cdedd7d 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -154,6 +154,7 @@ LUA_API const float* lua_tovector(lua_State* L, int idx); LUA_API int lua_toboolean(lua_State* L, int idx); LUA_API const char* lua_tolstring(lua_State* L, int idx, size_t* len); LUA_API const char* lua_tostringatom(lua_State* L, int idx, int* atom); +LUA_API const char* lua_tolstringatom(lua_State* L, int idx, size_t* len, int* atom); LUA_API const char* lua_namecallatom(lua_State* L, int* atom); LUA_API int lua_objlen(lua_State* L, int idx); LUA_API lua_CFunction lua_tocfunction(lua_State* L, int idx); diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 1a8af74d..98afca7b 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -454,6 +454,29 @@ const char* lua_tostringatom(lua_State* L, int idx, int* atom) return getstr(s); } +const char* lua_tolstringatom(lua_State* L, int idx, size_t* len, int* atom) +{ + StkId o = index2addr(L, idx); + + if (!ttisstring(o)) + { + if (len) + *len = 0; + return NULL; + } + + TString* s = tsvalue(o); + if (len) + *len = s->len; + if (atom) + { + updateatom(L, s); + *atom = s->atom; + } + + return getstr(s); +} + const char* lua_namecallatom(lua_State* L, int* atom) { TString* s = L->namecall; diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index 07cc117e..44da57c2 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -422,6 +422,20 @@ int luaG_isnative(lua_State* L, int level) return (ci->flags & LUA_CALLINFO_NATIVE) != 0 ? 1 : 0; } +int luaG_hasnative(lua_State* L, int level) +{ + if (unsigned(level) >= unsigned(L->ci - L->base_ci)) + return 0; + + CallInfo* ci = L->ci - level; + + Proto* proto = getluaproto(ci); + if (proto == nullptr) + return 0; + + return (proto->execdata != nullptr); +} + void lua_singlestep(lua_State* L, int enabled) { L->singlestep = bool(enabled); diff --git a/VM/src/ldebug.h b/VM/src/ldebug.h index 49b1ca88..f215e815 100644 --- a/VM/src/ldebug.h +++ b/VM/src/ldebug.h @@ -31,3 +31,4 @@ LUAI_FUNC bool luaG_onbreak(lua_State* L); LUAI_FUNC int luaG_getline(Proto* p, int pc); LUAI_FUNC int luaG_isnative(lua_State* L, int level); +LUAI_FUNC int luaG_hasnative(lua_State* L, int level); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 5e65a9e1..dd02671f 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -4304,4 +4304,29 @@ foo(@1) CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_at_end_of_stmt_should_continue_as_part_of_stmt") +{ + check(R"( +local data = { x = 1 } +local var = data.@1 + )"); + auto ac = autocomplete('1'); + CHECK(!ac.entryMap.empty()); + CHECK(ac.entryMap.count("x")); + CHECK_EQ(ac.context, AutocompleteContext::Property); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_after_semicolon_should_complete_a_new_statement") +{ + check(R"( +local data = { x = 1 } +local var = data;@1 + )"); + auto ac = autocomplete('1'); + CHECK(!ac.entryMap.empty()); + CHECK(ac.entryMap.count("table")); + CHECK(ac.entryMap.count("math")); + CHECK_EQ(ac.context, AutocompleteContext::Statement); +} + TEST_SUITE_END(); diff --git a/tests/ConstraintGeneratorFixture.cpp b/tests/ConstraintGeneratorFixture.cpp index ef91fdf7..90a8b507 100644 --- a/tests/ConstraintGeneratorFixture.cpp +++ b/tests/ConstraintGeneratorFixture.cpp @@ -22,7 +22,9 @@ ConstraintGeneratorFixture::ConstraintGeneratorFixture() void ConstraintGeneratorFixture::generateConstraints(const std::string& code) { AstStatBlock* root = parse(code); - dfg = std::make_unique(DataFlowGraphBuilder::build(root, NotNull{&ice})); + dfg = std::make_unique( + DataFlowGraphBuilder::build(root, NotNull{&mainModule->defArena}, NotNull{&mainModule->keyArena}, NotNull{&ice}) + ); cg = std::make_unique( mainModule, NotNull{&normalizer}, diff --git a/tests/DataFlowGraph.test.cpp b/tests/DataFlowGraph.test.cpp index 4ea656ee..1b7e243c 100644 --- a/tests/DataFlowGraph.test.cpp +++ b/tests/DataFlowGraph.test.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/DataFlowGraph.h" #include "Fixture.h" +#include "Luau/Def.h" #include "Luau/Error.h" #include "Luau/Parser.h" @@ -18,6 +19,8 @@ struct DataFlowGraphFixture // Only needed to fix the operator== reflexivity of an empty Symbol. ScopedFastFlag dcr{FFlag::LuauSolverV2, true}; + DefArena defArena; + RefinementKeyArena keyArena; InternalErrorReporter handle; Allocator allocator; @@ -32,7 +35,7 @@ struct DataFlowGraphFixture if (!parseResult.errors.empty()) throw ParseErrors(std::move(parseResult.errors)); module = parseResult.root; - graph = DataFlowGraphBuilder::build(module, NotNull{&handle}); + graph = DataFlowGraphBuilder::build(module, NotNull{&defArena}, NotNull{&keyArena}, NotNull{&handle}); } template diff --git a/tests/FragmentAutocomplete.test.cpp b/tests/FragmentAutocomplete.test.cpp index 8a30abf5..3ddd6a30 100644 --- a/tests/FragmentAutocomplete.test.cpp +++ b/tests/FragmentAutocomplete.test.cpp @@ -26,6 +26,7 @@ LUAU_FASTFLAG(LuauSymbolEquality); LUAU_FASTFLAG(LuauStoreSolverTypeOnModule); LUAU_FASTFLAG(LexerResumesFromPosition2) LUAU_FASTFLAG(LuauIncrementalAutocompleteCommentDetection) +LUAU_FASTINT(LuauParseErrorLimit) static std::optional nullCallback(std::string tag, std::optional ptr, std::optional contents) { @@ -69,7 +70,7 @@ struct FragmentAutocompleteFixtureImpl : BaseType } - FragmentParseResult parseFragment( + std::optional parseFragment( const std::string& document, const Position& cursorPos, std::optional fragmentEndPosition = std::nullopt @@ -164,6 +165,7 @@ end } }; +//NOLINTBEGIN(bugprone-unchecked-optional-access) TEST_SUITE_BEGIN("FragmentAutocompleteTraversalTests"); TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "just_two_locals") @@ -286,13 +288,23 @@ TEST_SUITE_END(); TEST_SUITE_BEGIN("FragmentAutocompleteParserTests"); +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "thrown_parse_error_leads_to_null_root") +{ + check("type A = "); + ScopedFastInt sfi{FInt::LuauParseErrorLimit, 1}; + auto fragment = parseFragment("type A = <>function<> more garbage here", Position(0, 39)); + CHECK(fragment == std::nullopt); +} + TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "local_initializer") { ScopedFastFlag sff{FFlag::LuauSolverV2, true}; check("local a ="); auto fragment = parseFragment("local a =", Position(0, 10)); - CHECK_EQ("local a =", fragment.fragmentToParse); - CHECK_EQ(Location{Position{0, 0}, 9}, fragment.root->location); + + REQUIRE(fragment.has_value()); + CHECK_EQ("local a =", fragment->fragmentToParse); + CHECK_EQ(Location{Position{0, 0}, 9}, fragment->root->location); } TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "statement_in_empty_fragment_is_non_null") @@ -310,11 +322,12 @@ TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "statement_in_empty_fragment_is_n )", Position(1, 0) ); - CHECK_EQ("\n", fragment.fragmentToParse); - CHECK_EQ(2, fragment.ancestry.size()); - REQUIRE(fragment.root); - CHECK_EQ(0, fragment.root->body.size); - auto statBody = fragment.root->as(); + REQUIRE(fragment.has_value()); + CHECK_EQ("\n", fragment->fragmentToParse); + CHECK_EQ(2, fragment->ancestry.size()); + REQUIRE(fragment->root); + CHECK_EQ(0, fragment->root->body.size); + auto statBody = fragment->root->as(); CHECK(statBody != nullptr); } @@ -339,13 +352,15 @@ local z = x + y Position{3, 15} ); - CHECK_EQ(Location{Position{2, 0}, Position{3, 15}}, fragment.root->location); + REQUIRE(fragment.has_value()); - CHECK_EQ("local y = 5\nlocal z = x + y", fragment.fragmentToParse); - CHECK_EQ(5, fragment.ancestry.size()); - REQUIRE(fragment.root); - CHECK_EQ(2, fragment.root->body.size); - auto stat = fragment.root->body.data[1]->as(); + CHECK_EQ(Location{Position{2, 0}, Position{3, 15}}, fragment->root->location); + + CHECK_EQ("local y = 5\nlocal z = x + y", fragment->fragmentToParse); + CHECK_EQ(5, fragment->ancestry.size()); + REQUIRE(fragment->root); + CHECK_EQ(2, fragment->root->body.size); + auto stat = fragment->root->body.data[1]->as(); REQUIRE(stat); CHECK_EQ(1, stat->vars.size); CHECK_EQ(1, stat->values.size); @@ -384,12 +399,14 @@ local y = 5 Position{2, 15} ); - CHECK_EQ("local z = x + y", fragment.fragmentToParse); - CHECK_EQ(5, fragment.ancestry.size()); - REQUIRE(fragment.root); - CHECK_EQ(Location{Position{2, 0}, Position{2, 15}}, fragment.root->location); - CHECK_EQ(1, fragment.root->body.size); - auto stat = fragment.root->body.data[0]->as(); + REQUIRE(fragment.has_value()); + + CHECK_EQ("local z = x + y", fragment->fragmentToParse); + CHECK_EQ(5, fragment->ancestry.size()); + REQUIRE(fragment->root); + CHECK_EQ(Location{Position{2, 0}, Position{2, 15}}, fragment->root->location); + CHECK_EQ(1, fragment->root->body.size); + auto stat = fragment->root->body.data[0]->as(); REQUIRE(stat); CHECK_EQ(1, stat->vars.size); CHECK_EQ(1, stat->values.size); @@ -429,7 +446,9 @@ TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_in_correct_scope") Position{6, 0} ); - CHECK_EQ("\n ", fragment.fragmentToParse); + REQUIRE(fragment.has_value()); + + CHECK_EQ("\n ", fragment->fragmentToParse); } TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_single_line_fragment_override") @@ -448,17 +467,19 @@ abc("bar") Position{1, 10} ); - CHECK_EQ("function abc(foo: string) end\nabc(\"foo\")", callFragment.fragmentToParse); - CHECK(callFragment.nearestStatement->is()); + REQUIRE(callFragment.has_value()); - CHECK_GE(callFragment.ancestry.size(), 2); + CHECK_EQ("function abc(foo: string) end\nabc(\"foo\")", callFragment->fragmentToParse); + CHECK(callFragment->nearestStatement->is()); - AstNode* back = callFragment.ancestry.back(); + CHECK_GE(callFragment->ancestry.size(), 2); + + AstNode* back = callFragment->ancestry.back(); CHECK(back->is()); CHECK_EQ(Position{1, 4}, back->location.begin); CHECK_EQ(Position{1, 9}, back->location.end); - AstNode* parent = callFragment.ancestry.rbegin()[1]; + AstNode* parent = callFragment->ancestry.rbegin()[1]; CHECK(parent->is()); CHECK_EQ(Position{1, 0}, parent->location.begin); CHECK_EQ(Position{1, 10}, parent->location.end); @@ -473,12 +494,14 @@ abc("bar") Position{1, 9} ); - CHECK_EQ("function abc(foo: string) end\nabc(\"foo\")", stringFragment.fragmentToParse); - CHECK(stringFragment.nearestStatement->is()); + REQUIRE(stringFragment.has_value()); - CHECK_GE(stringFragment.ancestry.size(), 1); + CHECK_EQ("function abc(foo: string) end\nabc(\"foo\")", stringFragment->fragmentToParse); + CHECK(stringFragment->nearestStatement->is()); - back = stringFragment.ancestry.back(); + CHECK_GE(stringFragment->ancestry.size(), 1); + + back = stringFragment->ancestry.back(); auto asString = back->as(); CHECK(asString); @@ -508,17 +531,19 @@ abc("bar") Position{3, 1} ); - CHECK_EQ("function abc(foo: string) end\nabc(\n\"foo\"\n)", fragment.fragmentToParse); - CHECK(fragment.nearestStatement->is()); + REQUIRE(fragment.has_value()); - CHECK_GE(fragment.ancestry.size(), 2); + CHECK_EQ("function abc(foo: string) end\nabc(\n\"foo\"\n)", fragment->fragmentToParse); + CHECK(fragment->nearestStatement->is()); - AstNode* back = fragment.ancestry.back(); + CHECK_GE(fragment->ancestry.size(), 2); + + AstNode* back = fragment->ancestry.back(); CHECK(back->is()); CHECK_EQ(Position{2, 0}, back->location.begin); CHECK_EQ(Position{2, 5}, back->location.end); - AstNode* parent = fragment.ancestry.rbegin()[1]; + AstNode* parent = fragment->ancestry.rbegin()[1]; CHECK(parent->is()); CHECK_EQ(Position{1, 0}, parent->location.begin); CHECK_EQ(Position{3, 1}, parent->location.end); @@ -549,6 +574,7 @@ t } TEST_SUITE_END(); +//NOLINTEND(bugprone-unchecked-optional-access) TEST_SUITE_BEGIN("FragmentAutocompleteTypeCheckerTests"); @@ -1558,4 +1584,26 @@ if x == 5 then -- a comment ); } +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "fragment_autocomplete_handles_parse_errors") +{ + + ScopedFastInt sfi{FInt::LuauParseErrorLimit, 1}; + const std::string source = R"( + +)"; + const std::string updated = R"( +type A = <>random non code text here +)"; + + autocompleteFragmentInBothSolvers( + source, + updated, + Position{1, 38}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.empty()); + } + ); +} + TEST_SUITE_END(); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 434701b2..32476e86 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -20,6 +20,7 @@ LUAU_FASTFLAG(LuauAllowComplexTypesInGenericParams) LUAU_FASTFLAG(LuauErrorRecoveryForTableTypes) LUAU_FASTFLAG(LuauErrorRecoveryForClassNames) LUAU_FASTFLAG(LuauFixFunctionNameStartPosition) +LUAU_FASTFLAG(LuauExtendStatEndPosWithSemicolon) namespace { @@ -3766,5 +3767,32 @@ TEST_CASE_FIXTURE(Fixture, "function_name_has_correct_start_location") CHECK_EQ(Position{4, 17}, function2->name->location.begin); } +TEST_CASE_FIXTURE(Fixture, "stat_end_includes_semicolon_position") +{ + ScopedFastFlag _{FFlag::LuauExtendStatEndPosWithSemicolon, true}; + AstStatBlock* block = parse(R"( + local x = 1 + local y = 2; + local z = 3 ; + )"); + + REQUIRE_EQ(3, block->body.size); + + const auto stat1 = block->body.data[0]; + LUAU_ASSERT(stat1); + CHECK_FALSE(stat1->hasSemicolon); + CHECK_EQ(Position{1, 19}, stat1->location.end); + + const auto stat2 = block->body.data[1]; + LUAU_ASSERT(stat2); + CHECK(stat2->hasSemicolon); + CHECK_EQ(Position{2, 20}, stat2->location.end); + + const auto stat3 = block->body.data[2]; + LUAU_ASSERT(stat3); + CHECK(stat3->hasSemicolon); + CHECK_EQ(Position{3, 22}, stat3->location.end); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 34e430ea..750f066c 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -12,6 +12,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauTypestateBuiltins2) LUAU_FASTFLAG(LuauStringFormatArityFix) +LUAU_FASTFLAG(LuauTableCloneClonesType) LUAU_FASTFLAG(LuauStringFormatErrorSuppression) TEST_SUITE_BEGIN("BuiltinTests"); @@ -1587,6 +1588,30 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "string_find_should_not_crash") )")); } +TEST_CASE_FIXTURE(BuiltinsFixture, "table_dot_clone_type_states") +{ + CheckResult result = check(R"( + local t1 = {} + t1.x = 5 + local t2 = table.clone(t1) + t2.y = 6 + t1.z = 3 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::LuauTableCloneClonesType) + { + CHECK_EQ(toString(requireType("t1"), {true}), "{ x: number, z: number }"); + CHECK_EQ(toString(requireType("t2"), {true}), "{ x: number, y: number }"); + } + else + { + CHECK_EQ(toString(requireType("t1"), {true}), "{ x: number, y: number, z: number }"); + CHECK_EQ(toString(requireType("t2"), {true}), "{ x: number, y: number, z: number }"); + } +} + TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_should_support_any") { ScopedFastFlag _{FFlag::LuauSolverV2, true}; diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index bd43410c..bc1d55dd 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -20,7 +20,6 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauRetrySubtypingWithoutHiddenPack) -LUAU_FASTFLAG(LuauDontRefCountTypesInTypeFunctions) LUAU_FASTFLAG(DebugLuauEqSatSimplification) TEST_SUITE_BEGIN("TypeInferFunctions"); @@ -2566,7 +2565,7 @@ end TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_return_type") { - ScopedFastFlag sffs[] = {{FFlag::LuauSolverV2, true}, {FFlag::LuauDontRefCountTypesInTypeFunctions, true}}; + ScopedFastFlag _{FFlag::LuauSolverV2, true}; // CLI-114134: This test: // a) Has a kind of weird result (suggesting `number | false` is not great); @@ -2878,8 +2877,6 @@ TEST_CASE_FIXTURE(Fixture, "fuzzer_missing_follow_in_ast_stat_fun") TEST_CASE_FIXTURE(Fixture, "unifier_should_not_bind_free_types") { - ScopedFastFlag _{FFlag::LuauDontRefCountTypesInTypeFunctions, true}; - CheckResult result = check(R"( function foo(player) local success,result = player:thing() diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 0cd23c10..535b9961 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -16,52 +16,63 @@ using namespace Luau; namespace { -std::optional> magicFunctionInstanceIsA( - TypeChecker& typeChecker, - const ScopePtr& scope, - const AstExprCall& expr, - WithPredicate withPredicate -) + +struct MagicInstanceIsA final : MagicFunction { - if (expr.args.size != 1) - return std::nullopt; + std::optional> handleOldSolver( + TypeChecker& typeChecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate + ) override + { + if (expr.args.size != 1) + return std::nullopt; - auto index = expr.func->as(); - auto str = expr.args.data[0]->as(); - if (!index || !str) - return std::nullopt; + auto index = expr.func->as(); + auto str = expr.args.data[0]->as(); + if (!index || !str) + return std::nullopt; - std::optional lvalue = tryGetLValue(*index->expr); - std::optional tfun = scope->lookupType(std::string(str->value.data, str->value.size)); - if (!lvalue || !tfun) - return std::nullopt; + std::optional lvalue = tryGetLValue(*index->expr); + std::optional tfun = scope->lookupType(std::string(str->value.data, str->value.size)); + if (!lvalue || !tfun) + return std::nullopt; - ModulePtr module = typeChecker.currentModule; - TypePackId booleanPack = module->internalTypes.addTypePack({typeChecker.booleanType}); - return WithPredicate{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; -} + ModulePtr module = typeChecker.currentModule; + TypePackId booleanPack = module->internalTypes.addTypePack({typeChecker.booleanType}); + return WithPredicate{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; + } -void dcrMagicRefinementInstanceIsA(const MagicRefinementContext& ctx) -{ - if (ctx.callSite->args.size != 1 || ctx.discriminantTypes.empty()) - return; + bool infer(const MagicFunctionCallContext&) override + { + return false; + } - auto index = ctx.callSite->func->as(); - auto str = ctx.callSite->args.data[0]->as(); - if (!index || !str) - return; + void refine(const MagicRefinementContext& ctx) override + { + if (ctx.callSite->args.size != 1 || ctx.discriminantTypes.empty()) + return; - std::optional discriminantTy = ctx.discriminantTypes[0]; - if (!discriminantTy) - return; + auto index = ctx.callSite->func->as(); + auto str = ctx.callSite->args.data[0]->as(); + if (!index || !str) + return; + + std::optional discriminantTy = ctx.discriminantTypes[0]; + if (!discriminantTy) + return; + + std::optional tfun = ctx.scope->lookupType(std::string(str->value.data, str->value.size)); + if (!tfun) + return; + + LUAU_ASSERT(get(*discriminantTy)); + asMutable(*discriminantTy)->ty.emplace(tfun->type); + } +}; - std::optional tfun = ctx.scope->lookupType(std::string(str->value.data, str->value.size)); - if (!tfun) - return; - LUAU_ASSERT(get(*discriminantTy)); - asMutable(*discriminantTy)->ty.emplace(tfun->type); -} struct RefinementClassFixture : BuiltinsFixture { @@ -85,8 +96,7 @@ struct RefinementClassFixture : BuiltinsFixture TypePackId isAParams = arena.addTypePack({inst, builtinTypes->stringType}); TypePackId isARets = arena.addTypePack({builtinTypes->booleanType}); TypeId isA = arena.addType(FunctionType{isAParams, isARets}); - getMutable(isA)->magicFunction = magicFunctionInstanceIsA; - getMutable(isA)->dcrMagicRefinement = dcrMagicRefinementInstanceIsA; + getMutable(isA)->magic = std::make_shared(); getMutable(inst)->props = { {"Name", Property{builtinTypes->stringType}}, diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index cea3fc6d..4264c777 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -24,7 +24,6 @@ LUAU_FASTINT(LuauNormalizeCacheLimit); LUAU_FASTINT(LuauRecursionLimit); LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTFLAG(LuauNewSolverVisitErrorExprLvalues) -LUAU_FASTFLAG(LuauDontRefCountTypesInTypeFunctions) LUAU_FASTFLAG(InferGlobalTypes) using namespace Luau; @@ -1731,7 +1730,7 @@ TEST_CASE_FIXTURE(Fixture, "visit_error_nodes_in_lvalue") TEST_CASE_FIXTURE(Fixture, "avoid_blocking_type_function") { - ScopedFastFlag sffs[] = {{FFlag::LuauSolverV2, true}, {FFlag::LuauDontRefCountTypesInTypeFunctions, true}}; + ScopedFastFlag _{FFlag::LuauSolverV2, true}; LUAU_CHECK_NO_ERRORS(check(R"( --!strict @@ -1744,7 +1743,7 @@ TEST_CASE_FIXTURE(Fixture, "avoid_blocking_type_function") TEST_CASE_FIXTURE(Fixture, "avoid_double_reference_to_free_type") { - ScopedFastFlag sffs[] = {{FFlag::LuauSolverV2, true}, {FFlag::LuauDontRefCountTypesInTypeFunctions, true}}; + ScopedFastFlag _{FFlag::LuauSolverV2, true}; LUAU_CHECK_NO_ERRORS(check(R"( --!strict