diff --git a/Analysis/include/Luau/ConstraintGenerator.h b/Analysis/include/Luau/ConstraintGenerator.h index 0a012327..c3a531d1 100644 --- a/Analysis/include/Luau/ConstraintGenerator.h +++ b/Analysis/include/Luau/ConstraintGenerator.h @@ -11,15 +11,14 @@ #include "Luau/ModuleResolver.h" #include "Luau/Normalize.h" #include "Luau/NotNull.h" +#include "Luau/Polarity.h" #include "Luau/Refinement.h" #include "Luau/Symbol.h" #include "Luau/TypeFwd.h" #include "Luau/TypeUtils.h" -#include "Luau/Variant.h" #include #include -#include namespace Luau { @@ -162,19 +161,26 @@ struct ConstraintGenerator void visitFragmentRoot(const ScopePtr& resumeScope, AstStatBlock* block); private: - std::vector> interiorTypes; + struct InteriorFreeTypes + { + std::vector types; + std::vector typePacks; + }; + + std::vector> DEPRECATED_interiorTypes; + std::vector interiorFreeTypes; /** * Fabricates a new free type belonging to a given scope. * @param scope the scope the free type belongs to. */ - TypeId freshType(const ScopePtr& scope); + TypeId freshType(const ScopePtr& scope, Polarity polarity = Polarity::Unknown); /** * Fabricates a new free type pack belonging to a given scope. * @param scope the scope the free type pack belongs to. */ - TypePackId freshTypePack(const ScopePtr& scope); + TypePackId freshTypePack(const ScopePtr& scope, Polarity polarity = Polarity::Unknown); /** * Allocate a new TypePack with the given head and tail. @@ -295,7 +301,7 @@ private: ); Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType, bool forceSingleton); - Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional expectedType, bool forceSingleton); + Inference check(const ScopePtr& scope, AstExprConstantBool* boolExpr, std::optional expectedType, bool forceSingleton); Inference check(const ScopePtr& scope, AstExprLocal* local); Inference check(const ScopePtr& scope, AstExprGlobal* global); Inference checkIndexName(const ScopePtr& scope, const RefinementKey* key, AstExpr* indexee, const std::string& index, Location indexLocation); @@ -371,6 +377,11 @@ private: **/ TypeId resolveType(const ScopePtr& scope, AstType* ty, bool inTypeArguments, bool replaceErrorWithFresh = false); + // resolveType() is recursive, but we only want to invoke + // inferGenericPolarities() once at the very end. We thus isolate the + // recursive part of the algorithm to this internal helper. + TypeId resolveType_(const ScopePtr& scope, AstType* ty, bool inTypeArguments, bool replaceErrorWithFresh = false); + /** * Resolves a type pack from its AST annotation. * @param scope the scope that the type annotation appears within. @@ -380,6 +391,9 @@ private: **/ TypePackId resolveTypePack(const ScopePtr& scope, AstTypePack* tp, bool inTypeArguments, bool replaceErrorWithFresh = false); + // Inner hepler for resolveTypePack + TypePackId resolveTypePack_(const ScopePtr& scope, AstTypePack* tp, bool inTypeArguments, bool replaceErrorWithFresh = false); + /** * Resolves a type pack from its AST annotation. * @param scope the scope that the type annotation appears within. @@ -418,7 +432,7 @@ private: **/ std::vector> createGenericPacks( const ScopePtr& scope, - AstArray packs, + AstArray generics, bool useCache = false, bool addTypes = true ); diff --git a/Analysis/include/Luau/DataFlowGraph.h b/Analysis/include/Luau/DataFlowGraph.h index 8a70332e..90ab6419 100644 --- a/Analysis/include/Luau/DataFlowGraph.h +++ b/Analysis/include/Luau/DataFlowGraph.h @@ -129,8 +129,8 @@ private: /// A stack of scopes used by the visitor to see where we are. ScopeStack scopeStack; - - DfgScope* currentScope(); + NotNull currentScope(); + DfgScope* currentScope_DEPRECATED(); struct FunctionCapture { @@ -148,8 +148,8 @@ private: void joinBindings(DfgScope* p, const DfgScope& a, const DfgScope& b); void joinProps(DfgScope* p, const DfgScope& a, const DfgScope& b); - DefId lookup(Symbol symbol); - DefId lookup(DefId def, const std::string& key); + DefId lookup(Symbol symbol, Location location); + DefId lookup(DefId def, const std::string& key, Location location); ControlFlow visit(AstStatBlock* b); ControlFlow visitBlockWithoutChildScope(AstStatBlock* b); diff --git a/Analysis/include/Luau/Def.h b/Analysis/include/Luau/Def.h index 9627f998..1d712d8a 100644 --- a/Analysis/include/Luau/Def.h +++ b/Analysis/include/Luau/Def.h @@ -4,7 +4,8 @@ #include "Luau/NotNull.h" #include "Luau/TypedAllocator.h" #include "Luau/Variant.h" - +#include "Luau/Location.h" +#include "Luau/Symbol.h" #include #include @@ -13,6 +14,7 @@ namespace Luau struct Def; using DefId = NotNull; +struct AstLocal; /** * A cell is a "single-object" value. @@ -64,6 +66,8 @@ struct Def using V = Variant; V v; + Symbol name; + Location location; }; template @@ -79,7 +83,7 @@ struct DefArena { TypedAllocator allocator; - DefId freshCell(bool subscripted = false); + DefId freshCell(Symbol sym, Location location, bool subscripted = false); DefId phi(DefId a, DefId b); DefId phi(const std::vector& defs); }; diff --git a/Analysis/include/Luau/FragmentAutocomplete.h b/Analysis/include/Luau/FragmentAutocomplete.h index d073ea58..6dfd1378 100644 --- a/Analysis/include/Luau/FragmentAutocomplete.h +++ b/Analysis/include/Luau/FragmentAutocomplete.h @@ -75,7 +75,7 @@ struct FragmentAutocompleteResult { ModulePtr incrementalModule; Scope* freshScope; - TypeArena arenaForAutocomplete; + TypeArena arenaForAutocomplete_DEPRECATED; AutocompleteResult acResults; }; diff --git a/Analysis/include/Luau/Generalization.h b/Analysis/include/Luau/Generalization.h index 7f20ea2e..b2b89c07 100644 --- a/Analysis/include/Luau/Generalization.h +++ b/Analysis/include/Luau/Generalization.h @@ -8,6 +8,34 @@ namespace Luau { +template +struct GeneralizationParams +{ + bool foundOutsideFunctions = false; + size_t useCount = 0; + Polarity polarity = Polarity::None; +}; + +// Replace a single free type by its bounds according to the polarity provided. +std::optional generalizeType( + NotNull arena, + NotNull builtinTypes, + NotNull scope, + TypeId freeTy, + const GeneralizationParams& params +); + +// Generalize one type pack +std::optional generalizeTypePack( + NotNull arena, + NotNull builtinTypes, + NotNull scope, + TypePackId tp, + const GeneralizationParams& params +); + +void sealTable(NotNull scope, TypeId ty); + std::optional generalize( NotNull arena, NotNull builtinTypes, @@ -15,5 +43,4 @@ std::optional generalize( NotNull> cachedTypes, TypeId ty ); - } diff --git a/Analysis/include/Luau/InferPolarity.h b/Analysis/include/Luau/InferPolarity.h new file mode 100644 index 00000000..7f96327f --- /dev/null +++ b/Analysis/include/Luau/InferPolarity.h @@ -0,0 +1,16 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/NotNull.h" +#include "Luau/TypeFwd.h" + +namespace Luau +{ + +struct Scope; +struct TypeArena; + +void inferGenericPolarities(NotNull arena, NotNull scope, TypeId ty); +void inferGenericPolarities(NotNull arena, NotNull scope, TypePackId tp); + +} // namespace Luau diff --git a/Analysis/include/Luau/InsertionOrderedMap.h b/Analysis/include/Luau/InsertionOrderedMap.h index 2937dcda..e99bad06 100644 --- a/Analysis/include/Luau/InsertionOrderedMap.h +++ b/Analysis/include/Luau/InsertionOrderedMap.h @@ -67,6 +67,19 @@ public: return &pairs.at(it->second).second; } + V& operator[](const K& k) + { + auto it = indices.find(k); + if (it == indices.end()) + { + pairs.push_back(std::make_pair(k, V())); + indices[k] = pairs.size() - 1; + return pairs.back().second; + } + else + return pairs.at(it->second).second; + } + const_iterator begin() const { return pairs.begin(); diff --git a/Analysis/include/Luau/Polarity.h b/Analysis/include/Luau/Polarity.h new file mode 100644 index 00000000..b2117709 --- /dev/null +++ b/Analysis/include/Luau/Polarity.h @@ -0,0 +1,68 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include + +namespace Luau +{ + +enum struct Polarity : uint8_t +{ + None = 0b000, + Positive = 0b001, + Negative = 0b010, + Mixed = 0b011, + Unknown = 0b100, +}; + +inline Polarity operator|(Polarity lhs, Polarity rhs) +{ + return Polarity(uint8_t(lhs) | uint8_t(rhs)); +} + +inline Polarity& operator|=(Polarity& lhs, Polarity rhs) +{ + lhs = lhs | rhs; + return lhs; +} + +inline Polarity operator&(Polarity lhs, Polarity rhs) +{ + return Polarity(uint8_t(lhs) & uint8_t(rhs)); +} + +inline Polarity& operator&=(Polarity& lhs, Polarity rhs) +{ + lhs = lhs & rhs; + return lhs; +} + +inline bool isPositive(Polarity p) +{ + return bool(p & Polarity::Positive); +} + +inline bool isNegative(Polarity p) +{ + return bool(p & Polarity::Negative); +} + +inline bool isKnown(Polarity p) +{ + return p != Polarity::Unknown; +} + +inline Polarity invert(Polarity p) +{ + switch (p) + { + case Polarity::Positive: + return Polarity::Negative; + case Polarity::Negative: + return Polarity::Positive; + default: + return p; + } +} + +} // namespace Luau diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index 6c3e15df..7d253d24 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -40,7 +40,7 @@ struct Scope // All the children of this scope. std::vector> children; std::unordered_map bindings; - TypePackId returnType; + TypePackId returnType = nullptr; std::optional varargPack; TypeLevel level; @@ -100,6 +100,7 @@ struct Scope std::unordered_map typeAliasTypePackParameters; std::optional> interiorFreeTypes; + std::optional> interiorFreeTypePacks; }; // Returns true iff the left scope encloses the right scope. A Scope* equal to diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index ebeeaa1a..3af8fa6c 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -5,10 +5,11 @@ #include "Luau/Ast.h" #include "Luau/Common.h" -#include "Luau/Refinement.h" #include "Luau/DenseHash.h" #include "Luau/NotNull.h" +#include "Luau/Polarity.h" #include "Luau/Predicate.h" +#include "Luau/Refinement.h" #include "Luau/Unifiable.h" #include "Luau/Variant.h" #include "Luau/VecDeque.h" @@ -37,15 +38,6 @@ struct Constraint; struct Subtyping; struct TypeChecker2; -enum struct Polarity : uint8_t -{ - None = 0b000, - Positive = 0b001, - Negative = 0b010, - Mixed = 0b011, - Unknown = 0b100, -}; - /** * There are three kinds of type variables: * - `Free` variables are metavariables, which stand for unconstrained types. @@ -80,7 +72,7 @@ struct FreeType // New constructors explicit FreeType(TypeLevel level, TypeId lowerBound, TypeId upperBound); // This one got promoted to explicit - explicit FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound); + explicit FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound, Polarity polarity = Polarity::Unknown); explicit FreeType(Scope* scope, TypeLevel level, TypeId lowerBound, TypeId upperBound); // Old constructors explicit FreeType(TypeLevel level); @@ -99,6 +91,8 @@ struct FreeType // Only used under local type inference TypeId lowerBound = nullptr; TypeId upperBound = nullptr; + + Polarity polarity = Polarity::Unknown; }; struct GenericType @@ -107,8 +101,8 @@ struct GenericType GenericType(); explicit GenericType(TypeLevel level); - explicit GenericType(const Name& name); - explicit GenericType(Scope* scope); + explicit GenericType(const Name& name, Polarity polarity = Polarity::Unknown); + explicit GenericType(Scope* scope, Polarity polarity = Polarity::Unknown); GenericType(TypeLevel level, const Name& name); GenericType(Scope* scope, const Name& name); @@ -118,6 +112,8 @@ struct GenericType Scope* scope = nullptr; Name name; bool explicitName = false; + + Polarity polarity = Polarity::Unknown; }; // When an equality constraint is found, it is then "bound" to that type, @@ -1206,7 +1202,7 @@ private: } }; -TypeId freshType(NotNull arena, NotNull builtinTypes, Scope* scope); +TypeId freshType(NotNull arena, NotNull builtinTypes, Scope* scope, Polarity polarity = Polarity::Unknown); using TypeIdPredicate = std::function(TypeId)>; std::vector filterMap(TypeId type, TypeIdPredicate predicate); diff --git a/Analysis/include/Luau/TypeArena.h b/Analysis/include/Luau/TypeArena.h index ebefa41f..dde1a5be 100644 --- a/Analysis/include/Luau/TypeArena.h +++ b/Analysis/include/Luau/TypeArena.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Polarity.h" #include "Luau/TypedAllocator.h" #include "Luau/Type.h" #include "Luau/TypePack.h" @@ -40,7 +41,7 @@ struct TypeArena TypeId freshType_DEPRECATED(Scope* scope); TypeId freshType_DEPRECATED(Scope* scope, TypeLevel level); - TypePackId freshTypePack(Scope* scope); + TypePackId freshTypePack(Scope* scope, Polarity polarity = Polarity::Unknown); TypePackId addTypePack(std::initializer_list types); TypePackId addTypePack(std::vector types, std::optional tail = {}); diff --git a/Analysis/include/Luau/TypeFunction.h b/Analysis/include/Luau/TypeFunction.h index 88f95507..396fc3c1 100644 --- a/Analysis/include/Luau/TypeFunction.h +++ b/Analysis/include/Luau/TypeFunction.h @@ -248,6 +248,8 @@ struct BuiltinTypeFunctions TypeFunction setmetatableFunc; TypeFunction getmetatableFunc; + TypeFunction weakoptionalFunc; + void addToScope(NotNull arena, NotNull scope) const; }; diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index 8509da03..1266f27c 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -1,11 +1,12 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Common.h" +#include "Luau/NotNull.h" +#include "Luau/Polarity.h" +#include "Luau/TypeFwd.h" #include "Luau/Unifiable.h" #include "Luau/Variant.h" -#include "Luau/TypeFwd.h" -#include "Luau/NotNull.h" -#include "Luau/Common.h" #include #include @@ -26,12 +27,14 @@ struct TypeFunctionInstanceTypePack; struct FreeTypePack { explicit FreeTypePack(TypeLevel level); - explicit FreeTypePack(Scope* scope); + explicit FreeTypePack(Scope* scope, Polarity polarity = Polarity::Unknown); FreeTypePack(Scope* scope, TypeLevel level); int index; TypeLevel level; Scope* scope = nullptr; + + Polarity polarity = Polarity::Unknown; }; struct GenericTypePack @@ -40,7 +43,7 @@ struct GenericTypePack GenericTypePack(); explicit GenericTypePack(TypeLevel level); explicit GenericTypePack(const Name& name); - explicit GenericTypePack(Scope* scope); + explicit GenericTypePack(Scope* scope, Polarity polarity = Polarity::Unknown); GenericTypePack(TypeLevel level, const Name& name); GenericTypePack(Scope* scope, const Name& name); @@ -49,6 +52,8 @@ struct GenericTypePack Scope* scope = nullptr; Name name; bool explicitName = false; + + Polarity polarity = Polarity::Unknown; }; using BoundTypePack = Unifiable::Bound; @@ -100,9 +105,9 @@ struct TypeFunctionInstanceTypePack struct TypePackVar { - explicit TypePackVar(const TypePackVariant& ty); - explicit TypePackVar(TypePackVariant&& ty); - TypePackVar(TypePackVariant&& ty, bool persistent); + explicit TypePackVar(const TypePackVariant& tp); + explicit TypePackVar(TypePackVariant&& tp); + TypePackVar(TypePackVariant&& tp, bool persistent); bool operator==(const TypePackVar& rhs) const; @@ -169,6 +174,7 @@ struct TypePackIterator private: TypePackId currentTypePack = nullptr; + TypePackId tailCycleCheck = nullptr; const TypePack* tp = nullptr; size_t currentIndex = 0; diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index c3bed421..04516396 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -289,4 +289,6 @@ std::vector findBlockedArgTypesIn(AstExprCall* expr, NotNull genericPackSubstitutions{nullptr}; + // Unification sometimes results in the creation of new free types. + // We collect them here so that other systems can perform necessary + // bookkeeping. + std::vector newFreshTypes; + std::vector newFreshTypePacks; + int recursionCount = 0; int recursionLimit = 0; @@ -113,6 +119,9 @@ private: // Returns true if needle occurs within haystack already. ie if we bound // needle to haystack, would a cyclic TypePack result? OccursCheckResult occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack); + + TypeId freshType(NotNull scope, Polarity polarity); + TypePackId freshTypePack(NotNull scope, Polarity polarity); }; } // namespace Luau diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index 815164d8..02d6444b 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -13,8 +13,6 @@ LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauExtendStatEndPosWithSemicolon) - namespace Luau { @@ -43,24 +41,13 @@ struct AutocompleteNodeFinder : public AstVisitor bool visit(AstStat* stat) override { - if (FFlag::LuauExtendStatEndPosWithSemicolon) + // Consider 'local myLocal = 4;|' and 'local myLocal = 4', where '|' is the cursor position. In both cases, the cursor position is equal + // to `AstStatLocal.location.end`. However, in the first case (semicolon), we are starting a new statement, whilst in the second case + // (no semicolon) we are still part of the AstStatLocal, hence the different comparison check. + if (stat->location.begin < pos && (stat->hasSemicolon ? pos < stat->location.end : pos <= stat->location.end)) { - // Consider 'local myLocal = 4;|' and 'local myLocal = 4', where '|' is the cursor position. In both cases, the cursor position is equal - // to `AstStatLocal.location.end`. However, in the first case (semicolon), we are starting a new statement, whilst in the second case - // (no semicolon) we are still part of the AstStatLocal, hence the different comparison check. - if (stat->location.begin < pos && (stat->hasSemicolon ? pos < stat->location.end : pos <= stat->location.end)) - { - ancestry.push_back(stat); - return true; - } - } - else - { - if (stat->location.begin < pos && pos <= stat->location.end) - { - ancestry.push_back(stat); - return true; - } + ancestry.push_back(stat); + return true; } return false; diff --git a/Analysis/src/AutocompleteCore.cpp b/Analysis/src/AutocompleteCore.cpp index 03e5c31e..5d15f751 100644 --- a/Analysis/src/AutocompleteCore.cpp +++ b/Analysis/src/AutocompleteCore.cpp @@ -493,7 +493,7 @@ static void autocompleteProps( // Then we are on a one way journey to a stack overflow. if (FFlag::LuauAutocompleteUnionCopyPreviousSeen) { - for (auto ty: seen) + for (auto ty : seen) { if (is(ty)) innerSeen.insert(ty); diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index ff532c43..7c5d8e76 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -3,22 +3,23 @@ #include "Luau/Ast.h" #include "Luau/Clone.h" +#include "Luau/Common.h" +#include "Luau/ConstraintGenerator.h" +#include "Luau/ConstraintSolver.h" #include "Luau/DenseHash.h" #include "Luau/Error.h" #include "Luau/Frontend.h" -#include "Luau/Symbol.h" -#include "Luau/Common.h" -#include "Luau/ToString.h" -#include "Luau/ConstraintSolver.h" -#include "Luau/ConstraintGenerator.h" +#include "Luau/InferPolarity.h" #include "Luau/NotNull.h" -#include "Luau/TypeInfer.h" +#include "Luau/Subtyping.h" +#include "Luau/Symbol.h" +#include "Luau/ToString.h" +#include "Luau/Type.h" #include "Luau/TypeChecker2.h" #include "Luau/TypeFunction.h" +#include "Luau/TypeInfer.h" #include "Luau/TypePack.h" -#include "Luau/Type.h" #include "Luau/TypeUtils.h" -#include "Luau/Subtyping.h" #include @@ -29,10 +30,12 @@ */ LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauNonReentrantGeneralization) LUAU_FASTFLAGVARIABLE(LuauTableCloneClonesType3) LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope) LUAU_FASTFLAGVARIABLE(LuauFollowTableFreeze) LUAU_FASTFLAGVARIABLE(LuauUserTypeFunTypecheck) +LUAU_FASTFLAGVARIABLE(LuauMagicFreezeCheckBlocked) namespace Luau { @@ -246,6 +249,7 @@ void addGlobalBinding(GlobalTypes& globals, const ScopePtr& scope, const std::st void addGlobalBinding(GlobalTypes& globals, const ScopePtr& scope, const std::string& name, Binding binding) { + inferGenericPolarities(NotNull{&globals.globalTypes}, NotNull{scope.get()}, binding.typeId); scope->bindings[globals.globalNames.names->getOrAdd(name.c_str())] = binding; } @@ -310,6 +314,9 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC TypeArena& arena = globals.globalTypes; NotNull builtinTypes = globals.builtinTypes; + Scope* globalScope = nullptr; // NotNull when removing FFlag::LuauNonReentrantGeneralization + if (FFlag::LuauNonReentrantGeneralization) + globalScope = globals.globalScope.get(); if (FFlag::LuauSolverV2) builtinTypeFunctions().addToScope(NotNull{&arena}, NotNull{globals.globalScope.get()}); @@ -319,8 +326,8 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC ); LUAU_ASSERT(loadResult.success); - TypeId genericK = arena.addType(GenericType{"K"}); - TypeId genericV = arena.addType(GenericType{"V"}); + TypeId genericK = arena.addType(GenericType{globalScope, "K"}); + TypeId genericV = arena.addType(GenericType{globalScope, "V"}); TypeId mapOfKtoV = arena.addType(TableType{{}, TableIndexer(genericK, genericV), globals.globalScope->level, TableState::Generic}); std::optional stringMetatableTy = getMetatable(builtinTypes->stringType, builtinTypes); @@ -368,7 +375,7 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) addGlobalBinding(globals, "pairs", arena.addType(FunctionType{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); - TypeId genericMT = arena.addType(GenericType{"MT"}); + TypeId genericMT = arena.addType(GenericType{globalScope, "MT"}); TableType tab{TableState::Generic, globals.globalScope->level}; TypeId tabTy = arena.addType(tab); @@ -380,7 +387,7 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC if (FFlag::LuauSolverV2) { - TypeId genericT = arena.addType(GenericType{"T"}); + TypeId genericT = arena.addType(GenericType{globalScope, "T"}); TypeId tMetaMT = arena.addType(MetatableType{genericT, genericMT}); // clang-format off @@ -437,7 +444,7 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC if (FFlag::LuauSolverV2) { // declare function assert(value: T, errorMessage: string?): intersect - TypeId genericT = arena.addType(GenericType{"T"}); + TypeId genericT = arena.addType(GenericType{globalScope, "T"}); TypeId refinedTy = arena.addType(TypeFunctionInstanceType{ NotNull{&builtinTypeFunctions().intersectFunc}, {genericT, arena.addType(NegationType{builtinTypes->falsyType})}, {} }); @@ -460,12 +467,16 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC // the top table type. We do the best we can by modelling these // functions using unconstrained generics. It's not quite right, // but it'll be ok for now. - TypeId genericTy = arena.addType(GenericType{"T"}); + TypeId genericTy = arena.addType(GenericType{globalScope, "T"}); TypePackId thePack = arena.addTypePack({genericTy}); TypeId idTyWithMagic = arena.addType(FunctionType{{genericTy}, {}, thePack, thePack}); ttv->props["freeze"] = makeProperty(idTyWithMagic, "@luau/global/table.freeze"); + if (globalScope) + inferGenericPolarities(NotNull{&globals.globalTypes}, NotNull{globalScope}, idTyWithMagic); TypeId idTy = arena.addType(FunctionType{{genericTy}, {}, thePack, thePack}); + if (globalScope) + inferGenericPolarities(NotNull{&globals.globalTypes}, NotNull{globalScope}, idTy); ttv->props["clone"] = makeProperty(idTy, "@luau/global/table.clone"); } else @@ -713,15 +724,15 @@ bool MagicFormat::typeCheck(const MagicFunctionTypeCheckContext& context) { switch (shouldSuppressErrors(NotNull{&context.typechecker->normalizer}, actualTy)) { - case ErrorSuppression::Suppress: - break; - case ErrorSuppression::NormalizationFailed: - break; - case ErrorSuppression::DoNotSuppress: - Reasonings reasonings = context.typechecker->explainReasonings(actualTy, expectedTy, location, result); + case ErrorSuppression::Suppress: + break; + case ErrorSuppression::NormalizationFailed: + break; + case ErrorSuppression::DoNotSuppress: + Reasonings reasonings = context.typechecker->explainReasonings(actualTy, expectedTy, location, result); - if (!reasonings.suppressed) - context.typechecker->reportError(TypeMismatch{expectedTy, actualTy, reasonings.toString()}, location); + if (!reasonings.suppressed) + context.typechecker->reportError(TypeMismatch{expectedTy, actualTy, reasonings.toString()}, location); } } } @@ -1598,6 +1609,17 @@ bool MagicFreeze::infer(const MagicFunctionCallContext& context) std::optional resultDef = dfg->getDefOptional(targetExpr); std::optional resultTy = resultDef ? scope->lookup(*resultDef) : std::nullopt; + if (FFlag::LuauMagicFreezeCheckBlocked) + { + if (resultTy && !get(resultTy)) + { + // If there's an existing result type but it's _not_ blocked, then + // we aren't type stating this builtin and should fall back to + // regular inference. + return false; + } + } + std::optional frozenType = freezeTable(inputType, context); if (!frozenType) diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index 1be0bd5b..ef4f4690 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -9,6 +9,7 @@ #include "Luau/DcrLogger.h" #include "Luau/Def.h" #include "Luau/DenseHash.h" +#include "Luau/InferPolarity.h" #include "Luau/ModuleResolver.h" #include "Luau/NotNull.h" #include "Luau/RecursionCounter.h" @@ -32,6 +33,7 @@ LUAU_FASTINT(LuauCheckRecursionLimit) LUAU_FASTFLAG(DebugLuauLogSolverToJson) LUAU_FASTFLAG(DebugLuauMagicTypes) +LUAU_FASTFLAG(LuauNonReentrantGeneralization) LUAU_FASTFLAG(LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) LUAU_FASTFLAGVARIABLE(LuauPropagateExpectedTypesForCalls) LUAU_FASTFLAG(DebugLuauGreedyGeneralization) @@ -50,6 +52,7 @@ LUAU_FASTFLAGVARIABLE(LuauRetainDefinitionAliasLocations) LUAU_FASTFLAG(LuauDeprecatedAttribute) LUAU_FASTFLAGVARIABLE(LuauCacheInferencePerAstExpr) LUAU_FASTFLAGVARIABLE(LuauAlwaysResolveAstTypes) +LUAU_FASTFLAGVARIABLE(LuauWeakNilRefinementType) namespace Luau { @@ -227,7 +230,10 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block) rootScope->location = block->location; module->astScopes[block] = NotNull{scope.get()}; - rootScope->returnType = freshTypePack(scope); + if (FFlag::LuauNonReentrantGeneralization) + interiorFreeTypes.emplace_back(); + else + DEPRECATED_interiorTypes.emplace_back(); if (FFlag::LuauUserTypeFunTypecheck) { @@ -237,8 +243,8 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block) typeFunctionRuntime->rootScope = localTypeFunctionScope; } + rootScope->returnType = freshTypePack(scope, Polarity::Positive); TypeId moduleFnTy = arena->addType(FunctionType{TypeLevel{}, builtinTypes->anyTypePack, rootScope->returnType}); - interiorTypes.emplace_back(); prepopulateGlobalScope(scope, block); @@ -255,12 +261,20 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block) scope, block->location, GeneralizationConstraint{ - result, moduleFnTy, FFlag::LuauTrackInteriorFreeTypesOnScope ? std::vector{} : std::move(interiorTypes.back()) + result, + moduleFnTy, + (FFlag::LuauNonReentrantGeneralization || FFlag::LuauTrackInteriorFreeTypesOnScope) ? std::vector{} + : std::move(DEPRECATED_interiorTypes.back()) } ); - if (FFlag::LuauTrackInteriorFreeTypesOnScope) - scope->interiorFreeTypes = std::move(interiorTypes.back()); + if (FFlag::LuauNonReentrantGeneralization) + { + scope->interiorFreeTypes = std::move(interiorFreeTypes.back().types); + scope->interiorFreeTypePacks = std::move(interiorFreeTypes.back().typePacks); + } + else if (FFlag::LuauTrackInteriorFreeTypesOnScope) + scope->interiorFreeTypes = std::move(DEPRECATED_interiorTypes.back()); getMutable(result)->setOwner(genConstraint); forEachConstraint( @@ -273,7 +287,10 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block) } ); - interiorTypes.pop_back(); + if (FFlag::LuauNonReentrantGeneralization) + interiorFreeTypes.pop_back(); + else + DEPRECATED_interiorTypes.pop_back(); fillInInferredBindings(scope, block); @@ -302,11 +319,16 @@ void ConstraintGenerator::visitFragmentRoot(const ScopePtr& resumeScope, AstStat // We prepopulate global data in the resumeScope to avoid writing data into the old modules scopes prepopulateGlobalScopeForFragmentTypecheck(globalScope, resumeScope, block); // Pre - // We need to pop the interior types, - interiorTypes.emplace_back(); + if (FFlag::LuauNonReentrantGeneralization) + interiorFreeTypes.emplace_back(); + else + DEPRECATED_interiorTypes.emplace_back(); visitBlockWithoutChildScope(resumeScope, block); // Post - interiorTypes.pop_back(); + if (FFlag::LuauNonReentrantGeneralization) + interiorFreeTypes.pop_back(); + else + DEPRECATED_interiorTypes.pop_back(); fillInInferredBindings(resumeScope, block); @@ -331,12 +353,18 @@ void ConstraintGenerator::visitFragmentRoot(const ScopePtr& resumeScope, AstStat } -TypeId ConstraintGenerator::freshType(const ScopePtr& scope) +TypeId ConstraintGenerator::freshType(const ScopePtr& scope, Polarity polarity) { - if (FFlag::LuauTrackInteriorFreeTypesOnScope) + if (FFlag::LuauNonReentrantGeneralization) + { + auto ft = Luau::freshType(arena, builtinTypes, scope.get(), polarity); + interiorFreeTypes.back().types.push_back(ft); + return ft; + } + else if (FFlag::LuauTrackInteriorFreeTypesOnScope) { auto ft = Luau::freshType(arena, builtinTypes, scope.get()); - interiorTypes.back().push_back(ft); + DEPRECATED_interiorTypes.back().push_back(ft); return ft; } else @@ -345,10 +373,13 @@ TypeId ConstraintGenerator::freshType(const ScopePtr& scope) } } -TypePackId ConstraintGenerator::freshTypePack(const ScopePtr& scope) +TypePackId ConstraintGenerator::freshTypePack(const ScopePtr& scope, Polarity polarity) { - FreeTypePack f{scope.get()}; - return arena->addTypePack(TypePackVar{std::move(f)}); + FreeTypePack f{scope.get(), polarity}; + TypePackId result = arena->addTypePack(TypePackVar{std::move(f)}); + if (FFlag::LuauNonReentrantGeneralization) + interiorFreeTypes.back().typePacks.push_back(result); + return result; } TypePackId ConstraintGenerator::addTypePack(std::vector head, std::optional tail) @@ -654,7 +685,7 @@ void ConstraintGenerator::applyRefinements(const ScopePtr& scope, Location locat if (std::optional defTy = lookup(scope, location, def)) { TypeId ty = *defTy; - if (partition.shouldAppendNilType) + if (!FFlag::LuauWeakNilRefinementType && partition.shouldAppendNilType) ty = arena->addType(UnionType{{ty, builtinTypes->nilType}}); // Intersect ty with every discriminant type. If either type is not // sufficiently solved, we queue the intersection up via an @@ -702,6 +733,9 @@ void ConstraintGenerator::applyRefinements(const ScopePtr& scope, Location locat if (kind != RefinementsOpKind::None) ty = flushConstraints(kind, ty, discriminants); + if (FFlag::LuauWeakNilRefinementType && partition.shouldAppendNilType) + ty = createTypeFunctionInstance(builtinTypeFunctions().weakoptionalFunc, {ty}, {}, scope, location); + scope->rvalueRefinements[def] = ty; } } @@ -1783,7 +1817,10 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatTypeFunctio // Place this function as a child of the non-type function scope scope->children.push_back(NotNull{sig.signatureScope.get()}); - interiorTypes.push_back(std::vector{}); + if (FFlag::LuauNonReentrantGeneralization) + interiorFreeTypes.emplace_back(); + else + DEPRECATED_interiorTypes.push_back(std::vector{}); checkFunctionBody(sig.bodyScope, function->body); Checkpoint endCheckpoint = checkpoint(this); @@ -1792,15 +1829,25 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatTypeFunctio sig.signatureScope, function->location, GeneralizationConstraint{ - generalizedTy, sig.signature, FFlag::LuauTrackInteriorFreeTypesOnScope ? std::vector{} : std::move(interiorTypes.back()) + generalizedTy, + sig.signature, + FFlag::LuauTrackInteriorFreeTypesOnScope ? std::vector{} : std::move(DEPRECATED_interiorTypes.back()) } ); - if (FFlag::LuauTrackInteriorFreeTypesOnScope) - sig.signatureScope->interiorFreeTypes = std::move(interiorTypes.back()); + if (FFlag::LuauNonReentrantGeneralization) + { + sig.signatureScope->interiorFreeTypes = std::move(interiorFreeTypes.back().types); + sig.signatureScope->interiorFreeTypePacks = std::move(interiorFreeTypes.back().typePacks); + } + else if (FFlag::LuauTrackInteriorFreeTypesOnScope) + sig.signatureScope->interiorFreeTypes = std::move(DEPRECATED_interiorTypes.back()); getMutable(generalizedTy)->setOwner(gc); - interiorTypes.pop_back(); + if (FFlag::LuauNonReentrantGeneralization) + interiorFreeTypes.pop_back(); + else + DEPRECATED_interiorTypes.pop_back(); Constraint* previous = nullptr; forEachConstraint( @@ -2023,6 +2070,8 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareFunc defn.originalNameLocation = global->nameLocation; TypeId fnType = arena->addType(FunctionType{TypeLevel{}, std::move(genericTys), std::move(genericTps), paramPack, retPack, defn}); + inferGenericPolarities(arena, NotNull{scope.get()}, fnType); + FunctionType* ftv = getMutable(fnType); ftv->isCheckedFunction = global->isCheckedFunction(); if (FFlag::LuauDeprecatedAttribute) @@ -2190,7 +2239,7 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall* if (selfTy) args.push_back(*selfTy); else - args.push_back(freshType(scope)); + args.push_back(freshType(scope, Polarity::Negative)); } else if (i < exprArgs.size() - 1 || !(arg->is() || arg->is())) { @@ -2456,11 +2505,24 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprConstantStrin if (forceSingleton) return Inference{arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}})}; - FreeType ft = - FFlag::LuauFreeTypesMustHaveBounds ? FreeType{scope.get(), builtinTypes->neverType, builtinTypes->unknownType} : FreeType{scope.get()}; - ft.lowerBound = arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}}); - ft.upperBound = builtinTypes->stringType; - const TypeId freeTy = arena->addType(ft); + TypeId freeTy = nullptr; + if (FFlag::LuauNonReentrantGeneralization) + { + freeTy = freshType(scope, Polarity::Positive); + FreeType* ft = getMutable(freeTy); + LUAU_ASSERT(ft); + ft->lowerBound = arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}}); + ft->upperBound = builtinTypes->stringType; + } + else + { + FreeType ft = + FFlag::LuauFreeTypesMustHaveBounds ? FreeType{scope.get(), builtinTypes->neverType, builtinTypes->unknownType} : FreeType{scope.get()}; + ft.lowerBound = arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}}); + ft.upperBound = builtinTypes->stringType; + freeTy = arena->addType(ft); + } + addConstraint(scope, string->location, PrimitiveTypeConstraint{freeTy, expectedType, builtinTypes->stringType}); return Inference{freeTy}; } @@ -2471,11 +2533,24 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprConstantBool* if (forceSingleton) return Inference{singletonType}; - FreeType ft = - FFlag::LuauFreeTypesMustHaveBounds ? FreeType{scope.get(), builtinTypes->neverType, builtinTypes->unknownType} : FreeType{scope.get()}; - ft.lowerBound = singletonType; - ft.upperBound = builtinTypes->booleanType; - const TypeId freeTy = arena->addType(ft); + TypeId freeTy = nullptr; + if (FFlag::LuauNonReentrantGeneralization) + { + freeTy = freshType(scope, Polarity::Positive); + FreeType* ft = getMutable(freeTy); + LUAU_ASSERT(ft); + ft->lowerBound = singletonType; + ft->upperBound = builtinTypes->booleanType; + } + else + { + FreeType ft = + FFlag::LuauFreeTypesMustHaveBounds ? FreeType{scope.get(), builtinTypes->neverType, builtinTypes->unknownType} : FreeType{scope.get()}; + ft.lowerBound = singletonType; + ft.upperBound = builtinTypes->booleanType; + freeTy = arena->addType(ft); + } + addConstraint(scope, boolExpr->location, PrimitiveTypeConstraint{freeTy, expectedType, builtinTypes->booleanType}); return Inference{freeTy}; } @@ -2621,7 +2696,10 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprFunction* fun Checkpoint startCheckpoint = checkpoint(this); FunctionSignature sig = checkFunctionSignature(scope, func, expectedType); - interiorTypes.push_back(std::vector{}); + if (FFlag::LuauNonReentrantGeneralization) + interiorFreeTypes.emplace_back(); + else + DEPRECATED_interiorTypes.push_back(std::vector{}); checkFunctionBody(sig.bodyScope, func); Checkpoint endCheckpoint = checkpoint(this); @@ -2630,15 +2708,29 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprFunction* fun sig.signatureScope, func->location, GeneralizationConstraint{ - generalizedTy, sig.signature, FFlag::LuauTrackInteriorFreeTypesOnScope ? std::vector{} : std::move(interiorTypes.back()) + generalizedTy, + sig.signature, + (FFlag::LuauNonReentrantGeneralization || FFlag::LuauTrackInteriorFreeTypesOnScope) ? std::vector{} + : std::move(DEPRECATED_interiorTypes.back()) } ); - if (FFlag::LuauTrackInteriorFreeTypesOnScope) - sig.signatureScope->interiorFreeTypes = std::move(interiorTypes.back()); + if (FFlag::LuauNonReentrantGeneralization) + { + sig.signatureScope->interiorFreeTypes = std::move(interiorFreeTypes.back().types); + sig.signatureScope->interiorFreeTypePacks = std::move(interiorFreeTypes.back().typePacks); + interiorFreeTypes.pop_back(); - getMutable(generalizedTy)->setOwner(gc); - interiorTypes.pop_back(); + getMutable(generalizedTy)->setOwner(gc); + } + else + { + if (FFlag::LuauTrackInteriorFreeTypesOnScope) + sig.signatureScope->interiorFreeTypes = std::move(DEPRECATED_interiorTypes.back()); + + getMutable(generalizedTy)->setOwner(gc); + DEPRECATED_interiorTypes.pop_back(); + } Constraint* previous = nullptr; forEachConstraint( @@ -3118,7 +3210,10 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr, ttv->definitionLocation = expr->location; ttv->scope = scope.get(); - interiorTypes.back().push_back(ty); + if (FFlag::LuauNonReentrantGeneralization) + interiorFreeTypes.back().types.push_back(ty); + else + DEPRECATED_interiorTypes.back().push_back(ty); TypeIds indexKeyLowerBound; TypeIds indexValueLowerBound; @@ -3220,7 +3315,7 @@ ConstraintGenerator::FunctionSignature ConstraintGenerator::checkFunctionSignatu // We need to assign returnType before creating bodyScope so that the // return type gets propagated to bodyScope. - returnType = freshTypePack(signatureScope); + returnType = freshTypePack(signatureScope, Polarity::Positive); signatureScope->returnType = returnType; bodyScope = childScope(fn->body, signatureScope); @@ -3300,7 +3395,7 @@ ConstraintGenerator::FunctionSignature ConstraintGenerator::checkFunctionSignatu if (i < expectedArgPack.head.size()) argTy = expectedArgPack.head[i]; else - argTy = freshType(signatureScope); + argTy = freshType(signatureScope, Polarity::Negative); } argTypes.push_back(argTy); @@ -3379,6 +3474,8 @@ ConstraintGenerator::FunctionSignature ConstraintGenerator::checkFunctionSignatu LUAU_ASSERT(actualFunctionType); module->astTypes[fn] = actualFunctionType; + inferGenericPolarities(arena, NotNull{signatureScope.get()}, actualFunctionType); + if (expectedType && get(*expectedType)) bindFreeType(*expectedType, actualFunctionType); @@ -3423,7 +3520,7 @@ TypeId ConstraintGenerator::resolveReferenceType( return builtinTypes->errorRecoveryType(); } else - return resolveType(scope, ref->parameters.data[0].type, inTypeArguments); + return resolveType_(scope, ref->parameters.data[0].type, inTypeArguments); } } @@ -3456,11 +3553,11 @@ TypeId ConstraintGenerator::resolveReferenceType( // that is done in the parser. if (p.type) { - parameters.push_back(resolveType(scope, p.type, /* inTypeArguments */ true)); + parameters.push_back(resolveType_(scope, p.type, /* inTypeArguments */ true)); } else if (p.typePack) { - TypePackId tp = resolveTypePack(scope, p.typePack, /*inTypeArguments*/ true); + TypePackId tp = resolveTypePack_(scope, p.typePack, /*inTypeArguments*/ true); // If we need more regular types, we can use single element type packs to fill those in if (parameters.size() < alias->typeParams.size() && size(tp) == 1 && finite(tp) && first(tp)) @@ -3489,7 +3586,7 @@ TypeId ConstraintGenerator::resolveReferenceType( { result = builtinTypes->errorRecoveryType(); if (replaceErrorWithFresh) - result = freshType(scope); + result = freshType(scope, Polarity::Mixed); } return result; @@ -3502,7 +3599,7 @@ TypeId ConstraintGenerator::resolveTableType(const ScopePtr& scope, AstType* ty, for (const AstTableProp& prop : tab->props) { - TypeId propTy = resolveType(scope, prop.type, inTypeArguments); + TypeId propTy = resolveType_(scope, prop.type, inTypeArguments); Property& p = props[prop.name.value]; p.typeLocation = prop.location; @@ -3594,8 +3691,11 @@ TypeId ConstraintGenerator::resolveFunctionType( signatureScope = scope; } - TypePackId argTypes = resolveTypePack(signatureScope, fn->argTypes, inTypeArguments, replaceErrorWithFresh); - TypePackId returnTypes = resolveTypePack(signatureScope, fn->returnTypes, inTypeArguments, replaceErrorWithFresh); + AstTypePackExplicit tempArgTypes{Location{}, fn->argTypes}; + TypePackId argTypes = resolveTypePack_(signatureScope, &tempArgTypes, inTypeArguments, replaceErrorWithFresh); + + AstTypePackExplicit tempRetTypes{Location{}, fn->returnTypes}; + TypePackId returnTypes = resolveTypePack_(signatureScope, &tempRetTypes, inTypeArguments, replaceErrorWithFresh); // TODO: FunctionType needs a pointer to the scope so that we know // how to quantify/instantiate it. @@ -3627,6 +3727,13 @@ TypeId ConstraintGenerator::resolveFunctionType( } TypeId ConstraintGenerator::resolveType(const ScopePtr& scope, AstType* ty, bool inTypeArguments, bool replaceErrorWithFresh) +{ + TypeId result = resolveType_(scope, ty, inTypeArguments, replaceErrorWithFresh); + inferGenericPolarities(arena, NotNull{scope.get()}, result); + return result; +} + +TypeId ConstraintGenerator::resolveType_(const ScopePtr& scope, AstType* ty, bool inTypeArguments, bool replaceErrorWithFresh) { TypeId result = nullptr; @@ -3659,13 +3766,13 @@ TypeId ConstraintGenerator::resolveType(const ScopePtr& scope, AstType* ty, bool if (FFlag::LuauAlwaysResolveAstTypes) { if (unionAnnotation->types.size == 1) - result = resolveType(scope, unionAnnotation->types.data[0], inTypeArguments); + result = resolveType_(scope, unionAnnotation->types.data[0], inTypeArguments); else { std::vector parts; for (AstType* part : unionAnnotation->types) { - parts.push_back(resolveType(scope, part, inTypeArguments)); + parts.push_back(resolveType_(scope, part, inTypeArguments)); } result = arena->addType(UnionType{parts}); @@ -3676,7 +3783,7 @@ TypeId ConstraintGenerator::resolveType(const ScopePtr& scope, AstType* ty, bool if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) { if (unionAnnotation->types.size == 1) - return resolveType(scope, unionAnnotation->types.data[0], inTypeArguments); + return resolveType_(scope, unionAnnotation->types.data[0], inTypeArguments); } std::vector parts; @@ -3693,13 +3800,13 @@ TypeId ConstraintGenerator::resolveType(const ScopePtr& scope, AstType* ty, bool if (FFlag::LuauAlwaysResolveAstTypes) { if (intersectionAnnotation->types.size == 1) - result = resolveType(scope, intersectionAnnotation->types.data[0], inTypeArguments); + result = resolveType_(scope, intersectionAnnotation->types.data[0], inTypeArguments); else { std::vector parts; for (AstType* part : intersectionAnnotation->types) { - parts.push_back(resolveType(scope, part, inTypeArguments)); + parts.push_back(resolveType_(scope, part, inTypeArguments)); } result = arena->addType(IntersectionType{parts}); @@ -3710,7 +3817,7 @@ TypeId ConstraintGenerator::resolveType(const ScopePtr& scope, AstType* ty, bool if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) { if (intersectionAnnotation->types.size == 1) - return resolveType(scope, intersectionAnnotation->types.data[0], inTypeArguments); + return resolveType_(scope, intersectionAnnotation->types.data[0], inTypeArguments); } std::vector parts; @@ -3724,7 +3831,7 @@ TypeId ConstraintGenerator::resolveType(const ScopePtr& scope, AstType* ty, bool } else if (auto typeGroupAnnotation = ty->as()) { - result = resolveType(scope, typeGroupAnnotation->type, inTypeArguments); + result = resolveType_(scope, typeGroupAnnotation->type, inTypeArguments); } else if (auto boolAnnotation = ty->as()) { @@ -3754,6 +3861,13 @@ TypeId ConstraintGenerator::resolveType(const ScopePtr& scope, AstType* ty, bool } TypePackId ConstraintGenerator::resolveTypePack(const ScopePtr& scope, AstTypePack* tp, bool inTypeArgument, bool replaceErrorWithFresh) +{ + TypePackId result = resolveTypePack_(scope, tp, inTypeArgument, replaceErrorWithFresh); + inferGenericPolarities(arena, NotNull{scope.get()}, result); + return result; +} + +TypePackId ConstraintGenerator::resolveTypePack_(const ScopePtr& scope, AstTypePack* tp, bool inTypeArgument, bool replaceErrorWithFresh) { TypePackId result; if (auto expl = tp->as()) @@ -3762,7 +3876,7 @@ TypePackId ConstraintGenerator::resolveTypePack(const ScopePtr& scope, AstTypePa } else if (auto var = tp->as()) { - TypeId ty = resolveType(scope, var->variadicType, inTypeArgument, replaceErrorWithFresh); + TypeId ty = resolveType_(scope, var->variadicType, inTypeArgument, replaceErrorWithFresh); result = arena->addTypePack(TypePackVar{VariadicTypePack{ty}}); } else if (auto gen = tp->as()) @@ -3793,16 +3907,18 @@ TypePackId ConstraintGenerator::resolveTypePack(const ScopePtr& scope, const Ast for (AstType* headTy : list.types) { - head.push_back(resolveType(scope, headTy, inTypeArguments, replaceErrorWithFresh)); + head.push_back(resolveType_(scope, headTy, inTypeArguments, replaceErrorWithFresh)); } std::optional tail = std::nullopt; if (list.tailType) { - tail = resolveTypePack(scope, list.tailType, inTypeArguments, replaceErrorWithFresh); + tail = resolveTypePack_(scope, list.tailType, inTypeArguments, replaceErrorWithFresh); } - return addTypePack(std::move(head), tail); + TypePackId result = addTypePack(std::move(head), tail); + inferGenericPolarities(arena, NotNull{scope.get()}, result); + return result; } std::vector> ConstraintGenerator::createGenerics( diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 34cf81b9..c786f882 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -39,6 +39,7 @@ LUAU_FASTFLAGVARIABLE(LuauHasPropProperBlock) LUAU_FASTFLAGVARIABLE(DebugLuauGreedyGeneralization) LUAU_FASTFLAG(LuauSearchForRefineableType) LUAU_FASTFLAG(LuauDeprecatedAttribute) +LUAU_FASTFLAG(LuauNonReentrantGeneralization) LUAU_FASTFLAG(LuauBidirectionalInferenceCollectIndexerTypes) LUAU_FASTFLAG(LuauNewTypeFunReductionChecks2) LUAU_FASTFLAGVARIABLE(LuauTrackInferredFunctionTypeFromCall) @@ -603,12 +604,14 @@ struct TypeSearcher : TypeVisitor explicit TypeSearcher(TypeId needle) : TypeSearcher(needle, Polarity::Positive) - {} + { + } explicit TypeSearcher(TypeId needle, Polarity initialPolarity) : needle(needle) , current(initialPolarity) - {} + { + } bool visit(TypeId ty) override { @@ -625,14 +628,14 @@ struct TypeSearcher : TypeVisitor { switch (current) { - case Polarity::Positive: - current = Polarity::Negative; - break; - case Polarity::Negative: - current = Polarity::Positive; - break; - default: - break; + case Polarity::Positive: + current = Polarity::Negative; + break; + case Polarity::Negative: + current = Polarity::Positive; + break; + default: + break; } } @@ -710,32 +713,32 @@ void ConstraintSolver::generalizeOneType(TypeId ty) switch (ts.result) { - case Polarity::None: + case Polarity::None: + asMutable(ty)->reassign(Type{BoundType{upperBound}}); + break; + + case Polarity::Negative: + case Polarity::Mixed: + if (get(upperBound) && ts.count > 1) + { + asMutable(ty)->reassign(Type{GenericType{tyScope}}); + function->generics.emplace_back(ty); + } + else asMutable(ty)->reassign(Type{BoundType{upperBound}}); - break; + break; - case Polarity::Negative: - case Polarity::Mixed: - if (get(upperBound) && ts.count > 1) - { - asMutable(ty)->reassign(Type{GenericType{tyScope}}); - function->generics.emplace_back(ty); - } - else - asMutable(ty)->reassign(Type{BoundType{upperBound}}); - break; - - case Polarity::Positive: - if (get(lowerBound) && ts.count > 1) - { - asMutable(ty)->reassign(Type{GenericType{tyScope}}); - function->generics.emplace_back(ty); - } - else - asMutable(ty)->reassign(Type{BoundType{lowerBound}}); - break; - default: - LUAU_ASSERT(!"Unreachable"); + case Polarity::Positive: + if (get(lowerBound) && ts.count > 1) + { + asMutable(ty)->reassign(Type{GenericType{tyScope}}); + function->generics.emplace_back(ty); + } + else + asMutable(ty)->reassign(Type{BoundType{lowerBound}}); + break; + default: + LUAU_ASSERT(!"Unreachable"); } } } @@ -747,7 +750,16 @@ void ConstraintSolver::bind(NotNull constraint, TypeId ty, Typ boundTo = follow(boundTo); if (get(ty) && ty == boundTo) - return emplace(constraint, ty, constraint->scope, builtinTypes->neverType, builtinTypes->unknownType); + { + emplace( + constraint, ty, constraint->scope, builtinTypes->neverType, builtinTypes->unknownType, Polarity::Mixed + ); // FIXME? Is this the right polarity? + + if (FFlag::LuauNonReentrantGeneralization) + trackInteriorFreeType(constraint->scope, ty); + + return; + } shiftReferences(ty, boundTo); emplaceType(asMutable(ty), boundTo); @@ -903,8 +915,48 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNullscope->interiorFreeTypes) + { for (TypeId ty : *constraint->scope->interiorFreeTypes) // NOLINT(bugprone-unchecked-optional-access) - generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, ty); + { + if (FFlag::LuauNonReentrantGeneralization) + { + ty = follow(ty); + if (auto freeTy = get(ty)) + { + GeneralizationParams params; + params.foundOutsideFunctions = true; + params.useCount = 1; + params.polarity = freeTy->polarity; + + generalizeType(arena, builtinTypes, constraint->scope, ty, params); + } + else if (get(ty)) + sealTable(constraint->scope, ty); + } + else + generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, ty); + } + } + + if (FFlag::LuauNonReentrantGeneralization) + { + if (constraint->scope->interiorFreeTypePacks) + { + for (TypePackId tp : *constraint->scope->interiorFreeTypePacks) // NOLINT(bugprone-unchecked-optional-access) + { + tp = follow(tp); + if (auto freeTp = get(tp)) + { + GeneralizationParams params; + params.foundOutsideFunctions = true; + params.useCount = 1; + params.polarity = freeTp->polarity; + LUAU_ASSERT(isKnown(params.polarity)); + generalizeTypePack(arena, builtinTypes, constraint->scope, tp, params); + } + } + } + } } else { @@ -985,8 +1037,8 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNull(nextTy)) { - TypeId keyTy = freshType(arena, builtinTypes, constraint->scope); - TypeId valueTy = freshType(arena, builtinTypes, constraint->scope); + TypeId keyTy = freshType(arena, builtinTypes, constraint->scope, Polarity::Mixed); + TypeId valueTy = freshType(arena, builtinTypes, constraint->scope, Polarity::Mixed); if (FFlag::LuauTrackInteriorFreeTypesOnScope) { trackInteriorFreeType(constraint->scope, keyTy); @@ -1445,7 +1497,8 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulladdTypePack(TypePack{std::move(argsHead), argsTail}); fn = follow(*callMm); - emplace(constraint, c.result, constraint->scope); + emplace(constraint, c.result, constraint->scope, Polarity::Positive); + trackInteriorFreeTypePack(constraint->scope, c.result); } else { @@ -1462,7 +1515,10 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull(constraint, c.result, constraint->scope); + { + emplace(constraint, c.result, constraint->scope, Polarity::Positive); + trackInteriorFreeTypePack(constraint->scope, c.result); + } } fillInDiscriminantTypes(constraint, c.discriminantTypes); @@ -1488,6 +1544,14 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullscope, freeTy); + for (TypePackId freeTp : u2.newFreshTypePacks) + trackInteriorFreeTypePack(constraint->scope, freeTp); + } + if (!u2.genericSubstitutions.empty() || !u2.genericPackSubstitutions.empty()) { std::optional subst = instantiate2(arena, std::move(u2.genericSubstitutions), std::move(u2.genericPackSubstitutions), result); @@ -1855,7 +1919,7 @@ bool ConstraintSolver::tryDispatchHasIndexer( else if (auto mt = get(follow(ft->upperBound))) return tryDispatchHasIndexer(recursionDepth, constraint, mt->table, indexType, resultType, seen); - FreeType freeResult{ft->scope, builtinTypes->neverType, builtinTypes->unknownType}; + FreeType freeResult{ft->scope, builtinTypes->neverType, builtinTypes->unknownType, Polarity::Mixed}; emplace(constraint, resultType, freeResult); TypeId upperBound = @@ -1878,8 +1942,10 @@ bool ConstraintSolver::tryDispatchHasIndexer( { // FIXME this is greedy. - FreeType freeResult{tt->scope, builtinTypes->neverType, builtinTypes->unknownType}; + FreeType freeResult{tt->scope, builtinTypes->neverType, builtinTypes->unknownType, Polarity::Mixed}; emplace(constraint, resultType, freeResult); + if (FFlag::LuauNonReentrantGeneralization) + trackInteriorFreeType(constraint->scope, resultType); tt->indexer = TableIndexer{indexType, resultType}; return true; @@ -2101,11 +2167,7 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNulladdType(UnionType{{propTy, builtinTypes->nilType}}) : propTy - ); + bind(constraint, c.propType, isIndex ? arena->addType(UnionType{{propTy, builtinTypes->nilType}}) : propTy); unify(constraint, rhsType, propTy); return true; } @@ -2199,11 +2261,7 @@ bool ConstraintSolver::tryDispatch(const AssignIndexConstraint& c, NotNullindexer->indexType); unify(constraint, rhsType, lhsTable->indexer->indexResultType); - bind( - constraint, - c.propType, - arena->addType(UnionType{{lhsTable->indexer->indexResultType, builtinTypes->nilType}}) - ); + bind(constraint, c.propType, arena->addType(UnionType{{lhsTable->indexer->indexResultType, builtinTypes->nilType}})); return true; } @@ -2252,11 +2310,7 @@ bool ConstraintSolver::tryDispatch(const AssignIndexConstraint& c, NotNullindexer->indexType); unify(constraint, rhsType, lhsClass->indexer->indexResultType); - bind( - constraint, - c.propType, - arena->addType(UnionType{{lhsClass->indexer->indexResultType, builtinTypes->nilType}}) - ); + bind(constraint, c.propType, arena->addType(UnionType{{lhsClass->indexer->indexResultType, builtinTypes->nilType}})); return true; } @@ -2350,7 +2404,7 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNullscope); + TypeId f = freshType(arena, builtinTypes, constraint->scope, Polarity::Positive); // FIXME? Is this the right polarity? if (FFlag::LuauTrackInteriorFreeTypesOnScope) trackInteriorFreeType(constraint->scope, f); shiftReferences(resultTy, f); @@ -2493,8 +2547,8 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl if (get(iteratorTy)) { - TypeId keyTy = freshType(arena, builtinTypes, constraint->scope); - TypeId valueTy = freshType(arena, builtinTypes, constraint->scope); + TypeId keyTy = freshType(arena, builtinTypes, constraint->scope, Polarity::Mixed); + TypeId valueTy = freshType(arena, builtinTypes, constraint->scope, Polarity::Mixed); if (FFlag::LuauTrackInteriorFreeTypesOnScope) { trackInteriorFreeType(constraint->scope, keyTy); @@ -2755,7 +2809,7 @@ TablePropLookupResult ConstraintSolver::lookupTableProp( if (ttv->state == TableState::Free) { - TypeId result = freshType(arena, builtinTypes, ttv->scope); + TypeId result = freshType(arena, builtinTypes, ttv->scope, Polarity::Mixed); if (FFlag::LuauTrackInteriorFreeTypesOnScope) trackInteriorFreeType(ttv->scope, result); switch (context) @@ -2869,7 +2923,7 @@ TablePropLookupResult ConstraintSolver::lookupTableProp( TableType* tt = getMutable(newUpperBound); LUAU_ASSERT(tt); - TypeId propType = freshType(arena, builtinTypes, scope); + TypeId propType = freshType(arena, builtinTypes, scope, Polarity::Mixed); if (FFlag::LuauTrackInteriorFreeTypesOnScope) trackInteriorFreeType(scope, propType); diff --git a/Analysis/src/DataFlowGraph.cpp b/Analysis/src/DataFlowGraph.cpp index f732c874..0d75dcc2 100644 --- a/Analysis/src/DataFlowGraph.cpp +++ b/Analysis/src/DataFlowGraph.cpp @@ -15,7 +15,7 @@ LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAGVARIABLE(LuauPreprocessTypestatedArgument) LUAU_FASTFLAGVARIABLE(LuauDfgScopeStackTrueReset) - +LUAU_FASTFLAGVARIABLE(LuauDfgScopeStackNotNull) namespace Luau { @@ -196,7 +196,15 @@ DataFlowGraph DataFlowGraphBuilder::build( DataFlowGraphBuilder builder(defArena, keyArena); builder.handle = handle; - DfgScope* moduleScope = builder.makeChildScope(); + + DfgScope* moduleScope; + // We're not explicitly calling makeChildScope here because that function relies on currentScope + // which guarantees that the scope being returned is NotNull + // This means that while the scope stack is empty, we'll have to manually initialize the global scope + if (FFlag::LuauDfgScopeStackNotNull) + moduleScope = builder.scopes.emplace_back(new DfgScope{nullptr, DfgScope::ScopeType::Linear}).get(); + else + moduleScope = builder.makeChildScope(); PushScope ps{builder.scopeStack, moduleScope}; builder.visitBlockWithoutChildScope(block); builder.resolveCaptures(); @@ -228,7 +236,13 @@ void DataFlowGraphBuilder::resolveCaptures() } } -DfgScope* DataFlowGraphBuilder::currentScope() +NotNull DataFlowGraphBuilder::currentScope() +{ + LUAU_ASSERT(!scopeStack.empty()); + return NotNull{scopeStack.back()}; +} + +DfgScope* DataFlowGraphBuilder::currentScope_DEPRECATED() { if (scopeStack.empty()) return nullptr; // nullptr is the root DFG scope. @@ -237,7 +251,10 @@ DfgScope* DataFlowGraphBuilder::currentScope() DfgScope* DataFlowGraphBuilder::makeChildScope(DfgScope::ScopeType scopeType) { - return scopes.emplace_back(new DfgScope{currentScope(), scopeType}).get(); + if (FFlag::LuauDfgScopeStackNotNull) + return scopes.emplace_back(new DfgScope{currentScope(), scopeType}).get(); + else + return scopes.emplace_back(new DfgScope{currentScope_DEPRECATED(), scopeType}).get(); } void DataFlowGraphBuilder::join(DfgScope* p, DfgScope* a, DfgScope* b) @@ -312,9 +329,9 @@ void DataFlowGraphBuilder::joinProps(DfgScope* result, const DfgScope& a, const } } -DefId DataFlowGraphBuilder::lookup(Symbol symbol) +DefId DataFlowGraphBuilder::lookup(Symbol symbol, Location location) { - DfgScope* scope = currentScope(); + DfgScope* scope = FFlag::LuauDfgScopeStackNotNull ? currentScope() : currentScope_DEPRECATED(); // true if any of the considered scopes are a loop. bool outsideLoopScope = false; @@ -339,15 +356,15 @@ DefId DataFlowGraphBuilder::lookup(Symbol symbol) } } - DefId result = defArena->freshCell(); + DefId result = defArena->freshCell(symbol, location); scope->bindings[symbol] = result; captures[symbol].allVersions.push_back(result); return result; } -DefId DataFlowGraphBuilder::lookup(DefId def, const std::string& key) +DefId DataFlowGraphBuilder::lookup(DefId def, const std::string& key, Location location) { - DfgScope* scope = currentScope(); + DfgScope* scope = FFlag::LuauDfgScopeStackNotNull ? currentScope() : currentScope_DEPRECATED(); for (DfgScope* current = scope; current; current = current->parent) { if (auto props = current->props.find(def)) @@ -357,7 +374,7 @@ DefId DataFlowGraphBuilder::lookup(DefId def, const std::string& key) } else if (auto phi = get(def); phi && phi->operands.empty()) // Unresolved phi nodes { - DefId result = defArena->freshCell(); + DefId result = defArena->freshCell(def->name, location); scope->props[def][key] = result; return result; } @@ -367,7 +384,7 @@ DefId DataFlowGraphBuilder::lookup(DefId def, const std::string& key) { std::vector defs; for (DefId operand : phi->operands) - defs.push_back(lookup(operand, key)); + defs.push_back(lookup(operand, key, location)); DefId result = defArena->phi(defs); scope->props[def][key] = result; @@ -375,7 +392,7 @@ DefId DataFlowGraphBuilder::lookup(DefId def, const std::string& key) } else if (get(def)) { - DefId result = defArena->freshCell(); + DefId result = defArena->freshCell(def->name, location); scope->props[def][key] = result; return result; } @@ -393,7 +410,10 @@ ControlFlow DataFlowGraphBuilder::visit(AstStatBlock* b) cf = visitBlockWithoutChildScope(b); } - currentScope()->inherit(child); + if (FFlag::LuauDfgScopeStackNotNull) + currentScope()->inherit(child); + else + currentScope_DEPRECATED()->inherit(child); return cf; } @@ -478,7 +498,7 @@ ControlFlow DataFlowGraphBuilder::visit(AstStatIf* i) elsecf = visit(i->elsebody); } - DfgScope* scope = currentScope(); + DfgScope* scope = FFlag::LuauDfgScopeStackNotNull ? currentScope() : currentScope_DEPRECATED(); if (thencf != ControlFlow::None && elsecf == ControlFlow::None) join(scope, scope, elseScope); else if (thencf == ControlFlow::None && elsecf != ControlFlow::None) @@ -505,7 +525,10 @@ ControlFlow DataFlowGraphBuilder::visit(AstStatWhile* w) visit(w->body); } - currentScope()->inherit(whileScope); + if (FFlag::LuauDfgScopeStackNotNull) + currentScope()->inherit(whileScope); + else + currentScope_DEPRECATED()->inherit(whileScope); return ControlFlow::None; } @@ -521,7 +544,10 @@ ControlFlow DataFlowGraphBuilder::visit(AstStatRepeat* r) visitExpr(r->condition); } - currentScope()->inherit(repeatScope); + if (FFlag::LuauDfgScopeStackNotNull) + currentScope()->inherit(repeatScope); + else + currentScope_DEPRECATED()->inherit(repeatScope); return ControlFlow::None; } @@ -570,7 +596,7 @@ ControlFlow DataFlowGraphBuilder::visit(AstStatLocal* l) // We need to create a new def to intentionally avoid alias tracking, but we'd like to // make sure that the non-aliased defs are also marked as a subscript for refinements. bool subscripted = i < defs.size() && containsSubscriptedDefinition(defs[i]); - DefId def = defArena->freshCell(subscripted); + DefId def = defArena->freshCell(local, local->location, subscripted); if (i < l->values.size) { AstExpr* e = l->values.data[i]; @@ -580,7 +606,10 @@ ControlFlow DataFlowGraphBuilder::visit(AstStatLocal* l) } } graph.localDefs[local] = def; - currentScope()->bindings[local] = def; + if (FFlag::LuauDfgScopeStackNotNull) + currentScope()->bindings[local] = def; + else + currentScope_DEPRECATED()->bindings[local] = def; captures[local].allVersions.push_back(def); } @@ -602,16 +631,22 @@ ControlFlow DataFlowGraphBuilder::visit(AstStatFor* f) if (f->var->annotation) visitType(f->var->annotation); - DefId def = defArena->freshCell(); + DefId def = defArena->freshCell(f->var, f->var->location); graph.localDefs[f->var] = def; - currentScope()->bindings[f->var] = def; + if (FFlag::LuauDfgScopeStackNotNull) + currentScope()->bindings[f->var] = def; + else + currentScope_DEPRECATED()->bindings[f->var] = def; captures[f->var].allVersions.push_back(def); // TODO(controlflow): entry point has a back edge from exit point visit(f->body); } - currentScope()->inherit(forScope); + if (FFlag::LuauDfgScopeStackNotNull) + currentScope()->inherit(forScope); + else + currentScope_DEPRECATED()->inherit(forScope); return ControlFlow::None; } @@ -628,9 +663,12 @@ ControlFlow DataFlowGraphBuilder::visit(AstStatForIn* f) if (local->annotation) visitType(local->annotation); - DefId def = defArena->freshCell(); + DefId def = defArena->freshCell(local, local->location); graph.localDefs[local] = def; - currentScope()->bindings[local] = def; + if (FFlag::LuauDfgScopeStackNotNull) + currentScope()->bindings[local] = def; + else + currentScope_DEPRECATED()->bindings[local] = def; captures[local].allVersions.push_back(def); } @@ -641,8 +679,10 @@ ControlFlow DataFlowGraphBuilder::visit(AstStatForIn* f) visit(f->body); } - - currentScope()->inherit(forScope); + if (FFlag::LuauDfgScopeStackNotNull) + currentScope()->inherit(forScope); + else + currentScope_DEPRECATED()->inherit(forScope); return ControlFlow::None; } @@ -657,7 +697,7 @@ ControlFlow DataFlowGraphBuilder::visit(AstStatAssign* a) for (size_t i = 0; i < a->vars.size; ++i) { AstExpr* v = a->vars.data[i]; - visitLValue(v, i < defs.size() ? defs[i] : defArena->freshCell()); + visitLValue(v, i < defs.size() ? defs[i] : defArena->freshCell(Symbol{}, v->location)); } return ControlFlow::None; @@ -683,7 +723,7 @@ ControlFlow DataFlowGraphBuilder::visit(AstStatFunction* f) // // which is evidence that references to variables must be a phi node of all possible definitions, // but for bug compatibility, we'll assume the same thing here. - visitLValue(f->name, defArena->freshCell()); + visitLValue(f->name, defArena->freshCell(Symbol{}, f->name->location)); visitExpr(f->func); if (auto local = f->name->as()) @@ -703,9 +743,12 @@ ControlFlow DataFlowGraphBuilder::visit(AstStatFunction* f) ControlFlow DataFlowGraphBuilder::visit(AstStatLocalFunction* l) { - DefId def = defArena->freshCell(); + DefId def = defArena->freshCell(l->name, l->location); graph.localDefs[l->name] = def; - currentScope()->bindings[l->name] = def; + if (FFlag::LuauDfgScopeStackNotNull) + currentScope()->bindings[l->name] = def; + else + currentScope_DEPRECATED()->bindings[l->name] = def; captures[l->name].allVersions.push_back(def); visitExpr(l->func); @@ -736,9 +779,12 @@ ControlFlow DataFlowGraphBuilder::visit(AstStatTypeFunction* f) ControlFlow DataFlowGraphBuilder::visit(AstStatDeclareGlobal* d) { - DefId def = defArena->freshCell(); + DefId def = defArena->freshCell(d->name, d->nameLocation); graph.declaredDefs[d] = def; - currentScope()->bindings[d->name] = def; + if (FFlag::LuauDfgScopeStackNotNull) + currentScope()->bindings[d->name] = def; + else + currentScope_DEPRECATED()->bindings[d->name] = def; captures[d->name].allVersions.push_back(def); visitType(d->type); @@ -748,9 +794,12 @@ ControlFlow DataFlowGraphBuilder::visit(AstStatDeclareGlobal* d) ControlFlow DataFlowGraphBuilder::visit(AstStatDeclareFunction* d) { - DefId def = defArena->freshCell(); + DefId def = defArena->freshCell(d->name, d->nameLocation); graph.declaredDefs[d] = def; - currentScope()->bindings[d->name] = def; + if (FFlag::LuauDfgScopeStackNotNull) + currentScope()->bindings[d->name] = def; + else + currentScope_DEPRECATED()->bindings[d->name] = def; captures[d->name].allVersions.push_back(def); DfgScope* unreachable = makeChildScope(); @@ -805,19 +854,19 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(AstExpr* e) if (auto g = e->as()) return visitExpr(g); else if (auto c = e->as()) - return {defArena->freshCell(), nullptr}; // ok + return {defArena->freshCell(Symbol{}, c->location), nullptr}; // ok else if (auto c = e->as()) - return {defArena->freshCell(), nullptr}; // ok + return {defArena->freshCell(Symbol{}, c->location), nullptr}; // ok else if (auto c = e->as()) - return {defArena->freshCell(), nullptr}; // ok + return {defArena->freshCell(Symbol{}, c->location), nullptr}; // ok else if (auto c = e->as()) - return {defArena->freshCell(), nullptr}; // ok + return {defArena->freshCell(Symbol{}, c->location), nullptr}; // ok else if (auto l = e->as()) return visitExpr(l); else if (auto g = e->as()) return visitExpr(g); else if (auto v = e->as()) - return {defArena->freshCell(), nullptr}; // ok + return {defArena->freshCell(Symbol{}, v->location), nullptr}; // ok else if (auto c = e->as()) return visitExpr(c); else if (auto i = e->as()) @@ -858,14 +907,14 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprGroup* group) DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprLocal* l) { - DefId def = lookup(l->local); + DefId def = lookup(l->local, l->local->location); const RefinementKey* key = keyArena->leaf(def); return {def, key}; } DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprGlobal* g) { - DefId def = lookup(g->name); + DefId def = lookup(g->name, g->location); return {def, keyArena->leaf(def)}; } @@ -925,7 +974,7 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprCall* c) // local v = foo({}) // // We want to consider `v` to be subscripted here. - return {defArena->freshCell(/*subscripted=*/true)}; + return {defArena->freshCell(Symbol{}, c->location, /*subscripted=*/true)}; } DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprIndexName* i) @@ -933,7 +982,7 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprIndexName* i) auto [parentDef, parentKey] = visitExpr(i->expr); std::string index = i->index.value; - DefId def = lookup(parentDef, index); + DefId def = lookup(parentDef, index, i->location); return {def, keyArena->node(parentKey, def, index)}; } @@ -946,11 +995,11 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprIndexExpr* i) { std::string index{string->value.data, string->value.size}; - DefId def = lookup(parentDef, index); + DefId def = lookup(parentDef, index, i->location); return {def, keyArena->node(parentKey, def, index)}; } - return {defArena->freshCell(/* subscripted= */ true), nullptr}; + return {defArena->freshCell(Symbol{}, i->location, /* subscripted= */ true), nullptr}; } DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprFunction* f) @@ -963,7 +1012,7 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprFunction* f) // There's no syntax for `self` to have an annotation if using `function t:m()` LUAU_ASSERT(!self->annotation); - DefId def = defArena->freshCell(); + DefId def = defArena->freshCell(f->debugname, f->location); graph.localDefs[self] = def; signatureScope->bindings[self] = def; captures[self].allVersions.push_back(def); @@ -974,7 +1023,7 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprFunction* f) if (param->annotation) visitType(param->annotation); - DefId def = defArena->freshCell(); + DefId def = defArena->freshCell(param, param->location); graph.localDefs[param] = def; signatureScope->bindings[param] = def; captures[param].allVersions.push_back(def); @@ -996,13 +1045,16 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprFunction* f) // g() --> 5 visit(f->body); - return {defArena->freshCell(), nullptr}; + return {defArena->freshCell(f->debugname, f->location), nullptr}; } DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprTable* t) { - DefId tableCell = defArena->freshCell(); - currentScope()->props[tableCell] = {}; + DefId tableCell = defArena->freshCell(Symbol{}, t->location); + if (FFlag::LuauDfgScopeStackNotNull) + currentScope()->props[tableCell] = {}; + else + currentScope_DEPRECATED()->props[tableCell] = {}; for (AstExprTable::Item item : t->items) { DataFlowResult result = visitExpr(item.value); @@ -1010,7 +1062,12 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprTable* t) { visitExpr(item.key); if (auto string = item.key->as()) - currentScope()->props[tableCell][string->value.data] = result.def; + { + if (FFlag::LuauDfgScopeStackNotNull) + currentScope()->props[tableCell][string->value.data] = result.def; + else + currentScope_DEPRECATED()->props[tableCell][string->value.data] = result.def; + } } } @@ -1021,7 +1078,7 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprUnary* u) { visitExpr(u->expr); - return {defArena->freshCell(), nullptr}; + return {defArena->freshCell(Symbol{}, u->location), nullptr}; } DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprBinary* b) @@ -1029,7 +1086,7 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprBinary* b) visitExpr(b->left); visitExpr(b->right); - return {defArena->freshCell(), nullptr}; + return {defArena->freshCell(Symbol{}, b->location), nullptr}; } DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprTypeAssertion* t) @@ -1046,7 +1103,7 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprIfElse* i) visitExpr(i->trueExpr); visitExpr(i->falseExpr); - return {defArena->freshCell(), nullptr}; + return {defArena->freshCell(Symbol{}, i->location), nullptr}; } DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprInterpString* i) @@ -1054,7 +1111,7 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprInterpString* i) for (AstExpr* e : i->expressions) visitExpr(e); - return {defArena->freshCell(), nullptr}; + return {defArena->freshCell(Symbol{}, i->location), nullptr}; } DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprError* error) @@ -1065,7 +1122,7 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprError* error) for (AstExpr* e : error->expressions) visitExpr(e); - return {defArena->freshCell(), nullptr}; + return {defArena->freshCell(Symbol{}, error->location), nullptr}; } void DataFlowGraphBuilder::visitLValue(AstExpr* e, DefId incomingDef) @@ -1091,12 +1148,12 @@ void DataFlowGraphBuilder::visitLValue(AstExpr* e, DefId incomingDef) DefId DataFlowGraphBuilder::visitLValue(AstExprLocal* l, DefId incomingDef) { - DfgScope* scope = currentScope(); + DfgScope* scope = FFlag::LuauDfgScopeStackNotNull ? currentScope() : currentScope_DEPRECATED(); // In order to avoid alias tracking, we need to clip the reference to the parent def. if (scope->canUpdateDefinition(l->local)) { - DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef)); + DefId updated = defArena->freshCell(l->local, l->location, containsSubscriptedDefinition(incomingDef)); scope->bindings[l->local] = updated; captures[l->local].allVersions.push_back(updated); return updated; @@ -1107,12 +1164,12 @@ DefId DataFlowGraphBuilder::visitLValue(AstExprLocal* l, DefId incomingDef) DefId DataFlowGraphBuilder::visitLValue(AstExprGlobal* g, DefId incomingDef) { - DfgScope* scope = currentScope(); + DfgScope* scope = FFlag::LuauDfgScopeStackNotNull ? currentScope() : currentScope_DEPRECATED(); // In order to avoid alias tracking, we need to clip the reference to the parent def. if (scope->canUpdateDefinition(g->name)) { - DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef)); + DefId updated = defArena->freshCell(g->name, g->location, containsSubscriptedDefinition(incomingDef)); scope->bindings[g->name] = updated; captures[g->name].allVersions.push_back(updated); return updated; @@ -1125,10 +1182,10 @@ DefId DataFlowGraphBuilder::visitLValue(AstExprIndexName* i, DefId incomingDef) { DefId parentDef = visitExpr(i->expr).def; - DfgScope* scope = currentScope(); + DfgScope* scope = FFlag::LuauDfgScopeStackNotNull ? currentScope() : currentScope_DEPRECATED(); if (scope->canUpdateDefinition(parentDef, i->index.value)) { - DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef)); + DefId updated = defArena->freshCell(i->index, i->location, containsSubscriptedDefinition(incomingDef)); scope->props[parentDef][i->index.value] = updated; return updated; } @@ -1141,12 +1198,12 @@ DefId DataFlowGraphBuilder::visitLValue(AstExprIndexExpr* i, DefId incomingDef) DefId parentDef = visitExpr(i->expr).def; visitExpr(i->index); - DfgScope* scope = currentScope(); + DfgScope* scope = FFlag::LuauDfgScopeStackNotNull ? currentScope() : currentScope_DEPRECATED(); if (auto string = i->index->as()) { if (scope->canUpdateDefinition(parentDef, string->value.data)) { - DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef)); + DefId updated = defArena->freshCell(Symbol{}, i->location, containsSubscriptedDefinition(incomingDef)); scope->props[parentDef][string->value.data] = updated; return updated; } @@ -1154,7 +1211,7 @@ DefId DataFlowGraphBuilder::visitLValue(AstExprIndexExpr* i, DefId incomingDef) return visitExpr(static_cast(i)).def; } else - return defArena->freshCell(/*subscripted=*/true); + return defArena->freshCell(Symbol{}, i->location, /*subscripted=*/true); } DefId DataFlowGraphBuilder::visitLValue(AstExprError* error, DefId incomingDef) diff --git a/Analysis/src/Def.cpp b/Analysis/src/Def.cpp index 6d58b28f..a3b62af5 100644 --- a/Analysis/src/Def.cpp +++ b/Analysis/src/Def.cpp @@ -36,9 +36,9 @@ void collectOperands(DefId def, std::vector* operands) } } -DefId DefArena::freshCell(bool subscripted) +DefId DefArena::freshCell(Symbol sym, Location location, bool subscripted) { - return NotNull{allocator.allocate(Def{Cell{subscripted}})}; + return NotNull{allocator.allocate(Def{Cell{subscripted}, sym, location})}; } DefId DefArena::phi(DefId a, DefId b) diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 1cd25c94..4b302eff 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -606,7 +606,7 @@ struct ErrorConverter auto tfit = get(e.ty); LUAU_ASSERT(tfit); // Luau analysis has actually done something wrong if this type is not a type function. if (!tfit) - return "Unexpected type " + Luau::toString(e.ty) + " flagged as an uninhabited type function."; + return "Internal error: Unexpected type " + Luau::toString(e.ty) + " flagged as an uninhabited type function."; // unary operators if (auto unaryString = kUnaryOps.find(tfit->function->name); unaryString != kUnaryOps.end()) diff --git a/Analysis/src/FragmentAutocomplete.cpp b/Analysis/src/FragmentAutocomplete.cpp index 3c1395dc..722db46a 100644 --- a/Analysis/src/FragmentAutocomplete.cpp +++ b/Analysis/src/FragmentAutocomplete.cpp @@ -34,7 +34,6 @@ LUAU_FASTFLAGVARIABLE(LuauCloneIncrementalModule) LUAU_FASTFLAGVARIABLE(DebugLogFragmentsFromAutocomplete) LUAU_FASTFLAGVARIABLE(LuauBetterCursorInCommentDetection) LUAU_FASTFLAGVARIABLE(LuauAllFreeTypesHaveScopes) -LUAU_FASTFLAGVARIABLE(LuauFragmentAcSupportsReporter) LUAU_FASTFLAGVARIABLE(LuauPersistConstraintGenerationScopes) LUAU_FASTFLAGVARIABLE(LuauCloneTypeAliasBindings) LUAU_FASTFLAGVARIABLE(LuauCloneReturnTypePack) @@ -43,6 +42,7 @@ LUAU_FASTFLAG(LuauUserTypeFunTypecheck) LUAU_FASTFLAGVARIABLE(LuauFragmentNoTypeFunEval) LUAU_FASTFLAGVARIABLE(LuauBetterScopeSelection) LUAU_FASTFLAGVARIABLE(LuauBlockDiffFragmentSelection) +LUAU_FASTFLAGVARIABLE(LuauFragmentAcMemoryLeak) namespace { @@ -1213,7 +1213,7 @@ void mixedModeCompatibility( static void reportWaypoint(IFragmentAutocompleteReporter* reporter, FragmentAutocompleteWaypoint type) { - if (!FFlag::LuauFragmentAcSupportsReporter || !reporter) + if (!reporter) return; reporter->reportWaypoint(type); @@ -1221,7 +1221,7 @@ static void reportWaypoint(IFragmentAutocompleteReporter* reporter, FragmentAuto static void reportFragmentString(IFragmentAutocompleteReporter* reporter, std::string_view fragment) { - if (!FFlag::LuauFragmentAcSupportsReporter || !reporter) + if (!reporter) return; reporter->reportFragmentString(fragment); @@ -1351,6 +1351,8 @@ FragmentTypeCheckResult typecheckFragmentHelper_DEPRECATED( { if (!sc->interiorFreeTypes.has_value()) sc->interiorFreeTypes.emplace(); + if (!sc->interiorFreeTypePacks.has_value()) + sc->interiorFreeTypePacks.emplace(); } } @@ -1472,6 +1474,7 @@ FragmentTypeCheckResult typecheckFragment_( std::shared_ptr freshChildOfNearestScope = std::make_shared(nullptr); incrementalModule->scopes.emplace_back(root->location, freshChildOfNearestScope); freshChildOfNearestScope->interiorFreeTypes.emplace(); + freshChildOfNearestScope->interiorFreeTypePacks.emplace(); cg.rootScope = freshChildOfNearestScope.get(); if (FFlag::LuauUserTypeFunTypecheck) @@ -1622,7 +1625,7 @@ FragmentAutocompleteStatusResult tryFragmentAutocomplete( std::move(stringCompletionCB), context.DEPRECATED_fragmentEndPosition, context.freshParse.root, - FFlag::LuauFragmentAcSupportsReporter ? context.reporter : nullptr + context.reporter ); return {FragmentAutocompleteStatus::Success, std::move(fragmentAutocomplete)}; } @@ -1658,11 +1661,13 @@ FragmentAutocompleteResult fragmentAutocomplete( auto globalScope = (opts && opts->forAutocomplete) ? frontend.globalsForAutocomplete.globalScope.get() : frontend.globals.globalScope.get(); if (FFlag::DebugLogFragmentsFromAutocomplete) logLuau("Fragment Autocomplete Source Script", src); - TypeArena arenaForFragmentAutocomplete; + TypeArena arenaForAutocomplete_DEPRECATED; + if (FFlag::LuauFragmentAcMemoryLeak) + unfreeze(tcResult.incrementalModule->internalTypes); auto result = Luau::autocomplete_( tcResult.incrementalModule, frontend.builtinTypes, - &arenaForFragmentAutocomplete, + FFlag::LuauFragmentAcMemoryLeak ? &tcResult.incrementalModule->internalTypes : &arenaForAutocomplete_DEPRECATED, tcResult.ancestry, globalScope, tcResult.freshScope, @@ -1670,9 +1675,10 @@ FragmentAutocompleteResult fragmentAutocomplete( frontend.fileResolver, callback ); - + if (FFlag::LuauFragmentAcMemoryLeak) + freeze(tcResult.incrementalModule->internalTypes); reportWaypoint(reporter, FragmentAutocompleteWaypoint::AutocompleteEnd); - return {std::move(tcResult.incrementalModule), tcResult.freshScope.get(), std::move(arenaForFragmentAutocomplete), std::move(result)}; + return {std::move(tcResult.incrementalModule), tcResult.freshScope.get(), std::move(arenaForAutocomplete_DEPRECATED), std::move(result)}; } } // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 6030339f..81df13d8 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -45,7 +45,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJsonFile) LUAU_FASTFLAGVARIABLE(DebugLuauForbidInternalTypes) LUAU_FASTFLAGVARIABLE(DebugLuauForceStrictMode) LUAU_FASTFLAGVARIABLE(DebugLuauForceNonStrictMode) -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauRunCustomModuleChecks, false) LUAU_FASTFLAGVARIABLE(LuauSelectivelyRetainDFGArena) LUAU_FASTFLAG(LuauTypeFunResultInAutocomplete) @@ -952,7 +951,7 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item) item.stats.timeCheck += duration; item.stats.filesStrict += 1; - if (DFFlag::LuauRunCustomModuleChecks && item.options.customModuleCheck) + if (item.options.customModuleCheck) item.options.customModuleCheck(sourceModule, *moduleForAutocomplete); item.module = moduleForAutocomplete; @@ -972,7 +971,7 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item) item.stats.filesStrict += mode == Mode::Strict; item.stats.filesNonstrict += mode == Mode::Nonstrict; - if (DFFlag::LuauRunCustomModuleChecks && item.options.customModuleCheck) + if (item.options.customModuleCheck) item.options.customModuleCheck(sourceModule, *module); if (FFlag::LuauSolverV2 && mode == Mode::NoCheck) diff --git a/Analysis/src/Generalization.cpp b/Analysis/src/Generalization.cpp index 5138ad2f..71f82ba2 100644 --- a/Analysis/src/Generalization.cpp +++ b/Analysis/src/Generalization.cpp @@ -4,18 +4,84 @@ #include "Luau/Common.h" #include "Luau/DenseHash.h" +#include "Luau/InsertionOrderedMap.h" +#include "Luau/Polarity.h" #include "Luau/Scope.h" -#include "Luau/Type.h" #include "Luau/ToString.h" +#include "Luau/Type.h" #include "Luau/TypeArena.h" #include "Luau/TypePack.h" #include "Luau/VisitType.h" LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete) +LUAU_FASTFLAGVARIABLE(LuauNonReentrantGeneralization) + namespace Luau { +namespace +{ + +template +struct OrderedSet +{ + using iterator = typename std::vector::iterator; + using const_iterator = typename std::vector::const_iterator; + + bool empty() const + { + return elements.empty(); + } + + size_t size() const + { + return elements.size(); + } + + void insert(T t) + { + if (!elementSet.contains(t)) + { + elementSet.insert(t); + elements.push_back(t); + } + } + + iterator begin() + { + return elements.begin(); + } + + const_iterator begin() const + { + return elements.begin(); + } + + iterator end() + { + return elements.end(); + } + + const_iterator end() const + { + return elements.end(); + } + + /// Move the underlying vector out of the OrderedSet. + std::vector takeVector() + { + elementSet.clear(); + return std::move(elements); + } + +private: + std::vector elements; + DenseHashSet elementSet{nullptr}; +}; + +} // namespace + struct MutatingGeneralizer : TypeOnceVisitor { NotNull arena; @@ -270,6 +336,15 @@ struct MutatingGeneralizer : TypeOnceVisitor return 0; } + template + static size_t getCount(const DenseHashMap& map, TID ty) + { + if (const size_t* count = map.find(ty)) + return *count; + else + return 0; + } + bool visit(TypeId ty, const TableType&) override { if (cachedTypes->contains(ty)) @@ -327,21 +402,12 @@ struct FreeTypeSearcher : TypeVisitor { } + bool isWithinFunction = false; Polarity polarity = Polarity::Positive; void flip() { - switch (polarity) - { - case Polarity::Positive: - polarity = Polarity::Negative; - break; - case Polarity::Negative: - polarity = Polarity::Positive; - break; - default: - break; - } + polarity = invert(polarity); } DenseHashSet seenPositive{nullptr}; @@ -383,12 +449,14 @@ struct FreeTypeSearcher : TypeVisitor return false; } - // The keys in these maps are either TypeIds or TypePackIds. It's safe to - // mix them because we only use these pointers as unique keys. We never - // indirect them. DenseHashMap negativeTypes{0}; DenseHashMap positiveTypes{0}; + InsertionOrderedMap> types; + InsertionOrderedMap> typePacks; + + OrderedSet unsealedTables; + bool visit(TypeId ty) override { if (cachedTypes->contains(ty) || seenWithCurrentPolarity(ty)) @@ -400,38 +468,30 @@ struct FreeTypeSearcher : TypeVisitor bool visit(TypeId ty, const FreeType& ft) override { - if (cachedTypes->contains(ty) || seenWithCurrentPolarity(ty)) - return false; - - if (!subsumes(scope, ft.scope)) - return true; - - switch (polarity) + if (FFlag::LuauNonReentrantGeneralization) { - case Polarity::Positive: - positiveTypes[ty]++; - break; - case Polarity::Negative: - negativeTypes[ty]++; - break; - case Polarity::Mixed: - positiveTypes[ty]++; - negativeTypes[ty]++; - break; - default: - LUAU_ASSERT(!"Unreachable"); + if (!subsumes(scope, ft.scope)) + return true; + + GeneralizationParams& params = types[ty]; + ++params.useCount; + + if (cachedTypes->contains(ty) || seenWithCurrentPolarity(ty)) + return false; + + if (!isWithinFunction) + params.foundOutsideFunctions = true; + + params.polarity |= polarity; } - - return true; - } - - bool visit(TypeId ty, const TableType& tt) override - { - if (cachedTypes->contains(ty) || seenWithCurrentPolarity(ty)) - return false; - - if ((tt.state == TableState::Free || tt.state == TableState::Unsealed) && subsumes(scope, tt.scope)) + else { + if (cachedTypes->contains(ty) || seenWithCurrentPolarity(ty)) + return false; + + if (!subsumes(scope, ft.scope)) + return true; + switch (polarity) { case Polarity::Positive: @@ -449,6 +509,38 @@ struct FreeTypeSearcher : TypeVisitor } } + return true; + } + + bool visit(TypeId ty, const TableType& tt) override + { + if (cachedTypes->contains(ty) || seenWithCurrentPolarity(ty)) + return false; + + if ((tt.state == TableState::Free || tt.state == TableState::Unsealed) && subsumes(scope, tt.scope)) + { + if (FFlag::LuauNonReentrantGeneralization) + unsealedTables.insert(ty); + else + { + switch (polarity) + { + case Polarity::Positive: + positiveTypes[ty]++; + break; + case Polarity::Negative: + negativeTypes[ty]++; + break; + case Polarity::Mixed: + positiveTypes[ty]++; + negativeTypes[ty]++; + break; + default: + LUAU_ASSERT(!"Unreachable"); + } + } + } + for (const auto& [_name, prop] : tt.props) { if (prop.isReadOnly()) @@ -466,8 +558,27 @@ struct FreeTypeSearcher : TypeVisitor if (tt.indexer) { - traverse(tt.indexer->indexType); - traverse(tt.indexer->indexResultType); + if (FFlag::LuauNonReentrantGeneralization) + { + // {[K]: V} is equivalent to three functions: get, set, and iterate + // + // (K) -> V + // (K, V) -> () + // () -> {K} + // + // K and V therefore both have mixed polarity. + + const Polarity p = polarity; + polarity = Polarity::Mixed; + traverse(tt.indexer->indexType); + traverse(tt.indexer->indexResultType); + polarity = p; + } + else + { + traverse(tt.indexer->indexType); + traverse(tt.indexer->indexResultType); + } } return false; @@ -478,12 +589,17 @@ struct FreeTypeSearcher : TypeVisitor if (cachedTypes->contains(ty) || seenWithCurrentPolarity(ty)) return false; + const bool oldValue = isWithinFunction; + isWithinFunction = true; + flip(); traverse(ft.argTypes); flip(); traverse(ft.retTypes); + isWithinFunction = oldValue; + return false; } @@ -500,20 +616,33 @@ struct FreeTypeSearcher : TypeVisitor if (!subsumes(scope, ftp.scope)) return true; - switch (polarity) + if (FFlag::LuauNonReentrantGeneralization) { - case Polarity::Positive: - positiveTypes[tp]++; - break; - case Polarity::Negative: - negativeTypes[tp]++; - break; - case Polarity::Mixed: - positiveTypes[tp]++; - negativeTypes[tp]++; - break; - default: - LUAU_ASSERT(!"Unreachable"); + GeneralizationParams& params = typePacks[tp]; + ++params.useCount; + + if (!isWithinFunction) + params.foundOutsideFunctions = true; + + params.polarity |= polarity; + } + else + { + switch (polarity) + { + case Polarity::Positive: + positiveTypes[tp]++; + break; + case Polarity::Negative: + negativeTypes[tp]++; + break; + case Polarity::Mixed: + positiveTypes[tp]++; + negativeTypes[tp]++; + break; + default: + LUAU_ASSERT(!"Unreachable"); + } } return true; @@ -963,6 +1092,221 @@ struct TypeCacher : TypeOnceVisitor } }; +/** + * Remove occurrences of `needle` within `haystack`. This is used to cull cyclic bounds from free types. + * + * @param haystack Either the upper or lower bound of a free type. + * @param needle The type to be removed. + */ +[[nodiscard]] +static TypeId removeType(NotNull arena, NotNull builtinTypes, DenseHashSet& seen, TypeId haystack, TypeId needle) +{ + haystack = follow(haystack); + + if (seen.find(haystack)) + return haystack; + seen.insert(haystack); + + if (const UnionType* ut = get(haystack)) + { + OrderedSet newOptions; + + for (TypeId option : ut) + { + if (option == needle) + continue; + + if (get(option)) + continue; + + LUAU_ASSERT(!get(option)); + + if (get(option)) + newOptions.insert(removeType(arena, builtinTypes, seen, option, needle)); + else + newOptions.insert(option); + } + + if (newOptions.empty()) + return builtinTypes->neverType; + else if (newOptions.size() == 1) + { + TypeId onlyType = *newOptions.begin(); + LUAU_ASSERT(onlyType != haystack); + return onlyType; + } + else + return arena->addType(UnionType{newOptions.takeVector()}); + } + + if (const IntersectionType* it = get(haystack)) + { + OrderedSet newParts; + + for (TypeId part : it) + { + part = follow(part); + + if (part == needle) + continue; + + if (get(part)) + continue; + + LUAU_ASSERT(!get(follow(part))); + + if (get(part)) + newParts.insert(removeType(arena, builtinTypes, seen, part, needle)); + else + newParts.insert(part); + } + + if (newParts.empty()) + return builtinTypes->unknownType; + else if (newParts.size() == 1) + { + TypeId onlyType = *newParts.begin(); + LUAU_ASSERT(onlyType != needle); + return onlyType; + } + else + return arena->addType(IntersectionType{newParts.takeVector()}); + } + + return haystack; +} + +std::optional generalizeType( + NotNull arena, + NotNull builtinTypes, + NotNull scope, + TypeId freeTy, + const GeneralizationParams& params +) +{ + freeTy = follow(freeTy); + + FreeType* ft = getMutable(freeTy); + LUAU_ASSERT(ft); + + LUAU_ASSERT(isPositive(params.polarity) || isNegative(params.polarity)); + + const bool hasLowerBound = !get(follow(ft->lowerBound)); + const bool hasUpperBound = !get(follow(ft->upperBound)); + + const bool isWithinFunction = !params.foundOutsideFunctions; + + if (!hasLowerBound && !hasUpperBound) + { + if ((params.polarity != Polarity::Mixed && params.useCount == 1) || !isWithinFunction) + emplaceType(asMutable(freeTy), builtinTypes->unknownType); + else + { + emplaceType(asMutable(freeTy), scope, params.polarity); + return freeTy; + } + } + // It is possible that this free type has other free types in its upper + // or lower bounds. If this is the case, we must replace those + // references with never (for the lower bound) or unknown (for the upper + // bound). + // + // If we do not do this, we get tautological bounds like a <: a <: unknown. + else if (isPositive(params.polarity) && !hasUpperBound) + { + TypeId lb = follow(ft->lowerBound); + if (FreeType* lowerFree = getMutable(lb); lowerFree && lowerFree->upperBound == freeTy) + lowerFree->upperBound = builtinTypes->unknownType; + else + { + DenseHashSet replaceSeen{nullptr}; + lb = removeType(arena, builtinTypes, replaceSeen, lb, freeTy); + ft->lowerBound = lb; + } + + if (follow(lb) != freeTy) + emplaceType(asMutable(freeTy), lb); + else if (!isWithinFunction || params.useCount == 1) + emplaceType(asMutable(freeTy), builtinTypes->unknownType); + else + { + // if the lower bound is the type in question (eg 'a <: 'a), we don't actually have a lower bound. + emplaceType(asMutable(freeTy), scope, params.polarity); + return freeTy; + } + } + else + { + TypeId ub = follow(ft->upperBound); + if (FreeType* upperFree = getMutable(ub); upperFree && upperFree->lowerBound == freeTy) + upperFree->lowerBound = builtinTypes->neverType; + else + { + // If the free type appears within its own upper bound, cull that cycle. + DenseHashSet replaceSeen{nullptr}; + ub = removeType(arena, builtinTypes, replaceSeen, ub, freeTy); + ft->upperBound = ub; + } + + if (follow(ub) != freeTy) + emplaceType(asMutable(freeTy), ub); + else if (!isWithinFunction || params.useCount == 1) + emplaceType(asMutable(freeTy), builtinTypes->unknownType); + else + { + // if the upper bound is the type in question, we don't actually have an upper bound. + emplaceType(asMutable(freeTy), scope, params.polarity); + return freeTy; + } + } + + return std::nullopt; +} + +std::optional generalizeTypePack( + NotNull arena, + NotNull builtinTypes, + NotNull scope, + TypePackId tp, + const GeneralizationParams& params +) +{ + tp = follow(tp); + + if (tp->owningArena != arena) + return std::nullopt; + + const FreeTypePack* ftp = get(tp); + if (!ftp) + return std::nullopt; + + if (!subsumes(scope, ftp->scope)) + return std::nullopt; + + if (1 == params.useCount) + emplaceTypePack(asMutable(tp), builtinTypes->unknownTypePack); + else + { + emplaceTypePack(asMutable(tp), scope, params.polarity); + return tp; + } + + return std::nullopt; +} + +void sealTable(NotNull scope, TypeId ty) +{ + TableType* tableTy = getMutable(follow(ty)); + if (!tableTy) + return; + + if (!subsumes(scope, tableTy->scope)) + return; + + if (tableTy->state == TableState::Unsealed || tableTy->state == TableState::Free) + tableTy->state = TableState::Sealed; +} + std::optional generalize( NotNull arena, NotNull builtinTypes, @@ -979,35 +1323,74 @@ std::optional generalize( FreeTypeSearcher fts{scope, cachedTypes}; fts.traverse(ty); - MutatingGeneralizer gen{arena, builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), std::move(fts.negativeTypes)}; - - gen.traverse(ty); - - /* MutatingGeneralizer mutates types in place, so it is possible that ty has - * been transmuted to a BoundType. We must follow it again and verify that - * we are allowed to mutate it before we attach generics to it. - */ - ty = follow(ty); - - if (ty->owningArena != arena || ty->persistent) - return ty; - - TypeCacher cacher{cachedTypes}; - cacher.traverse(ty); - - FunctionType* ftv = getMutable(ty); - if (ftv) + if (FFlag::LuauNonReentrantGeneralization) { - // If we're generalizing a function type, add any of the newly inferred - // generics to the list of existing generic types. - for (const auto g : std::move(gen.generics)) + FunctionType* functionTy = getMutable(ty); + auto pushGeneric = [&](TypeId t) { - ftv->generics.push_back(g); + if (functionTy) + functionTy->generics.push_back(t); + }; + + auto pushGenericPack = [&](TypePackId tp) + { + if (functionTy) + functionTy->genericPacks.push_back(tp); + }; + + for (const auto& [freeTy, params] : fts.types) + { + if (std::optional genericTy = generalizeType(arena, builtinTypes, scope, freeTy, params)) + pushGeneric(*genericTy); } - // Ditto for generic packs. - for (const auto gp : std::move(gen.genericPacks)) + + for (TypeId unsealedTableTy : fts.unsealedTables) + sealTable(scope, unsealedTableTy); + + for (const auto& [freePackId, params] : fts.typePacks) { - ftv->genericPacks.push_back(gp); + TypePackId freePack = follow(freePackId); + std::optional generalizedTp = generalizeTypePack(arena, builtinTypes, scope, freePack, params); + + if (generalizedTp) + pushGenericPack(freePack); + } + + TypeCacher cacher{cachedTypes}; + cacher.traverse(ty); + } + else + { + MutatingGeneralizer gen{arena, builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), std::move(fts.negativeTypes)}; + + gen.traverse(ty); + + /* MutatingGeneralizer mutates types in place, so it is possible that ty has + * been transmuted to a BoundType. We must follow it again and verify that + * we are allowed to mutate it before we attach generics to it. + */ + ty = follow(ty); + + if (ty->owningArena != arena || ty->persistent) + return ty; + + TypeCacher cacher{cachedTypes}; + cacher.traverse(ty); + + FunctionType* ftv = getMutable(ty); + if (ftv) + { + // If we're generalizing a function type, add any of the newly inferred + // generics to the list of existing generic types. + for (const auto g : std::move(gen.generics)) + { + ftv->generics.push_back(g); + } + // Ditto for generic packs. + for (const auto gp : std::move(gen.genericPacks)) + { + ftv->genericPacks.push_back(gp); + } } } diff --git a/Analysis/src/InferPolarity.cpp b/Analysis/src/InferPolarity.cpp new file mode 100644 index 00000000..3399abcf --- /dev/null +++ b/Analysis/src/InferPolarity.cpp @@ -0,0 +1,169 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/DenseHash.h" +#include "Luau/Polarity.h" +#include "Luau/Scope.h" +#include "Luau/VisitType.h" + +LUAU_FASTFLAG(LuauNonReentrantGeneralization) + +namespace Luau +{ + +struct InferPolarity : TypeVisitor +{ + NotNull arena; + NotNull scope; + + DenseHashMap types{nullptr}; + DenseHashMap packs{nullptr}; + + Polarity polarity = Polarity::Positive; + + explicit InferPolarity(NotNull arena, NotNull scope) + : arena(arena) + , scope(scope) + { + } + + void flip() + { + polarity = invert(polarity); + } + + bool visit(TypeId ty, const GenericType& gt) override + { + if (ty->owningArena != arena) + return false; + + if (subsumes(scope, gt.scope)) + types[ty] |= polarity; + + return false; + } + + bool visit(TypeId ty, const TableType& tt) override + { + if (ty->owningArena != arena) + return false; + + const Polarity p = polarity; + for (const auto& [name, prop] : tt.props) + { + if (prop.isShared()) + { + polarity = Polarity::Mixed; + traverse(prop.type()); + } + else if (prop.isReadOnly()) + { + polarity = p; + traverse(*prop.readTy); + } + else if (prop.isWriteOnly()) + { + polarity = invert(p); + traverse(*prop.writeTy); + } + else + LUAU_ASSERT(!"Unreachable"); + } + + if (tt.indexer) + { + polarity = Polarity::Mixed; + traverse(tt.indexer->indexType); + traverse(tt.indexer->indexResultType); + } + + polarity = p; + + return false; + } + + bool visit(TypeId ty, const FunctionType& ft) override + { + if (ty->owningArena != arena) + return false; + + const Polarity p = polarity; + + polarity = Polarity::Positive; + + // If these types actually occur within the function signature, their + // polarity will be overwritten. If not, we infer that they are phantom + // types. + for (TypeId generic : ft.generics) + { + const auto gen = get(generic); + LUAU_ASSERT(gen); + if (subsumes(scope, gen->scope)) + types[generic] = Polarity::None; + } + for (TypePackId genericPack : ft.genericPacks) + { + const auto gen = get(genericPack); + LUAU_ASSERT(gen); + if (subsumes(scope, gen->scope)) + packs[genericPack] = Polarity::None; + } + + flip(); + traverse(ft.argTypes); + flip(); + traverse(ft.retTypes); + + polarity = p; + + return false; + } + + bool visit(TypeId, const ClassType&) override + { + return false; + } + + bool visit(TypePackId tp, const GenericTypePack& gtp) override + { + packs[tp] |= polarity; + return false; + } +}; + +template +static void inferGenericPolarities_(NotNull arena, NotNull scope, TID ty) +{ + if (!FFlag::LuauNonReentrantGeneralization) + return; + + InferPolarity infer{arena, scope}; + infer.traverse(ty); + + for (const auto& [ty, polarity] : infer.types) + { + auto gt = getMutable(ty); + LUAU_ASSERT(gt); + gt->polarity = polarity; + } + + for (const auto& [tp, polarity] : infer.packs) + { + if (tp->owningArena != arena) + continue; + auto gp = getMutable(tp); + LUAU_ASSERT(gp); + gp->polarity = polarity; + } +} + +void inferGenericPolarities(NotNull arena, NotNull scope, TypeId ty) +{ + inferGenericPolarities_(arena, scope, ty); +} + +void inferGenericPolarities(NotNull arena, NotNull scope, TypePackId tp) +{ + inferGenericPolarities_(arena, scope, tp); +} + +} // namespace Luau diff --git a/Analysis/src/Subtyping.cpp b/Analysis/src/Subtyping.cpp index f0833e7a..e7bc6b0f 100644 --- a/Analysis/src/Subtyping.cpp +++ b/Analysis/src/Subtyping.cpp @@ -23,6 +23,8 @@ LUAU_FASTFLAGVARIABLE(DebugLuauSubtypingCheckPathValidity) LUAU_FASTFLAGVARIABLE(LuauSubtypingStopAtNormFail) +LUAU_FASTINTVARIABLE(LuauSubtypingReasoningLimit, 100) +LUAU_FASTFLAGVARIABLE(LuauSubtypingEnableReasoningLimit) namespace Luau { @@ -100,6 +102,9 @@ static SubtypingReasonings mergeReasonings(const SubtypingReasonings& a, const S else result.insert(r); } + + if (FFlag::LuauSubtypingEnableReasoningLimit && result.size() >= size_t(FInt::LuauSubtypingReasoningLimit)) + return result; } for (const SubtypingReasoning& r : b) @@ -116,6 +121,9 @@ static SubtypingReasonings mergeReasonings(const SubtypingReasonings& a, const S else result.insert(r); } + + if (FFlag::LuauSubtypingEnableReasoningLimit && result.size() >= size_t(FInt::LuauSubtypingReasoningLimit)) + return result; } return result; diff --git a/Analysis/src/TableLiteralInference.cpp b/Analysis/src/TableLiteralInference.cpp index ceaf4798..36fcc34c 100644 --- a/Analysis/src/TableLiteralInference.cpp +++ b/Analysis/src/TableLiteralInference.cpp @@ -241,12 +241,10 @@ TypeId matchLiteralType( if (FFlag::LuauBidirectionalInferenceUpcast && expr->is()) { - // TODO: Push argument / return types into the lambda. For now, just do + // TODO: Push argument / return types into the lambda. For now, just do // the non-literal thing: check for a subtype and upcast if valid. auto result = subtyping->isSubtype(/*subTy=*/exprType, /*superTy=*/expectedType, unifier->scope); - return result.isSubtype - ? expectedType - : exprType; + return result.isSubtype ? expectedType : exprType; } if (auto exprTable = expr->as()) @@ -352,7 +350,6 @@ TypeId matchLiteralType( } keysToDelete.insert(item.key->as()); - } // If it's just an extra property and the expected type @@ -375,22 +372,25 @@ TypeId matchLiteralType( // quadratic in a hurry. if (expectedProp.isShared()) { - matchedType = - matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, subtyping, *expectedReadTy, propTy, item.value, toBlock); + matchedType = matchLiteralType( + astTypes, astExpectedTypes, builtinTypes, arena, unifier, subtyping, *expectedReadTy, propTy, item.value, toBlock + ); prop.readTy = matchedType; prop.writeTy = matchedType; } else if (expectedReadTy) { - matchedType = - matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, subtyping, *expectedReadTy, propTy, item.value, toBlock); + matchedType = matchLiteralType( + astTypes, astExpectedTypes, builtinTypes, arena, unifier, subtyping, *expectedReadTy, propTy, item.value, toBlock + ); prop.readTy = matchedType; prop.writeTy.reset(); } else if (expectedWriteTy) { - matchedType = - matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, subtyping, *expectedWriteTy, propTy, item.value, toBlock); + matchedType = matchLiteralType( + astTypes, astExpectedTypes, builtinTypes, arena, unifier, subtyping, *expectedWriteTy, propTy, item.value, toBlock + ); prop.readTy.reset(); prop.writeTy = matchedType; } @@ -448,7 +448,6 @@ TypeId matchLiteralType( if (tableTy->indexer->indexResultType == *propTy) tableTy->indexer->indexResultType = matchedType; } - } } else if (item.kind == AstExprTable::Item::General) @@ -476,7 +475,6 @@ TypeId matchLiteralType( indexerKeyTypes.insert(tKey); indexerValueTypes.insert(tProp); } - } else LUAU_ASSERT(!"Unexpected"); @@ -544,12 +542,12 @@ TypeId matchLiteralType( { TypeId inferredKeyType = builtinTypes->neverType; TypeId inferredValueType = builtinTypes->neverType; - for (auto kt: indexerKeyTypes) + for (auto kt : indexerKeyTypes) { auto simplified = simplifyUnion(builtinTypes, arena, inferredKeyType, kt); inferredKeyType = simplified.result; } - for (auto vt: indexerValueTypes) + for (auto vt : indexerValueTypes) { auto simplified = simplifyUnion(builtinTypes, arena, inferredValueType, vt); inferredValueType = simplified.result; diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 91ec3edc..1e3fedea 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -301,6 +301,28 @@ struct StringifierState emit(std::to_string(i).c_str()); } + void emit(Polarity p) + { + switch (p) + { + case Polarity::None: + emit(" "); + break; + case Polarity::Negative: + emit(" -"); + break; + case Polarity::Positive: + emit("+ "); + break; + case Polarity::Mixed: + emit("+-"); + break; + default: + emit("!!"); + break; + } + } + void indent() { indentation += 4; @@ -482,6 +504,8 @@ struct TypeStringifier { state.emit("'"); state.emit(state.getName(ty)); + if (FInt::DebugLuauVerboseTypeNames >= 1) + state.emit(ftv.polarity); } else { @@ -494,6 +518,9 @@ struct TypeStringifier state.emit("'"); state.emit(state.getName(ty)); + if (FInt::DebugLuauVerboseTypeNames >= 1) + state.emit(ftv.polarity); + if (!get(upperBound)) { state.emit(" <: "); @@ -509,6 +536,9 @@ struct TypeStringifier state.emit(state.getName(ty)); + if (FFlag::LuauSolverV2 && FInt::DebugLuauVerboseTypeNames >= 1) + state.emit(ftv.polarity); + if (FInt::DebugLuauVerboseTypeNames >= 2) { state.emit("-"); @@ -538,6 +568,9 @@ struct TypeStringifier else state.emit(state.getName(ty)); + if (FInt::DebugLuauVerboseTypeNames >= 1) + state.emit(gtv.polarity); + if (FInt::DebugLuauVerboseTypeNames >= 2) { state.emit("-"); @@ -1222,6 +1255,9 @@ struct TypePackStringifier state.emit(state.getName(tp)); } + if (FInt::DebugLuauVerboseTypeNames >= 1) + state.emit(pack.polarity); + if (FInt::DebugLuauVerboseTypeNames >= 2) { state.emit("-"); @@ -1241,6 +1277,9 @@ struct TypePackStringifier state.emit("free-"); state.emit(state.getName(tp)); + if (FInt::DebugLuauVerboseTypeNames >= 1) + state.emit(pack.polarity); + if (FInt::DebugLuauVerboseTypeNames >= 2) { state.emit("-"); diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 3d7ec6c2..5cb7f58a 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -11,7 +11,6 @@ #include LUAU_FASTFLAG(LuauStoreCSTData2) -LUAU_FASTFLAG(LuauExtendStatEndPosWithSemicolon) LUAU_FASTFLAG(LuauAstTypeGroup3) LUAU_FASTFLAG(LuauFixDoBlockEndLocation) LUAU_FASTFLAG(LuauParseOptionalAsNode2) diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index 027d5c7f..23f407d2 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -488,11 +488,12 @@ FreeType::FreeType(TypeLevel level, TypeId lowerBound, TypeId upperBound) { } -FreeType::FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound) +FreeType::FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound, Polarity polarity) : index(Unifiable::freshIndex()) , scope(scope) , lowerBound(lowerBound) , upperBound(upperBound) + , polarity(polarity) { } @@ -543,16 +544,18 @@ GenericType::GenericType(TypeLevel level) { } -GenericType::GenericType(const Name& name) +GenericType::GenericType(const Name& name, Polarity polarity) : index(Unifiable::freshIndex()) , name(name) , explicitName(true) + , polarity(polarity) { } -GenericType::GenericType(Scope* scope) +GenericType::GenericType(Scope* scope, Polarity polarity) : index(Unifiable::freshIndex()) , scope(scope) + , polarity(polarity) { } @@ -1268,9 +1271,9 @@ IntersectionTypeIterator end(const IntersectionType* itv) return IntersectionTypeIterator{}; } -TypeId freshType(NotNull arena, NotNull builtinTypes, Scope* scope) +TypeId freshType(NotNull arena, NotNull builtinTypes, Scope* scope, Polarity polarity) { - return arena->addType(FreeType{scope, builtinTypes->neverType, builtinTypes->unknownType}); + return arena->addType(FreeType{scope, builtinTypes->neverType, builtinTypes->unknownType, polarity}); } std::vector filterMap(TypeId type, TypeIdPredicate predicate) diff --git a/Analysis/src/TypeArena.cpp b/Analysis/src/TypeArena.cpp index e4e9e293..c5ccc0d6 100644 --- a/Analysis/src/TypeArena.cpp +++ b/Analysis/src/TypeArena.cpp @@ -77,9 +77,9 @@ TypeId TypeArena::freshType_DEPRECATED(Scope* scope, TypeLevel level) return allocated; } -TypePackId TypeArena::freshTypePack(Scope* scope) +TypePackId TypeArena::freshTypePack(Scope* scope, Polarity polarity) { - TypePackId allocated = typePacks.allocate(FreeTypePack{scope}); + TypePackId allocated = typePacks.allocate(FreeTypePack{scope, polarity}); asMutable(allocated)->owningArena = this; diff --git a/Analysis/src/TypeFunction.cpp b/Analysis/src/TypeFunction.cpp index ddc52cd7..3b5263f6 100644 --- a/Analysis/src/TypeFunction.cpp +++ b/Analysis/src/TypeFunction.cpp @@ -46,9 +46,11 @@ LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyApplicationCartesianProductLimit, 5'0 // when this value is set to a negative value, guessing will be totally disabled. LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyUseGuesserDepth, -1); -LUAU_FASTFLAGVARIABLE(DebugLuauLogTypeFamilies) LUAU_FASTFLAG(DebugLuauEqSatSimplification) LUAU_FASTFLAG(LuauTypeFunResultInAutocomplete) +LUAU_FASTFLAG(LuauNonReentrantGeneralization) + +LUAU_FASTFLAGVARIABLE(DebugLuauLogTypeFamilies) LUAU_FASTFLAGVARIABLE(LuauMetatableTypeFunctions) LUAU_FASTFLAGVARIABLE(LuauIndexTypeFunctionImprovements) LUAU_FASTFLAGVARIABLE(LuauIndexTypeFunctionFunctionMetamethods) @@ -520,7 +522,7 @@ struct TypeFunctionReducer return; if (FFlag::DebugLuauLogTypeFamilies) - printf("Trying to reduce %s\n", toString(subject, {true}).c_str()); + printf("Trying to %sreduce %s\n", force ? "force " : "", toString(subject, {true}).c_str()); if (const TypeFunctionInstanceType* tfit = get(subject)) { @@ -1219,6 +1221,9 @@ TypeFunctionReductionResult unmTypeFunction( if (isPending(operandTy, ctx->solver)) return {std::nullopt, Reduction::MaybeOk, {operandTy}, {}}; + if (FFlag::LuauNonReentrantGeneralization) + operandTy = follow(operandTy); + std::shared_ptr normTy = ctx->normalizer->normalize(operandTy); // if the operand failed to normalize, we can't reduce, but know nothing about inhabitance. @@ -2112,28 +2117,50 @@ struct FindRefinementBlockers : TypeOnceVisitor struct ContainsRefinableType : TypeOnceVisitor { bool found = false; - ContainsRefinableType() : TypeOnceVisitor(/* skipBoundTypes */ true) {} + ContainsRefinableType() + : TypeOnceVisitor(/* skipBoundTypes */ true) + { + } - bool visit(TypeId ty) override { + bool visit(TypeId ty) override + { // Default case: if we find *some* type that's worth refining against, // then we can claim that this type contains a refineable type. found = true; return false; } - bool visit(TypeId Ty, const NoRefineType&) override { + bool visit(TypeId Ty, const NoRefineType&) override + { // No refine types aren't interesting return false; } - bool visit(TypeId ty, const TableType&) override { return !found; } - bool visit(TypeId ty, const MetatableType&) override { return !found; } - bool visit(TypeId ty, const FunctionType&) override { return !found; } - bool visit(TypeId ty, const UnionType&) override { return !found; } - bool visit(TypeId ty, const IntersectionType&) override { return !found; } - bool visit(TypeId ty, const NegationType&) override { return !found; } - + bool visit(TypeId ty, const TableType&) override + { + return !found; + } + bool visit(TypeId ty, const MetatableType&) override + { + return !found; + } + bool visit(TypeId ty, const FunctionType&) override + { + return !found; + } + bool visit(TypeId ty, const UnionType&) override + { + return !found; + } + bool visit(TypeId ty, const IntersectionType&) override + { + return !found; + } + bool visit(TypeId ty, const NegationType&) override + { + return !found; + } }; TypeFunctionReductionResult refineTypeFunction( @@ -2414,7 +2441,6 @@ TypeFunctionReductionResult unionTypeFunction( } return {resultTy, Reduction::MaybeOk, {}, {}}; - } @@ -2849,7 +2875,14 @@ bool tblIndexInto_DEPRECATED(TypeId indexer, TypeId indexee, DenseHashSet& result, DenseHashSet& seenSet, NotNull ctx, bool isRaw) +bool tblIndexInto( + TypeId indexer, + TypeId indexee, + DenseHashSet& result, + DenseHashSet& seenSet, + NotNull ctx, + bool isRaw +) { indexer = follow(indexer); indexee = follow(indexee); @@ -2860,7 +2893,7 @@ bool tblIndexInto(TypeId indexer, TypeId indexee, DenseHashSet& result, if (FFlag::LuauIndexTypeFunctionFunctionMetamethods) { - if (auto unionTy = get(indexee)) + if (auto unionTy = get(indexee)) { bool res = true; for (auto component : unionTy) @@ -3087,8 +3120,9 @@ TypeFunctionReductionResult indexFunctionImpl( { return follow(ty); } - ); -} + ); + } + // If the type being reduced to is a single type, no need to union if (properties.size() == 1) return {*properties.begin(), Reduction::MaybeOk, {}, {}}; @@ -3332,6 +3366,39 @@ TypeFunctionReductionResult getmetatableTypeFunction( return getmetatableHelper(targetTy, location, ctx); } +TypeFunctionReductionResult weakoptionalTypeFunc( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 1 || !packParams.empty()) + { + ctx->ice->ice("weakoptional type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + TypeId targetTy = follow(typeParams.at(0)); + + if (isPending(targetTy, ctx->solver)) + return {std::nullopt, Reduction::MaybeOk, {targetTy}, {}}; + + if (is(instance)) + return {ctx->builtins->nilType, Reduction::MaybeOk, {}, {}}; + + std::shared_ptr targetNorm = ctx->normalizer->normalize(targetTy); + + if (!targetNorm) + return {std::nullopt, Reduction::MaybeOk, {}, {}}; + + auto result = ctx->normalizer->isInhabited(targetNorm.get()); + if (result == NormalizationResult::False) + return {ctx->builtins->nilType, Reduction::MaybeOk, {}, {}}; + + return {targetTy, Reduction::MaybeOk, {}, {}}; +} + BuiltinTypeFunctions::BuiltinTypeFunctions() : userFunc{"user", userDefinedTypeFunction} @@ -3361,6 +3428,7 @@ BuiltinTypeFunctions::BuiltinTypeFunctions() , rawgetFunc{"rawget", rawgetTypeFunction} , setmetatableFunc{"setmetatable", setmetatableTypeFunction} , getmetatableFunc{"getmetatable", getmetatableTypeFunction} + , weakoptionalFunc{"weakoptional", weakoptionalTypeFunc} { } @@ -3369,7 +3437,7 @@ void BuiltinTypeFunctions::addToScope(NotNull arena, NotNull s // make a type function for a one-argument type function auto mkUnaryTypeFunction = [&](const TypeFunction* tf) { - TypeId t = arena->addType(GenericType{"T"}); + TypeId t = arena->addType(GenericType{"T", Polarity::Negative}); GenericTypeDefinition genericT{t}; return TypeFun{{genericT}, arena->addType(TypeFunctionInstanceType{NotNull{tf}, {t}, {}})}; @@ -3378,8 +3446,8 @@ void BuiltinTypeFunctions::addToScope(NotNull arena, NotNull s // make a type function for a two-argument type function auto mkBinaryTypeFunction = [&](const TypeFunction* tf) { - TypeId t = arena->addType(GenericType{"T"}); - TypeId u = arena->addType(GenericType{"U"}); + TypeId t = arena->addType(GenericType{"T", Polarity::Negative}); + TypeId u = arena->addType(GenericType{"U", Polarity::Negative}); GenericTypeDefinition genericT{t}; GenericTypeDefinition genericU{u, {t}}; diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index f80e2be3..ef14abf8 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -35,6 +35,7 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) LUAU_FASTFLAG(LuauRetainDefinitionAliasLocations) +LUAU_FASTFLAGVARIABLE(LuauStatForInFix) namespace Luau { @@ -1317,8 +1318,24 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) // Extract the remaining return values of the call // and check them against the parameter types of the iterator function. auto [types, tail] = flatten(callRetPack); - std::vector argTypes = std::vector(types.begin() + 1, types.end()); - argPack = addTypePack(TypePackVar{TypePack{std::move(argTypes), tail}}); + + if (FFlag::LuauStatForInFix) + { + if (!types.empty()) + { + std::vector argTypes = std::vector(types.begin() + 1, types.end()); + argPack = addTypePack(TypePackVar{TypePack{std::move(argTypes), tail}}); + } + else + { + argPack = addTypePack(TypePack{}); + } + } + else + { + std::vector argTypes = std::vector(types.begin() + 1, types.end()); + argPack = addTypePack(TypePackVar{TypePack{std::move(argTypes), tail}}); + } } else { diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 7e11d462..a1aa6d16 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -6,7 +6,7 @@ #include -LUAU_FASTFLAG(LuauSolverV2); +LUAU_FASTFLAGVARIABLE(LuauTypePackDetectCycles) namespace Luau { @@ -18,10 +18,11 @@ FreeTypePack::FreeTypePack(TypeLevel level) { } -FreeTypePack::FreeTypePack(Scope* scope) +FreeTypePack::FreeTypePack(Scope* scope, Polarity polarity) : index(Unifiable::freshIndex()) , level{} , scope(scope) + , polarity(polarity) { } @@ -52,9 +53,10 @@ GenericTypePack::GenericTypePack(const Name& name) { } -GenericTypePack::GenericTypePack(Scope* scope) +GenericTypePack::GenericTypePack(Scope* scope, Polarity polarity) : index(Unifiable::freshIndex()) , scope(scope) + , polarity(polarity) { } @@ -147,6 +149,15 @@ TypePackIterator& TypePackIterator::operator++() currentTypePack = tp->tail ? log->follow(*tp->tail) : nullptr; tp = currentTypePack ? log->getMutable(currentTypePack) : nullptr; + if (FFlag::LuauTypePackDetectCycles && tp) + { + // Step twice on each iteration to detect cycles + tailCycleCheck = tp->tail ? log->follow(*tp->tail) : nullptr; + + if (currentTypePack == tailCycleCheck) + throw InternalCompilerError("TypePackIterator detected a type pack cycle"); + } + currentIndex = 0; } diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index bf8cf533..3240a2d3 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -14,7 +14,9 @@ LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete); LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope); LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) +LUAU_FASTFLAG(LuauNonReentrantGeneralization) LUAU_FASTFLAG(LuauDisableNewSolverAssertsInMixedMode) + namespace Luau { @@ -304,7 +306,11 @@ TypePack extendTypePack( // also have to create a new tail. TypePack newPack; - newPack.tail = arena.freshTypePack(ftp->scope); + newPack.tail = arena.freshTypePack(ftp->scope, ftp->polarity); + + if (FFlag::LuauNonReentrantGeneralization) + trackInteriorFreeTypePack(ftp->scope, *newPack.tail); + if (FFlag::LuauSolverV2) result.tail = newPack.tail; size_t overridesIndex = 0; @@ -319,7 +325,7 @@ TypePack extendTypePack( { if (FFlag::LuauSolverV2) { - FreeType ft{ftp->scope, builtinTypes->neverType, builtinTypes->unknownType}; + FreeType ft{ftp->scope, builtinTypes->neverType, builtinTypes->unknownType, ftp->polarity}; t = arena.addType(ft); if (FFlag::LuauTrackInteriorFreeTypesOnScope) trackInteriorFreeType(ftp->scope, t); @@ -568,4 +574,24 @@ void trackInteriorFreeType(Scope* scope, TypeId ty) LUAU_ASSERT(!"No scopes in parent chain had a present `interiorFreeTypes` member."); } +void trackInteriorFreeTypePack(Scope* scope, TypePackId tp) +{ + LUAU_ASSERT(tp); + if (!FFlag::LuauNonReentrantGeneralization) + return; + + for (; scope; scope = scope->parent.get()) + { + if (scope->interiorFreeTypePacks) + { + scope->interiorFreeTypePacks->push_back(tp); + return; + } + } + // There should at least be *one* generalization constraint per module + // where `interiorFreeTypes` is present, which would be the one made + // by ConstraintGenerator::visitModuleRoot. + LUAU_ASSERT(!"No scopes in parent chain had a present `interiorFreeTypePacks` member."); +} + } // namespace Luau diff --git a/Analysis/src/Unifier2.cpp b/Analysis/src/Unifier2.cpp index 9389df8b..8ad2fae4 100644 --- a/Analysis/src/Unifier2.cpp +++ b/Analysis/src/Unifier2.cpp @@ -20,6 +20,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAGVARIABLE(LuauUnifyMetatableWithAny) LUAU_FASTFLAG(LuauExtraFollows) +LUAU_FASTFLAG(LuauNonReentrantGeneralization) namespace Luau { @@ -321,10 +322,23 @@ bool Unifier2::unify(TypeId subTy, const FunctionType* superFn) if (shouldInstantiate) { for (auto generic : subFn->generics) - genericSubstitutions[generic] = freshType(arena, builtinTypes, scope); + { + const GenericType* gen = get(generic); + LUAU_ASSERT(gen); + genericSubstitutions[generic] = freshType(scope, gen->polarity); + } for (auto genericPack : subFn->genericPacks) - genericPackSubstitutions[genericPack] = arena->freshTypePack(scope); + { + if (FFlag::LuauNonReentrantGeneralization) + { + const GenericTypePack* gen = get(genericPack); + LUAU_ASSERT(gen); + genericPackSubstitutions[genericPack] = freshTypePack(scope, gen->polarity); + } + else + genericPackSubstitutions[genericPack] = arena->freshTypePack(scope); + } } bool argResult = unify(superFn->argTypes, subFn->argTypes); @@ -941,4 +955,23 @@ OccursCheckResult Unifier2::occursCheck(DenseHashSet& seen, TypePack return OccursCheckResult::Pass; } +TypeId Unifier2::freshType(NotNull scope, Polarity polarity) +{ + TypeId result = ::Luau::freshType(arena, builtinTypes, scope.get(), polarity); + newFreshTypes.emplace_back(result); + return result; +} + +TypePackId Unifier2::freshTypePack(NotNull scope, Polarity polarity) +{ + TypePackId result = arena->freshTypePack(scope.get()); + + auto ftp = getMutable(result); + LUAU_ASSERT(ftp); + ftp->polarity = polarity; + + newFreshTypePacks.emplace_back(result); + return result; +} + } // namespace Luau diff --git a/Ast/include/Luau/Cst.h b/Ast/include/Luau/Cst.h index c6489d9f..0f7b5911 100644 --- a/Ast/include/Luau/Cst.h +++ b/Ast/include/Luau/Cst.h @@ -113,11 +113,11 @@ public: CstExprFunction(); Position functionKeywordPosition{0, 0}; - Position openGenericsPosition{0,0}; + Position openGenericsPosition{0, 0}; AstArray genericsCommaPositions; - Position closeGenericsPosition{0,0}; + Position closeGenericsPosition{0, 0}; AstArray argsCommaPositions; - Position returnSpecifierPosition{0,0}; + Position returnSpecifierPosition{0, 0}; }; class CstExprTable : public CstNode diff --git a/Ast/src/Cst.cpp b/Ast/src/Cst.cpp index 3873a106..ca2bd105 100644 --- a/Ast/src/Cst.cpp +++ b/Ast/src/Cst.cpp @@ -38,7 +38,8 @@ CstExprIndexExpr::CstExprIndexExpr(Position openBracketPosition, Position closeB { } -CstExprFunction::CstExprFunction() : CstNode(CstClassIndex()) +CstExprFunction::CstExprFunction() + : CstNode(CstClassIndex()) { } diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 283baf52..04f32b3f 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -20,7 +20,6 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauSolverV2) LUAU_FASTFLAGVARIABLE(LuauAllowComplexTypesInGenericParams) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryForTableTypes) -LUAU_FASTFLAGVARIABLE(LuauExtendStatEndPosWithSemicolon) LUAU_FASTFLAGVARIABLE(LuauStoreCSTData2) LUAU_FASTFLAGVARIABLE(LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) LUAU_FASTFLAGVARIABLE(LuauAstTypeGroup3) @@ -204,7 +203,9 @@ ParseExprResult Parser::parseExpr(const char* buffer, size_t bufferSize, AstName AstExpr* expr = p.parseExpr(); size_t lines = p.lexer.current().location.end.line + (bufferSize > 0 && buffer[bufferSize - 1] != '\n'); - return ParseExprResult{expr, lines, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations), std::move(p.cstNodeMap)}; + return ParseExprResult{ + expr, lines, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations), std::move(p.cstNodeMap) + }; } catch (ParseError& err) { @@ -316,10 +317,7 @@ AstStatBlock* Parser::parseBlockNoScope() { nextLexeme(); stat->hasSemicolon = true; - if (FFlag::LuauExtendStatEndPosWithSemicolon) - { - stat->location.end = lexer.previousLocation().end; - } + stat->location.end = lexer.previousLocation().end; } body.push_back(stat); @@ -745,14 +743,7 @@ AstExpr* Parser::parseFunctionName(bool& hasself, AstName& debugname) // while we could concatenate the name chain, for now let's just write the short name debugname = name.name; - expr = allocator.alloc( - Location(expr->location, name.location), - expr, - name.name, - name.location, - opPosition, - '.' - ); + expr = allocator.alloc(Location(expr->location, name.location), expr, name.name, name.location, opPosition, '.'); // note: while the parser isn't recursive here, we're generating recursive structures of unbounded depth incrementRecursionCounter("function name"); @@ -771,14 +762,7 @@ AstExpr* Parser::parseFunctionName(bool& hasself, AstName& debugname) // while we could concatenate the name chain, for now let's just write the short name debugname = name.name; - expr = allocator.alloc( - Location(expr->location, name.location), - expr, - name.name, - name.location, - opPosition, - ':' - ); + expr = allocator.alloc(Location(expr->location, name.location), expr, name.name, name.location, opPosition, ':'); hasself = true; } @@ -1666,13 +1650,12 @@ std::pair Parser::parseFunctionBody( auto* cstNode = FFlag::LuauStoreCSTData2 && options.storeCstData ? allocator.alloc() : nullptr; - auto [generics, genericPacks] = FFlag::LuauStoreCSTData2 && cstNode ? parseGenericTypeList( - /* withDefaultValues= */ false, - &cstNode->openGenericsPosition, - &cstNode->genericsCommaPositions, - &cstNode->closeGenericsPosition - ) - : parseGenericTypeList(/* withDefaultValues= */ false); + auto [generics, genericPacks] = + FFlag::LuauStoreCSTData2 && cstNode + ? parseGenericTypeList( + /* withDefaultValues= */ false, &cstNode->openGenericsPosition, &cstNode->genericsCommaPositions, &cstNode->closeGenericsPosition + ) + : parseGenericTypeList(/* withDefaultValues= */ false); MatchLexeme matchParen = lexer.current(); expectAndConsume('(', "function"); @@ -1822,7 +1805,12 @@ Parser::Binding Parser::parseBinding() } // bindinglist ::= (binding | `...') [`,' bindinglist] -std::tuple Parser::parseBindingList(TempVector& result, bool allowDot3, AstArray* commaPositions, std::optional initialCommaPosition) +std::tuple Parser::parseBindingList( + TempVector& result, + bool allowDot3, + AstArray* commaPositions, + std::optional initialCommaPosition +) { TempVector localCommaPositions(scratchPosition); diff --git a/CLI/include/Luau/AnalyzeRequirer.h b/CLI/include/Luau/AnalyzeRequirer.h new file mode 100644 index 00000000..a4b395da --- /dev/null +++ b/CLI/include/Luau/AnalyzeRequirer.h @@ -0,0 +1,35 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/RequireNavigator.h" +#include "Luau/RequirerUtils.h" + +struct FileNavigationContext : Luau::Require::NavigationContext +{ + using NavigateResult = Luau::Require::NavigationContext::NavigateResult; + + FileNavigationContext(std::string requirerPath); + + std::string getRequirerIdentifier() const override; + + // Navigation interface + NavigateResult reset(const std::string& requirerChunkname) override; + NavigateResult jumpToAlias(const std::string& path) override; + + NavigateResult toParent() override; + NavigateResult toChild(const std::string& component) override; + + bool isConfigPresent() const override; + std::optional getConfig() const override; + + // Custom capabilities + bool isModulePresent() const; + std::optional getIdentifier() const; + + std::string path; + std::string suffix; + std::string requirerPath; + +private: + NavigateResult storePathResult(PathResult result); +}; diff --git a/CLI/include/Luau/ReplRequirer.h b/CLI/include/Luau/ReplRequirer.h new file mode 100644 index 00000000..2f5f7c0b --- /dev/null +++ b/CLI/include/Luau/ReplRequirer.h @@ -0,0 +1,32 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Require.h" + +#include "Luau/Compiler.h" + +#include "lua.h" + +#include +#include + +void requireConfigInit(luarequire_Configuration* config); + +struct ReplRequirer +{ + ReplRequirer( + std::function copts, + std::function coverageActive, + std::function codegenEnabled, + std::function coverageTrack + ); + + std::function copts; + std::function coverageActive; + std::function codegenEnabled; + std::function coverageTrack; + + std::string absPath; + std::string relPath; + std::string suffix; +}; diff --git a/CLI/include/Luau/Require.h b/CLI/include/Luau/Require.h deleted file mode 100644 index e4fc019a..00000000 --- a/CLI/include/Luau/Require.h +++ /dev/null @@ -1,84 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#pragma once - -#include "Luau/Config.h" - -#include -#include -#include - -class RequireResolver -{ -public: - enum class ModuleStatus - { - Cached, - FileRead, - ErrorReported - }; - - struct ResolvedRequire - { - ModuleStatus status; - std::string identifier; - std::string absolutePath; - std::string sourceCode; - }; - - struct RequireContext - { - virtual ~RequireContext() = default; - virtual std::string getPath() = 0; - virtual bool isRequireAllowed() = 0; - virtual bool isStdin() = 0; - virtual std::string createNewIdentifer(const std::string& path) = 0; - }; - - struct CacheManager - { - virtual ~CacheManager() = default; - virtual bool isCached(const std::string& path) - { - return false; - } - }; - - struct ErrorHandler - { - virtual ~ErrorHandler() = default; - virtual void reportError(const std::string message) {} - }; - - RequireResolver(std::string pathToResolve, RequireContext& requireContext, CacheManager& cacheManager, ErrorHandler& errorHandler); - - [[nodiscard]] ResolvedRequire resolveRequire(std::function completionCallback = nullptr); - -private: - std::string pathToResolve; - - RequireContext& requireContext; - CacheManager& cacheManager; - ErrorHandler& errorHandler; - - ResolvedRequire resolvedRequire; - bool isRequireResolved = false; - - Luau::Config config; - std::string lastSearchedDir; - bool isConfigFullyResolved = false; - - [[nodiscard]] bool initialize(); - - ModuleStatus findModule(); - ModuleStatus findModuleImpl(); - - [[nodiscard]] bool resolveAndStoreDefaultPaths(); - std::optional getRequiringContextAbsolute(); - std::string getRequiringContextRelative(); - - [[nodiscard]] bool substituteAliasIfPresent(std::string& path); - std::optional getAlias(std::string alias); - - [[nodiscard]] bool parseNextConfig(); - [[nodiscard]] bool parseConfigInDirectory(const std::string& directory); -}; \ No newline at end of file diff --git a/CLI/include/Luau/RequirerUtils.h b/CLI/include/Luau/RequirerUtils.h new file mode 100644 index 00000000..a9cda9f9 --- /dev/null +++ b/CLI/include/Luau/RequirerUtils.h @@ -0,0 +1,36 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include +#include +#include + +struct PathResult +{ + enum class Status + { + SUCCESS, + AMBIGUOUS, + NOT_FOUND + }; + + Status status; + std::string absPath; + std::string relPath; + std::string suffix; +}; + +PathResult getStdInResult(); + +PathResult getAbsolutePathResult(const std::string& path); + +// If given an absolute path, this will implicitly call getAbsolutePathResult. +// Aliases prevent us from solely operating on relative paths, so we need to +// be able to fall back to operating on absolute paths if needed. +PathResult tryGetRelativePathResult(const std::string& path); + +PathResult getParent(const std::string& absPath, const std::string& relPath); +PathResult getChild(const std::string& absPath, const std::string& relPath, const std::string& name); + +bool isFilePresent(const std::string& path, const std::string& suffix); +std::optional getFileContents(const std::string& path, const std::string& suffix); diff --git a/CLI/src/Analyze.cpp b/CLI/src/Analyze.cpp index e10a2c2e..fdd8df33 100644 --- a/CLI/src/Analyze.cpp +++ b/CLI/src/Analyze.cpp @@ -7,9 +7,10 @@ #include "Luau/TypeAttach.h" #include "Luau/Transpiler.h" +#include "Luau/AnalyzeRequirer.h" #include "Luau/FileUtils.h" #include "Luau/Flags.h" -#include "Luau/Require.h" +#include "Luau/RequireNavigator.h" #include #include @@ -173,15 +174,18 @@ struct CliFileResolver : Luau::FileResolver { std::string path{expr->value.data, expr->value.size}; - AnalysisRequireContext requireContext{context->name}; - AnalysisCacheManager cacheManager; - AnalysisErrorHandler errorHandler; + FileNavigationContext navigationContext{context->name}; + Luau::Require::ErrorHandler nullErrorHandler{}; - RequireResolver resolver(path, requireContext, cacheManager, errorHandler); - RequireResolver::ResolvedRequire resolvedRequire = resolver.resolveRequire(); + Luau::Require::Navigator navigator(navigationContext, nullErrorHandler); + if (navigator.navigate(path) != Luau::Require::Navigator::Status::Success) + return std::nullopt; - if (resolvedRequire.status == RequireResolver::ModuleStatus::FileRead) - return {{resolvedRequire.identifier}}; + if (!navigationContext.isModulePresent()) + return std::nullopt; + + if (std::optional identifier = navigationContext.getIdentifier()) + return {{*identifier}}; } return std::nullopt; @@ -193,48 +197,6 @@ struct CliFileResolver : Luau::FileResolver return "stdin"; return name; } - -private: - struct AnalysisRequireContext : RequireResolver::RequireContext - { - explicit AnalysisRequireContext(std::string path) - : path(std::move(path)) - { - } - - std::string getPath() override - { - return path; - } - - bool isRequireAllowed() override - { - return true; - } - - bool isStdin() override - { - return path == "-"; - } - - std::string createNewIdentifer(const std::string& path) override - { - return path; - } - - private: - std::string path; - }; - - struct AnalysisCacheManager : public RequireResolver::CacheManager - { - AnalysisCacheManager() = default; - }; - - struct AnalysisErrorHandler : RequireResolver::ErrorHandler - { - AnalysisErrorHandler() = default; - }; }; struct CliConfigResolver : Luau::ConfigResolver diff --git a/CLI/src/AnalyzeRequirer.cpp b/CLI/src/AnalyzeRequirer.cpp new file mode 100644 index 00000000..19d1b431 --- /dev/null +++ b/CLI/src/AnalyzeRequirer.cpp @@ -0,0 +1,99 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/AnalyzeRequirer.h" + +#include "Luau/RequireNavigator.h" +#include "Luau/RequirerUtils.h" + +#include +#include + +Luau::Require::NavigationContext::NavigateResult FileNavigationContext::storePathResult(PathResult result) +{ + if (result.status == PathResult::Status::AMBIGUOUS) + return Luau::Require::NavigationContext::NavigateResult::Ambiguous; + + if (result.status == PathResult::Status::NOT_FOUND) + return Luau::Require::NavigationContext::NavigateResult::NotFound; + + path = result.absPath; + suffix = result.suffix; + + return Luau::Require::NavigationContext::NavigateResult::Success; +} + +FileNavigationContext::FileNavigationContext(std::string requirerPath) +{ + std::string_view path = requirerPath; + if (path.size() >= 10 && path.substr(path.size() - 10) == "/init.luau") + { + path.remove_suffix(10); + } + else if (path.size() >= 9 && path.substr(path.size() - 9) == "/init.lua") + { + path.remove_suffix(9); + } + else if (path.size() >= 5 && path.substr(path.size() - 5) == ".luau") + { + path.remove_suffix(5); + } + else if (path.size() >= 4 && path.substr(path.size() - 4) == ".lua") + { + path.remove_suffix(4); + } + + this->requirerPath = path; +} + +std::string FileNavigationContext::getRequirerIdentifier() const +{ + return requirerPath; +} + +Luau::Require::NavigationContext::NavigateResult FileNavigationContext::reset(const std::string& requirerChunkname) +{ + if (requirerChunkname == "-") + { + return storePathResult(getStdInResult()); + } + + return storePathResult(tryGetRelativePathResult(requirerChunkname)); +} + +Luau::Require::NavigationContext::NavigateResult FileNavigationContext::jumpToAlias(const std::string& path) +{ + Luau::Require::NavigationContext::NavigateResult result = storePathResult(getAbsolutePathResult(path)); + if (result != Luau::Require::NavigationContext::NavigateResult::Success) + return result; + + return Luau::Require::NavigationContext::NavigateResult::Success; +} + +Luau::Require::NavigationContext::NavigateResult FileNavigationContext::toParent() +{ + return storePathResult(getParent(path, path)); +} + +Luau::Require::NavigationContext::NavigateResult FileNavigationContext::toChild(const std::string& component) +{ + return storePathResult(getChild(path, path, component)); +} + +bool FileNavigationContext::isModulePresent() const +{ + return isFilePresent(path, suffix); +} + +std::optional FileNavigationContext::getIdentifier() const +{ + return path + suffix; +} + +bool FileNavigationContext::isConfigPresent() const +{ + return isFilePresent(path, "/.luaurc"); +} + +std::optional FileNavigationContext::getConfig() const +{ + return getFileContents(path, "/.luaurc"); +} diff --git a/CLI/src/Repl.cpp b/CLI/src/Repl.cpp index 3e3ae182..56d9cf8b 100644 --- a/CLI/src/Repl.cpp +++ b/CLI/src/Repl.cpp @@ -14,6 +14,7 @@ #include "Luau/FileUtils.h" #include "Luau/Flags.h" #include "Luau/Profiler.h" +#include "Luau/ReplRequirer.h" #include "Luau/Require.h" #include "isocline.h" @@ -113,172 +114,6 @@ static int lua_loadstring(lua_State* L) return 2; // return nil plus error message } -static int finishrequire(lua_State* L) -{ - if (lua_isstring(L, -1)) - lua_error(L); - - return 1; -} - -struct RuntimeRequireContext : public RequireResolver::RequireContext -{ - // In the context of the REPL, source is the calling context's chunkname. - // - // These chunknames have certain prefixes that indicate context. These - // are used when displaying debug information (see luaO_chunkid). - // - // Generally, the '@' prefix is used for filepaths, and the '=' prefix is - // used for custom chunknames, such as =stdin. - explicit RuntimeRequireContext(std::string source) - : source(std::move(source)) - { - } - - std::string getPath() override - { - return source.substr(1); - } - - bool isRequireAllowed() override - { - return isStdin() || (!source.empty() && source[0] == '@'); - } - - bool isStdin() override - { - return source == "=stdin"; - } - - std::string createNewIdentifer(const std::string& path) override - { - return "@" + path; - } - -private: - std::string source; -}; - -struct RuntimeCacheManager : public RequireResolver::CacheManager -{ - explicit RuntimeCacheManager(lua_State* L) - : L(L) - { - } - - bool isCached(const std::string& path) override - { - luaL_findtable(L, LUA_REGISTRYINDEX, "_MODULES", 1); - lua_getfield(L, -1, path.c_str()); - bool cached = !lua_isnil(L, -1); - lua_pop(L, 2); - - if (cached) - cacheKey = path; - - return cached; - } - - std::string cacheKey; - -private: - lua_State* L; -}; - -struct RuntimeErrorHandler : RequireResolver::ErrorHandler -{ - explicit RuntimeErrorHandler(lua_State* L) - : L(L) - { - } - - void reportError(const std::string message) override - { - luaL_errorL(L, "%s", message.c_str()); - } - -private: - lua_State* L; -}; - -static int lua_require(lua_State* L) -{ - std::string name = luaL_checkstring(L, 1); - - RequireResolver::ResolvedRequire resolvedRequire; - { - lua_Debug ar; - lua_getinfo(L, 1, "s", &ar); - - RuntimeRequireContext requireContext{ar.source}; - RuntimeCacheManager cacheManager{L}; - RuntimeErrorHandler errorHandler{L}; - - RequireResolver resolver(std::move(name), requireContext, cacheManager, errorHandler); - - resolvedRequire = resolver.resolveRequire( - [L, &cacheKey = cacheManager.cacheKey](const RequireResolver::ModuleStatus status) - { - lua_getfield(L, LUA_REGISTRYINDEX, "_MODULES"); - if (status == RequireResolver::ModuleStatus::Cached) - lua_getfield(L, -1, cacheKey.c_str()); - } - ); - } - - if (resolvedRequire.status == RequireResolver::ModuleStatus::Cached) - return finishrequire(L); - - // module needs to run in a new thread, isolated from the rest - // note: we create ML on main thread so that it doesn't inherit environment of L - lua_State* GL = lua_mainthread(L); - lua_State* ML = lua_newthread(GL); - lua_xmove(GL, L, 1); - - // new thread needs to have the globals sandboxed - luaL_sandboxthread(ML); - - // now we can compile & run module on the new thread - std::string bytecode = Luau::compile(resolvedRequire.sourceCode, copts()); - if (luau_load(ML, resolvedRequire.identifier.c_str(), bytecode.data(), bytecode.size(), 0) == 0) - { - if (codegen) - { - Luau::CodeGen::CompilationOptions nativeOptions; - Luau::CodeGen::compile(ML, -1, nativeOptions); - } - - if (coverageActive()) - coverageTrack(ML, -1); - - int status = lua_resume(ML, L, 0); - - if (status == 0) - { - if (lua_gettop(ML) == 0) - lua_pushstring(ML, "module must return a value"); - else if (!lua_istable(ML, -1) && !lua_isfunction(ML, -1)) - lua_pushstring(ML, "module must return a table or function"); - } - else if (status == LUA_YIELD) - { - lua_pushstring(ML, "module can not yield"); - } - else if (!lua_isstring(ML, -1)) - { - lua_pushstring(ML, "unknown error while running module"); - } - } - - // there's now a return value on top of ML; L stack: _MODULES ML - lua_xmove(ML, L, 1); - lua_pushvalue(L, -1); - lua_setfield(L, -4, resolvedRequire.absolutePath.c_str()); - - // L stack: _MODULES ML result - return finishrequire(L); -} - static int lua_collectgarbage(lua_State* L) { const char* option = luaL_optstring(L, 1, "collect"); @@ -329,6 +164,39 @@ static int lua_callgrind(lua_State* L) } #endif +static void* createCliRequireContext(lua_State* L) +{ + void* ctx = lua_newuserdatadtor( + L, + sizeof(ReplRequirer), + [](void* ptr) + { + static_cast(ptr)->~ReplRequirer(); + } + ); + + if (!ctx) + luaL_error(L, "unable to allocate ReplRequirer"); + + ctx = new (ctx) ReplRequirer{ + copts, + coverageActive, + []() + { + return codegen; + }, + coverageTrack, + }; + + // Store ReplRequirer in the registry to keep it alive for the lifetime of + // this lua_State. Memory address is used as a key to avoid collisions. + lua_pushlightuserdata(L, ctx); + lua_insert(L, -2); + lua_settable(L, LUA_REGISTRYINDEX); + + return ctx; +} + void setupState(lua_State* L) { if (codegen) @@ -338,7 +206,6 @@ void setupState(lua_State* L) static const luaL_Reg funcs[] = { {"loadstring", lua_loadstring}, - {"require", lua_require}, {"collectgarbage", lua_collectgarbage}, #ifdef CALLGRIND {"callgrind", lua_callgrind}, @@ -350,6 +217,8 @@ void setupState(lua_State* L) luaL_register(L, NULL, funcs); lua_pop(L, 1); + luaopen_require(L, requireConfigInit, createCliRequireContext(L)); + luaL_sandbox(L); } diff --git a/CLI/src/ReplRequirer.cpp b/CLI/src/ReplRequirer.cpp new file mode 100644 index 00000000..22184ddb --- /dev/null +++ b/CLI/src/ReplRequirer.cpp @@ -0,0 +1,221 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/ReplRequirer.h" + +#include "Luau/CodeGen.h" +#include "Luau/CodeGenOptions.h" +#include "Luau/Require.h" + +#include "Luau/RequirerUtils.h" +#include "lua.h" +#include "lualib.h" + +#include +#include +#include + +static luarequire_WriteResult write(std::optional contents, char* buffer, size_t bufferSize, size_t* sizeOut) +{ + if (!contents) + return luarequire_WriteResult::WRITE_FAILURE; + + size_t nullTerminatedSize = contents->size() + 1; + + if (bufferSize < nullTerminatedSize) + { + *sizeOut = nullTerminatedSize; + return luarequire_WriteResult::WRITE_BUFFER_TOO_SMALL; + } + + *sizeOut = nullTerminatedSize; + memcpy(buffer, contents->c_str(), nullTerminatedSize); + return luarequire_WriteResult::WRITE_SUCCESS; +} + +static luarequire_NavigateResult storePathResult(ReplRequirer* req, PathResult result) +{ + if (result.status == PathResult::Status::AMBIGUOUS) + return NAVIGATE_AMBIGUOUS; + + if (result.status == PathResult::Status::NOT_FOUND) + return NAVIGATE_NOT_FOUND; + + req->absPath = result.absPath; + req->relPath = result.relPath; + req->suffix = result.suffix; + + return NAVIGATE_SUCCESS; +} + +static bool is_require_allowed(lua_State* L, void* ctx, const char* requirer_chunkname) +{ + std::string_view chunkname = requirer_chunkname; + return chunkname == "=stdin" || (!chunkname.empty() && chunkname[0] == '@'); +} + +static luarequire_NavigateResult reset(lua_State* L, void* ctx, const char* requirer_chunkname) +{ + ReplRequirer* req = static_cast(ctx); + + std::string chunkname = requirer_chunkname; + if (chunkname == "=stdin") + { + return storePathResult(req, getStdInResult()); + } + else if (!chunkname.empty() && chunkname[0] == '@') + { + return storePathResult(req, tryGetRelativePathResult(chunkname.substr(1))); + } + + return NAVIGATE_NOT_FOUND; +} + +static luarequire_NavigateResult jump_to_alias(lua_State* L, void* ctx, const char* path) +{ + ReplRequirer* req = static_cast(ctx); + + luarequire_NavigateResult result = storePathResult(req, getAbsolutePathResult(path)); + if (result != NAVIGATE_SUCCESS) + return result; + + // Jumping to an absolute path breaks the relative-require chain. The best + // we can do is to store the absolute path itself. + req->relPath = req->absPath; + return NAVIGATE_SUCCESS; +} + +static luarequire_NavigateResult to_parent(lua_State* L, void* ctx) +{ + ReplRequirer* req = static_cast(ctx); + return storePathResult(req, getParent(req->absPath, req->relPath)); +} + +static luarequire_NavigateResult to_child(lua_State* L, void* ctx, const char* name) +{ + ReplRequirer* req = static_cast(ctx); + return storePathResult(req, getChild(req->absPath, req->relPath, name)); +} + +static bool is_module_present(lua_State* L, void* ctx) +{ + ReplRequirer* req = static_cast(ctx); + return isFilePresent(req->absPath, req->suffix); +} + +static luarequire_WriteResult get_contents(lua_State* L, void* ctx, char* buffer, size_t buffer_size, size_t* size_out) +{ + ReplRequirer* req = static_cast(ctx); + return write(getFileContents(req->absPath, req->suffix), buffer, buffer_size, size_out); +} + +static luarequire_WriteResult get_chunkname(lua_State* L, void* ctx, char* buffer, size_t buffer_size, size_t* size_out) +{ + ReplRequirer* req = static_cast(ctx); + return write("@" + req->relPath, buffer, buffer_size, size_out); +} + +static luarequire_WriteResult get_cache_key(lua_State* L, void* ctx, char* buffer, size_t buffer_size, size_t* size_out) +{ + ReplRequirer* req = static_cast(ctx); + return write(req->absPath + req->suffix, buffer, buffer_size, size_out); +} + +static bool is_config_present(lua_State* L, void* ctx) +{ + ReplRequirer* req = static_cast(ctx); + return isFilePresent(req->absPath, "/.luaurc"); +} + +static luarequire_WriteResult get_config(lua_State* L, void* ctx, char* buffer, size_t buffer_size, size_t* size_out) +{ + ReplRequirer* req = static_cast(ctx); + return write(getFileContents(req->absPath, "/.luaurc"), buffer, buffer_size, size_out); +} + +static int load(lua_State* L, void* ctx, const char* chunkname, const char* contents) +{ + ReplRequirer* req = static_cast(ctx); + + // module needs to run in a new thread, isolated from the rest + // note: we create ML on main thread so that it doesn't inherit environment of L + lua_State* GL = lua_mainthread(L); + lua_State* ML = lua_newthread(GL); + lua_xmove(GL, L, 1); + + // new thread needs to have the globals sandboxed + luaL_sandboxthread(ML); + + // now we can compile & run module on the new thread + std::string bytecode = Luau::compile(contents, req->copts()); + if (luau_load(ML, chunkname, bytecode.data(), bytecode.size(), 0) == 0) + { + if (req->codegenEnabled()) + { + Luau::CodeGen::CompilationOptions nativeOptions; + Luau::CodeGen::compile(ML, -1, nativeOptions); + } + + if (req->coverageActive()) + req->coverageTrack(ML, -1); + + int status = lua_resume(ML, L, 0); + + if (status == 0) + { + if (lua_gettop(ML) == 0) + lua_pushstring(ML, "module must return a value"); + else if (!lua_istable(ML, -1) && !lua_isfunction(ML, -1)) + lua_pushstring(ML, "module must return a table or function"); + } + else if (status == LUA_YIELD) + { + lua_pushstring(ML, "module can not yield"); + } + else if (!lua_isstring(ML, -1)) + { + lua_pushstring(ML, "unknown error while running module"); + } + } + + // add ML result to L stack + lua_xmove(ML, L, 1); + if (lua_isstring(L, -1)) + lua_error(L); + + // remove ML thread from L stack + lua_remove(L, -2); + + // added one value to L stack: module result + return 1; +} + +void requireConfigInit(luarequire_Configuration* config) +{ + if (config == nullptr) + return; + + config->is_require_allowed = is_require_allowed; + config->reset = reset; + config->jump_to_alias = jump_to_alias; + config->to_parent = to_parent; + config->to_child = to_child; + config->is_module_present = is_module_present; + config->get_contents = get_contents; + config->is_config_present = is_config_present; + config->get_chunkname = get_chunkname; + config->get_cache_key = get_cache_key; + config->get_config = get_config; + config->load = load; +} + +ReplRequirer::ReplRequirer( + std::function copts, + std::function coverageActive, + std::function codegenEnabled, + std::function coverageTrack +) + : copts(std::move(copts)) + , coverageActive(std::move(coverageActive)) + , codegenEnabled(std::move(codegenEnabled)) + , coverageTrack(std::move(coverageTrack)) +{ +} diff --git a/CLI/src/Require.cpp b/CLI/src/Require.cpp deleted file mode 100644 index 1039f85c..00000000 --- a/CLI/src/Require.cpp +++ /dev/null @@ -1,313 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Require.h" - -#include "Luau/FileUtils.h" -#include "Luau/Common.h" -#include "Luau/Config.h" - -#include -#include -#include - -static constexpr char kRequireErrorGeneric[] = "error requiring module"; - -RequireResolver::RequireResolver(std::string path, RequireContext& requireContext, CacheManager& cacheManager, ErrorHandler& errorHandler) - : pathToResolve(std::move(path)) - , requireContext(requireContext) - , cacheManager(cacheManager) - , errorHandler(errorHandler) -{ -} - -RequireResolver::ResolvedRequire RequireResolver::resolveRequire(std::function completionCallback) -{ - if (isRequireResolved) - { - errorHandler.reportError("require statement has already been resolved"); - return ResolvedRequire{ModuleStatus::ErrorReported}; - } - - if (!initialize()) - return ResolvedRequire{ModuleStatus::ErrorReported}; - - resolvedRequire.status = findModule(); - - if (completionCallback) - completionCallback(resolvedRequire.status); - - isRequireResolved = true; - return resolvedRequire; -} - -static bool hasValidPrefix(std::string_view path) -{ - return path.compare(0, 2, "./") == 0 || path.compare(0, 3, "../") == 0 || path.compare(0, 1, "@") == 0; -} - -static bool isPathAmbiguous(const std::string& path) -{ - bool found = false; - for (const char* suffix : {".luau", ".lua"}) - { - if (isFile(path + suffix)) - { - if (found) - return true; - else - found = true; - } - } - if (isDirectory(path) && found) - return true; - - return false; -} - -bool RequireResolver::initialize() -{ - if (!requireContext.isRequireAllowed()) - { - errorHandler.reportError("require is not supported in this context"); - return false; - } - - if (isAbsolutePath(pathToResolve)) - { - errorHandler.reportError("cannot require an absolute path"); - return false; - } - - std::replace(pathToResolve.begin(), pathToResolve.end(), '\\', '/'); - - if (!hasValidPrefix(pathToResolve)) - { - errorHandler.reportError("require path must start with a valid prefix: ./, ../, or @"); - return false; - } - - return substituteAliasIfPresent(pathToResolve); -} - -RequireResolver::ModuleStatus RequireResolver::findModule() -{ - if (!resolveAndStoreDefaultPaths()) - return ModuleStatus::ErrorReported; - - if (isPathAmbiguous(resolvedRequire.absolutePath)) - { - errorHandler.reportError("require path could not be resolved to a unique file"); - return ModuleStatus::ErrorReported; - } - - static constexpr std::array possibleSuffixes = {".luau", ".lua", "/init.luau", "/init.lua"}; - size_t unsuffixedAbsolutePathSize = resolvedRequire.absolutePath.size(); - - for (const char* possibleSuffix : possibleSuffixes) - { - resolvedRequire.absolutePath += possibleSuffix; - - if (cacheManager.isCached(resolvedRequire.absolutePath)) - return ModuleStatus::Cached; - - // Try to read the matching file - if (std::optional source = readFile(resolvedRequire.absolutePath)) - { - resolvedRequire.identifier = requireContext.createNewIdentifer(resolvedRequire.identifier + possibleSuffix); - resolvedRequire.sourceCode = *source; - return ModuleStatus::FileRead; - } - - resolvedRequire.absolutePath.resize(unsuffixedAbsolutePathSize); // truncate to remove suffix - } - - if (hasFileExtension(resolvedRequire.absolutePath, {".luau", ".lua"}) && isFile(resolvedRequire.absolutePath)) - { - errorHandler.reportError("error requiring module: consider removing the file extension"); - return ModuleStatus::ErrorReported; - } - - errorHandler.reportError(kRequireErrorGeneric); - return ModuleStatus::ErrorReported; -} - -bool RequireResolver::resolveAndStoreDefaultPaths() -{ - if (!isAbsolutePath(pathToResolve)) - { - std::string identifierContext = getRequiringContextRelative(); - std::optional absolutePathContext = getRequiringContextAbsolute(); - - if (!absolutePathContext) - return false; - - // resolvePath automatically sanitizes/normalizes the paths - std::optional identifier = resolvePath(pathToResolve, identifierContext); - std::optional absolutePath = resolvePath(pathToResolve, *absolutePathContext); - - if (!identifier || !absolutePath) - { - errorHandler.reportError("could not resolve require path"); - return false; - } - - resolvedRequire.identifier = std::move(*identifier); - resolvedRequire.absolutePath = std::move(*absolutePath); - } - else - { - // Here we must explicitly sanitize, as the path is taken as is - std::string sanitizedPath = normalizePath(pathToResolve); - resolvedRequire.identifier = sanitizedPath; - resolvedRequire.absolutePath = std::move(sanitizedPath); - } - return true; -} - -std::optional RequireResolver::getRequiringContextAbsolute() -{ - std::string requiringFile; - if (isAbsolutePath(requireContext.getPath())) - { - // We already have an absolute path for the requiring file - requiringFile = requireContext.getPath(); - } - else - { - // Requiring file's stored path is relative to the CWD, must make absolute - std::optional cwd = getCurrentWorkingDirectory(); - if (!cwd) - { - errorHandler.reportError("could not determine current working directory"); - return std::nullopt; - } - - if (requireContext.isStdin()) - { - // Require statement is being executed from REPL input prompt - // The requiring context is the pseudo-file "stdin" in the CWD - requiringFile = joinPaths(*cwd, "stdin"); - } - else - { - // Require statement is being executed in a file, must resolve relative to CWD - requiringFile = normalizePath(joinPaths(*cwd, requireContext.getPath())); - } - } - std::replace(requiringFile.begin(), requiringFile.end(), '\\', '/'); - return requiringFile; -} - -std::string RequireResolver::getRequiringContextRelative() -{ - return requireContext.isStdin() ? "./" : requireContext.getPath(); -} - -bool RequireResolver::substituteAliasIfPresent(std::string& path) -{ - if (path.size() < 1 || path[0] != '@') - return true; - - // To ignore the '@' alias prefix when processing the alias - const size_t aliasStartPos = 1; - - // If a directory separator was found, the length of the alias is the - // distance between the start of the alias and the separator. Otherwise, - // the whole string after the alias symbol is the alias. - size_t aliasLen = path.find_first_of("\\/"); - if (aliasLen != std::string::npos) - aliasLen -= aliasStartPos; - - const std::string potentialAlias = path.substr(aliasStartPos, aliasLen); - - // Not worth searching when potentialAlias cannot be an alias - if (!Luau::isValidAlias(potentialAlias)) - { - errorHandler.reportError("@" + potentialAlias + " is not a valid alias"); - return false; - } - - if (std::optional alias = getAlias(potentialAlias)) - { - path = *alias + path.substr(potentialAlias.size() + 1); - return true; - } - - errorHandler.reportError("@" + potentialAlias + " is not a valid alias"); - return false; -} - -std::optional RequireResolver::getAlias(std::string alias) -{ - std::transform( - alias.begin(), - alias.end(), - alias.begin(), - [](unsigned char c) - { - return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c; - } - ); - while (!config.aliases.contains(alias) && !isConfigFullyResolved) - { - if (!parseNextConfig()) - return std::nullopt; // error parsing config - } - if (!config.aliases.contains(alias) && isConfigFullyResolved) - return std::nullopt; // could not find alias - - const Luau::Config::AliasInfo& aliasInfo = config.aliases[alias]; - return resolvePath(aliasInfo.value, aliasInfo.configLocation); -} - -bool RequireResolver::parseNextConfig() -{ - if (isConfigFullyResolved) - return true; // no config files left to parse - - std::optional directory; - if (lastSearchedDir.empty()) - { - std::optional requiringFile = getRequiringContextAbsolute(); - if (!requiringFile) - return false; - - directory = getParentPath(*requiringFile); - } - else - directory = getParentPath(lastSearchedDir); - - if (directory) - { - lastSearchedDir = *directory; - if (!parseConfigInDirectory(*directory)) - return false; - } - else - isConfigFullyResolved = true; - - return true; -} - -bool RequireResolver::parseConfigInDirectory(const std::string& directory) -{ - std::string configPath = joinPaths(directory, Luau::kConfigName); - - Luau::ConfigOptions::AliasOptions aliasOpts; - aliasOpts.configLocation = configPath; - aliasOpts.overwriteAliases = false; - - Luau::ConfigOptions opts; - opts.aliasOptions = std::move(aliasOpts); - - if (std::optional contents = readFile(configPath)) - { - std::optional error = Luau::parseConfig(*contents, config, opts); - if (error) - { - errorHandler.reportError("error parsing " + configPath + "(" + *error + ")"); - return false; - } - } - - return true; -} diff --git a/CLI/src/RequirerUtils.cpp b/CLI/src/RequirerUtils.cpp new file mode 100644 index 00000000..dcbcdb5f --- /dev/null +++ b/CLI/src/RequirerUtils.cpp @@ -0,0 +1,119 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/RequirerUtils.h" + +#include "Luau/FileUtils.h" + +#include +#include +#include + +static std::pair getSuffixWithAmbiguityCheck(const std::string& path) +{ + bool found = false; + std::string suffix; + + for (const char* potentialSuffix : {".luau", ".lua"}) + { + if (isFile(path + potentialSuffix)) + { + if (found) + return {PathResult::Status::AMBIGUOUS, ""}; + + suffix = potentialSuffix; + found = true; + } + } + if (isDirectory(path)) + { + if (found) + return {PathResult::Status::AMBIGUOUS, ""}; + + for (const char* potentialSuffix : {"/init.luau", "/init.lua"}) + { + if (isFile(path + potentialSuffix)) + { + if (found) + return {PathResult::Status::AMBIGUOUS, ""}; + + suffix = potentialSuffix; + found = true; + } + } + + found = true; + } + + if (!found) + return {PathResult::Status::NOT_FOUND, ""}; + + return {PathResult::Status::SUCCESS, suffix}; +} + +static PathResult addSuffix(PathResult partialResult) +{ + if (partialResult.status != PathResult::Status::SUCCESS) + return partialResult; + + auto [status, suffix] = getSuffixWithAmbiguityCheck(partialResult.absPath); + if (status != PathResult::Status::SUCCESS) + return PathResult{status}; + + partialResult.suffix = std::move(suffix); + return partialResult; +} + +PathResult getStdInResult() +{ + std::optional cwd = getCurrentWorkingDirectory(); + if (!cwd) + return PathResult{PathResult::Status::NOT_FOUND}; + + std::replace(cwd->begin(), cwd->end(), '\\', '/'); + + return PathResult{PathResult::Status::SUCCESS, *cwd + "/stdin", "./stdin", ""}; +} + +PathResult getAbsolutePathResult(const std::string& path) +{ + return addSuffix(PathResult{PathResult::Status::SUCCESS, path}); +} + +PathResult tryGetRelativePathResult(const std::string& path) +{ + if (isAbsolutePath(path)) + return getAbsolutePathResult(path); + + std::optional cwd = getCurrentWorkingDirectory(); + if (!cwd) + return PathResult{PathResult::Status::NOT_FOUND}; + + std::optional resolvedAbsPath = resolvePath(path, *cwd + "/stdin"); + if (!resolvedAbsPath) + return PathResult{PathResult::Status::NOT_FOUND}; + + return addSuffix(PathResult{PathResult::Status::SUCCESS, std::move(*resolvedAbsPath), path}); +} + +PathResult getParent(const std::string& absPath, const std::string& relPath) +{ + std::optional parent = getParentPath(absPath); + if (!parent) + return PathResult{PathResult::Status::NOT_FOUND}; + + return addSuffix(PathResult{PathResult::Status::SUCCESS, *parent, normalizePath(relPath + "/..")}); +} + +PathResult getChild(const std::string& absPath, const std::string& relPath, const std::string& name) +{ + return addSuffix(PathResult{PathResult::Status::SUCCESS, joinPaths(absPath, name), joinPaths(relPath, name)}); +} + +bool isFilePresent(const std::string& path, const std::string& suffix) +{ + return isFile(path + suffix); +} + +std::optional getFileContents(const std::string& path, const std::string& suffix) +{ + return readFile(path + suffix); +} diff --git a/CMakeLists.txt b/CMakeLists.txt index 220031e2..03f235ce 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -31,6 +31,8 @@ add_library(Luau.Analysis STATIC) add_library(Luau.EqSat STATIC) add_library(Luau.CodeGen STATIC) add_library(Luau.VM STATIC) +add_library(Luau.Require STATIC) +add_library(Luau.RequireNavigator STATIC) add_library(isocline STATIC) if(LUAU_BUILD_CLI) @@ -101,6 +103,15 @@ target_compile_features(Luau.VM PRIVATE cxx_std_11) target_include_directories(Luau.VM PUBLIC VM/include) target_link_libraries(Luau.VM PUBLIC Luau.Common) +target_compile_features(Luau.Require PUBLIC cxx_std_17) +target_include_directories(Luau.Require PUBLIC Require/Runtime/include) +target_link_libraries(Luau.Require PUBLIC Luau.VM) +target_link_libraries(Luau.Require PRIVATE Luau.RequireNavigator) + +target_compile_features(Luau.RequireNavigator PUBLIC cxx_std_17) +target_include_directories(Luau.RequireNavigator PUBLIC Require/Navigator/include) +target_link_libraries(Luau.RequireNavigator PUBLIC Luau.Config) + target_include_directories(isocline PUBLIC extern/isocline/include) target_include_directories(Luau.VM.Internals INTERFACE VM/src) @@ -215,12 +226,12 @@ if(LUAU_BUILD_CLI) target_include_directories(Luau.Repl.CLI PRIVATE extern extern/isocline/include) - target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.Config Luau.CodeGen Luau.VM Luau.CLI.lib isocline) + target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.Config Luau.CodeGen Luau.VM Luau.Require Luau.CLI.lib isocline) target_link_libraries(Luau.Repl.CLI PRIVATE osthreads) target_link_libraries(Luau.Analyze.CLI PRIVATE osthreads) - target_link_libraries(Luau.Analyze.CLI PRIVATE Luau.Analysis Luau.CLI.lib) + target_link_libraries(Luau.Analyze.CLI PRIVATE Luau.Analysis Luau.CLI.lib Luau.RequireNavigator) target_link_libraries(Luau.Ast.CLI PRIVATE Luau.Ast Luau.Analysis Luau.CLI.lib) @@ -252,7 +263,7 @@ if(LUAU_BUILD_TESTS) target_compile_options(Luau.CLI.Test PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.CLI.Test PRIVATE extern CLI) - target_link_libraries(Luau.CLI.Test PRIVATE Luau.Compiler Luau.Config Luau.CodeGen Luau.VM Luau.CLI.lib isocline) + target_link_libraries(Luau.CLI.Test PRIVATE Luau.Compiler Luau.Config Luau.CodeGen Luau.VM Luau.Require Luau.CLI.lib isocline) target_link_libraries(Luau.CLI.Test PRIVATE osthreads) endif() diff --git a/CodeGen/src/CodeGenUtils.cpp b/CodeGen/src/CodeGenUtils.cpp index 26451eea..6fb8b5c8 100644 --- a/CodeGen/src/CodeGenUtils.cpp +++ b/CodeGen/src/CodeGenUtils.cpp @@ -18,8 +18,6 @@ #include -LUAU_DYNAMIC_FASTFLAG(LuauPopIncompleteCi) - // All external function calls that can cause stack realloc or Lua calls have to be wrapped in VM_PROTECT // This makes sure that we save the pc (in case the Lua call needs to generate a backtrace) before the call, // and restores the stack pointer after in case stack gets reallocated @@ -193,14 +191,7 @@ Closure* callProlog(lua_State* L, TValue* ra, StkId argtop, int nresults) // note: this reallocs stack, but we don't need to VM_PROTECT this // this is because we're going to modify base/savedpc manually anyhow // crucially, we can't use ra/argtop after this line - if (DFFlag::LuauPopIncompleteCi) - { - luaD_checkstackfornewci(L, ccl->stacksize); - } - else - { - luaD_checkstack(L, ccl->stacksize); - } + luaD_checkstackfornewci(L, ccl->stacksize); return ccl; } @@ -270,14 +261,7 @@ Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults) // note: this reallocs stack, but we don't need to VM_PROTECT this // this is because we're going to modify base/savedpc manually anyhow // crucially, we can't use ra/argtop after this line - if (DFFlag::LuauPopIncompleteCi) - { - luaD_checkstackfornewci(L, ccl->stacksize); - } - else - { - luaD_checkstack(L, ccl->stacksize); - } + luaD_checkstackfornewci(L, ccl->stacksize); LUAU_ASSERT(ci->top <= L->stack_last); diff --git a/LICENSE.txt b/LICENSE.txt index 2eac525f..45d93ade 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2019-2024 Roblox Corporation +Copyright (c) 2019-2025 Roblox Corporation Copyright (c) 1994–2019 Lua.org, PUC-Rio. Permission is hereby granted, free of charge, to any person obtaining a copy of diff --git a/Makefile b/Makefile index 2ad0fc00..a3ec7c66 100644 --- a/Makefile +++ b/Makefile @@ -38,19 +38,27 @@ VM_SOURCES=$(wildcard VM/src/*.cpp) VM_OBJECTS=$(VM_SOURCES:%=$(BUILD)/%.o) VM_TARGET=$(BUILD)/libluauvm.a +REQUIRE_SOURCES=$(wildcard Require/Runtime/src/*.cpp) +REQUIRE_OBJECTS=$(REQUIRE_SOURCES:%=$(BUILD)/%.o) +REQUIRE_TARGET=$(BUILD)/libluaurequire.a + +REQUIRENAVIGATOR_SOURCES=$(wildcard Require/Navigator/src/*.cpp) +REQUIRENAVIGATOR_OBJECTS=$(REQUIRENAVIGATOR_SOURCES:%=$(BUILD)/%.o) +REQUIRENAVIGATOR_TARGET=$(BUILD)/libluaurequirenavigator.a + ISOCLINE_SOURCES=extern/isocline/src/isocline.c ISOCLINE_OBJECTS=$(ISOCLINE_SOURCES:%=$(BUILD)/%.o) ISOCLINE_TARGET=$(BUILD)/libisocline.a -TESTS_SOURCES=$(wildcard tests/*.cpp) CLI/src/FileUtils.cpp CLI/src/Flags.cpp CLI/src/Profiler.cpp CLI/src/Coverage.cpp CLI/src/Repl.cpp CLI/src/Require.cpp +TESTS_SOURCES=$(wildcard tests/*.cpp) CLI/src/FileUtils.cpp CLI/src/Flags.cpp CLI/src/Profiler.cpp CLI/src/Coverage.cpp CLI/src/Repl.cpp CLI/src/ReplRequirer.cpp CLI/src/RequirerUtils.cpp TESTS_OBJECTS=$(TESTS_SOURCES:%=$(BUILD)/%.o) TESTS_TARGET=$(BUILD)/luau-tests -REPL_CLI_SOURCES=CLI/src/FileUtils.cpp CLI/src/Flags.cpp CLI/src/Profiler.cpp CLI/src/Coverage.cpp CLI/src/Repl.cpp CLI/src/ReplEntry.cpp CLI/src/Require.cpp +REPL_CLI_SOURCES=CLI/src/FileUtils.cpp CLI/src/Flags.cpp CLI/src/Profiler.cpp CLI/src/Coverage.cpp CLI/src/Repl.cpp CLI/src/ReplEntry.cpp CLI/src/ReplRequirer.cpp CLI/src/RequirerUtils.cpp REPL_CLI_OBJECTS=$(REPL_CLI_SOURCES:%=$(BUILD)/%.o) REPL_CLI_TARGET=$(BUILD)/luau -ANALYZE_CLI_SOURCES=CLI/src/FileUtils.cpp CLI/src/Flags.cpp CLI/src/Require.cpp CLI/src/Analyze.cpp +ANALYZE_CLI_SOURCES=CLI/src/FileUtils.cpp CLI/src/Flags.cpp CLI/src/Analyze.cpp CLI/src/AnalyzeRequirer.cpp CLI/src/RequirerUtils.cpp ANALYZE_CLI_OBJECTS=$(ANALYZE_CLI_SOURCES:%=$(BUILD)/%.o) ANALYZE_CLI_TARGET=$(BUILD)/luau-analyze @@ -73,7 +81,7 @@ ifneq ($(opt),) TESTS_ARGS+=-O$(opt) endif -OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(CONFIG_OBJECTS) $(ANALYSIS_OBJECTS) $(EQSAT_OBJECTS) $(CODEGEN_OBJECTS) $(VM_OBJECTS) $(ISOCLINE_OBJECTS) $(TESTS_OBJECTS) $(REPL_CLI_OBJECTS) $(ANALYZE_CLI_OBJECTS) $(COMPILE_CLI_OBJECTS) $(BYTECODE_CLI_OBJECTS) $(FUZZ_OBJECTS) +OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(CONFIG_OBJECTS) $(ANALYSIS_OBJECTS) $(EQSAT_OBJECTS) $(CODEGEN_OBJECTS) $(VM_OBJECTS) $(REQUIRE_OBJECTS) $(REQUIRENAVIGATOR_OBJECTS) $(ISOCLINE_OBJECTS) $(TESTS_OBJECTS) $(REPL_CLI_OBJECTS) $(ANALYZE_CLI_OBJECTS) $(COMPILE_CLI_OBJECTS) $(BYTECODE_CLI_OBJECTS) $(FUZZ_OBJECTS) EXECUTABLE_ALIASES = luau luau-analyze luau-compile luau-bytecode luau-tests # common flags @@ -148,10 +156,12 @@ $(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnaly $(EQSAT_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IEqSat/include $(CODEGEN_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -ICodeGen/include -IVM/include -IVM/src # Code generation needs VM internals $(VM_OBJECTS): CXXFLAGS+=-std=c++11 -ICommon/include -IVM/include +$(REQUIRE_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IVM/include -IAst/include -IConfig/include -IRequire/Navigator/include -IRequire/Runtime/include +$(REQUIRENAVIGATOR_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IConfig/include -IRequire/Navigator/include $(ISOCLINE_OBJECTS): CXXFLAGS+=-Wno-unused-function -Iextern/isocline/include -$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IConfig/include -IAnalysis/include -IEqSat/include -ICodeGen/include -IVM/include -ICLI/include -Iextern -DDOCTEST_CONFIG_DOUBLE_STRINGIFY -$(REPL_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include -Iextern -Iextern/isocline/include -ICLI/include -$(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IEqSat/include -IConfig/include -Iextern -ICLI/include +$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IConfig/include -IAnalysis/include -IEqSat/include -ICodeGen/include -IVM/include -IRequire/Runtime/include -ICLI/include -Iextern -DDOCTEST_CONFIG_DOUBLE_STRINGIFY +$(REPL_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include -IRequire/Runtime/include -Iextern -Iextern/isocline/include -ICLI/include +$(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IEqSat/include -IConfig/include -IRequire/Navigator/include -Iextern -ICLI/include $(COMPILE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include -ICLI/include $(BYTECODE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include -ICLI/include $(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -IEqSat/include -IVM/include -ICodeGen/include -IConfig/include @@ -227,9 +237,9 @@ luau-tests: $(TESTS_TARGET) ln -fs $^ $@ # executable targets -$(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(COMPILER_TARGET) $(CONFIG_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET) -$(REPL_CLI_TARGET): $(REPL_CLI_OBJECTS) $(COMPILER_TARGET) $(CONFIG_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET) -$(ANALYZE_CLI_TARGET): $(ANALYZE_CLI_OBJECTS) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(COMPILER_TARGET) $(VM_TARGET) +$(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(REQUIRE_TARGET) $(REQUIRENAVIGATOR_TARGET) $(CONFIG_TARGET) $(ISOCLINE_TARGET) +$(REPL_CLI_TARGET): $(REPL_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(REQUIRE_TARGET) $(REQUIRENAVIGATOR_TARGET) $(CONFIG_TARGET) $(ISOCLINE_TARGET) +$(ANALYZE_CLI_TARGET): $(ANALYZE_CLI_OBJECTS) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(AST_TARGET) $(COMPILER_TARGET) $(VM_TARGET) $(REQUIRENAVIGATOR_TARGET) $(CONFIG_TARGET) $(COMPILE_CLI_TARGET): $(COMPILE_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(BYTECODE_CLI_TARGET): $(BYTECODE_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) @@ -251,9 +261,11 @@ $(ANALYSIS_TARGET): $(ANALYSIS_OBJECTS) $(EQSAT_TARGET): $(EQSAT_OBJECTS) $(CODEGEN_TARGET): $(CODEGEN_OBJECTS) $(VM_TARGET): $(VM_OBJECTS) +$(REQUIRE_TARGET): $(REQUIRE_OBJECTS) +$(REQUIRENAVIGATOR_TARGET): $(REQUIRENAVIGATOR_OBJECTS) $(ISOCLINE_TARGET): $(ISOCLINE_OBJECTS) -$(AST_TARGET) $(COMPILER_TARGET) $(CONFIG_TARGET) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET): +$(AST_TARGET) $(COMPILER_TARGET) $(CONFIG_TARGET) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(REQUIRE_TARGET) $(REQUIRENAVIGATOR_TARGET) $(ISOCLINE_TARGET): ar rcs $@ $^ # object file targets diff --git a/Require/Navigator/include/Luau/PathUtilities.h b/Require/Navigator/include/Luau/PathUtilities.h new file mode 100644 index 00000000..dce8cc04 --- /dev/null +++ b/Require/Navigator/include/Luau/PathUtilities.h @@ -0,0 +1,22 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include +#include + +namespace Luau::Require +{ + +enum class PathType +{ + RelativeToCurrent, + RelativeToParent, + Aliased, + Unsupported +}; + +PathType getPathType(std::string_view path); + +std::pair splitPath(std::string_view path); + +} // namespace Luau::Require diff --git a/Require/Navigator/include/Luau/RequireNavigator.h b/Require/Navigator/include/Luau/RequireNavigator.h new file mode 100644 index 00000000..21b01abd --- /dev/null +++ b/Require/Navigator/include/Luau/RequireNavigator.h @@ -0,0 +1,96 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Config.h" + +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// +// +// The RequireNavigator library provides a C++ interface for navigating the +// context in which require-by-string operates. This is used internally by the +// require-by-string runtime library to resolve paths based on the rules defined +// by its consumers. +// +// Directly linking against this library allows for inspection of the +// require-by-string path resolution algorithm's behavior without enabling the +// runtime library, which is useful for static tooling as well. +// +//////////////////////////////////////////////////////////////////////////////// + +namespace Luau::Require +{ + +// The ErrorHandler interface is used to report errors during navigation. +// The default implementation does nothing but can be overridden to enable +// custom error handling behavior. +class ErrorHandler +{ +public: + virtual ~ErrorHandler() = default; + virtual void reportError(std::string message) {} +}; + +// NavigationContext is an pure virtual class that is intended to be implemented +// and injected into a Navigator. +// +// When a Navigator traverses a require path, its NavigationContext's methods +// are invoked, with the expectation that the NavigationContext will keep track +// of the current state of the navigation and provide information about the +// current context as needed. +class NavigationContext +{ +public: + virtual ~NavigationContext() = default; + virtual std::string getRequirerIdentifier() const = 0; + + enum class NavigateResult + { + Success, + Ambiguous, + NotFound + }; + + virtual NavigateResult reset(const std::string& identifier) = 0; + virtual NavigateResult jumpToAlias(const std::string& path) = 0; + + virtual NavigateResult toParent() = 0; + virtual NavigateResult toChild(const std::string& component) = 0; + + virtual bool isConfigPresent() const = 0; + virtual std::optional getConfig() const = 0; +}; + +// The Navigator class is responsible for traversing a given require path in the +// context of a given NavigationContext. +// +// The Navigator is not intended to be overridden. Rather, it expects a custom +// injected NavigationContext that provides the desired navigation behavior. +class Navigator +{ +public: + enum class Status + { + Success, + ErrorReported + }; + + Navigator(NavigationContext& navigationContext, ErrorHandler& errorHandler); + [[nodiscard]] Status navigate(std::string path); + +private: + using Error = std::optional; + [[nodiscard]] Error navigateImpl(std::string_view path); + [[nodiscard]] Error navigateThroughPath(std::string_view path); + [[nodiscard]] Error navigateToAlias(const std::string& alias, const std::string& value); + [[nodiscard]] Error navigateToAndPopulateConfig(const std::string& desiredAlias); + + NavigationContext& navigationContext; + ErrorHandler& errorHandler; + Luau::Config config; +}; + +} // namespace Luau::Require diff --git a/Require/Navigator/src/PathUtilities.cpp b/Require/Navigator/src/PathUtilities.cpp new file mode 100644 index 00000000..6de4f7ad --- /dev/null +++ b/Require/Navigator/src/PathUtilities.cpp @@ -0,0 +1,31 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/PathUtilities.h" + +#include + +namespace Luau::Require +{ + +PathType getPathType(std::string_view path) +{ + if (path.size() >= 2 && path.substr(0, 2) == "./") + return PathType::RelativeToCurrent; + if (path.size() >= 3 && path.substr(0, 3) == "../") + return PathType::RelativeToParent; + if (path.size() >= 1 && path[0] == '@') + return PathType::Aliased; + + return PathType::Unsupported; +} + +std::pair splitPath(std::string_view path) +{ + size_t pos = path.find_first_of('/'); + if (pos == std::string_view::npos) + return {path, {}}; + + return {path.substr(0, pos), path.substr(pos + 1)}; +} + +} // namespace Luau::Require diff --git a/Require/Navigator/src/RequireNavigator.cpp b/Require/Navigator/src/RequireNavigator.cpp new file mode 100644 index 00000000..f7e0c0b1 --- /dev/null +++ b/Require/Navigator/src/RequireNavigator.cpp @@ -0,0 +1,208 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/RequireNavigator.h" + +#include "Luau/PathUtilities.h" + +#include "Luau/Config.h" + +#include +#include +#include + +static constexpr char kRequireErrorAmbiguous[] = "require path could not be resolved to a unique file"; +static constexpr char kRequireErrorGeneric[] = "error requiring module"; + +namespace Luau::Require +{ + +using Error = std::optional; + +static Error toError(NavigationContext::NavigateResult result) +{ + if (result == NavigationContext::NavigateResult::Success) + return std::nullopt; + if (result == NavigationContext::NavigateResult::Ambiguous) + return kRequireErrorAmbiguous; + else + return kRequireErrorGeneric; +} + +static std::string extractAlias(std::string_view path) +{ + // To ignore the '@' alias prefix when processing the alias + const size_t aliasStartPos = 1; + + // If a directory separator was found, the length of the alias is the + // distance between the start of the alias and the separator. Otherwise, + // the whole string after the alias symbol is the alias. + size_t aliasLen = path.find_first_of('/'); + if (aliasLen != std::string::npos) + aliasLen -= aliasStartPos; + + return std::string{path.substr(aliasStartPos, aliasLen)}; +} + +Navigator::Navigator(NavigationContext& navigationContext, ErrorHandler& errorHandler) + : navigationContext(navigationContext) + , errorHandler(errorHandler) +{ +} + +Navigator::Status Navigator::navigate(std::string path) +{ + std::replace(path.begin(), path.end(), '\\', '/'); + + if (Error error = toError(navigationContext.reset(navigationContext.getRequirerIdentifier()))) + { + errorHandler.reportError(*error); + return Status::ErrorReported; + } + + if (Error error = navigateImpl(path)) + { + errorHandler.reportError(*error); + return Status::ErrorReported; + } + + return Status::Success; +} + +Error Navigator::navigateImpl(std::string_view path) +{ + PathType pathType = getPathType(path); + + if (pathType == PathType::Unsupported) + return "require path must start with a valid prefix: ./, ../, or @"; + + if (pathType == PathType::Aliased) + { + std::string alias = extractAlias(path); + std::transform( + alias.begin(), + alias.end(), + alias.begin(), + [](unsigned char c) + { + return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c; + } + ); + + if (Error error = navigateToAndPopulateConfig(alias)) + return error; + + if (!config.aliases.contains(alias)) + { + if (alias != "self") + return "@" + alias + " is not a valid alias"; + + // If the alias is "@self", we reset to the requirer's context and + // navigate directly from there. + if (Error error = toError(navigationContext.reset(navigationContext.getRequirerIdentifier()))) + return error; + if (Error error = navigateThroughPath(path)) + return error; + + return std::nullopt; + } + + if (Error error = navigateToAlias(alias, config.aliases[alias].value)) + return error; + if (Error error = navigateThroughPath(path)) + return error; + } + + if (pathType == PathType::RelativeToCurrent || pathType == PathType::RelativeToParent) + { + if (Error error = toError(navigationContext.toParent())) + return error; + if (Error error = navigateThroughPath(path)) + return error; + } + + return std::nullopt; +} + +Error Navigator::navigateThroughPath(std::string_view path) +{ + std::pair components = splitPath(path); + if (path.size() >= 1 && path[0] == '@') + { + // If the path is aliased, we ignore the alias: this function assumes + // that navigation to an alias is handled by the caller. + components = splitPath(components.second); + } + + while (!(components.first.empty() && components.second.empty())) + { + if (components.first == "." || components.first.empty()) + { + components = splitPath(components.second); + continue; + } + else if (components.first == "..") + { + if (Error error = toError(navigationContext.toParent())) + return error; + } + else + { + if (Error error = toError(navigationContext.toChild(std::string{components.first}))) + return error; + } + components = splitPath(components.second); + } + + return std::nullopt; +} + +Error Navigator::navigateToAlias(const std::string& alias, const std::string& value) +{ + PathType pathType = getPathType(value); + + if (pathType == PathType::RelativeToCurrent || pathType == PathType::RelativeToParent) + { + if (Error error = navigateThroughPath(value)) + return error; + } + else if (pathType == PathType::Aliased) + { + return "@" + alias + " cannot point to other aliases"; + } + else + { + if (Error error = toError(navigationContext.jumpToAlias(value))) + return error; + } + + return std::nullopt; +} + +Error Navigator::navigateToAndPopulateConfig(const std::string& desiredAlias) +{ + while (!config.aliases.contains(desiredAlias)) + { + if (navigationContext.toParent() != NavigationContext::NavigateResult::Success) + break; + + if (navigationContext.isConfigPresent()) + { + std::optional configContents = navigationContext.getConfig(); + if (!configContents) + return "could not get configuration file contents"; + + Luau::ConfigOptions opts; + Luau::ConfigOptions::AliasOptions aliasOpts; + aliasOpts.configLocation = "unused"; + aliasOpts.overwriteAliases = false; + opts.aliasOptions = std::move(aliasOpts); + + if (Error error = Luau::parseConfig(*configContents, config, opts)) + return error; + } + }; + + return std::nullopt; +} + +} // namespace Luau::Require diff --git a/Require/Runtime/include/Luau/Require.h b/Require/Runtime/include/Luau/Require.h new file mode 100644 index 00000000..4e3ef555 --- /dev/null +++ b/Require/Runtime/include/Luau/Require.h @@ -0,0 +1,117 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "lua.h" + +#include + +//////////////////////////////////////////////////////////////////////////////// +// +// Require-by-string assumes that the context in which it is embedded adheres to +// a particular structure. +// +// Each component in a require path either represents a module or a directory. +// Modules contain Luau code, whereas directories serve solely as organizational +// units. For the purposes of navigation, both modules and directories are +// functionally identical: modules and directories can both have children, which +// could themselves be modules or directories, and both types can have at most +// one parent, which could also be either a module or a directory. +// +// Without more context, it is impossible to tell which components in a given +// path "./foo/bar/baz" are modules and which are directories. To provide this +// context, the require-by-string runtime library must be opened with a +// luarequire_Configuration object, which defines the navigation behavior of the +// context in which Luau is embedded. +// +// Calls to to_parent and to_child signal a move up or down the context's +// hierarchy. The context is expected to maintain an internal state so that +// when is_module_present is called, require-by-string can determine whether it +// is currently pointing at a module or a directory. +// +// In a conventional filesystem context, "modules" map either to *.luau files or +// to directories on disk containing an init.luau file, whereas "directories" +// map to directories on disk not containing an init.luau file. In a more +// abstract context, a module and a directory could be represented by any +// nestable code unit and organizational unit, respectively. +// +// Require-by-string's runtime behavior can be additionally be configured in +// configuration files, such as .luaurc files in a filesystem context. The +// presence of a configuration file in the current context is signaled by the +// is_config_present function. Both modules and directories can contain +// configuration files; however, note that a given configuration file's scope is +// limited to the descendants of the module or directory in which it resides. In +// other words, when searching for a relevant configuration file for a given +// module, the search begins at the module's parent context and proceeds up the +// hierarchy from there, resolving to the first configuration file found. +// +//////////////////////////////////////////////////////////////////////////////// + +enum luarequire_NavigateResult +{ + NAVIGATE_SUCCESS, + NAVIGATE_AMBIGUOUS, + NAVIGATE_NOT_FOUND +}; + +// Functions returning WRITE_SUCCESS are expected to set their size_out argument +// to the number of bytes written to the buffer. If WRITE_BUFFER_TOO_SMALL is +// returned, size_out should be set to the required buffer size. +enum luarequire_WriteResult +{ + WRITE_SUCCESS, + WRITE_BUFFER_TOO_SMALL, + WRITE_FAILURE +}; + +struct luarequire_Configuration +{ + // Returns whether requires are permitted from the given chunkname. + bool (*is_require_allowed)(lua_State* L, void* ctx, const char* requirer_chunkname); + + // Resets the internal state to point at the requirer module. + luarequire_NavigateResult (*reset)(lua_State* L, void* ctx, const char* requirer_chunkname); + + // Resets the internal state to point at an aliased module, given its exact + // path from a configuration file. This function is only called when an + // alias's path cannot be resolved relative to its configuration file. + luarequire_NavigateResult (*jump_to_alias)(lua_State* L, void* ctx, const char* path); + + // Navigates through the context by making mutations to the internal state. + luarequire_NavigateResult (*to_parent)(lua_State* L, void* ctx); + luarequire_NavigateResult (*to_child)(lua_State* L, void* ctx, const char* name); + + // Returns whether the context is currently pointing at a module. + bool (*is_module_present)(lua_State* L, void* ctx); + + // Provides the contents of the current module. This function is only called + // if is_module_present returns true. + luarequire_WriteResult (*get_contents)(lua_State* L, void* ctx, char* buffer, size_t buffer_size, size_t* size_out); + + // Provides a chunkname for the current module. This will be accessible + // through the debug library. This function is only called if + // is_module_present returns true. + luarequire_WriteResult (*get_chunkname)(lua_State* L, void* ctx, char* buffer, size_t buffer_size, size_t* size_out); + + // Provides a cache key representing the current module. This function is + // only called if is_module_present returns true. + luarequire_WriteResult (*get_cache_key)(lua_State* L, void* ctx, char* buffer, size_t buffer_size, size_t* size_out); + + // Returns whether a configuration file is present in the current context. + // If not, require-by-string will call to_parent until either a + // configuration file is present or NAVIGATE_FAILURE is returned (at root). + bool (*is_config_present)(lua_State* L, void* ctx); + + // Provides the contents of the configuration file in the current context. + // This function is only called if is_config_present returns true. + luarequire_WriteResult (*get_config)(lua_State* L, void* ctx, char* buffer, size_t buffer_size, size_t* size_out); + + // Executes the module and places the result on the stack. Returns the + // number of results placed on the stack. + int (*load)(lua_State* L, void* ctx, const char* chunkname, const char* contents); +}; + +// Populates function pointers in the given luarequire_Configuration. +typedef void (*luarequire_Configuration_init)(luarequire_Configuration* config); + +// Initializes the require library with the given configuration and context. +LUALIB_API void luaopen_require(lua_State* L, luarequire_Configuration_init config_init, void* ctx); diff --git a/Require/Runtime/src/Navigation.cpp b/Require/Runtime/src/Navigation.cpp new file mode 100644 index 00000000..79712062 --- /dev/null +++ b/Require/Runtime/src/Navigation.cpp @@ -0,0 +1,124 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Navigation.h" + +#include "Luau/Require.h" +#include "lua.h" +#include "lualib.h" + +static constexpr size_t initalFileBufferSize = 1024; +static constexpr size_t initalIdentifierBufferSize = 64; + +namespace Luau::Require +{ + +static NavigationContext::NavigateResult convertNavigateResult(luarequire_NavigateResult result) +{ + if (result == NAVIGATE_SUCCESS) + return NavigationContext::NavigateResult::Success; + if (result == NAVIGATE_AMBIGUOUS) + return NavigationContext::NavigateResult::Ambiguous; + + return NavigationContext::NavigateResult::NotFound; +} + +RuntimeNavigationContext::RuntimeNavigationContext(luarequire_Configuration* config, lua_State* L, void* ctx, std::string requirerChunkname) + : config(config) + , L(L) + , ctx(ctx) + , requirerChunkname(std::move(requirerChunkname)) +{ +} + +std::string RuntimeNavigationContext::getRequirerIdentifier() const +{ + return requirerChunkname; +} + +NavigationContext::NavigateResult RuntimeNavigationContext::reset(const std::string& requirerChunkname) +{ + return convertNavigateResult(config->reset(L, ctx, requirerChunkname.c_str())); +} + +NavigationContext::NavigateResult RuntimeNavigationContext::jumpToAlias(const std::string& path) +{ + return convertNavigateResult(config->jump_to_alias(L, ctx, path.c_str())); +} + +NavigationContext::NavigateResult RuntimeNavigationContext::toParent() +{ + return convertNavigateResult(config->to_parent(L, ctx)); +} + +NavigationContext::NavigateResult RuntimeNavigationContext::toChild(const std::string& component) +{ + return convertNavigateResult(config->to_child(L, ctx, component.c_str())); +} + +bool RuntimeNavigationContext::isModulePresent() const +{ + return config->is_module_present(L, ctx); +} + +std::optional RuntimeNavigationContext::getContents() const +{ + return getStringFromCWriter(config->get_contents, initalFileBufferSize); +} + +std::optional RuntimeNavigationContext::getChunkname() const +{ + return getStringFromCWriter(config->get_chunkname, initalIdentifierBufferSize); +} + +std::optional RuntimeNavigationContext::getCacheKey() const +{ + return getStringFromCWriter(config->get_cache_key, initalIdentifierBufferSize); +} + +bool RuntimeNavigationContext::isConfigPresent() const +{ + return config->is_config_present(L, ctx); +} + +std::optional RuntimeNavigationContext::getConfig() const +{ + return getStringFromCWriter(config->get_config, initalFileBufferSize); +} + +std::optional RuntimeNavigationContext::getStringFromCWriter( + luarequire_WriteResult (*writer)(lua_State* L, void* ctx, char* buffer, size_t buffer_size, size_t* size_out), + size_t initalBufferSize +) const +{ + std::string buffer; + buffer.resize(initalBufferSize); + + size_t size; + luarequire_WriteResult result = writer(L, ctx, buffer.data(), buffer.size(), &size); + if (result == WRITE_BUFFER_TOO_SMALL) + { + buffer.resize(size); + result = writer(L, ctx, buffer.data(), buffer.size(), &size); + } + + if (result == WRITE_SUCCESS) + { + buffer.resize(size); + return buffer; + } + + return std::nullopt; +} + + +RuntimeErrorHandler::RuntimeErrorHandler(lua_State* L) + : L(L) +{ +} + +void RuntimeErrorHandler::reportError(std::string message) +{ + luaL_errorL(L, "%s", message.c_str()); +} + +} // namespace Luau::Require diff --git a/Require/Runtime/src/Navigation.h b/Require/Runtime/src/Navigation.h new file mode 100644 index 00000000..35a836b3 --- /dev/null +++ b/Require/Runtime/src/Navigation.h @@ -0,0 +1,58 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/RequireNavigator.h" +#include "Luau/Require.h" + +struct lua_State; +struct luarequire_Configuration; + +namespace Luau::Require +{ + +class RuntimeNavigationContext : public NavigationContext +{ +public: + RuntimeNavigationContext(luarequire_Configuration* config, lua_State* L, void* ctx, std::string requirerChunkname); + + std::string getRequirerIdentifier() const override; + + // Navigation interface + NavigateResult reset(const std::string& requirerChunkname) override; + NavigateResult jumpToAlias(const std::string& path) override; + + NavigateResult toParent() override; + NavigateResult toChild(const std::string& component) override; + + bool isConfigPresent() const override; + std::optional getConfig() const override; + + // Custom capabilities + bool isModulePresent() const; + std::optional getContents() const; + std::optional getChunkname() const; + std::optional getCacheKey() const; + +private: + std::optional getStringFromCWriter( + luarequire_WriteResult (*writer)(lua_State* L, void* ctx, char* buffer, size_t buffer_size, size_t* size_out), + size_t initalBufferSize + ) const; + + luarequire_Configuration* config; + lua_State* L; + void* ctx; + std::string requirerChunkname; +}; + +class RuntimeErrorHandler : public ErrorHandler +{ +public: + RuntimeErrorHandler(lua_State* L); + void reportError(std::string message) override; + +private: + lua_State* L; +}; + +} // namespace Luau::Require diff --git a/Require/Runtime/src/Require.cpp b/Require/Runtime/src/Require.cpp new file mode 100644 index 00000000..8939ccc2 --- /dev/null +++ b/Require/Runtime/src/Require.cpp @@ -0,0 +1,52 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Require.h" + +#include "RequireImpl.h" + +#include "lua.h" +#include "lualib.h" + +static void validateConfig(lua_State* L, const luarequire_Configuration& config) +{ + if (!config.is_require_allowed) + luaL_error(L, "require configuration is missing required function pointer: is_require_allowed"); + if (!config.reset) + luaL_error(L, "require configuration is missing required function pointer: reset"); + if (!config.jump_to_alias) + luaL_error(L, "require configuration is missing required function pointer: jump_to_alias"); + if (!config.to_parent) + luaL_error(L, "require configuration is missing required function pointer: to_parent"); + if (!config.to_child) + luaL_error(L, "require configuration is missing required function pointer: to_child"); + if (!config.is_module_present) + luaL_error(L, "require configuration is missing required function pointer: is_module_present"); + if (!config.get_contents) + luaL_error(L, "require configuration is missing required function pointer: get_contents"); + if (!config.get_chunkname) + luaL_error(L, "require configuration is missing required function pointer: get_chunkname"); + if (!config.get_cache_key) + luaL_error(L, "require configuration is missing required function pointer: get_cache_key"); + if (!config.is_config_present) + luaL_error(L, "require configuration is missing required function pointer: is_config_present"); + if (!config.get_config) + luaL_error(L, "require configuration is missing required function pointer: get_config"); + if (!config.load) + luaL_error(L, "require configuration is missing required function pointer: load"); +} + +void luaopen_require(lua_State* L, luarequire_Configuration_init config_init, void* ctx) +{ + luarequire_Configuration* config = static_cast(lua_newuserdata(L, sizeof(luarequire_Configuration))); + if (!config) + luaL_error(L, "failed to allocate memory for require configuration"); + + config_init(config); + validateConfig(L, *config); + + lua_pushlightuserdata(L, ctx); + + // "require" captures config and ctx as upvalues + lua_pushcclosure(L, Luau::Require::lua_require, "require", 2); + lua_setglobal(L, "require"); +} diff --git a/Require/Runtime/src/RequireImpl.cpp b/Require/Runtime/src/RequireImpl.cpp new file mode 100644 index 00000000..1858abcb --- /dev/null +++ b/Require/Runtime/src/RequireImpl.cpp @@ -0,0 +1,146 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "RequireImpl.h" + +#include "Navigation.h" + +#include "Luau/RequireNavigator.h" +#include "Luau/Require.h" + +#include "lua.h" +#include "lualib.h" + +namespace Luau::Require +{ + +static const char* cacheTableKey = "_MODULES"; + +struct ResolvedRequire +{ + enum class Status + { + Cached, + ModuleRead, + ErrorReported + }; + + Status status; + std::string contents; + std::string chunkname; + std::string cacheKey; +}; + +static bool isCached(lua_State* L, const std::string& key) +{ + luaL_findtable(L, LUA_REGISTRYINDEX, cacheTableKey, 1); + lua_getfield(L, -1, key.c_str()); + bool cached = !lua_isnil(L, -1); + lua_pop(L, 2); + + return cached; +} + +static ResolvedRequire resolveRequire(luarequire_Configuration* lrc, lua_State* L, void* ctx, std::string path) +{ + lua_Debug ar; + lua_getinfo(L, 1, "s", &ar); + + if (!lrc->is_require_allowed(L, ctx, ar.source)) + luaL_error(L, "require is not supported in this context"); + + RuntimeNavigationContext navigationContext{lrc, L, ctx, ar.source}; + RuntimeErrorHandler errorHandler{L}; // Errors reported directly to lua_State. + + Navigator navigator(navigationContext, errorHandler); + + // Updates navigationContext while navigating through the given path. + Navigator::Status status = navigator.navigate(std::move(path)); + if (status == Navigator::Status::ErrorReported) + return {ResolvedRequire::Status::ErrorReported}; + + if (!navigationContext.isModulePresent()) + { + luaL_errorL(L, "no module present at resolved path"); + return ResolvedRequire{ResolvedRequire::Status::ErrorReported}; + } + + std::optional cacheKey = navigationContext.getCacheKey(); + if (!cacheKey) + { + errorHandler.reportError("could not get cache key for module"); + return ResolvedRequire{ResolvedRequire::Status::ErrorReported}; + } + + if (isCached(L, *cacheKey)) + { + // Put cached result on top of stack before returning. + lua_getfield(L, LUA_REGISTRYINDEX, cacheTableKey); + lua_getfield(L, -1, cacheKey->c_str()); + lua_remove(L, -2); + + return ResolvedRequire{ResolvedRequire::Status::Cached}; + } + + std::optional chunkname = navigationContext.getChunkname(); + if (!chunkname) + { + errorHandler.reportError("could not get chunkname for module"); + return ResolvedRequire{ResolvedRequire::Status::ErrorReported}; + } + + std::optional contents = navigationContext.getContents(); + if (!contents) + { + errorHandler.reportError("could not get contents for module"); + return ResolvedRequire{ResolvedRequire::Status::ErrorReported}; + } + + return ResolvedRequire{ + ResolvedRequire::Status::ModuleRead, + std::move(*contents), + std::move(*chunkname), + std::move(*cacheKey), + }; +} + +int lua_require(lua_State* L) +{ + luarequire_Configuration* lrc = static_cast(lua_touserdata(L, lua_upvalueindex(1))); + if (!lrc) + luaL_error(L, "unable to find require configuration"); + + void* ctx = lua_tolightuserdata(L, lua_upvalueindex(2)); + + const char* path = luaL_checkstring(L, 1); + + ResolvedRequire resolvedRequire = resolveRequire(lrc, L, ctx, path); + if (resolvedRequire.status == ResolvedRequire::Status::Cached) + return 1; + + int numResults = lrc->load(L, ctx, resolvedRequire.chunkname.c_str(), resolvedRequire.contents.c_str()); + if (numResults > 1) + luaL_error(L, "module must return a single value"); + + // Cache the result + if (numResults == 1) + { + // Initial stack state + // (-1) result + + lua_getfield(L, LUA_REGISTRYINDEX, cacheTableKey); + // (-2) result, (-1) cache table + + lua_pushvalue(L, -2); + // (-3) result, (-2) cache table, (-1) result + + lua_setfield(L, -2, resolvedRequire.cacheKey.c_str()); + // (-2) result, (-1) cache table + + lua_pop(L, 1); + // (-1) result + } + + return numResults; +} + +} // namespace Luau::Require diff --git a/Require/Runtime/src/RequireImpl.h b/Require/Runtime/src/RequireImpl.h new file mode 100644 index 00000000..9889acc7 --- /dev/null +++ b/Require/Runtime/src/RequireImpl.h @@ -0,0 +1,11 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +struct lua_State; + +namespace Luau::Require +{ + +int lua_require(lua_State* L); + +} // namespace Luau::Require diff --git a/Sources.cmake b/Sources.cmake index a40505df..d3453184 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -193,6 +193,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Frontend.h Analysis/include/Luau/Generalization.h Analysis/include/Luau/GlobalTypes.h + Analysis/include/Luau/InferPolarity.h Analysis/include/Luau/InsertionOrderedMap.h Analysis/include/Luau/Instantiation.h Analysis/include/Luau/Instantiation2.h @@ -206,6 +207,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/NonStrictTypeChecker.h Analysis/include/Luau/Normalize.h Analysis/include/Luau/OverloadResolution.h + Analysis/include/Luau/Polarity.h Analysis/include/Luau/Predicate.h Analysis/include/Luau/Quantify.h Analysis/include/Luau/RecursionCounter.h @@ -269,6 +271,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Frontend.cpp Analysis/src/Generalization.cpp Analysis/src/GlobalTypes.cpp + Analysis/src/InferPolarity.cpp Analysis/src/Instantiation.cpp Analysis/src/Instantiation2.cpp Analysis/src/IostreamHelpers.cpp @@ -394,11 +397,9 @@ target_sources(isocline PRIVATE target_sources(Luau.CLI.lib PRIVATE CLI/include/Luau/FileUtils.h CLI/include/Luau/Flags.h - CLI/include/Luau/Require.h CLI/src/FileUtils.cpp CLI/src/Flags.cpp - CLI/src/Require.cpp ) if(TARGET Luau.Repl.CLI) @@ -406,18 +407,27 @@ if(TARGET Luau.Repl.CLI) target_sources(Luau.Repl.CLI PRIVATE CLI/include/Luau/Coverage.h CLI/include/Luau/Profiler.h + CLI/include/Luau/ReplRequirer.h + CLI/include/Luau/RequirerUtils.h CLI/src/Coverage.cpp CLI/src/Profiler.cpp CLI/src/Repl.cpp CLI/src/ReplEntry.cpp + CLI/src/ReplRequirer.cpp + CLI/src/RequirerUtils.cpp ) endif() if(TARGET Luau.Analyze.CLI) # Luau.Analyze.CLI Sources target_sources(Luau.Analyze.CLI PRIVATE + CLI/include/Luau/AnalyzeRequirer.h + CLI/include/Luau/RequirerUtils.h + CLI/src/Analyze.cpp + CLI/src/AnalyzeRequirer.cpp + CLI/src/RequirerUtils.cpp ) endif() @@ -461,9 +471,10 @@ if(TARGET Luau.UnitTest) tests/Error.test.cpp tests/Fixture.cpp tests/Fixture.h - tests/FragmentAutocomplete.test.cpp + tests/FragmentAutocomplete.test.cpp tests/Frontend.test.cpp tests/Generalization.test.cpp + tests/InferPolarity.test.cpp tests/InsertionOrderedMap.test.cpp tests/Instantiation2.test.cpp tests/IostreamOptional.h @@ -550,10 +561,14 @@ if(TARGET Luau.CLI.Test) target_sources(Luau.CLI.Test PRIVATE CLI/include/Luau/Coverage.h CLI/include/Luau/Profiler.h + CLI/include/Luau/ReplRequirer.h + CLI/include/Luau/RequirerUtils.h CLI/src/Coverage.cpp CLI/src/Profiler.cpp CLI/src/Repl.cpp + CLI/src/ReplRequirer.cpp + CLI/src/RequirerUtils.cpp tests/RegisterCallbacks.h tests/RegisterCallbacks.cpp @@ -562,6 +577,29 @@ if(TARGET Luau.CLI.Test) tests/main.cpp) endif() +if(TARGET Luau.Require) + # Luau.Require Sources + target_sources(Luau.Require PRIVATE + Require/Runtime/include/Luau/Require.h + + Require/Runtime/src/Navigation.h + Require/Runtime/src/RequireImpl.h + + Require/Runtime/src/Navigation.cpp + Require/Runtime/src/Require.cpp + Require/Runtime/src/RequireImpl.cpp) +endif() + +if(TARGET Luau.RequireNavigator) + # Luau.Require Sources + target_sources(Luau.RequireNavigator PRIVATE + Require/Navigator/include/Luau/PathUtilities.h + Require/Navigator/include/Luau/RequireNavigator.h + + Require/Navigator/src/PathUtilities.cpp + Require/Navigator/src/RequireNavigator.cpp) +endif() + if(TARGET Luau.Web) # Luau.Web Sources target_sources(Luau.Web PRIVATE diff --git a/VM/src/lcorolib.cpp b/VM/src/lcorolib.cpp index 5a372aec..ddd9b731 100644 --- a/VM/src/lcorolib.cpp +++ b/VM/src/lcorolib.cpp @@ -6,8 +6,6 @@ #include "lstate.h" #include "lvm.h" -LUAU_DYNAMIC_FASTFLAG(LuauStackLimit) - #define CO_STATUS_ERROR -1 #define CO_STATUS_BREAK -2 @@ -237,20 +235,12 @@ static int coclose(lua_State* L) { lua_pushboolean(L, false); - if (DFFlag::LuauStackLimit) - { - if (co->status == LUA_ERRMEM) - lua_pushstring(L, LUA_MEMERRMSG); - else if (co->status == LUA_ERRERR) - lua_pushstring(L, LUA_ERRERRMSG); - else if (lua_gettop(co)) - lua_xmove(co, L, 1); // move error message - } - else - { - if (lua_gettop(co)) - lua_xmove(co, L, 1); // move error message - } + if (co->status == LUA_ERRMEM) + lua_pushstring(L, LUA_MEMERRMSG); + else if (co->status == LUA_ERRERR) + lua_pushstring(L, LUA_ERRERRMSG); + else if (lua_gettop(co)) + lua_xmove(co, L, 1); // move error message lua_resetthread(co); return 2; diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index f9fe30d6..950e85d6 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,9 +17,6 @@ #include -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauStackLimit, false) -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauPopIncompleteCi, false) - // keep max stack allocation request under 1GB #define MAX_STACK_SIZE (int(1024 / sizeof(TValue)) * 1024 * 1024) @@ -183,10 +180,10 @@ static void correctstack(lua_State* L, TValue* oldstack) void luaD_reallocstack(lua_State* L, int newsize, int fornewci) { // throw 'out of memory' error because space for a custom error message cannot be guaranteed here - if (DFFlag::LuauStackLimit && newsize > MAX_STACK_SIZE) + if (newsize > MAX_STACK_SIZE) { - // reallocation was performaed to setup a new CallInfo frame, which we have to remove - if (DFFlag::LuauPopIncompleteCi && fornewci) + // reallocation was performed to setup a new CallInfo frame, which we have to remove + if (fornewci) { CallInfo* cip = L->ci - 1; @@ -221,17 +218,7 @@ void luaD_reallocCI(lua_State* L, int newsize) void luaD_growstack(lua_State* L, int n) { - if (DFFlag::LuauPopIncompleteCi) - { - luaD_reallocstack(L, getgrownstacksize(L, n), 0); - } - else - { - if (n <= L->stacksize) // double size is enough? - luaD_reallocstack(L, 2 * L->stacksize, 0); - else - luaD_reallocstack(L, L->stacksize + n, 0); - } + luaD_reallocstack(L, getgrownstacksize(L, n), 0); } CallInfo* luaD_growCI(lua_State* L) diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index ce07d878..a355af34 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,8 +16,6 @@ #include -LUAU_DYNAMIC_FASTFLAG(LuauPopIncompleteCi) - // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ #if __has_warning("-Wc99-designator") @@ -937,14 +935,7 @@ reentry: // note: this reallocs stack, but we don't need to VM_PROTECT this // this is because we're going to modify base/savedpc manually anyhow // crucially, we can't use ra/argtop after this line - if (DFFlag::LuauPopIncompleteCi) - { - luaD_checkstackfornewci(L, ccl->stacksize); - } - else - { - luaD_checkstack(L, ccl->stacksize); - } + luaD_checkstackfornewci(L, ccl->stacksize); LUAU_ASSERT(ci->top <= L->stack_last); @@ -3080,14 +3071,7 @@ int luau_precall(lua_State* L, StkId func, int nresults) L->base = ci->base; // Note: L->top is assigned externally - if (DFFlag::LuauPopIncompleteCi) - { - luaD_checkstackfornewci(L, ccl->stacksize); - } - else - { - luaD_checkstack(L, ccl->stacksize); - } + luaD_checkstackfornewci(L, ccl->stacksize); LUAU_ASSERT(ci->top <= L->stack_last); if (!ccl->isC) diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index e38ba4f9..959ea78d 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -4364,12 +4364,19 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_generic_on_argument_type_pack_vararg") { - check(R"( -local function foo(a: (...: T...) -> number) - return a(4, 5, 6) -end + // Caveat lector! This is actually invalid syntax! + // The correct syntax would be as follows: + // + // local function foo(a: (T...) -> number) + // + // We leave it as-written here because we still expect autocomplete to + // handle this code sensibly. + CheckResult result = check(R"( + local function foo(a: (...: T...) -> number) + return a(4, 5, 6) + end -foo(@1) + foo(@1) )"); const std::optional EXPECTED_INSERT = FFlag::LuauSolverV2 ? "function(...: number): number end" : "function(...): number end"; diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 0b343771..ad9cf7fd 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -34,7 +34,6 @@ void luaC_validate(lua_State* L); LUAU_FASTFLAG(LuauLibWhereErrorAutoreserve) LUAU_FASTFLAG(DebugLuauAbortingChecks) LUAU_FASTINT(CodegenHeuristicsInstructionLimit) -LUAU_DYNAMIC_FASTFLAG(LuauStackLimit) LUAU_FASTFLAG(LuauVectorLibNativeDot) LUAU_DYNAMIC_FASTFLAG(LuauStringFormatFixC) @@ -756,8 +755,6 @@ TEST_CASE("Closure") TEST_CASE("Calls") { - ScopedFastFlag LuauStackLimit{DFFlag::LuauStackLimit, true}; - runConformance("calls.luau"); } @@ -797,8 +794,6 @@ static int cxxthrow(lua_State* L) TEST_CASE("PCall") { - ScopedFastFlag LuauStackLimit{DFFlag::LuauStackLimit, true}; - runConformance( "pcall.luau", [](lua_State* L) diff --git a/tests/FragmentAutocomplete.test.cpp b/tests/FragmentAutocomplete.test.cpp index 1f823063..35e78af0 100644 --- a/tests/FragmentAutocomplete.test.cpp +++ b/tests/FragmentAutocomplete.test.cpp @@ -45,6 +45,7 @@ LUAU_FASTFLAG(LuauIncrementalAutocompleteDemandBasedCloning) LUAU_FASTFLAG(LuauUserTypeFunTypecheck) LUAU_FASTFLAG(LuauBetterScopeSelection) LUAU_FASTFLAG(LuauBlockDiffFragmentSelection) +LUAU_FASTFLAG(LuauFragmentAcMemoryLeak) static std::optional nullCallback(std::string tag, std::optional ptr, std::optional contents) { @@ -87,6 +88,7 @@ struct FragmentAutocompleteFixtureImpl : BaseType ScopedFastFlag luauBetterScopeSelection{FFlag::LuauBetterScopeSelection, true}; ScopedFastFlag luauBlockDiffFragmentSelection{FFlag::LuauBlockDiffFragmentSelection, true}; ScopedFastFlag luauAutocompleteUsesModuleForTypeCompatibility{FFlag::LuauAutocompleteUsesModuleForTypeCompatibility, true}; + ScopedFastFlag luauFragmentAcMemoryLeak{FFlag::LuauFragmentAcMemoryLeak, true}; FragmentAutocompleteFixtureImpl() : BaseType(true) @@ -1331,6 +1333,7 @@ t FragmentAutocompleteStatusResult frag = Luau::tryFragmentAutocomplete(frontend, "game/A", Position{2, 1}, context, nullCallback); REQUIRE(frag.result); + REQUIRE(frag.result->incrementalModule); CHECK_EQ("game/A", frag.result->incrementalModule->name); CHECK_NE(frontend.moduleResolverForAutocomplete.getModule("game/A"), nullptr); CHECK_EQ(frontend.moduleResolver.getModule("game/A"), nullptr); diff --git a/tests/Generalization.test.cpp b/tests/Generalization.test.cpp index 1614001c..280909d5 100644 --- a/tests/Generalization.test.cpp +++ b/tests/Generalization.test.cpp @@ -15,6 +15,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauNonReentrantGeneralization) LUAU_FASTFLAG(DebugLuauForbidInternalTypes) LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope) LUAU_FASTFLAG(LuauTrackInferredFunctionTypeFromCall) @@ -214,6 +215,92 @@ TEST_CASE_FIXTURE(GeneralizationFixture, "intersection_type_traversal_doesnt_cra generalize(intersectionType); } +TEST_CASE_FIXTURE(GeneralizationFixture, "('a) -> 'a") +{ + TypeId freeTy = freshType().first; + TypeId fnTy = arena.addType(FunctionType{arena.addTypePack({freeTy}), arena.addTypePack({freeTy})}); + + generalize(fnTy); + + CHECK("(a) -> a" == toString(fnTy)); +} + +TEST_CASE_FIXTURE(GeneralizationFixture, "(t1, (t1 <: 'b)) -> () where t1 = ('a <: (t1 <: 'b) & {number} & {number})") +{ + ScopedFastFlag sff{FFlag::LuauNonReentrantGeneralization, true}; + + TableType tt; + tt.indexer = TableIndexer{builtinTypes.numberType, builtinTypes.numberType}; + TypeId numberArray = arena.addType(TableType{tt}); + + auto [aTy, aFree] = freshType(); + auto [bTy, bFree] = freshType(); + + aFree->upperBound = arena.addType(IntersectionType{{bTy, numberArray, numberArray}}); + bFree->lowerBound = aTy; + + TypeId functionTy = arena.addType(FunctionType{arena.addTypePack({aTy, bTy}), builtinTypes.emptyTypePack}); + + generalize(functionTy); + + CHECK("(unknown & {number}, unknown) -> ()" == toString(functionTy)); +} + +TEST_CASE_FIXTURE(GeneralizationFixture, "(('a <: number | string)) -> string?") +{ + auto [aTy, aFree] = freshType(); + + aFree->upperBound = arena.addType(UnionType{{builtinTypes.numberType, builtinTypes.stringType}}); + + TypeId fnType = arena.addType(FunctionType{arena.addTypePack({aTy}), arena.addTypePack({builtinTypes.optionalStringType})}); + + generalize(fnType); + + CHECK("(number | string) -> string?" == toString(fnType)); +} + +TEST_CASE_FIXTURE(GeneralizationFixture, "(('a <: {'b})) -> ()") +{ + ScopedFastFlag sff{FFlag::LuauNonReentrantGeneralization, true}; + + auto [aTy, aFree] = freshType(); + auto [bTy, bFree] = freshType(); + + TableType tt; + tt.indexer = TableIndexer{builtinTypes.numberType, bTy}; + + aFree->upperBound = arena.addType(tt); + + TypeId functionTy = arena.addType(FunctionType{arena.addTypePack({aTy}), builtinTypes.emptyTypePack}); + + generalize(functionTy); + + // The free type 'b is not replace with unknown because it appears in an + // invariant context. + CHECK("({a}) -> ()" == toString(functionTy)); +} + +TEST_CASE_FIXTURE(GeneralizationFixture, "(('b <: {t1}), ('a <: t1)) -> t1 where t1 = (('a <: t1) <: 'c)") +{ + auto [aTy, aFree] = freshType(); + auto [bTy, bFree] = freshType(); + auto [cTy, cFree] = freshType(); + + aFree->upperBound = cTy; + cFree->lowerBound = aTy; + + TableType tt; + tt.indexer = TableIndexer{builtinTypes.numberType, cTy}; + + bFree->upperBound = arena.addType(tt); + + TypeId functionTy = arena.addType(FunctionType{arena.addTypePack({bTy, aTy}), arena.addTypePack({cTy})}); + + generalize(functionTy); + + CHECK("({a}, a) -> a" == toString(functionTy)); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "generalization_traversal_should_re_traverse_unions_if_they_change_type") { // This test case should just not assert @@ -233,7 +320,7 @@ function foo() button.LayoutOrder = func(product) * dir end end - + function(mode) if mode == 'Name'then else diff --git a/tests/InferPolarity.test.cpp b/tests/InferPolarity.test.cpp new file mode 100644 index 00000000..aa681a0d --- /dev/null +++ b/tests/InferPolarity.test.cpp @@ -0,0 +1,51 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Fixture.h" + +#include "Luau/InferPolarity.h" +#include "Luau/Polarity.h" +#include "Luau/Type.h" +#include "Luau/TypeArena.h" + +using namespace Luau; + +LUAU_FASTFLAG(LuauNonReentrantGeneralization); + +TEST_SUITE_BEGIN("InferPolarity"); + +TEST_CASE_FIXTURE(Fixture, "T where T = { m: (a) -> T }") +{ + ScopedFastFlag sff{FFlag::LuauNonReentrantGeneralization, true}; + + TypeArena arena; + ScopePtr globalScope = std::make_shared(builtinTypes->anyTypePack); + + TypeId tType = arena.addType(BlockedType{}); + TypeId aType = arena.addType(GenericType{globalScope.get(), "a"}); + + TypeId mType = arena.addType(FunctionType{ + TypeLevel{}, + /* generics */ {aType}, + /* genericPacks */ {}, + /* argPack */ arena.addTypePack({aType}), + /* retPack */ arena.addTypePack({tType}) + }); + + emplaceType( + asMutable(tType), + TableType{ + TableType::Props{{"m", Property::rw(mType)}}, + /* indexer */ std::nullopt, + TypeLevel{}, + globalScope.get(), + TableState::Sealed + } + ); + + inferGenericPolarities(NotNull{&arena}, NotNull{globalScope.get()}, tType); + + const GenericType* aGeneric = get(aType); + REQUIRE(aGeneric); + CHECK(aGeneric->polarity == Polarity::Negative); +} + +TEST_SUITE_END(); diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 122be810..b402d33a 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -19,6 +19,8 @@ LUAU_FASTINT(LuauNormalizeUnionLimit) LUAU_FASTFLAG(LuauNormalizeLimitFunctionSet) LUAU_FASTFLAG(LuauSubtypingStopAtNormFail) LUAU_FASTFLAG(LuauNormalizationCatchMetatableCycles) +LUAU_FASTFLAG(LuauSubtypingEnableReasoningLimit) +LUAU_FASTFLAG(LuauTypePackDetectCycles) using namespace Luau; @@ -1193,6 +1195,7 @@ end TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_limit_function_intersection_complexity") { + ScopedFastInt luauTypeInferRecursionLimit{FInt::LuauTypeInferRecursionLimit, 80}; ScopedFastInt luauNormalizeIntersectionLimit{FInt::LuauNormalizeIntersectionLimit, 50}; ScopedFastInt luauNormalizeUnionLimit{FInt::LuauNormalizeUnionLimit, 20}; ScopedFastFlag luauNormalizeLimitFunctionSet{FFlag::LuauNormalizeLimitFunctionSet, true}; @@ -1215,6 +1218,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_propagate_normalization_failures") ScopedFastInt luauNormalizeUnionLimit{FInt::LuauNormalizeUnionLimit, 20}; ScopedFastFlag luauNormalizeLimitFunctionSet{FFlag::LuauNormalizeLimitFunctionSet, true}; ScopedFastFlag luauSubtypingStopAtNormFail{FFlag::LuauSubtypingStopAtNormFail, true}; + ScopedFastFlag luauSubtypingEnableReasoningLimit{FFlag::LuauSubtypingEnableReasoningLimit, true}; CheckResult result = check(R"( function _(_,"").readu32(l0) @@ -1227,4 +1231,41 @@ _().readu32 %= _(_(_(_),_)) } #endif +TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_flatten_type_pack_cycle") +{ + ScopedFastFlag sff[] = {{FFlag::LuauSolverV2, true}, {FFlag::LuauTypePackDetectCycles, true}}; + + // Note: if this stops throwing an exception, it means we fixed cycle construction and can replace with a regular check + CHECK_THROWS_AS( + check(R"( +function _(_).readu32() +repeat +until function() +end +return if _ then _,_(_) +end +_(_(_(_)),``) +do end + )"), + InternalCompilerError + ); +} +#if 0 +TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_union_type_pack_cycle") +{ + ScopedFastFlag sff[] = {{FFlag::LuauSolverV2, true}, {FFlag::LuauTypePackDetectCycles, true}}; + + // Note: if this stops throwing an exception, it means we fixed cycle construction and can replace with a regular check + CHECK_THROWS_AS( + check(R"( +function _(_).n0(l32,...) +return ({n0=_,[_(if _ then _,nil)]=- _,[_(_(_))]=_,})[_],_(_) +end +_[_] ^= _(_(_)) + )"), + InternalCompilerError + ); +} +#endif + TEST_SUITE_END(); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 6ca13aed..c0da303e 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -18,7 +18,6 @@ LUAU_FASTINT(LuauParseErrorLimit) LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauAllowComplexTypesInGenericParams) LUAU_FASTFLAG(LuauErrorRecoveryForTableTypes) -LUAU_FASTFLAG(LuauExtendStatEndPosWithSemicolon) LUAU_FASTFLAG(LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) LUAU_FASTFLAG(LuauAstTypeGroup3) LUAU_FASTFLAG(LuauFixDoBlockEndLocation) @@ -3940,7 +3939,6 @@ TEST_CASE_FIXTURE(Fixture, "function_name_has_correct_start_location") TEST_CASE_FIXTURE(Fixture, "stat_end_includes_semicolon_position") { - ScopedFastFlag _{FFlag::LuauExtendStatEndPosWithSemicolon, true}; AstStatBlock* block = parse(R"( local x = 1 local y = 2; diff --git a/tests/RequireByString.test.cpp b/tests/RequireByString.test.cpp index 59a1af3b..14447c6a 100644 --- a/tests/RequireByString.test.cpp +++ b/tests/RequireByString.test.cpp @@ -136,7 +136,11 @@ public: return luauDirAbs; } - luauDirRel += "/.."; + if (luauDirRel == ".") + luauDirRel = ".."; + else + luauDirRel += "/.."; + std::optional parentPath = getParentPath(luauDirAbs); REQUIRE_MESSAGE(parentPath, "Error getting Luau path"); luauDirAbs = *parentPath; @@ -355,6 +359,13 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireInitLua") assertOutputContainsAll({"true", "result from init.lua"}); } +TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireSubmoduleUsingSelf") +{ + std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/nested_module_requirer"; + runProtectedRequire(path); + assertOutputContainsAll({"true", "result from submodule"}); +} + TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireWithFileAmbiguity") { std::string ambiguousPath = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/ambiguous_file_requirer"; @@ -462,13 +473,10 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "LoadStringRelative") TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireAbsolutePath") { -#ifdef _WIN32 - std::string absolutePath = "C:/an/absolute/path"; -#else std::string absolutePath = "/an/absolute/path"; -#endif + runProtectedRequire(absolutePath); - assertOutputContainsAll({"false", "cannot require an absolute path"}); + assertOutputContainsAll({"false", "require path must start with a valid prefix: ./, ../, or @"}); } TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireUnprefixedPath") @@ -478,13 +486,6 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireUnprefixedPath") assertOutputContainsAll({"false", "require path must start with a valid prefix: ./, ../, or @"}); } -TEST_CASE_FIXTURE(ReplWithPathFixture, "RequirePathWithExtension") -{ - std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/dependency.luau"; - runProtectedRequire(path); - assertOutputContainsAll({"false", "error requiring module: consider removing the file extension"}); -} - TEST_CASE_FIXTURE(ReplWithPathFixture, "RequirePathWithAlias") { std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/with_config/src/alias_requirer"; diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index 58fb65ea..e1fe18f8 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -13,7 +13,6 @@ using namespace Luau; LUAU_FASTFLAG(LuauStoreCSTData2) -LUAU_FASTFLAG(LuauExtendStatEndPosWithSemicolon) LUAU_FASTFLAG(LuauAstTypeGroup3); LUAU_FASTFLAG(LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) LUAU_FASTFLAG(LuauParseOptionalAsNode2) @@ -360,9 +359,9 @@ TEST_CASE("function_with_types_spaces_around_tokens") code = R"( function p (o: string, m: number, ...: any): string end )"; CHECK_EQ(code, transpile(code, {}, true).code); -// TODO(CLI-139347): re-enable test once colon positions are supported -// code = R"( function p(o : string, m: number, ...: any): string end )"; -// CHECK_EQ(code, transpile(code, {}, true).code); + // TODO(CLI-139347): re-enable test once colon positions are supported + // code = R"( function p(o : string, m: number, ...: any): string end )"; + // CHECK_EQ(code, transpile(code, {}, true).code); code = R"( function p(o: string, m: number, ...: any): string end )"; CHECK_EQ(code, transpile(code, {}, true).code); @@ -376,9 +375,9 @@ TEST_CASE("function_with_types_spaces_around_tokens") code = R"( function p(o: string, m: number, ...: any): string end )"; CHECK_EQ(code, transpile(code, {}, true).code); -// TODO(CLI-139347): re-enable test once colon positions are supported -// code = R"( function p(o: string, m: number, ... : any): string end )"; -// CHECK_EQ(code, transpile(code, {}, true).code); + // TODO(CLI-139347): re-enable test once colon positions are supported + // code = R"( function p(o: string, m: number, ... : any): string end )"; + // CHECK_EQ(code, transpile(code, {}, true).code); code = R"( function p(o: string, m: number, ...: any): string end )"; CHECK_EQ(code, transpile(code, {}, true).code); @@ -386,9 +385,9 @@ TEST_CASE("function_with_types_spaces_around_tokens") code = R"( function p(o: string, m: number, ...: any ): string end )"; CHECK_EQ(code, transpile(code, {}, true).code); -// TODO(CLI-139347): re-enable test once return type positions are supported -// code = R"( function p(o: string, m: number, ...: any) :string end )"; -// CHECK_EQ(code, transpile(code, {}, true).code); + // TODO(CLI-139347): re-enable test once return type positions are supported + // code = R"( function p(o: string, m: number, ...: any) :string end )"; + // CHECK_EQ(code, transpile(code, {}, true).code); code = R"( function p(o: string, m: number, ...: any): string end )"; CHECK_EQ(code, transpile(code, {}, true).code); @@ -858,7 +857,6 @@ TEST_CASE_FIXTURE(Fixture, "stmt_semicolon") { ScopedFastFlag flags[] = { {FFlag::LuauStoreCSTData2, true}, - {FFlag::LuauExtendStatEndPosWithSemicolon, true}, }; std::string code = R"( local test = 1; )"; CHECK_EQ(code, transpile(code, {}, true).code); @@ -869,6 +867,7 @@ TEST_CASE_FIXTURE(Fixture, "stmt_semicolon") TEST_CASE_FIXTURE(Fixture, "do_block_ending_with_semicolon") { + ScopedFastFlag sff{FFlag::LuauStoreCSTData2, true}; std::string code = R"( do return; @@ -881,7 +880,6 @@ TEST_CASE_FIXTURE(Fixture, "if_stmt_semicolon") { ScopedFastFlag flags[] = { {FFlag::LuauStoreCSTData2, true}, - {FFlag::LuauExtendStatEndPosWithSemicolon, true}, }; std::string code = R"( if init then @@ -895,7 +893,6 @@ TEST_CASE_FIXTURE(Fixture, "if_stmt_semicolon_2") { ScopedFastFlag flags[] = { {FFlag::LuauStoreCSTData2, true}, - {FFlag::LuauExtendStatEndPosWithSemicolon, true}, }; std::string code = R"( if (t < 1) then return c/2*t*t + b end; @@ -907,7 +904,6 @@ TEST_CASE_FIXTURE(Fixture, "for_loop_stmt_semicolon") { ScopedFastFlag flags[] = { {FFlag::LuauStoreCSTData2, true}, - {FFlag::LuauExtendStatEndPosWithSemicolon, true}, }; std::string code = R"( for i,v in ... do @@ -920,7 +916,6 @@ TEST_CASE_FIXTURE(Fixture, "while_do_semicolon") { ScopedFastFlag flags[] = { {FFlag::LuauStoreCSTData2, true}, - {FFlag::LuauExtendStatEndPosWithSemicolon, true}, }; std::string code = R"( while true do @@ -933,7 +928,6 @@ TEST_CASE_FIXTURE(Fixture, "function_definition_semicolon") { ScopedFastFlag flags[] = { {FFlag::LuauStoreCSTData2, true}, - {FFlag::LuauExtendStatEndPosWithSemicolon, true}, }; std::string code = R"( function foo() @@ -1068,7 +1062,7 @@ TEST_CASE_FIXTURE(Fixture, "type_lists_should_be_emitted_correctly") local c:()->()=function(): () end )" - : R"( + : R"( local a:(string,number,...string)->(string,...number)=function(a:string,b:number,...:string): (string,...number) end diff --git a/tests/TypeFunction.test.cpp b/tests/TypeFunction.test.cpp index b464d3d8..854a2021 100644 --- a/tests/TypeFunction.test.cpp +++ b/tests/TypeFunction.test.cpp @@ -966,6 +966,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "index_should_not_crash_on_cyclic_stuff2") CHECK(toString(requireTypeAlias("Keys")) == "number"); } +#if 0 +// CLI-148701 TEST_CASE_FIXTURE(BuiltinsFixture, "index_should_not_crash_on_cyclic_stuff3") { if (!FFlag::LuauSolverV2) @@ -993,6 +995,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "index_should_not_crash_on_cyclic_stuff3") LUAU_REQUIRE_NO_ERRORS(result); CHECK(toString(requireTypeAlias("Keys")) == "unknown"); } +#endif TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_function_works") { @@ -1123,8 +1126,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_function_works_on_function_metame if (!FFlag::LuauSolverV2) return; - ScopedFastFlag sff[] - { + ScopedFastFlag sff[]{ {FFlag::LuauIndexTypeFunctionFunctionMetamethods, true}, {FFlag::LuauIndexTypeFunctionImprovements, true}, }; @@ -1152,8 +1154,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_function_works_on_function_metame if (!FFlag::LuauSolverV2) return; - ScopedFastFlag sff[] - { + ScopedFastFlag sff[]{ {FFlag::LuauIndexTypeFunctionFunctionMetamethods, true}, {FFlag::LuauIndexTypeFunctionImprovements, true}, }; @@ -1236,7 +1237,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "index_type_function_rfc_alternative_section" type MyObject = {a: string} type MyObject2 = {a: string, b: number} - local function edgeCase(param: MyObject) + local function edgeCase(param: MyObject) type unknownType = index end )"); diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index a13de0f0..b5b494a9 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -556,10 +556,7 @@ TEST_CASE_FIXTURE(Fixture, "recursive_redefinition_reduces_rightfully") TEST_CASE_FIXTURE(BuiltinsFixture, "cli_142285_reduce_minted_union_func") { - ScopedFastFlag sffs[] = { - {FFlag::LuauSolverV2, true}, - {FFlag::LuauDontForgetToReduceUnionFunc, true} - }; + ScopedFastFlag sffs[] = {{FFlag::LuauSolverV2, true}, {FFlag::LuauDontForgetToReduceUnionFunc, true}}; CheckResult result = check(R"( local function middle(a: number, b: number): number @@ -582,7 +579,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "cli_142285_reduce_minted_union_func") LUAU_REQUIRE_ERROR_COUNT(3, result); // There are three errors in the above snippet, but they should all be where // clause needed errors. - for (const auto& e: result.errors) + for (const auto& e : result.errors) CHECK(get(e)); } diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 5c2b752f..eafc78c1 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -3194,10 +3194,7 @@ TEST_CASE_FIXTURE(Fixture, "recursive_function_calls_should_not_use_the_generali TEST_CASE_FIXTURE(Fixture, "fuzz_unwind_mutually_recursive_union_type_func") { - ScopedFastFlag sffs[] = { - {FFlag::LuauSolverV2, true}, - {FFlag::LuauReduceUnionFollowUnionType, true} - }; + ScopedFastFlag sffs[] = {{FFlag::LuauSolverV2, true}, {FFlag::LuauReduceUnionFollowUnionType, true}}; // This block ends up minting a type like: // diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 41753b66..85b65b7c 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -15,6 +15,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauStatForInFix) TEST_SUITE_BEGIN("TypeInferLoops"); @@ -1249,10 +1250,25 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "tryDispatchIterableFunction_under_constraine { CheckResult result = check(R"( local function foo(Instance) - for _, Child in next, Instance:GetChildren() do - end + for _, Child in next, Instance:GetChildren() do + end end )"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_surprising_iterator") +{ + ScopedFastFlag luauStatForInFix{FFlag::LuauStatForInFix, true}; + + CheckResult result = check(R"( +function broken(): (...() -> ()) + return function() end, function() end +end + +for p in broken() do print(p) end + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 38d349ad..037e87e4 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -58,7 +58,7 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") end end )" - : R"( + : R"( function f(a:{fn:()->(a,b...)}): () if type(a) == 'boolean'then local a1:boolean=a @@ -77,7 +77,7 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") end end )" - : R"( + : R"( function f(a:{fn:()->(unknown,...unknown)}): () if type(a) == 'boolean'then local a1:{fn:()->(unknown,...unknown)}&boolean=a @@ -96,7 +96,7 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") end end )" - : R"( + : R"( function f(a:{fn:()->(unknown,...unknown)}): () if type(a) == 'boolean'then local a1:{fn:()->(unknown,...unknown)}&boolean=a diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index a3377146..98238bb6 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -14,6 +14,7 @@ LUAU_FASTFLAG(LuauSkipNoRefineDuringRefinement) LUAU_FASTFLAG(LuauFunctionCallsAreNotNilable) LUAU_FASTFLAG(LuauDoNotLeakNilInRefinement) LUAU_FASTFLAG(LuauSimplyRefineNotNil) +LUAU_FASTFLAG(LuauWeakNilRefinementType) using namespace Luau; @@ -670,6 +671,8 @@ TEST_CASE_FIXTURE(Fixture, "unknown_lvalue_is_not_synonymous_with_other_on_not_e TEST_CASE_FIXTURE(Fixture, "string_not_equal_to_string_or_nil") { + ScopedFastFlag _{FFlag::LuauWeakNilRefinementType, true}; + CheckResult result = check(R"( local t: {string} = {"hello"} @@ -687,16 +690,8 @@ TEST_CASE_FIXTURE(Fixture, "string_not_equal_to_string_or_nil") CHECK_EQ(toString(requireTypeAtPosition({6, 29})), "string"); // a ~= b CHECK_EQ(toString(requireTypeAtPosition({6, 32})), "string?"); // a ~= b - if (FFlag::LuauSolverV2) - { - CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string?"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string?"); // a == b - } - else - { - CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string?"); // a == b - } + CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string?"); // a == b } TEST_CASE_FIXTURE(Fixture, "narrow_property_of_a_bounded_variable") @@ -734,6 +729,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_narrow_to_vector") TEST_CASE_FIXTURE(BuiltinsFixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true") { + ScopedFastFlag sffs[] = { + {FFlag::LuauSimplyRefineNotNil, true}, + {FFlag::LuauWeakNilRefinementType, true}, + }; + CheckResult result = check(R"( local t = {"hello"} local v = t[2] @@ -752,37 +752,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "nonoptional_type_can_narrow_to_nil_if_sense_ LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauSolverV2) - { - // CLI-115281 Types produced by refinements do not consistently get simplified - CHECK_EQ("(nil & string)?", toString(requireTypeAtPosition({4, 24}))); // type(v) == "nil" + CHECK_EQ("nil", toString(requireTypeAtPosition({4, 24}))); // type(v) == "nil" + CHECK_EQ("string", toString(requireTypeAtPosition({6, 24}))); // type(v) ~= "nil" - if (FFlag::LuauSimplyRefineNotNil) - CHECK_EQ( - "string & ~nil", toString(requireTypeAtPosition({6, 24})) - ); // type(v) ~= "nil" - else - CHECK_EQ( - "(boolean | buffer | class | function | number | string | table | thread) & string", toString(requireTypeAtPosition({6, 24})) - ); // type(v) ~= "nil" - - CHECK_EQ("(nil & string)?", toString(requireTypeAtPosition({10, 24}))); // equivalent to type(v) == "nil" - - if (FFlag::LuauSimplyRefineNotNil) - CHECK_EQ("string & ~nil", toString(requireTypeAtPosition({12, 24}))); // equivalent to type(v) ~= "nil" - else - CHECK_EQ( - "(boolean | buffer | class | function | number | string | table | thread) & string", toString(requireTypeAtPosition({12, 24})) - ); // equivalent to type(v) ~= "nil" - } - else - { - CHECK_EQ("nil", toString(requireTypeAtPosition({4, 24}))); // type(v) == "nil" - CHECK_EQ("string", toString(requireTypeAtPosition({6, 24}))); // type(v) ~= "nil" - - CHECK_EQ("nil", toString(requireTypeAtPosition({10, 24}))); // equivalent to type(v) == "nil" - CHECK_EQ("string", toString(requireTypeAtPosition({12, 24}))); // equivalent to type(v) ~= "nil" - } + CHECK_EQ("nil", toString(requireTypeAtPosition({10, 24}))); // equivalent to type(v) == "nil" + CHECK_EQ("string", toString(requireTypeAtPosition({12, 24}))); // equivalent to type(v) ~= "nil" } TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_not_to_be_string") @@ -2454,6 +2428,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "remove_recursive_upper_bound_when_generalizi { ScopedFastFlag sffs[] = { {FFlag::LuauSolverV2, true}, + {FFlag::LuauWeakNilRefinementType, true}, {FFlag::DebugLuauEqSatSimplification, true}, }; @@ -2465,7 +2440,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "remove_recursive_upper_bound_when_generalizi end )")); - CHECK_EQ("(nil & string)?", toString(requireTypeAtPosition({4, 24}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({4, 24}))); } TEST_CASE_FIXTURE(BuiltinsFixture, "nonnil_refinement_on_generic") @@ -2547,7 +2522,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "function_calls_are_not_nillable") return nil end )")); - } TEST_CASE_FIXTURE(BuiltinsFixture, "oss_1528_method_calls_are_not_nillable") @@ -2571,4 +2545,26 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "oss_1528_method_calls_are_not_nillable") )")); } +TEST_CASE_FIXTURE(Fixture, "oss_1687_equality_shouldnt_leak_nil") +{ + ScopedFastFlag _{FFlag::LuauWeakNilRefinementType, true}; + + LUAU_REQUIRE_NO_ERRORS(check(R"( + --!strict + function returns_two(): number + return 2 + end + + function is_two(num: number): boolean + return num==2 + end + + local my_number = returns_two() + + if my_number == 2 then + is_two(my_number) --type error, my_number: number? + end + )")); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 83cd581a..8985cd9f 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -24,6 +24,7 @@ LUAU_FASTFLAG(LuauFixIndexerSubtypingOrdering) LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope) LUAU_FASTFLAG(LuauTrackInteriorFreeTablesOnScope) LUAU_FASTFLAG(LuauFollowTableFreeze) +LUAU_FASTFLAG(LuauNonReentrantGeneralization) LUAU_FASTFLAG(LuauBidirectionalInferenceUpcast) LUAU_FASTFLAG(DebugLuauAssertOnForcedConstraint) LUAU_FASTFLAG(LuauSearchForRefineableType) @@ -701,7 +702,9 @@ TEST_CASE_FIXTURE(Fixture, "indexers_get_quantified_too") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauSolverV2) + if (FFlag::LuauSolverV2 && FFlag::LuauNonReentrantGeneralization) + CHECK("({a}) -> ()" == toString(requireType("swap"))); + else if (FFlag::LuauSolverV2) CHECK("({unknown}) -> ()" == toString(requireType("swap"))); else { @@ -1966,7 +1969,7 @@ TEST_CASE_FIXTURE(Fixture, "ok_to_set_nil_even_on_non_lvalue_base_expr") LUAU_REQUIRE_NO_ERRORS(check(R"( local function f( - t: {known_prop: boolean, [string]: number}, + t: {known_prop: boolean, [string]: number}, key: string ) t[key] = nil @@ -4675,7 +4678,7 @@ TEST_CASE_FIXTURE(Fixture, "table_writes_introduce_write_properties") if (!FFlag::LuauSolverV2) return; - ScopedFastFlag sff[] = {{FFlag::LuauSolverV2, true}}; + ScopedFastFlag sff[] = {{FFlag::LuauNonReentrantGeneralization, true}}; CheckResult result = check(R"( function oc(player, speaker) @@ -4687,9 +4690,9 @@ TEST_CASE_FIXTURE(Fixture, "table_writes_introduce_write_properties") LUAU_REQUIRE_NO_ERRORS(result); CHECK( - "({{ read Character: t1 }}, { Character: t1 }) -> () " + "({{ read Character: t1 }}, { Character: t1 }) -> () " "where " - "t1 = { read FindFirstChild: (t1, string) -> (a, b...) }" == toString(requireType("oc")) + "t1 = { read FindFirstChild: (t1, string) -> (a, ...unknown) }" == toString(requireType("oc")) ); } @@ -5515,7 +5518,6 @@ TEST_CASE_FIXTURE(Fixture, "missing_fields_bidirectional_inference") CHECK_EQ(toString(err->givenType), "{{ author: string }}"); CHECK_EQ(toString(err->wantedType), "{Book}"); CHECK_EQ(result.errors[1].location, Location{{3, 28}, {7, 9}}); - } TEST_CASE_FIXTURE(Fixture, "generic_index_syntax_bidirectional_infer_with_tables") diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index b81f1806..6a5292d1 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -33,6 +33,7 @@ LUAU_FASTFLAG(LuauImproveTypePathsInErrors) LUAU_FASTFLAG(LuauTypeCheckerAcceptNumberConcats) LUAU_FASTFLAG(LuauPreprocessTypestatedArgument) LUAU_FASTFLAG(LuauCacheInferencePerAstExpr) +LUAU_FASTFLAG(LuauMagicFreezeCheckBlocked) using namespace Luau; @@ -1972,10 +1973,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_local_before_declaration_ice") TEST_CASE_FIXTURE(Fixture, "fuzz_dont_double_solve_compound_assignment" * doctest::timeout(1.0)) { - ScopedFastFlag sffs[] = { - {FFlag::LuauSolverV2, true}, - {FFlag::LuauCacheInferencePerAstExpr, true} - }; + ScopedFastFlag sffs[] = {{FFlag::LuauSolverV2, true}, {FFlag::LuauCacheInferencePerAstExpr, true}}; CheckResult result = check(R"( local _ = {} @@ -2001,5 +1999,39 @@ TEST_CASE_FIXTURE(Fixture, "assert_allows_singleton_union_or_intersection") )")); } +TEST_CASE_FIXTURE(BuiltinsFixture, "assert_table_freeze_constraint_solving") +{ + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauMagicFreezeCheckBlocked, true} + }; + LUAU_REQUIRE_NO_ERRORS(check(R"( + local f = table.freeze + f(table) + )")); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_assert_table_freeze_constraint_solving") +{ + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauMagicFreezeCheckBlocked, true} + }; + // This is the original fuzzer version of the above issue. + CheckResult results = check(R"( + local function l0() + end + for l0 in false do + _ = (if _ then table) + repeat + do end + _:freeze(table) + until if _ then {{n0=_,},(_:freeze()._[_]),} + end + )"); + LUAU_REQUIRE_ERRORS(results); + LUAU_REQUIRE_NO_ERROR(results, ConstraintSolvingIncompleteError); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.typestates.test.cpp b/tests/TypeInfer.typestates.test.cpp index 0bce7546..9729628f 100644 --- a/tests/TypeInfer.typestates.test.cpp +++ b/tests/TypeInfer.typestates.test.cpp @@ -83,7 +83,7 @@ TEST_CASE_FIXTURE(TypeStateFixture, "parameter_x_was_constrained_by_two_types") LUAU_REQUIRE_ERRORS(result); TypePackMismatch* tpm = get(result.errors[0]); - REQUIRE(tpm); + REQUIRE_MESSAGE(tpm, "Expected TypePackMismatch but got " << result.errors[0]); CHECK("string?" == toString(tpm->wantedTp)); CHECK("number | string" == toString(tpm->givenTp)); diff --git a/tests/require/with_config/.luaurc b/tests/require/with_config/.luaurc index 7e7abf18..2b64ad06 100644 --- a/tests/require/with_config/.luaurc +++ b/tests/require/with_config/.luaurc @@ -1,6 +1,6 @@ { "aliases": { - "dep": "this_should_be_overwritten_by_child_luaurc", - "otherdep": "src/other_dependency" + "dep": "./this_should_be_overwritten_by_child_luaurc", + "otherdep": "./src/other_dependency" } } diff --git a/tests/require/with_config/src/.luaurc b/tests/require/with_config/src/.luaurc index 90c6b646..27263339 100644 --- a/tests/require/with_config/src/.luaurc +++ b/tests/require/with_config/src/.luaurc @@ -1,6 +1,6 @@ { "aliases": { - "dep": "dependency", - "subdir": "subdirectory" + "dep": "./dependency", + "subdir": "./subdirectory" } } diff --git a/tests/require/without_config/luau/init.lua b/tests/require/without_config/luau/init.lua deleted file mode 100644 index 7e3680b2..00000000 --- a/tests/require/without_config/luau/init.lua +++ /dev/null @@ -1 +0,0 @@ -return {"wrong file"} diff --git a/tests/require/without_config/nested/init.luau b/tests/require/without_config/nested/init.luau new file mode 100644 index 00000000..75b9617d --- /dev/null +++ b/tests/require/without_config/nested/init.luau @@ -0,0 +1,2 @@ +local result = require("@self/submodule") +return result diff --git a/tests/require/without_config/nested/submodule.luau b/tests/require/without_config/nested/submodule.luau new file mode 100644 index 00000000..9221587e --- /dev/null +++ b/tests/require/without_config/nested/submodule.luau @@ -0,0 +1 @@ +return {"result from submodule"} diff --git a/tests/require/without_config/nested_module_requirer.luau b/tests/require/without_config/nested_module_requirer.luau new file mode 100644 index 00000000..fc8d5e79 --- /dev/null +++ b/tests/require/without_config/nested_module_requirer.luau @@ -0,0 +1,3 @@ +local result = require("./nested") +result[#result+1] = "required into module" +return result