diff --git a/Analysis/include/Luau/Connective.h b/Analysis/include/Luau/Connective.h new file mode 100644 index 00000000..c9daa0f9 --- /dev/null +++ b/Analysis/include/Luau/Connective.h @@ -0,0 +1,68 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Def.h" +#include "Luau/TypedAllocator.h" +#include "Luau/TypeVar.h" +#include "Luau/Variant.h" + +#include + +namespace Luau +{ + +struct Negation; +struct Conjunction; +struct Disjunction; +struct Equivalence; +struct Proposition; +using Connective = Variant; +using ConnectiveId = Connective*; // Can and most likely is nullptr. + +struct Negation +{ + ConnectiveId connective; +}; + +struct Conjunction +{ + ConnectiveId lhs; + ConnectiveId rhs; +}; + +struct Disjunction +{ + ConnectiveId lhs; + ConnectiveId rhs; +}; + +struct Equivalence +{ + ConnectiveId lhs; + ConnectiveId rhs; +}; + +struct Proposition +{ + DefId def; + TypeId discriminantTy; +}; + +template +const T* get(ConnectiveId connective) +{ + return get_if(connective); +} + +struct ConnectiveArena +{ + TypedAllocator allocator; + + ConnectiveId negation(ConnectiveId connective); + ConnectiveId conjunction(ConnectiveId lhs, ConnectiveId rhs); + ConnectiveId disjunction(ConnectiveId lhs, ConnectiveId rhs); + ConnectiveId equivalence(ConnectiveId lhs, ConnectiveId rhs); + ConnectiveId proposition(DefId def, TypeId discriminantTy); +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 7f092f5b..4370d0cf 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -132,15 +132,16 @@ struct HasPropConstraint std::string prop; }; -struct RefinementConstraint +// result ~ if isSingleton D then ~D else unknown where D = discriminantType +struct SingletonOrTopTypeConstraint { - DefId def; + TypeId resultType; TypeId discriminantType; }; using ConstraintV = Variant; + HasPropConstraint, SingletonOrTopTypeConstraint>; struct Constraint { diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 6106717c..2bfc62f1 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -2,6 +2,7 @@ #pragma once #include "Luau/Ast.h" +#include "Luau/Connective.h" #include "Luau/Constraint.h" #include "Luau/DataFlowGraphBuilder.h" #include "Luau/Module.h" @@ -26,11 +27,13 @@ struct DcrLogger; struct Inference { TypeId ty = nullptr; + ConnectiveId connective = nullptr; Inference() = default; - explicit Inference(TypeId ty) + explicit Inference(TypeId ty, ConnectiveId connective = nullptr) : ty(ty) + , connective(connective) { } }; @@ -38,11 +41,13 @@ struct Inference struct InferencePack { TypePackId tp = nullptr; + std::vector connectives; InferencePack() = default; - explicit InferencePack(TypePackId tp) + explicit InferencePack(TypePackId tp, const std::vector& connectives = {}) : tp(tp) + , connectives(connectives) { } }; @@ -61,6 +66,14 @@ struct ConstraintGraphBuilder // The root scope of the module we're generating constraints for. // This is null when the CGB is initially constructed. Scope* rootScope; + + // Constraints that go straight to the solver. + std::vector constraints; + + // Constraints that do not go to the solver right away. Other constraints + // will enqueue them during solving. + std::vector unqueuedConstraints; + // A mapping of AST node to TypeId. DenseHashMap astTypes{nullptr}; // A mapping of AST node to TypePackId. @@ -73,6 +86,7 @@ struct ConstraintGraphBuilder // Defining scopes for AST nodes. DenseHashMap astTypeAliasDefiningScopes{nullptr}; NotNull dfg; + ConnectiveArena connectiveArena; int recursionCount = 0; @@ -126,6 +140,8 @@ struct ConstraintGraphBuilder */ NotNull addConstraint(const ScopePtr& scope, std::unique_ptr c); + void applyRefinements(const ScopePtr& scope, Location location, ConnectiveId connective); + /** * The entry point to the ConstraintGraphBuilder. This will construct a set * of scopes, constraints, and free types that can be solved later. @@ -167,10 +183,10 @@ struct ConstraintGraphBuilder * surrounding context. Used to implement bidirectional type checking. * @return the type of the expression. */ - Inference check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType = {}); + Inference check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType = {}, bool forceSingleton = false); - Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType); - Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional expectedType); + Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType, bool forceSingleton); + Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional expectedType, bool forceSingleton); Inference check(const ScopePtr& scope, AstExprLocal* local); Inference check(const ScopePtr& scope, AstExprGlobal* global); Inference check(const ScopePtr& scope, AstExprIndexName* indexName); @@ -180,6 +196,7 @@ struct ConstraintGraphBuilder Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType); Inference check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert); Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType); + std::tuple checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); TypePackId checkLValues(const ScopePtr& scope, AstArray exprs); @@ -243,16 +260,8 @@ struct ConstraintGraphBuilder void prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program); }; -/** - * Collects a vector of borrowed constraints from the scope and all its child - * scopes. It is important to only call this function when you're done adding - * constraints to the scope or its descendants, lest the borrowed pointers - * become invalid due to a container reallocation. - * @param rootScope the root scope of the scope graph to collect constraints - * from. - * @return a list of pointers to constraints contained within the scope graph. - * None of these pointers should be null. +/** Borrow a vector of pointers from a vector of owning pointers to constraints. */ -std::vector> collectConstraints(NotNull rootScope); +std::vector> borrowConstraints(const std::vector& constraints); } // namespace Luau diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 5cc63e65..7b89a278 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -76,8 +76,8 @@ struct ConstraintSolver DcrLogger* logger; - explicit ConstraintSolver(NotNull normalizer, NotNull rootScope, ModuleName moduleName, NotNull moduleResolver, - std::vector requireCycles, DcrLogger* logger); + explicit ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, + ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger); // Randomize the order in which to dispatch constraints void randomize(unsigned seed); @@ -110,7 +110,7 @@ struct ConstraintSolver bool tryDispatch(const FunctionCallConstraint& c, NotNull constraint); bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint); bool tryDispatch(const HasPropConstraint& c, NotNull constraint); - bool tryDispatch(const RefinementConstraint& c, NotNull constraint); + bool tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull constraint); // for a, ... in some_table do // also handles __iter metamethod diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index f98442dd..b28c06a5 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -17,10 +17,8 @@ struct SingletonTypes; using ModulePtr = std::shared_ptr; -bool isSubtype( - TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, bool anyIsTop = true); -bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, - bool anyIsTop = true); +bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice); +bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice); class TypeIds { @@ -169,12 +167,26 @@ struct NormalizedStringType bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr); -// A normalized function type is either `never` (represented by `nullopt`) +// A normalized function type can be `never`, the top function type `function`, // or an intersection of function types. -// NOTE: type normalization can fail on function types with generics -// (e.g. because we do not support unions and intersections of generic type packs), -// so this type may contain `error`. -using NormalizedFunctionType = std::optional; +// +// NOTE: type normalization can fail on function types with generics (e.g. +// because we do not support unions and intersections of generic type packs), so +// this type may contain `error`. +struct NormalizedFunctionType +{ + NormalizedFunctionType(); + + bool isTop = false; + // TODO: Remove this wrapping optional when clipping + // FFlagLuauNegatedFunctionTypes. + std::optional parts; + + void resetToNever(); + void resetToTop(); + + bool isNever() const; +}; // A normalized generic/free type is a union, where each option is of the form (X & T) where // * X is either a free type or a generic @@ -234,12 +246,14 @@ struct NormalizedType NormalizedType(NotNull singletonTypes); - NormalizedType(const NormalizedType&) = delete; - NormalizedType(NormalizedType&&) = default; NormalizedType() = delete; ~NormalizedType() = default; + + NormalizedType(const NormalizedType&) = delete; + NormalizedType& operator=(const NormalizedType&) = delete; + + NormalizedType(NormalizedType&&) = default; NormalizedType& operator=(NormalizedType&&) = default; - NormalizedType& operator=(NormalizedType&) = delete; }; class Normalizer @@ -291,7 +305,7 @@ public: bool unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars = -1); // ------- Negations - NormalizedType negateNormal(const NormalizedType& here); + std::optional negateNormal(const NormalizedType& here); TypeIds negateAll(const TypeIds& theres); TypeId negate(TypeId there); void subtractPrimitive(NormalizedType& here, TypeId ty); diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index ccf2964c..a26f506d 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -38,11 +38,6 @@ struct Scope std::unordered_map bindings; TypePackId returnType; std::optional varargPack; - // All constraints belonging to this scope. - std::vector constraints; - // Constraints belonging to this scope that are queued manually by other - // constraints. - std::vector unqueuedConstraints; TypeLevel level; diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index ff2561e6..0200a719 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -34,7 +34,6 @@ struct ToStringOptions size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); ToStringNameMap nameMap; - std::optional DEPRECATED_nameMap; std::shared_ptr scope; // If present, module names will be added and types that are not available in scope will be marked as 'invalid' std::vector namedFunctionOverrideArgNames; // If present, named function argument names will be overridden }; @@ -42,7 +41,6 @@ struct ToStringOptions struct ToStringResult { std::string name; - ToStringNameMap DEPRECATED_nameMap; bool invalid = false; bool error = false; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 4f61eaa5..2eea191a 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -281,14 +281,14 @@ private: TypeId singletonType(bool value); TypeId singletonType(std::string value); - TypeIdPredicate mkTruthyPredicate(bool sense); + TypeIdPredicate mkTruthyPredicate(bool sense, TypeId emptySetTy); // TODO: Return TypeId only. std::optional filterMapImpl(TypeId type, TypeIdPredicate predicate); std::pair, bool> filterMap(TypeId type, TypeIdPredicate predicate); public: - std::pair, bool> pickTypesFromSense(TypeId type, bool sense); + std::pair, bool> pickTypesFromSense(TypeId type, bool sense, TypeId emptySetTy); private: TypeId unionOfTypes(TypeId a, TypeId b, const ScopePtr& scope, const Location& location, bool unifyFreeTypes = true); diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index 7409dbe7..085ee21b 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -35,7 +35,7 @@ std::vector flatten(TypeArena& arena, NotNull singletonT * identity) types. * @param types the input type list to reduce. * @returns the reduced type list. -*/ + */ std::vector reduceUnion(const std::vector& types); /** @@ -45,7 +45,7 @@ std::vector reduceUnion(const std::vector& types); * @param arena the type arena to allocate the new type in, if necessary * @param ty the type to remove nil from * @returns a type with nil removed, or nil itself if that were the only option. -*/ + */ TypeId stripNil(NotNull singletonTypes, TypeArena& arena, TypeId ty); } // namespace Luau diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 70c12cb9..0ab4d474 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -115,6 +115,7 @@ struct PrimitiveTypeVar Number, String, Thread, + Function, }; Type type; @@ -504,14 +505,6 @@ struct NeverTypeVar { }; -// Invariant 1: there should never be a reason why such UseTypeVar exists without it mapping to another type. -// Invariant 2: UseTypeVar should always disappear across modules. -struct UseTypeVar -{ - DefId def; - NotNull scope; -}; - // ~T // TODO: Some simplification step that overwrites the type graph to make sure negation // types disappear from the user's view, and (?) a debug flag to disable that @@ -522,9 +515,9 @@ struct NegationTypeVar using ErrorTypeVar = Unifiable::Error; -using TypeVariant = Unifiable::Variant; +using TypeVariant = + Unifiable::Variant; struct TypeVar final { @@ -644,13 +637,14 @@ public: const TypeId stringType; const TypeId booleanType; const TypeId threadType; + const TypeId functionType; const TypeId trueType; const TypeId falseType; const TypeId anyType; const TypeId unknownType; const TypeId neverType; const TypeId errorType; - const TypeId falsyType; // No type binding! + const TypeId falsyType; // No type binding! const TypeId truthyType; // No type binding! const TypePackId anyTypePack; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 7bf4d50b..b5f58d3c 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -61,7 +61,6 @@ struct Unifier ErrorVec errors; Location location; Variance variance = Covariant; - bool anyIsTop = false; // If true, we consider any to be a top type. If false, it is a familiar but weird mix of top and bottom all at once. bool normalize; // Normalize unions and intersections if necessary bool useScopes = false; // If true, we use the scope hierarchy rather than TypeLevels CountMismatch::Context ctx = CountMismatch::Arg; @@ -131,6 +130,7 @@ public: Unifier makeChildUnifier(); void reportError(TypeError err); + LUAU_NOINLINE void reportError(Location location, TypeErrorData data); private: bool isNonstrictMode() const; diff --git a/Analysis/include/Luau/Variant.h b/Analysis/include/Luau/Variant.h index 76812c9b..016c51f6 100644 --- a/Analysis/include/Luau/Variant.h +++ b/Analysis/include/Luau/Variant.h @@ -58,13 +58,15 @@ public: constexpr int tid = getTypeId(); typeId = tid; - new (&storage) TT(value); + new (&storage) TT(std::forward(value)); } Variant(const Variant& other) { + static constexpr FnCopy table[sizeof...(Ts)] = {&fnCopy...}; + typeId = other.typeId; - tableCopy[typeId](&storage, &other.storage); + table[typeId](&storage, &other.storage); } Variant(Variant&& other) @@ -192,7 +194,6 @@ private: return *static_cast(lhs) == *static_cast(rhs); } - static constexpr FnCopy tableCopy[sizeof...(Ts)] = {&fnCopy...}; static constexpr FnMove tableMove[sizeof...(Ts)] = {&fnMove...}; static constexpr FnDtor tableDtor[sizeof...(Ts)] = {&fnDtor...}; diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index d4f8528f..3dcddba1 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -155,10 +155,6 @@ struct GenericTypeVarVisitor { return visit(ty); } - virtual bool visit(TypeId ty, const UseTypeVar& utv) - { - return visit(ty); - } virtual bool visit(TypeId ty, const NegationTypeVar& ntv) { return visit(ty); @@ -321,8 +317,6 @@ struct GenericTypeVarVisitor traverse(a); } } - else if (auto utv = get(ty)) - visit(ty, *utv); else if (auto ntv = get(ty)) visit(ty, *ntv); else if (!FFlag::LuauCompleteVisitor) diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 6051e117..67e3979a 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -17,7 +17,9 @@ LUAU_FASTFLAGVARIABLE(LuauSetMetaTableArgsCheck, false) LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauBuiltInMetatableNoBadSynthetic, false) +LUAU_FASTFLAG(LuauOptionalNextKey) LUAU_FASTFLAG(LuauReportShadowedTypeAlias) +LUAU_FASTFLAG(LuauNewLibraryTypeNames) /** FIXME: Many of these type definitions are not quite completely accurate. * @@ -276,18 +278,38 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) addGlobalBinding(typeChecker, "string", it->second.type, "@luau"); - // next(t: Table, i: K?) -> (K, V) - TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); - addGlobalBinding(typeChecker, "next", - arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); + if (FFlag::LuauOptionalNextKey) + { + // next(t: Table, i: K?) -> (K?, V) + TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); + TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(typeChecker, arena, genericK), genericV}}); + addGlobalBinding(typeChecker, "next", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); - TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); + TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); - TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); - TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); + TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, nextRetsTypePack}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); - // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) - addGlobalBinding(typeChecker, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) + addGlobalBinding( + typeChecker, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + } + else + { + // next(t: Table, i: K?) -> (K, V) + TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); + addGlobalBinding(typeChecker, "next", + arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); + + TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); + + TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); + + // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) + addGlobalBinding( + typeChecker, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + } TypeId genericMT = arena.addType(GenericTypeVar{"MT"}); @@ -319,7 +341,12 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) if (TableTypeVar* ttv = getMutable(pair.second.typeId)) { if (!ttv->name) - ttv->name = toString(pair.first); + { + if (FFlag::LuauNewLibraryTypeNames) + ttv->name = "typeof(" + toString(pair.first) + ")"; + else + ttv->name = toString(pair.first); + } } } @@ -370,18 +397,38 @@ void registerBuiltinGlobals(Frontend& frontend) addGlobalBinding(frontend, "string", it->second.type, "@luau"); - // next(t: Table, i: K?) -> (K, V) - TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(frontend, arena, genericK)}}); - addGlobalBinding(frontend, "next", - arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); + if (FFlag::LuauOptionalNextKey) + { + // next(t: Table, i: K?) -> (K?, V) + TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(frontend, arena, genericK)}}); + TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(frontend, arena, genericK), genericV}}); + addGlobalBinding(frontend, "next", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); - TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); + TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); - TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); - TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, frontend.singletonTypes->nilType}}); + TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, nextRetsTypePack}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, frontend.singletonTypes->nilType}}); - // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) - addGlobalBinding(frontend, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + // pairs(t: Table) -> ((Table, K?) -> (K?, V), Table, nil) + addGlobalBinding( + frontend, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + } + else + { + // next(t: Table, i: K?) -> (K, V) + TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(frontend, arena, genericK)}}); + addGlobalBinding(frontend, "next", + arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); + + TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); + + TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, frontend.singletonTypes->nilType}}); + + // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) + addGlobalBinding( + frontend, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + } TypeId genericMT = arena.addType(GenericTypeVar{"MT"}); @@ -413,7 +460,12 @@ void registerBuiltinGlobals(Frontend& frontend) if (TableTypeVar* ttv = getMutable(pair.second.typeId)) { if (!ttv->name) - ttv->name = toString(pair.first); + { + if (FFlag::LuauNewLibraryTypeNames) + ttv->name = "typeof(" + toString(pair.first) + ")"; + else + ttv->name = toString(pair.first); + } } } @@ -623,7 +675,7 @@ static std::optional> magicFunctionAssert( if (head.size() > 0) { - auto [ty, ok] = typechecker.pickTypesFromSense(head[0], true); + auto [ty, ok] = typechecker.pickTypesFromSense(head[0], true, typechecker.singletonTypes->nilType); if (FFlag::LuauUnknownAndNeverType) { if (get(*ty)) @@ -714,8 +766,7 @@ static bool dcrMagicFunctionPack(MagicFunctionCallContext context) result = arena->addType(UnionTypeVar{std::move(options)}); TypeId numberType = context.solver->singletonTypes->numberType; - TypeId packedTable = arena->addType( - TableTypeVar{{{"n", {numberType}}}, TableIndexer(numberType, result), {}, TableState::Sealed}); + TypeId packedTable = arena->addType(TableTypeVar{{{"n", {numberType}}}, TableIndexer(numberType, result), {}, TableState::Sealed}); TypePackId tableTypePack = arena->addTypePack({packedTable}); asMutable(context.result)->ty.emplace(tableTypePack); diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 85408919..86e1c7fc 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -62,7 +62,6 @@ struct TypeCloner void operator()(const LazyTypeVar& t); void operator()(const UnknownTypeVar& t); void operator()(const NeverTypeVar& t); - void operator()(const UseTypeVar& t); void operator()(const NegationTypeVar& t); }; @@ -338,12 +337,6 @@ void TypeCloner::operator()(const NeverTypeVar& t) defaultClone(t); } -void TypeCloner::operator()(const UseTypeVar& t) -{ - TypeId result = dest.addType(BoundTypeVar{follow(typeId)}); - seenTypes[typeId] = result; -} - void TypeCloner::operator()(const NegationTypeVar& t) { TypeId result = dest.addType(AnyTypeVar{}); diff --git a/Analysis/src/Connective.cpp b/Analysis/src/Connective.cpp new file mode 100644 index 00000000..114b5f2f --- /dev/null +++ b/Analysis/src/Connective.cpp @@ -0,0 +1,32 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Connective.h" + +namespace Luau +{ + +ConnectiveId ConnectiveArena::negation(ConnectiveId connective) +{ + return NotNull{allocator.allocate(Negation{connective})}; +} + +ConnectiveId ConnectiveArena::conjunction(ConnectiveId lhs, ConnectiveId rhs) +{ + return NotNull{allocator.allocate(Conjunction{lhs, rhs})}; +} + +ConnectiveId ConnectiveArena::disjunction(ConnectiveId lhs, ConnectiveId rhs) +{ + return NotNull{allocator.allocate(Disjunction{lhs, rhs})}; +} + +ConnectiveId ConnectiveArena::equivalence(ConnectiveId lhs, ConnectiveId rhs) +{ + return NotNull{allocator.allocate(Equivalence{lhs, rhs})}; +} + +ConnectiveId ConnectiveArena::proposition(DefId def, TypeId discriminantTy) +{ + return NotNull{allocator.allocate(Proposition{def, discriminantTy})}; +} + +} // namespace Luau diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 455fc221..42dc07f6 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -52,6 +52,70 @@ static bool matchSetmetatable(const AstExprCall& call) return true; } +struct TypeGuard +{ + bool isTypeof; + AstExpr* target; + std::string type; +}; + +static std::optional matchTypeGuard(const AstExprBinary* binary) +{ + if (binary->op != AstExprBinary::CompareEq && binary->op != AstExprBinary::CompareNe) + return std::nullopt; + + AstExpr* left = binary->left; + AstExpr* right = binary->right; + if (right->is()) + std::swap(left, right); + + if (!right->is()) + return std::nullopt; + + AstExprCall* call = left->as(); + AstExprConstantString* string = right->as(); + if (!call || !string) + return std::nullopt; + + AstExprGlobal* callee = call->func->as(); + if (!callee) + return std::nullopt; + + if (callee->name != "type" && callee->name != "typeof") + return std::nullopt; + + if (call->args.size != 1) + return std::nullopt; + + return TypeGuard{ + /*isTypeof*/ callee->name == "typeof", + /*target*/ call->args.data[0], + /*type*/ std::string(string->value.data, string->value.size), + }; +} + +namespace +{ + +struct Checkpoint +{ + size_t offset; +}; + +Checkpoint checkpoint(const ConstraintGraphBuilder* cgb) +{ + return Checkpoint{cgb->constraints.size()}; +} + +template +void forEachConstraint(const Checkpoint& start, const Checkpoint& end, const ConstraintGraphBuilder* cgb, F f) +{ + for (size_t i = start.offset; i < end.offset; ++i) + f(cgb->constraints[i]); +} + +} // namespace + ConstraintGraphBuilder::ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, NotNull moduleResolver, NotNull singletonTypes, NotNull ice, const ScopePtr& globalScope, DcrLogger* logger, NotNull dfg) @@ -99,12 +163,107 @@ ScopePtr ConstraintGraphBuilder::childScope(AstNode* node, const ScopePtr& paren NotNull ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv) { - return NotNull{scope->constraints.emplace_back(new Constraint{NotNull{scope.get()}, location, std::move(cv)}).get()}; + return NotNull{constraints.emplace_back(new Constraint{NotNull{scope.get()}, location, std::move(cv)}).get()}; } NotNull ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, std::unique_ptr c) { - return NotNull{scope->constraints.emplace_back(std::move(c)).get()}; + return NotNull{constraints.emplace_back(std::move(c)).get()}; +} + +static void unionRefinements(const std::unordered_map& lhs, const std::unordered_map& rhs, + std::unordered_map& dest, NotNull arena) +{ + for (auto [def, ty] : lhs) + { + auto rhsIt = rhs.find(def); + if (rhsIt == rhs.end()) + continue; + + std::vector discriminants{{ty, rhsIt->second}}; + + if (auto destIt = dest.find(def); destIt != dest.end()) + discriminants.push_back(destIt->second); + + dest[def] = arena->addType(UnionTypeVar{std::move(discriminants)}); + } +} + +static void computeRefinement(const ScopePtr& scope, ConnectiveId connective, std::unordered_map* refis, bool sense, + NotNull arena, bool eq, std::vector* constraints) +{ + using RefinementMap = std::unordered_map; + + if (!connective) + return; + else if (auto negation = get(connective)) + return computeRefinement(scope, negation->connective, refis, !sense, arena, eq, constraints); + else if (auto conjunction = get(connective)) + { + RefinementMap lhsRefis; + RefinementMap rhsRefis; + + computeRefinement(scope, conjunction->lhs, sense ? refis : &lhsRefis, sense, arena, eq, constraints); + computeRefinement(scope, conjunction->rhs, sense ? refis : &rhsRefis, sense, arena, eq, constraints); + + if (!sense) + unionRefinements(lhsRefis, rhsRefis, *refis, arena); + } + else if (auto disjunction = get(connective)) + { + RefinementMap lhsRefis; + RefinementMap rhsRefis; + + computeRefinement(scope, disjunction->lhs, sense ? &lhsRefis : refis, sense, arena, eq, constraints); + computeRefinement(scope, disjunction->rhs, sense ? &rhsRefis : refis, sense, arena, eq, constraints); + + if (sense) + unionRefinements(lhsRefis, rhsRefis, *refis, arena); + } + else if (auto equivalence = get(connective)) + { + computeRefinement(scope, equivalence->lhs, refis, sense, arena, true, constraints); + computeRefinement(scope, equivalence->rhs, refis, sense, arena, true, constraints); + } + else if (auto proposition = get(connective)) + { + TypeId discriminantTy = proposition->discriminantTy; + if (!sense && !eq) + discriminantTy = arena->addType(NegationTypeVar{proposition->discriminantTy}); + else if (!sense && eq) + { + discriminantTy = arena->addType(BlockedTypeVar{}); + constraints->push_back(SingletonOrTopTypeConstraint{discriminantTy, proposition->discriminantTy}); + } + + if (auto it = refis->find(proposition->def); it != refis->end()) + (*refis)[proposition->def] = arena->addType(IntersectionTypeVar{{discriminantTy, it->second}}); + else + (*refis)[proposition->def] = discriminantTy; + } +} + +void ConstraintGraphBuilder::applyRefinements(const ScopePtr& scope, Location location, ConnectiveId connective) +{ + if (!connective) + return; + + std::unordered_map refinements; + std::vector constraints; + computeRefinement(scope, connective, &refinements, /*sense*/ true, arena, /*eq*/ false, &constraints); + + for (auto [def, discriminantTy] : refinements) + { + std::optional defTy = scope->lookup(def); + if (!defTy) + ice->ice("Every DefId must map to a type!"); + + TypeId resultTy = arena->addType(IntersectionTypeVar{{*defTy, discriminantTy}}); + scope->dcrRefinements[def] = resultTy; + } + + for (auto& c : constraints) + addConstraint(scope, location, c); } void ConstraintGraphBuilder::visit(AstStatBlock* block) @@ -250,14 +409,33 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) if (value->is()) { - // HACK: we leave nil-initialized things floating under the assumption that they will later be populated. - // See the test TypeInfer/infer_locals_with_nil_value. - // Better flow awareness should make this obsolete. + // HACK: we leave nil-initialized things floating under the + // assumption that they will later be populated. + // + // See the test TypeInfer/infer_locals_with_nil_value. Better flow + // awareness should make this obsolete. if (!varTypes[i]) varTypes[i] = freshType(scope); } - else if (i == local->values.size - 1) + // Only function calls and vararg expressions can produce packs. All + // other expressions produce exactly one value. + else if (i != local->values.size - 1 || (!value->is() && !value->is())) + { + std::optional expectedType; + if (hasAnnotation) + expectedType = varTypes.at(i); + + TypeId exprType = check(scope, value, expectedType).ty; + if (i < varTypes.size()) + { + if (varTypes[i]) + addConstraint(scope, local->location, SubtypeConstraint{exprType, varTypes[i]}); + else + varTypes[i] = exprType; + } + } + else { std::vector expectedTypes; if (hasAnnotation) @@ -286,21 +464,6 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) addConstraint(scope, local->location, PackSubtypeConstraint{exprPack, tailPack}); } } - else - { - std::optional expectedType; - if (hasAnnotation) - expectedType = varTypes.at(i); - - TypeId exprType = check(scope, value, expectedType).ty; - if (i < varTypes.size()) - { - if (varTypes[i]) - addConstraint(scope, local->location, SubtypeConstraint{varTypes[i], exprType}); - else - varTypes[i] = exprType; - } - } } for (size_t i = 0; i < local->vars.size; ++i) @@ -377,6 +540,9 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatForIn* forIn) TypeId ty = freshType(loopScope); loopScope->bindings[var] = Binding{ty, var->location}; variableTypes.push_back(ty); + + if (auto def = dfg->getDef(var)) + loopScope->dcrRefinements[*def] = ty; } // It is always ok to provide too few variables, so we give this pack a free tail. @@ -407,20 +573,6 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatRepeat* repeat) check(repeatScope, repeat->condition); } -void addConstraints(Constraint* constraint, NotNull scope) -{ - scope->constraints.reserve(scope->constraints.size() + scope->constraints.size()); - - for (const auto& c : scope->constraints) - constraint->dependencies.push_back(NotNull{c.get()}); - - for (const auto& c : scope->unqueuedConstraints) - constraint->dependencies.push_back(NotNull{c.get()}); - - for (NotNull childScope : scope->children) - addConstraints(constraint, childScope); -} - void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFunction* function) { // Local @@ -438,12 +590,17 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFunction* FunctionSignature sig = checkFunctionSignature(scope, function->func); sig.bodyScope->bindings[function->name] = Binding{sig.signature, function->func->location}; + auto start = checkpoint(this); checkFunctionBody(sig.bodyScope, function->func); + auto end = checkpoint(this); NotNull constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}; std::unique_ptr c = std::make_unique(constraintScope, function->name->location, GeneralizationConstraint{functionType, sig.signature}); - addConstraints(c.get(), NotNull(sig.bodyScope.get())); + + forEachConstraint(start, end, this, [&c](const ConstraintPtr& constraint) { + c->dependencies.push_back(NotNull{constraint.get()}); + }); addConstraint(scope, std::move(c)); } @@ -511,12 +668,17 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct LUAU_ASSERT(functionType != nullptr); + auto start = checkpoint(this); checkFunctionBody(sig.bodyScope, function->func); + auto end = checkpoint(this); NotNull constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}; std::unique_ptr c = std::make_unique(constraintScope, function->name->location, GeneralizationConstraint{functionType, sig.signature}); - addConstraints(c.get(), NotNull(sig.bodyScope.get())); + + forEachConstraint(start, end, this, [&c](const ConstraintPtr& constraint) { + c->dependencies.push_back(NotNull{constraint.get()}); + }); addConstraint(scope, std::move(c)); } @@ -569,14 +731,16 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement // TODO: Optimization opportunity, the interior scope of the condition could be // reused for the then body, so we don't need to refine twice. ScopePtr condScope = childScope(ifStatement->condition, scope); - check(condScope, ifStatement->condition, std::nullopt); + auto [_, connective] = check(condScope, ifStatement->condition, std::nullopt); ScopePtr thenScope = childScope(ifStatement->thenbody, scope); + applyRefinements(thenScope, Location{}, connective); visit(thenScope, ifStatement->thenbody); if (ifStatement->elsebody) { ScopePtr elseScope = childScope(ifStatement->elsebody, scope); + applyRefinements(elseScope, Location{}, connectiveArena.negation(connective)); visit(elseScope, ifStatement->elsebody); } } @@ -846,8 +1010,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCall* call, const std::vector& expectedTypes) { TypeId fnType = check(scope, call->func).ty; - const size_t constraintIndex = scope->constraints.size(); - const size_t scopeIndex = scopes.size(); + auto startCheckpoint = checkpoint(this); std::vector args; @@ -876,8 +1039,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa } else { - const size_t constraintEndIndex = scope->constraints.size(); - const size_t scopeEndIndex = scopes.size(); + auto endCheckpoint = checkpoint(this); astOriginalCallTypes[call->func] = fnType; @@ -888,29 +1050,22 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa FunctionTypeVar ftv(TypeLevel{}, scope.get(), argPack, rets); TypeId inferredFnType = arena->addType(ftv); - scope->unqueuedConstraints.push_back( + unqueuedConstraints.push_back( std::make_unique(NotNull{scope.get()}, call->func->location, InstantiationConstraint{instantiatedType, fnType})); - NotNull ic(scope->unqueuedConstraints.back().get()); + NotNull ic(unqueuedConstraints.back().get()); - scope->unqueuedConstraints.push_back( + unqueuedConstraints.push_back( std::make_unique(NotNull{scope.get()}, call->func->location, SubtypeConstraint{inferredFnType, instantiatedType})); - NotNull sc(scope->unqueuedConstraints.back().get()); + NotNull sc(unqueuedConstraints.back().get()); // We force constraints produced by checking function arguments to wait // until after we have resolved the constraint on the function itself. // This ensures, for instance, that we start inferring the contents of // lambdas under the assumption that their arguments and return types // will be compatible with the enclosing function call. - for (size_t ci = constraintIndex; ci < constraintEndIndex; ++ci) - scope->constraints[ci]->dependencies.push_back(sc); - - for (size_t si = scopeIndex; si < scopeEndIndex; ++si) - { - for (auto& c : scopes[si].second->constraints) - { - c->dependencies.push_back(sc); - } - } + forEachConstraint(startCheckpoint, endCheckpoint, this, [sc](const ConstraintPtr& constraint) { + constraint->dependencies.push_back(sc); + }); addConstraint(scope, call->func->location, FunctionCallConstraint{ @@ -925,7 +1080,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa } } -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType, bool forceSingleton) { RecursionCounter counter{&recursionCount}; @@ -938,13 +1093,13 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, st Inference result; if (auto group = expr->as()) - result = check(scope, group->expr, expectedType); + result = check(scope, group->expr, expectedType, forceSingleton); else if (auto stringExpr = expr->as()) - result = check(scope, stringExpr, expectedType); + result = check(scope, stringExpr, expectedType, forceSingleton); else if (expr->is()) result = Inference{singletonTypes->numberType}; else if (auto boolExpr = expr->as()) - result = check(scope, boolExpr, expectedType); + result = check(scope, boolExpr, expectedType, forceSingleton); else if (expr->is()) result = Inference{singletonTypes->nilType}; else if (auto local = expr->as()) @@ -999,8 +1154,11 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, st return result; } -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType, bool forceSingleton) { + if (forceSingleton) + return Inference{arena->addType(SingletonTypeVar{StringSingleton{std::string{string->value.data, string->value.size}}})}; + if (expectedType) { const TypeId expectedTy = follow(*expectedType); @@ -1020,12 +1178,15 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantSt return Inference{singletonTypes->stringType}; } -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBool* boolExpr, std::optional expectedType) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBool* boolExpr, std::optional expectedType, bool forceSingleton) { + const TypeId singletonType = boolExpr->value ? singletonTypes->trueType : singletonTypes->falseType; + if (forceSingleton) + return Inference{singletonType}; + if (expectedType) { const TypeId expectedTy = follow(*expectedType); - const TypeId singletonType = boolExpr->value ? singletonTypes->trueType : singletonTypes->falseType; if (get(expectedTy) || get(expectedTy)) { @@ -1045,8 +1206,8 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBo Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local) { std::optional resultTy; - - if (auto def = dfg->getDef(local)) + auto def = dfg->getDef(local); + if (def) resultTy = scope->lookup(*def); if (!resultTy) @@ -1058,7 +1219,10 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* loc if (!resultTy) return Inference{singletonTypes->errorRecoveryType()}; // TODO: replace with ice, locals should never exist before its definition. - return Inference{*resultTy}; + if (def) + return Inference{*resultTy, connectiveArena.proposition(*def, singletonTypes->truthyType)}; + else + return Inference{*resultTy}; } Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* global) @@ -1107,20 +1271,23 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary) { - TypeId operandType = check(scope, unary->expr).ty; + auto [operandType, connective] = check(scope, unary->expr); TypeId resultType = arena->addType(BlockedTypeVar{}); addConstraint(scope, unary->location, UnaryConstraint{unary->op, operandType, resultType}); - return Inference{resultType}; + + if (unary->op == AstExprUnary::Not) + return Inference{resultType, connectiveArena.negation(connective)}; + else + return Inference{resultType}; } Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType) { - TypeId leftType = check(scope, binary->left, expectedType).ty; - TypeId rightType = check(scope, binary->right, expectedType).ty; + auto [leftType, rightType, connective] = checkBinary(scope, binary, expectedType); TypeId resultType = arena->addType(BlockedTypeVar{}); addConstraint(scope, binary->location, BinaryConstraint{binary->op, leftType, rightType, resultType}); - return Inference{resultType}; + return Inference{resultType, std::move(connective)}; } Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType) @@ -1147,6 +1314,106 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssert return Inference{resolveType(scope, typeAssert->annotation)}; } +std::tuple ConstraintGraphBuilder::checkBinary( + const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType) +{ + if (binary->op == AstExprBinary::And) + { + auto [leftType, leftConnective] = check(scope, binary->left, expectedType); + + ScopePtr rightScope = childScope(binary->right, scope); + applyRefinements(rightScope, binary->right->location, leftConnective); + auto [rightType, rightConnective] = check(rightScope, binary->right, expectedType); + + return {leftType, rightType, connectiveArena.conjunction(leftConnective, rightConnective)}; + } + else if (binary->op == AstExprBinary::Or) + { + auto [leftType, leftConnective] = check(scope, binary->left, expectedType); + + ScopePtr rightScope = childScope(binary->right, scope); + applyRefinements(rightScope, binary->right->location, connectiveArena.negation(leftConnective)); + auto [rightType, rightConnective] = check(rightScope, binary->right, expectedType); + + return {leftType, rightType, connectiveArena.disjunction(leftConnective, rightConnective)}; + } + else if (auto typeguard = matchTypeGuard(binary)) + { + TypeId leftType = check(scope, binary->left).ty; + TypeId rightType = check(scope, binary->right).ty; + + std::optional def = dfg->getDef(typeguard->target); + if (!def) + return {leftType, rightType, nullptr}; + + TypeId discriminantTy = singletonTypes->neverType; + if (typeguard->type == "nil") + discriminantTy = singletonTypes->nilType; + else if (typeguard->type == "string") + discriminantTy = singletonTypes->stringType; + else if (typeguard->type == "number") + discriminantTy = singletonTypes->numberType; + else if (typeguard->type == "boolean") + discriminantTy = singletonTypes->threadType; + else if (typeguard->type == "table") + discriminantTy = singletonTypes->neverType; // TODO: replace with top table type + else if (typeguard->type == "function") + discriminantTy = singletonTypes->functionType; + else if (typeguard->type == "userdata") + { + // For now, we don't really care about being accurate with userdata if the typeguard was using typeof + discriminantTy = singletonTypes->neverType; // TODO: replace with top class type + } + else if (!typeguard->isTypeof && typeguard->type == "vector") + discriminantTy = singletonTypes->neverType; // TODO: figure out a way to deal with this quirky type + else if (!typeguard->isTypeof) + discriminantTy = singletonTypes->neverType; + else if (auto typeFun = globalScope->lookupType(typeguard->type); typeFun && typeFun->typeParams.empty() && typeFun->typePackParams.empty()) + { + TypeId ty = follow(typeFun->type); + + // We're only interested in the root class of any classes. + if (auto ctv = get(ty); !ctv || !ctv->parent) + discriminantTy = ty; + } + + ConnectiveId proposition = connectiveArena.proposition(*def, discriminantTy); + if (binary->op == AstExprBinary::CompareEq) + return {leftType, rightType, proposition}; + else if (binary->op == AstExprBinary::CompareNe) + return {leftType, rightType, connectiveArena.negation(proposition)}; + else + ice->ice("matchTypeGuard should only return a Some under `==` or `~=`!"); + } + else if (binary->op == AstExprBinary::CompareEq || binary->op == AstExprBinary::CompareNe) + { + TypeId leftType = check(scope, binary->left, expectedType, true).ty; + TypeId rightType = check(scope, binary->right, expectedType, true).ty; + + ConnectiveId leftConnective = nullptr; + if (auto def = dfg->getDef(binary->left)) + leftConnective = connectiveArena.proposition(*def, rightType); + + ConnectiveId rightConnective = nullptr; + if (auto def = dfg->getDef(binary->right)) + rightConnective = connectiveArena.proposition(*def, leftType); + + if (binary->op == AstExprBinary::CompareNe) + { + leftConnective = connectiveArena.negation(leftConnective); + rightConnective = connectiveArena.negation(rightConnective); + } + + return {leftType, rightType, connectiveArena.equivalence(leftConnective, rightConnective)}; + } + else + { + TypeId leftType = check(scope, binary->left, expectedType).ty; + TypeId rightType = check(scope, binary->right, expectedType).ty; + return {leftType, rightType, nullptr}; + } +} + TypePackId ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray exprs) { std::vector types; @@ -1841,9 +2108,13 @@ std::vector> ConstraintGraphBuilder:: Inference ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location location, InferencePack pack) { - auto [tp] = pack; + const auto& [tp, connectives] = pack; + ConnectiveId connective = nullptr; + if (!connectives.empty()) + connective = connectives[0]; + if (auto f = first(tp)) - return Inference{*f}; + return Inference{*f, connective}; TypeId typeResult = freshType(scope); TypePack onePack{{typeResult}, freshTypePack(scope)}; @@ -1851,7 +2122,7 @@ Inference ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location lo addConstraint(scope, location, PackSubtypeConstraint{tp, oneTypePack}); - return Inference{typeResult}; + return Inference{typeResult, connective}; } void ConstraintGraphBuilder::reportError(Location location, TypeErrorData err) @@ -1897,19 +2168,14 @@ void ConstraintGraphBuilder::prepopulateGlobalScope(const ScopePtr& globalScope, program->visit(&gp); } -void collectConstraints(std::vector>& result, NotNull scope) -{ - for (const auto& c : scope->constraints) - result.push_back(NotNull{c.get()}); - - for (NotNull child : scope->children) - collectConstraints(result, child); -} - -std::vector> collectConstraints(NotNull rootScope) +std::vector> borrowConstraints(const std::vector& constraints) { std::vector> result; - collectConstraints(result, rootScope); + result.reserve(constraints.size()); + + for (const auto& c : constraints) + result.emplace_back(c.get()); + return result; } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 5e43be0f..533652e2 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -18,7 +18,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); -LUAU_FASTFLAG(LuauFixNameMaps) namespace Luau { @@ -27,34 +26,14 @@ namespace Luau { for (const auto& [k, v] : scope->bindings) { - if (FFlag::LuauFixNameMaps) - { - auto d = toString(v.typeId, opts); - printf("\t%s : %s\n", k.c_str(), d.c_str()); - } - else - { - auto d = toStringDetailed(v.typeId, opts); - opts.DEPRECATED_nameMap = d.DEPRECATED_nameMap; - printf("\t%s : %s\n", k.c_str(), d.name.c_str()); - } + auto d = toString(v.typeId, opts); + printf("\t%s : %s\n", k.c_str(), d.c_str()); } for (NotNull child : scope->children) dumpBindings(child, opts); } -static void dumpConstraints(NotNull scope, ToStringOptions& opts) -{ - for (const ConstraintPtr& c : scope->constraints) - { - printf("\t%s\n", toString(*c, opts).c_str()); - } - - for (NotNull child : scope->children) - dumpConstraints(child, opts); -} - static std::pair, std::vector> saturateArguments(TypeArena* arena, NotNull singletonTypes, const TypeFun& fn, const std::vector& rawTypeArguments, const std::vector& rawPackArguments) { @@ -219,12 +198,6 @@ size_t HashInstantiationSignature::operator()(const InstantiationSignature& sign return hash; } -void dump(NotNull rootScope, ToStringOptions& opts) -{ - printf("constraints:\n"); - dumpConstraints(rootScope, opts); -} - void dump(ConstraintSolver* cs, ToStringOptions& opts) { printf("constraints:\n"); @@ -248,17 +221,17 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts) if (auto fcc = get(*c)) { for (NotNull inner : fcc->innerConstraints) - printf("\t\t\t%s\n", toString(*inner, opts).c_str()); + printf("\t ->\t\t%s\n", toString(*inner, opts).c_str()); } } } -ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNull rootScope, ModuleName moduleName, - NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger) +ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, + ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger) : arena(normalizer->arena) , singletonTypes(normalizer->singletonTypes) , normalizer(normalizer) - , constraints(collectConstraints(rootScope)) + , constraints(std::move(constraints)) , rootScope(rootScope) , currentModuleName(std::move(moduleName)) , moduleResolver(moduleResolver) @@ -267,7 +240,7 @@ ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNull c : constraints) + for (NotNull c : this->constraints) { unsolvedConstraints.push_back(c); @@ -310,6 +283,8 @@ void ConstraintSolver::run() { printf("Starting solver\n"); dump(this, opts); + printf("Bindings:\n"); + dumpBindings(rootScope, opts); } if (FFlag::DebugLuauLogSolverToJson) @@ -440,8 +415,8 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*fcc, constraint); else if (auto hpc = get(*constraint)) success = tryDispatch(*hpc, constraint); - else if (auto rc = get(*constraint)) - success = tryDispatch(*rc, constraint); + else if (auto sottc = get(*constraint)) + success = tryDispatch(*sottc, constraint); else LUAU_ASSERT(false); @@ -633,13 +608,6 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullty.emplace(unionOfTypes(rightType, singletonTypes->booleanType, constraint->scope, false)); + { + TypeId leftFilteredTy = arena->addType(IntersectionTypeVar{{singletonTypes->falsyType, leftType}}); + + // TODO: normaliztion here should be replaced by a more limited 'simplification' + const NormalizedType* normalized = normalizer->normalize(arena->addType(UnionTypeVar{{leftFilteredTy, rightType}})); + + if (!normalized) + { + reportError(CodeTooComplex{}, constraint->location); + asMutable(resultType)->ty.emplace(errorRecoveryType()); + } + else + { + asMutable(resultType)->ty.emplace(normalizer->typeFromNormal(*normalized)); + } + unblock(resultType); return true; + } // Or evaluates to the LHS type if the LHS is truthy, and the RHS type if // LHS is falsey. case AstExprBinary::Op::Or: - asMutable(resultType)->ty.emplace(unionOfTypes(rightType, leftType, constraint->scope, true)); + { + TypeId rightFilteredTy = arena->addType(IntersectionTypeVar{{singletonTypes->truthyType, leftType}}); + + // TODO: normaliztion here should be replaced by a more limited 'simplification' + const NormalizedType* normalized = normalizer->normalize(arena->addType(UnionTypeVar{{rightFilteredTy, rightType}})); + + if (!normalized) + { + reportError(CodeTooComplex{}, constraint->location); + asMutable(resultType)->ty.emplace(errorRecoveryType()); + } + else + { + asMutable(resultType)->ty.emplace(normalizer->typeFromNormal(*normalized)); + } + unblock(resultType); return true; + } default: iceReporter.ice("Unhandled AstExprBinary::Op for binary operation", constraint->location); break; @@ -1148,6 +1148,17 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullsecond) + block(ic, blockedConstraint); + } + } + asMutable(*c.innerConstraints.at(0)).c = InstantiationConstraint{instantiatedType, *callMm}; asMutable(*c.innerConstraints.at(1)).c = SubtypeConstraint{inferredFnType, instantiatedType}; @@ -1180,6 +1191,17 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullsecond) + block(ic, blockedConstraint); + } + } + unsolvedConstraints.insert(end(unsolvedConstraints), begin(c.innerConstraints), end(c.innerConstraints)); asMutable(c.result)->ty.emplace(constraint->scope); } @@ -1274,25 +1296,18 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull constraint) +bool ConstraintSolver::tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull constraint) { - // TODO: Figure out exact details on when refinements need to be blocked. - // It's possible that it never needs to be, since we can just use intersection types with the discriminant type? + if (isBlocked(c.discriminantType)) + return false; - if (!constraint->scope->parent) - iceReporter.ice("No parent scope"); + TypeId followed = follow(c.discriminantType); - std::optional previousTy = constraint->scope->parent->lookup(c.def); - if (!previousTy) - iceReporter.ice("No previous type"); - - std::optional useTy = constraint->scope->lookup(c.def); - if (!useTy) - iceReporter.ice("The def is not bound to a type"); - - TypeId resultTy = follow(*useTy); - std::vector parts{*previousTy, c.discriminantType}; - asMutable(resultTy)->ty.emplace(std::move(parts)); + // `nil` is a singleton type too! There's only one value of type `nil`. + if (get(followed) || isNil(followed)) + *asMutable(c.resultType) = NegationTypeVar{c.discriminantType}; + else + *asMutable(c.resultType) = BoundTypeVar{singletonTypes->unknownType}; return true; } diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 67abbff1..b0f21737 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -2,6 +2,7 @@ #include "Luau/BuiltinDefinitions.h" LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAG(LuauOptionalNextKey) namespace Luau { @@ -13,16 +14,16 @@ declare bit32: { bor: (...number) -> number, bxor: (...number) -> number, btest: (number, ...number) -> boolean, - rrotate: (number, number) -> number, - lrotate: (number, number) -> number, - lshift: (number, number) -> number, - arshift: (number, number) -> number, - rshift: (number, number) -> number, - bnot: (number) -> number, - extract: (number, number, number?) -> number, - replace: (number, number, number, number?) -> number, - countlz: (number) -> number, - countrz: (number) -> number, + rrotate: (x: number, disp: number) -> number, + lrotate: (x: number, disp: number) -> number, + lshift: (x: number, disp: number) -> number, + arshift: (x: number, disp: number) -> number, + rshift: (x: number, disp: number) -> number, + bnot: (x: number) -> number, + extract: (n: number, field: number, width: number?) -> number, + replace: (n: number, v: number, field: number, width: number?) -> number, + countlz: (n: number) -> number, + countrz: (n: number) -> number, } declare math: { @@ -93,9 +94,9 @@ type DateTypeResult = { } declare os: { - time: (DateTypeArg?) -> number, - date: (string?, number?) -> DateTypeResult | string, - difftime: (DateTypeResult | number, DateTypeResult | number) -> number, + time: (time: DateTypeArg?) -> number, + date: (formatString: string?, time: number?) -> DateTypeResult | string, + difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number, clock: () -> number, } @@ -126,7 +127,7 @@ declare function rawlen(obj: {[K]: V} | string): number declare function setfenv(target: number | (T...) -> R..., env: {[string]: any}): ((T...) -> R...)? -declare function ipairs(tab: {V}): (({V}, number) -> (number, V), {V}, number) +-- TODO: place ipairs definition here with removal of FFlagLuauOptionalNextKey declare function pcall(f: (A...) -> R..., ...: A...): (boolean, R...) @@ -145,51 +146,51 @@ declare function loadstring(src: string, chunkname: string?): (((A...) -> declare function newproxy(mt: boolean?): any declare coroutine: { - create: ((A...) -> R...) -> thread, - resume: (thread, A...) -> (boolean, R...), + create: (f: (A...) -> R...) -> thread, + resume: (co: thread, A...) -> (boolean, R...), running: () -> thread, - status: (thread) -> "dead" | "running" | "normal" | "suspended", + status: (co: thread) -> "dead" | "running" | "normal" | "suspended", -- FIXME: This technically returns a function, but we can't represent this yet. - wrap: ((A...) -> R...) -> any, + wrap: (f: (A...) -> R...) -> any, yield: (A...) -> R..., isyieldable: () -> boolean, - close: (thread) -> (boolean, any) + close: (co: thread) -> (boolean, any) } declare table: { - concat: ({V}, string?, number?, number?) -> string, - insert: (({V}, V) -> ()) & (({V}, number, V) -> ()), - maxn: ({V}) -> number, - remove: ({V}, number?) -> V?, - sort: ({V}, ((V, V) -> boolean)?) -> (), - create: (number, V?) -> {V}, - find: ({V}, V, number?) -> number?, + concat: (t: {V}, sep: string?, i: number?, j: number?) -> string, + insert: ((t: {V}, value: V) -> ()) & ((t: {V}, pos: number, value: V) -> ()), + maxn: (t: {V}) -> number, + remove: (t: {V}, number?) -> V?, + sort: (t: {V}, comp: ((V, V) -> boolean)?) -> (), + create: (count: number, value: V?) -> {V}, + find: (haystack: {V}, needle: V, init: number?) -> number?, - unpack: ({V}, number?, number?) -> ...V, + unpack: (list: {V}, i: number?, j: number?) -> ...V, pack: (...V) -> { n: number, [number]: V }, - getn: ({V}) -> number, - foreach: ({[K]: V}, (K, V) -> ()) -> (), + getn: (t: {V}) -> number, + foreach: (t: {[K]: V}, f: (K, V) -> ()) -> (), foreachi: ({V}, (number, V) -> ()) -> (), - move: ({V}, number, number, number, {V}?) -> {V}, - clear: ({[K]: V}) -> (), + move: (src: {V}, a: number, b: number, t: number, dst: {V}?) -> {V}, + clear: (table: {[K]: V}) -> (), - isfrozen: ({[K]: V}) -> boolean, + isfrozen: (t: {[K]: V}) -> boolean, } declare debug: { - info: ((thread, number, string) -> R...) & ((number, string) -> R...) & (((A...) -> R1..., string) -> R2...), - traceback: ((string?, number?) -> string) & ((thread, string?, number?) -> string), + info: ((thread: thread, level: number, options: string) -> R...) & ((level: number, options: string) -> R...) & ((func: (A...) -> R1..., options: string) -> R2...), + traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string), } declare utf8: { char: (...number) -> string, charpattern: string, - codes: (string) -> ((string, number) -> (number, number), string, number), - codepoint: (string, number?, number?) -> ...number, - len: (string, number?, number?) -> (number?, number?), - offset: (string, number?, number?) -> number, + codes: (str: string) -> ((string, number) -> (number, number), string, number), + codepoint: (str: string, i: number?, j: number?) -> ...number, + len: (s: string, i: number?, j: number?) -> (number?, number?), + offset: (s: string, n: number?, i: number?) -> number, } -- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. @@ -207,6 +208,11 @@ std::string getBuiltinDefinitionSource() else result += "declare function error(message: T, level: number?)\n"; + if (FFlag::LuauOptionalNextKey) + result += "declare function ipairs(tab: {V}): (({V}, number) -> (number?, V), {V}, number)\n"; + else + result += "declare function ipairs(tab: {V}): (({V}, number) -> (number, V), {V}, number)\n"; + return result; } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 39e6428d..22a9ecfa 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -955,7 +955,8 @@ ModulePtr Frontend::check( cgb.visit(sourceModule.root); result->errors = std::move(cgb.errors); - ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), sourceModule.name, NotNull(&moduleResolver), requireCycles, logger.get()}; + ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), sourceModule.name, NotNull(&moduleResolver), + requireCycles, logger.get()}; if (options.randomizeConstraintResolutionSeed) cs.randomize(*options.randomizeConstraintResolutionSeed); diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 0412f007..62674aa8 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -14,9 +14,7 @@ #include -LUAU_FASTFLAG(LuauAnyifyModuleReturnGenerics) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -LUAU_FASTFLAGVARIABLE(LuauForceExportSurfacesToBeNormal, false); LUAU_FASTFLAGVARIABLE(LuauClonePublicInterfaceLess, false); LUAU_FASTFLAG(LuauSubstitutionReentrant); LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution); @@ -222,18 +220,6 @@ void Module::clonePublicInterface(NotNull singletonTypes, Intern } } - if (!FFlag::LuauAnyifyModuleReturnGenerics) - { - for (TypeId ty : returnType) - { - if (get(follow(ty))) - { - auto t = asMutable(ty); - t->ty = AnyTypeVar{}; - } - } - } - for (auto& [name, ty] : declaredGlobals) { if (FFlag::LuauClonePublicInterfaceLess) diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 5ef4b7e7..21e9f787 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -7,9 +7,9 @@ #include "Luau/Clone.h" #include "Luau/Common.h" +#include "Luau/RecursionCounter.h" #include "Luau/TypeVar.h" #include "Luau/Unifier.h" -#include "Luau/VisitTypeVar.h" LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false) LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) @@ -20,6 +20,7 @@ LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauTypeNormalization2, false); LUAU_FASTFLAGVARIABLE(LuauNegatedStringSingletons, false); +LUAU_FASTFLAGVARIABLE(LuauNegatedFunctionTypes, false); LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauOverloadedFunctionSubtypingPerf); @@ -206,6 +207,28 @@ bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& s return true; } +NormalizedFunctionType::NormalizedFunctionType() + : parts(FFlag::LuauNegatedFunctionTypes ? std::optional{TypeIds{}} : std::nullopt) +{ +} + +void NormalizedFunctionType::resetToTop() +{ + isTop = true; + parts.emplace(); +} + +void NormalizedFunctionType::resetToNever() +{ + isTop = false; + parts.emplace(); +} + +bool NormalizedFunctionType::isNever() const +{ + return !isTop && (!parts || parts->empty()); +} + NormalizedType::NormalizedType(NotNull singletonTypes) : tops(singletonTypes->neverType) , booleans(singletonTypes->neverType) @@ -220,8 +243,8 @@ NormalizedType::NormalizedType(NotNull singletonTypes) static bool isInhabited(const NormalizedType& norm) { return !get(norm.tops) || !get(norm.booleans) || !norm.classes.empty() || !get(norm.errors) || - !get(norm.nils) || !get(norm.numbers) || !norm.strings.isNever() || - !get(norm.threads) || norm.functions || !norm.tables.empty() || !norm.tyvars.empty(); + !get(norm.nils) || !get(norm.numbers) || !norm.strings.isNever() || !get(norm.threads) || + !norm.functions.isNever() || !norm.tables.empty() || !norm.tyvars.empty(); } static int tyvarIndex(TypeId ty) @@ -317,10 +340,14 @@ static bool isNormalizedThread(TypeId ty) static bool areNormalizedFunctions(const NormalizedFunctionType& tys) { - if (tys) - for (TypeId ty : *tys) + if (tys.parts) + { + for (TypeId ty : *tys.parts) + { if (!get(ty) && !get(ty)) return false; + } + } return true; } @@ -420,7 +447,7 @@ void Normalizer::clearNormal(NormalizedType& norm) norm.strings.resetToNever(); norm.threads = singletonTypes->neverType; norm.tables.clear(); - norm.functions = std::nullopt; + norm.functions.resetToNever(); norm.tyvars.clear(); } @@ -809,20 +836,28 @@ std::optional Normalizer::unionOfFunctions(TypeId here, TypeId there) void Normalizer::unionFunctions(NormalizedFunctionType& heres, const NormalizedFunctionType& theres) { - if (!theres) + if (FFlag::LuauNegatedFunctionTypes) + { + if (heres.isTop) + return; + if (theres.isTop) + heres.resetToTop(); + } + + if (theres.isNever()) return; TypeIds tmps; - if (!heres) + if (heres.isNever()) { - tmps.insert(theres->begin(), theres->end()); - heres = std::move(tmps); + tmps.insert(theres.parts->begin(), theres.parts->end()); + heres.parts = std::move(tmps); return; } - for (TypeId here : *heres) - for (TypeId there : *theres) + for (TypeId here : *heres.parts) + for (TypeId there : *theres.parts) { if (std::optional fun = unionOfFunctions(here, there)) tmps.insert(*fun); @@ -830,28 +865,28 @@ void Normalizer::unionFunctions(NormalizedFunctionType& heres, const NormalizedF tmps.insert(singletonTypes->errorRecoveryType(there)); } - heres = std::move(tmps); + heres.parts = std::move(tmps); } void Normalizer::unionFunctionsWithFunction(NormalizedFunctionType& heres, TypeId there) { - if (!heres) + if (heres.isNever()) { TypeIds tmps; tmps.insert(there); - heres = std::move(tmps); + heres.parts = std::move(tmps); return; } TypeIds tmps; - for (TypeId here : *heres) + for (TypeId here : *heres.parts) { if (std::optional fun = unionOfFunctions(here, there)) tmps.insert(*fun); else tmps.insert(singletonTypes->errorRecoveryType(there)); } - heres = std::move(tmps); + heres.parts = std::move(tmps); } void Normalizer::unionTablesWithTable(TypeIds& heres, TypeId there) @@ -1004,6 +1039,11 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor here.strings.resetToString(); else if (ptv->type == PrimitiveTypeVar::Thread) here.threads = there; + else if (ptv->type == PrimitiveTypeVar::Function) + { + LUAU_ASSERT(FFlag::LuauNegatedFunctionTypes); + here.functions.resetToTop(); + } else LUAU_ASSERT(!"Unreachable"); } @@ -1036,8 +1076,11 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor else if (const NegationTypeVar* ntv = get(there)) { const NormalizedType* thereNormal = normalize(ntv->ty); - NormalizedType tn = negateNormal(*thereNormal); - if (!unionNormals(here, tn)) + std::optional tn = negateNormal(*thereNormal); + if (!tn) + return false; + + if (!unionNormals(here, *tn)) return false; } else @@ -1053,7 +1096,7 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor // ------- Negations -NormalizedType Normalizer::negateNormal(const NormalizedType& here) +std::optional Normalizer::negateNormal(const NormalizedType& here) { NormalizedType result{singletonTypes}; if (!get(here.tops)) @@ -1092,10 +1135,24 @@ NormalizedType Normalizer::negateNormal(const NormalizedType& here) result.threads = get(here.threads) ? singletonTypes->threadType : singletonTypes->neverType; + /* + * Things get weird and so, so complicated if we allow negations of + * arbitrary function types. Ordinary code can never form these kinds of + * types, so we decline to negate them. + */ + if (FFlag::LuauNegatedFunctionTypes) + { + if (here.functions.isNever()) + result.functions.resetToTop(); + else if (here.functions.isTop) + result.functions.resetToNever(); + else + return std::nullopt; + } + // TODO: negating tables - // TODO: negating functions // TODO: negating tyvars? - + return result; } @@ -1142,21 +1199,25 @@ void Normalizer::subtractPrimitive(NormalizedType& here, TypeId ty) LUAU_ASSERT(ptv); switch (ptv->type) { - case PrimitiveTypeVar::NilType: - here.nils = singletonTypes->neverType; - break; - case PrimitiveTypeVar::Boolean: - here.booleans = singletonTypes->neverType; - break; - case PrimitiveTypeVar::Number: - here.numbers = singletonTypes->neverType; - break; - case PrimitiveTypeVar::String: - here.strings.resetToNever(); - break; - case PrimitiveTypeVar::Thread: - here.threads = singletonTypes->neverType; - break; + case PrimitiveTypeVar::NilType: + here.nils = singletonTypes->neverType; + break; + case PrimitiveTypeVar::Boolean: + here.booleans = singletonTypes->neverType; + break; + case PrimitiveTypeVar::Number: + here.numbers = singletonTypes->neverType; + break; + case PrimitiveTypeVar::String: + here.strings.resetToNever(); + break; + case PrimitiveTypeVar::Thread: + here.threads = singletonTypes->neverType; + break; + case PrimitiveTypeVar::Function: + LUAU_ASSERT(FFlag::LuauNegatedStringSingletons); + here.functions.resetToNever(); + break; } } @@ -1589,7 +1650,7 @@ std::optional Normalizer::intersectionOfFunctions(TypeId here, TypeId th TypePackId argTypes; TypePackId retTypes; - + if (hftv->retTypes == tftv->retTypes) { std::optional argTypesOpt = unionOfTypePacks(hftv->argTypes, tftv->argTypes); @@ -1598,7 +1659,7 @@ std::optional Normalizer::intersectionOfFunctions(TypeId here, TypeId th argTypes = *argTypesOpt; retTypes = hftv->retTypes; } - else if (FFlag::LuauOverloadedFunctionSubtypingPerf && hftv->argTypes == tftv->argTypes) + else if (FFlag::LuauOverloadedFunctionSubtypingPerf && hftv->argTypes == tftv->argTypes) { std::optional retTypesOpt = intersectionOfTypePacks(hftv->argTypes, tftv->argTypes); if (!retTypesOpt) @@ -1738,18 +1799,20 @@ std::optional Normalizer::unionSaturatedFunctions(TypeId here, TypeId th void Normalizer::intersectFunctionsWithFunction(NormalizedFunctionType& heres, TypeId there) { - if (!heres) + if (heres.isNever()) return; - for (auto it = heres->begin(); it != heres->end();) + heres.isTop = false; + + for (auto it = heres.parts->begin(); it != heres.parts->end();) { TypeId here = *it; if (get(here)) it++; else if (std::optional tmp = intersectionOfFunctions(here, there)) { - heres->erase(it); - heres->insert(*tmp); + heres.parts->erase(it); + heres.parts->insert(*tmp); return; } else @@ -1757,27 +1820,27 @@ void Normalizer::intersectFunctionsWithFunction(NormalizedFunctionType& heres, T } TypeIds tmps; - for (TypeId here : *heres) + for (TypeId here : *heres.parts) { if (std::optional tmp = unionSaturatedFunctions(here, there)) tmps.insert(*tmp); } - heres->insert(there); - heres->insert(tmps.begin(), tmps.end()); + heres.parts->insert(there); + heres.parts->insert(tmps.begin(), tmps.end()); } void Normalizer::intersectFunctions(NormalizedFunctionType& heres, const NormalizedFunctionType& theres) { - if (!heres) + if (heres.isNever()) return; - else if (!theres) + else if (theres.isNever()) { - heres = std::nullopt; + heres.resetToNever(); return; } else { - for (TypeId there : *theres) + for (TypeId there : *theres.parts) intersectFunctionsWithFunction(heres, there); } } @@ -1935,6 +1998,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) TypeId nils = here.nils; TypeId numbers = here.numbers; NormalizedStringType strings = std::move(here.strings); + NormalizedFunctionType functions = std::move(here.functions); TypeId threads = here.threads; clearNormal(here); @@ -1949,6 +2013,11 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) here.strings = std::move(strings); else if (ptv->type == PrimitiveTypeVar::Thread) here.threads = threads; + else if (ptv->type == PrimitiveTypeVar::Function) + { + LUAU_ASSERT(FFlag::LuauNegatedFunctionTypes); + here.functions = std::move(functions); + } else LUAU_ASSERT(!"Unreachable"); } @@ -1981,8 +2050,10 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) for (TypeId part : itv->options) { const NormalizedType* normalPart = normalize(part); - NormalizedType negated = negateNormal(*normalPart); - intersectNormals(here, negated); + std::optional negated = negateNormal(*normalPart); + if (!negated) + return false; + intersectNormals(here, *negated); } } else @@ -2016,14 +2087,16 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) result.insert(result.end(), norm.classes.begin(), norm.classes.end()); if (!get(norm.errors)) result.push_back(norm.errors); - if (norm.functions) + if (FFlag::LuauNegatedFunctionTypes && norm.functions.isTop) + result.push_back(singletonTypes->functionType); + else if (!norm.functions.isNever()) { - if (norm.functions->size() == 1) - result.push_back(*norm.functions->begin()); + if (norm.functions.parts->size() == 1) + result.push_back(*norm.functions.parts->begin()); else { std::vector parts; - parts.insert(parts.end(), norm.functions->begin(), norm.functions->end()); + parts.insert(parts.end(), norm.functions.parts->begin(), norm.functions.parts->end()); result.push_back(arena->addType(IntersectionTypeVar{std::move(parts)})); } } @@ -2070,62 +2143,24 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) return arena->addType(UnionTypeVar{std::move(result)}); } -namespace -{ - -struct Replacer -{ - TypeArena* arena; - TypeId sourceType; - TypeId replacedType; - DenseHashMap newTypes; - - Replacer(TypeArena* arena, TypeId sourceType, TypeId replacedType) - : arena(arena) - , sourceType(sourceType) - , replacedType(replacedType) - , newTypes(nullptr) - { - } - - TypeId smartClone(TypeId t) - { - t = follow(t); - TypeId* res = newTypes.find(t); - if (res) - return *res; - - TypeId result = shallowClone(t, *arena, TxnLog::empty()); - newTypes[t] = result; - newTypes[result] = result; - - return result; - } -}; - -} // anonymous namespace - -bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, bool anyIsTop) +bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice) { UnifierSharedState sharedState{&ice}; TypeArena arena; Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; - u.anyIsTop = anyIsTop; u.tryUnify(subTy, superTy); const bool ok = u.errors.empty() && u.log.empty(); return ok; } -bool isSubtype( - TypePackId subPack, TypePackId superPack, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, bool anyIsTop) +bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice) { UnifierSharedState sharedState{&ice}; TypeArena arena; Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; - u.anyIsTop = anyIsTop; u.tryUnify(subPack, superPack); const bool ok = u.errors.empty() && u.log.empty(); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 7d0fb22d..5fd15cf7 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -10,12 +10,11 @@ #include #include -LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauLvaluelessPath) -LUAU_FASTFLAGVARIABLE(LuauSpecialTypesAsterisked, false) -LUAU_FASTFLAGVARIABLE(LuauFixNameMaps, false) -LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false) +LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauFunctionReturnStringificationFixup, false) +LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false) /* * Prefix generic typenames with gen- @@ -130,28 +129,15 @@ struct StringifierState bool exhaustive; - StringifierState(ToStringOptions& opts, ToStringResult& result, const std::optional& DEPRECATED_nameMap) + StringifierState(ToStringOptions& opts, ToStringResult& result) : opts(opts) , result(result) , exhaustive(opts.exhaustive) { - if (!FFlag::LuauFixNameMaps && DEPRECATED_nameMap) - result.DEPRECATED_nameMap = *DEPRECATED_nameMap; - - if (!FFlag::LuauFixNameMaps) - { - for (const auto& [_, v] : result.DEPRECATED_nameMap.typeVars) - usedNames.insert(v); - for (const auto& [_, v] : result.DEPRECATED_nameMap.typePacks) - usedNames.insert(v); - } - else - { - for (const auto& [_, v] : opts.nameMap.typeVars) - usedNames.insert(v); - for (const auto& [_, v] : opts.nameMap.typePacks) - usedNames.insert(v); - } + for (const auto& [_, v] : opts.nameMap.typeVars) + usedNames.insert(v); + for (const auto& [_, v] : opts.nameMap.typePacks) + usedNames.insert(v); } bool hasSeen(const void* tv) @@ -174,8 +160,8 @@ struct StringifierState std::string getName(TypeId ty) { - const size_t s = FFlag::LuauFixNameMaps ? opts.nameMap.typeVars.size() : result.DEPRECATED_nameMap.typeVars.size(); - std::string& n = FFlag::LuauFixNameMaps ? opts.nameMap.typeVars[ty] : result.DEPRECATED_nameMap.typeVars[ty]; + const size_t s = opts.nameMap.typeVars.size(); + std::string& n = opts.nameMap.typeVars[ty]; if (!n.empty()) return n; @@ -197,8 +183,8 @@ struct StringifierState std::string getName(TypePackId ty) { - const size_t s = FFlag::LuauFixNameMaps ? opts.nameMap.typePacks.size() : result.DEPRECATED_nameMap.typePacks.size(); - std::string& n = FFlag::LuauFixNameMaps ? opts.nameMap.typePacks[ty] : result.DEPRECATED_nameMap.typePacks[ty]; + const size_t s = opts.nameMap.typePacks.size(); + std::string& n = opts.nameMap.typePacks[ty]; if (!n.empty()) return n; @@ -225,6 +211,20 @@ struct StringifierState result.name += s; } + void emitLevel(Scope* scope) + { + size_t count = 0; + for (Scope* s = scope; s; s = s->parent.get()) + ++count; + + emit(count); + emit("-"); + char buffer[16]; + uint32_t s = uint32_t(intptr_t(scope) & 0xFFFFFF); + snprintf(buffer, sizeof(buffer), "0x%x", s); + emit(buffer); + } + void emit(TypeLevel level) { emit(std::to_string(level.level)); @@ -296,10 +296,7 @@ struct TypeVarStringifier if (tv->ty.valueless_by_exception()) { state.result.error = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("* VALUELESS BY EXCEPTION *"); - else - state.emit("< VALUELESS BY EXCEPTION >"); + state.emit("* VALUELESS BY EXCEPTION *"); return; } @@ -377,7 +374,10 @@ struct TypeVarStringifier if (FFlag::DebugLuauVerboseTypeNames) { state.emit("-"); - state.emit(ftv.level); + if (FFlag::DebugLuauDeferredConstraintResolution) + state.emitLevel(ftv.scope); + else + state.emit(ftv.level); } } @@ -391,14 +391,20 @@ struct TypeVarStringifier if (gtv.explicitName) { state.usedNames.insert(gtv.name); - if (FFlag::LuauFixNameMaps) - state.opts.nameMap.typeVars[ty] = gtv.name; - else - state.result.DEPRECATED_nameMap.typeVars[ty] = gtv.name; + state.opts.nameMap.typeVars[ty] = gtv.name; state.emit(gtv.name); } else state.emit(state.getName(ty)); + + if (FFlag::DebugLuauVerboseTypeNames) + { + state.emit("-"); + if (FFlag::DebugLuauDeferredConstraintResolution) + state.emitLevel(gtv.scope); + else + state.emit(gtv.level); + } } void operator()(TypeId, const BlockedTypeVar& btv) @@ -434,6 +440,9 @@ struct TypeVarStringifier case PrimitiveTypeVar::Thread: state.emit("thread"); return; + case PrimitiveTypeVar::Function: + state.emit("function"); + return; default: LUAU_ASSERT(!"Unknown primitive type"); throwRuntimeError("Unknown primitive type " + std::to_string(ptv.type)); @@ -462,10 +471,7 @@ struct TypeVarStringifier if (state.hasSeen(&ftv)) { state.result.cycle = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("*CYCLE*"); - else - state.emit(""); + state.emit("*CYCLE*"); return; } @@ -573,10 +579,7 @@ struct TypeVarStringifier if (state.hasSeen(&ttv)) { state.result.cycle = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("*CYCLE*"); - else - state.emit(""); + state.emit("*CYCLE*"); return; } @@ -710,10 +713,7 @@ struct TypeVarStringifier if (state.hasSeen(&uv)) { state.result.cycle = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("*CYCLE*"); - else - state.emit(""); + state.emit("*CYCLE*"); return; } @@ -780,10 +780,7 @@ struct TypeVarStringifier if (state.hasSeen(&uv)) { state.result.cycle = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("*CYCLE*"); - else - state.emit(""); + state.emit("*CYCLE*"); return; } @@ -828,10 +825,7 @@ struct TypeVarStringifier void operator()(TypeId, const ErrorTypeVar& tv) { state.result.error = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*"); - else - state.emit(FFlag::LuauUnknownAndNeverType ? "" : "*unknown*"); + state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*"); } void operator()(TypeId, const LazyTypeVar& ltv) @@ -850,11 +844,6 @@ struct TypeVarStringifier state.emit("never"); } - void operator()(TypeId ty, const UseTypeVar&) - { - stringify(follow(ty)); - } - void operator()(TypeId, const NegationTypeVar& ntv) { state.emit("~"); @@ -907,10 +896,7 @@ struct TypePackStringifier if (tp->ty.valueless_by_exception()) { state.result.error = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("* VALUELESS TP BY EXCEPTION *"); - else - state.emit("< VALUELESS TP BY EXCEPTION >"); + state.emit("* VALUELESS TP BY EXCEPTION *"); return; } @@ -934,10 +920,7 @@ struct TypePackStringifier if (state.hasSeen(&tp)) { state.result.cycle = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("*CYCLETP*"); - else - state.emit(""); + state.emit("*CYCLETP*"); return; } @@ -982,10 +965,7 @@ struct TypePackStringifier void operator()(TypePackId, const Unifiable::Error& error) { state.result.error = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*"); - else - state.emit(FFlag::LuauUnknownAndNeverType ? "" : "*unknown*"); + state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*"); } void operator()(TypePackId, const VariadicTypePack& pack) @@ -993,10 +973,7 @@ struct TypePackStringifier state.emit("..."); if (FFlag::DebugLuauVerboseTypeNames && pack.hidden) { - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("*hidden*"); - else - state.emit(""); + state.emit("*hidden*"); } stringify(pack.ty); } @@ -1008,10 +985,7 @@ struct TypePackStringifier if (pack.explicitName) { state.usedNames.insert(pack.name); - if (FFlag::LuauFixNameMaps) - state.opts.nameMap.typePacks[tp] = pack.name; - else - state.result.DEPRECATED_nameMap.typePacks[tp] = pack.name; + state.opts.nameMap.typePacks[tp] = pack.name; state.emit(pack.name); } else @@ -1031,7 +1005,10 @@ struct TypePackStringifier if (FFlag::DebugLuauVerboseTypeNames) { state.emit("-"); - state.emit(pack.level); + if (FFlag::DebugLuauDeferredConstraintResolution) + state.emitLevel(pack.scope); + else + state.emit(pack.level); } state.emit("..."); @@ -1111,8 +1088,7 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts) ToStringResult result; - StringifierState state = - FFlag::LuauFixNameMaps ? StringifierState{opts, result, opts.nameMap} : StringifierState{opts, result, opts.DEPRECATED_nameMap}; + StringifierState state{opts, result}; std::set cycles; std::set cycleTPs; @@ -1204,10 +1180,7 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts) { result.truncated = true; - if (FFlag::LuauSpecialTypesAsterisked) - result.name += "... *TRUNCATED*"; - else - result.name += "... "; + result.name += "... *TRUNCATED*"; } return result; @@ -1222,8 +1195,7 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts) * 4. Print out the root of the type using the same algorithm as step 3. */ ToStringResult result; - StringifierState state = - FFlag::LuauFixNameMaps ? StringifierState{opts, result, opts.nameMap} : StringifierState{opts, result, opts.DEPRECATED_nameMap}; + StringifierState state{opts, result}; std::set cycles; std::set cycleTPs; @@ -1280,10 +1252,7 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts) if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) { - if (FFlag::LuauSpecialTypesAsterisked) - result.name += "... *TRUNCATED*"; - else - result.name += "... "; + result.name += "... *TRUNCATED*"; } return result; @@ -1312,8 +1281,7 @@ std::string toString(const TypePackVar& tp, ToStringOptions& opts) std::string toStringNamedFunction(const std::string& funcName, const FunctionTypeVar& ftv, ToStringOptions& opts) { ToStringResult result; - StringifierState state = - FFlag::LuauFixNameMaps ? StringifierState{opts, result, opts.nameMap} : StringifierState{opts, result, opts.DEPRECATED_nameMap}; + StringifierState state{opts, result}; TypeVarStringifier tvs{state}; state.emit(funcName); @@ -1445,90 +1413,86 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) auto go = [&opts](auto&& c) -> std::string { using T = std::decay_t; - // TODO: Inline and delete this function when clipping FFlag::LuauFixNameMaps - auto tos = [](auto&& a, ToStringOptions& opts) { - if (FFlag::LuauFixNameMaps) - return toString(a, opts); - else - { - ToStringResult tsr = toStringDetailed(a, opts); - opts.DEPRECATED_nameMap = std::move(tsr.DEPRECATED_nameMap); - return tsr.name; - } + auto tos = [&opts](auto&& a) + { + return toString(a, opts); }; if constexpr (std::is_same_v) { - std::string subStr = tos(c.subType, opts); - std::string superStr = tos(c.superType, opts); + std::string subStr = tos(c.subType); + std::string superStr = tos(c.superType); return subStr + " <: " + superStr; } else if constexpr (std::is_same_v) { - std::string subStr = tos(c.subPack, opts); - std::string superStr = tos(c.superPack, opts); + std::string subStr = tos(c.subPack); + std::string superStr = tos(c.superPack); return subStr + " <: " + superStr; } else if constexpr (std::is_same_v) { - std::string subStr = tos(c.generalizedType, opts); - std::string superStr = tos(c.sourceType, opts); + std::string subStr = tos(c.generalizedType); + std::string superStr = tos(c.sourceType); return subStr + " ~ gen " + superStr; } else if constexpr (std::is_same_v) { - std::string subStr = tos(c.subType, opts); - std::string superStr = tos(c.superType, opts); + std::string subStr = tos(c.subType); + std::string superStr = tos(c.superType); return subStr + " ~ inst " + superStr; } else if constexpr (std::is_same_v) { - std::string resultStr = tos(c.resultType, opts); - std::string operandStr = tos(c.operandType, opts); + std::string resultStr = tos(c.resultType); + std::string operandStr = tos(c.operandType); return resultStr + " ~ Unary<" + toString(c.op) + ", " + operandStr + ">"; } else if constexpr (std::is_same_v) { - std::string resultStr = tos(c.resultType, opts); - std::string leftStr = tos(c.leftType, opts); - std::string rightStr = tos(c.rightType, opts); + std::string resultStr = tos(c.resultType); + std::string leftStr = tos(c.leftType); + std::string rightStr = tos(c.rightType); return resultStr + " ~ Binary<" + toString(c.op) + ", " + leftStr + ", " + rightStr + ">"; } else if constexpr (std::is_same_v) { - std::string iteratorStr = tos(c.iterator, opts); - std::string variableStr = tos(c.variables, opts); + std::string iteratorStr = tos(c.iterator); + std::string variableStr = tos(c.variables); return variableStr + " ~ Iterate<" + iteratorStr + ">"; } else if constexpr (std::is_same_v) { - std::string namedStr = tos(c.namedType, opts); + std::string namedStr = tos(c.namedType); return "@name(" + namedStr + ") = " + c.name; } else if constexpr (std::is_same_v) { - std::string targetStr = tos(c.target, opts); + std::string targetStr = tos(c.target); return "expand " + targetStr; } else if constexpr (std::is_same_v) { - return "call " + tos(c.fn, opts) + " with { result = " + tos(c.result, opts) + " }"; + return "call " + tos(c.fn) + " with { result = " + tos(c.result) + " }"; } else if constexpr (std::is_same_v) { - return tos(c.resultType, opts) + " ~ prim " + tos(c.expectedType, opts) + ", " + tos(c.singletonType, opts) + ", " + - tos(c.multitonType, opts); + return tos(c.resultType) + " ~ prim " + tos(c.expectedType) + ", " + tos(c.singletonType) + ", " + + tos(c.multitonType); } else if constexpr (std::is_same_v) { - return tos(c.resultType, opts) + " ~ hasProp " + tos(c.subjectType, opts) + ", \"" + c.prop + "\""; + return tos(c.resultType) + " ~ hasProp " + tos(c.subjectType) + ", \"" + c.prop + "\""; } - else if constexpr (std::is_same_v) + else if constexpr (std::is_same_v) { - return "TODO"; + std::string result = tos(c.resultType); + std::string discriminant = tos(c.discriminantType); + + return result + " ~ if isSingleton D then ~D else unknown where D = " + discriminant; } else static_assert(always_false_v, "Non-exhaustive constraint switch"); diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 179846d7..c97ed05d 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -338,12 +338,6 @@ public: { return allocator->alloc(Location(), std::nullopt, AstName{"never"}); } - AstType* operator()(const UseTypeVar& utv) - { - std::optional ty = utv.scope->lookup(utv.def); - LUAU_ASSERT(ty); - return Luau::visit(*this, (*ty)->ty); - } AstType* operator()(const NegationTypeVar& ntv) { // FIXME: do the same thing we do with ErrorTypeVar diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index a2673158..03575c40 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -301,7 +301,6 @@ struct TypeChecker2 UnifierSharedState sharedState{&ice}; Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, Mode::Strict, stack.back(), ret->location, Covariant}; - u.anyIsTop = true; u.tryUnify(actualRetType, expectedRetType); const bool ok = u.errors.empty() && u.log.empty(); @@ -331,16 +330,21 @@ struct TypeChecker2 if (value) visit(value); - if (i != local->values.size - 1) + TypeId* maybeValueType = value ? module->astTypes.find(value) : nullptr; + if (i != local->values.size - 1 || maybeValueType) { AstLocal* var = i < local->vars.size ? local->vars.data[i] : nullptr; if (var && var->annotation) { - TypeId varType = lookupAnnotation(var->annotation); + TypeId annotationType = lookupAnnotation(var->annotation); TypeId valueType = value ? lookupType(value) : nullptr; - if (valueType && !isSubtype(varType, valueType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) - reportError(TypeMismatch{varType, valueType}, value->location); + if (valueType) + { + ErrorVec errors = tryUnify(stack.back(), value->location, valueType, annotationType); + if (!errors.empty()) + reportErrors(std::move(errors)); + } } } else @@ -606,7 +610,7 @@ struct TypeChecker2 visit(rhs); TypeId rhsType = lookupType(rhs); - if (!isSubtype(rhsType, lhsType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(rhsType, lhsType, stack.back())) { reportError(TypeMismatch{lhsType, rhsType}, rhs->location); } @@ -757,7 +761,7 @@ struct TypeChecker2 TypeId actualType = lookupType(number); TypeId numberType = singletonTypes->numberType; - if (!isSubtype(numberType, actualType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(numberType, actualType, stack.back())) { reportError(TypeMismatch{actualType, numberType}, number->location); } @@ -768,7 +772,7 @@ struct TypeChecker2 TypeId actualType = lookupType(string); TypeId stringType = singletonTypes->stringType; - if (!isSubtype(actualType, stringType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(actualType, stringType, stack.back())) { reportError(TypeMismatch{actualType, stringType}, string->location); } @@ -857,7 +861,7 @@ struct TypeChecker2 FunctionTypeVar ftv{argsTp, expectedRetType}; TypeId expectedType = arena.addType(ftv); - if (!isSubtype(testFunctionType, expectedType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(testFunctionType, expectedType, stack.back())) { CloneState cloneState; expectedType = clone(expectedType, module->internalTypes, cloneState); @@ -876,7 +880,7 @@ struct TypeChecker2 getIndexTypeFromType(module->getModuleScope(), leftType, indexName->index.value, indexName->location, /* addErrors */ true); if (ty) { - if (!isSubtype(resultType, *ty, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(resultType, *ty, stack.back())) { reportError(TypeMismatch{resultType, *ty}, indexName->location); } @@ -909,7 +913,7 @@ struct TypeChecker2 TypeId inferredArgTy = *argIt; TypeId annotatedArgTy = lookupAnnotation(arg->annotation); - if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back())) { reportError(TypeMismatch{annotatedArgTy, inferredArgTy}, arg->location); } @@ -950,7 +954,7 @@ struct TypeChecker2 if (const FunctionTypeVar* ftv = get(follow(*mm))) { TypePackId expectedArgs = module->internalTypes.addTypePack({operandType}); - reportErrors(tryUnify(scope, expr->location, ftv->argTypes, expectedArgs)); + reportErrors(tryUnify(scope, expr->location, expectedArgs, ftv->argTypes)); if (std::optional ret = first(ftv->retTypes)) { @@ -1092,7 +1096,7 @@ struct TypeChecker2 expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::Op::CompareLe || expr->op == AstExprBinary::Op::CompareLt) { TypePackId expectedRets = module->internalTypes.addTypePack({singletonTypes->booleanType}); - if (!isSubtype(ftv->retTypes, expectedRets, scope, singletonTypes, ice)) + if (!isSubtype(ftv->retTypes, expectedRets, scope)) { reportError(GenericError{format("Metamethod '%s' must return type 'boolean'", it->second)}, expr->location); } @@ -1203,10 +1207,10 @@ struct TypeChecker2 TypeId computedType = lookupType(expr->expr); // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. - if (isSubtype(annotationType, computedType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (isSubtype(annotationType, computedType, stack.back())) return; - if (isSubtype(computedType, annotationType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (isSubtype(computedType, annotationType, stack.back())) return; reportError(TypesAreUnrelated{computedType, annotationType}, expr->location); @@ -1501,13 +1505,27 @@ struct TypeChecker2 } } + template + bool isSubtype(TID subTy, TID superTy, NotNull scope) + { + UnifierSharedState sharedState{&ice}; + TypeArena arena; + Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; + Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; + u.useScopes = true; + + u.tryUnify(subTy, superTy); + const bool ok = u.errors.empty() && u.log.empty(); + return ok; + } + template ErrorVec tryUnify(NotNull scope, const Location& location, TID subTy, TID superTy) { UnifierSharedState sharedState{&ice}; Normalizer normalizer{&module->internalTypes, singletonTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, Mode::Strict, scope, location, Covariant}; - u.anyIsTop = true; + u.useScopes = true; u.tryUnify(subTy, superTy); return std::move(u.errors); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index fef33d1d..9c0a38a1 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -35,18 +35,20 @@ LUAU_FASTFLAG(LuauTypeNormalization2) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) -LUAU_FASTFLAGVARIABLE(LuauAnyifyModuleReturnGenerics, false) LUAU_FASTFLAGVARIABLE(LuauLvaluelessPath, false) +LUAU_FASTFLAGVARIABLE(LuauNilIterator, false) LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false) LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false) -LUAU_FASTFLAGVARIABLE(LuauFixVarargExprHeadType, false) LUAU_FASTFLAGVARIABLE(LuauNeverTypesAndOperatorsInference, false) LUAU_FASTFLAGVARIABLE(LuauReturnsFromCallsitesAreNotWidened, false) +LUAU_FASTFLAGVARIABLE(LuauTryhardAnd, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAGVARIABLE(LuauCompleteVisitor, false) +LUAU_FASTFLAGVARIABLE(LuauOptionalNextKey, false) LUAU_FASTFLAGVARIABLE(LuauReportShadowedTypeAlias, false) LUAU_FASTFLAGVARIABLE(LuauBetterMessagingOnCountMismatch, false) LUAU_FASTFLAGVARIABLE(LuauArgMismatchReportFunctionLocation, false) +LUAU_FASTFLAGVARIABLE(LuauImplicitElseRefinement, false) LUAU_FASTFLAGVARIABLE(LuauDeclareClassPrototype, false) namespace Luau @@ -332,8 +334,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo else moduleScope->returnType = anyify(moduleScope, moduleScope->returnType, Location{}); - if (FFlag::LuauAnyifyModuleReturnGenerics) - moduleScope->returnType = anyifyModuleReturnTypePackGenerics(moduleScope->returnType); + moduleScope->returnType = anyifyModuleReturnTypePackGenerics(moduleScope->returnType); for (auto& [_, typeFun] : moduleScope->exportedTypeBindings) typeFun.type = anyify(moduleScope, typeFun.type, Location{}); @@ -1215,10 +1216,10 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) AstExpr* firstValue = forin.values.data[0]; // next is a function that takes Table and an optional index of type K - // next(t: Table, index: K | nil) -> (K, V) + // next(t: Table, index: K | nil) -> (K?, V) // however, pairs and ipairs are quite messy, but they both share the same types // pairs returns 'next, t, nil', thus the type would be - // pairs(t: Table) -> ((Table, K | nil) -> (K, V), Table, K | nil) + // pairs(t: Table) -> ((Table, K | nil) -> (K?, V), Table, K | nil) // ipairs returns 'next, t, 0', thus ipairs will also share the same type as pairs, except K = number // // we can also define our own custom iterators by by returning a wrapped coroutine that calls coroutine.yield @@ -1261,6 +1262,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) iterTy = instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location); } + if (FFlag::LuauNilIterator) + iterTy = stripFromNilAndReport(iterTy, firstValue->location); + if (std::optional iterMM = findMetatableEntry(iterTy, "__iter", firstValue->location, /* addErrors= */ true)) { // if __iter metamethod is present, it will be called and the results are going to be called as if they are functions @@ -1344,21 +1348,61 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) reportErrors(state.errors); } - TypePackId varPack = addTypePack(TypePackVar{TypePack{varTypes, freshTypePack(scope)}}); - - if (forin.values.size >= 2) + if (FFlag::LuauOptionalNextKey) { - AstArray arguments{forin.values.data + 1, forin.values.size - 1}; + TypePackId retPack = iterFunc->retTypes; - Position start = firstValue->location.begin; - Position end = values[forin.values.size - 1]->location.end; - AstExprCall exprCall{Location(start, end), firstValue, arguments, /* self= */ false, Location()}; + if (forin.values.size >= 2) + { + AstArray arguments{forin.values.data + 1, forin.values.size - 1}; + + Position start = firstValue->location.begin; + Position end = values[forin.values.size - 1]->location.end; + AstExprCall exprCall{Location(start, end), firstValue, arguments, /* self= */ false, Location()}; + + retPack = checkExprPack(scope, exprCall).type; + } + + // We need to remove 'nil' from the set of options of the first return value + // Because for loop stops when it gets 'nil', this result is never actually assigned to the first variable + if (std::optional fty = first(retPack); fty && !varTypes.empty()) + { + TypeId keyTy = follow(*fty); + + if (get(keyTy)) + { + if (std::optional ty = tryStripUnionFromNil(keyTy)) + keyTy = *ty; + } + + unify(keyTy, varTypes.front(), scope, forin.location); + + // We have already handled the first variable type, make it match in the pack check + varTypes.front() = *fty; + } + + TypePackId varPack = addTypePack(TypePackVar{TypePack{varTypes, freshTypePack(scope)}}); - TypePackId retPack = checkExprPack(scope, exprCall).type; unify(retPack, varPack, scope, forin.location); } else - unify(iterFunc->retTypes, varPack, scope, forin.location); + { + TypePackId varPack = addTypePack(TypePackVar{TypePack{varTypes, freshTypePack(scope)}}); + + if (forin.values.size >= 2) + { + AstArray arguments{forin.values.data + 1, forin.values.size - 1}; + + Position start = firstValue->location.begin; + Position end = values[forin.values.size - 1]->location.end; + AstExprCall exprCall{Location(start, end), firstValue, arguments, /* self= */ false, Location()}; + + TypePackId retPack = checkExprPack(scope, exprCall).type; + unify(retPack, varPack, scope, forin.location); + } + else + unify(iterFunc->retTypes, varPack, scope, forin.location); + } check(loopScope, *forin.body); } @@ -1977,18 +2021,10 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (get(varargPack)) { - if (FFlag::LuauFixVarargExprHeadType) - { - if (std::optional ty = first(varargPack)) - return {*ty}; + if (std::optional ty = first(varargPack)) + return {*ty}; - return {nilType}; - } - else - { - std::vector types = flatten(varargPack).first; - return {!types.empty() ? types[0] : nilType}; - } + return {nilType}; } else if (get(varargPack)) { @@ -2839,12 +2875,54 @@ TypeId TypeChecker::checkRelationalOperation( case AstExprBinary::And: if (lhsIsAny) + { return lhsType; - return unionOfTypes(rhsType, booleanType, scope, expr.location, false); + } + else if (FFlag::LuauTryhardAnd) + { + // If lhs is free, we can't tell which 'falsy' components it has, if any + if (get(lhsType)) + return unionOfTypes(addType(UnionTypeVar{{nilType, singletonType(false)}}), rhsType, scope, expr.location, false); + + auto [oty, notNever] = pickTypesFromSense(lhsType, false, neverType); // Filter out falsy types + + if (notNever) + { + LUAU_ASSERT(oty); + return unionOfTypes(*oty, rhsType, scope, expr.location, false); + } + else + { + return rhsType; + } + } + else + { + return unionOfTypes(rhsType, booleanType, scope, expr.location, false); + } case AstExprBinary::Or: if (lhsIsAny) + { return lhsType; - return unionOfTypes(lhsType, rhsType, scope, expr.location); + } + else if (FFlag::LuauTryhardAnd) + { + auto [oty, notNever] = pickTypesFromSense(lhsType, true, neverType); // Filter out truthy types + + if (notNever) + { + LUAU_ASSERT(oty); + return unionOfTypes(*oty, rhsType, scope, expr.location); + } + else + { + return rhsType; + } + } + else + { + return unionOfTypes(lhsType, rhsType, scope, expr.location); + } default: LUAU_ASSERT(0); ice(format("checkRelationalOperation called with incorrect binary expression '%s'", toString(expr.op).c_str()), expr.location); @@ -4962,9 +5040,9 @@ TypePackId TypeChecker::errorRecoveryTypePack(TypePackId guess) return singletonTypes->errorRecoveryTypePack(guess); } -TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense) +TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense, TypeId emptySetTy) { - return [this, sense](TypeId ty) -> std::optional { + return [this, sense, emptySetTy](TypeId ty) -> std::optional { // any/error/free gets a special pass unconditionally because they can't be decided. if (get(ty) || get(ty) || get(ty)) return ty; @@ -4982,7 +5060,7 @@ TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense) return sense ? std::nullopt : std::optional(ty); // at this point, anything else is kept if sense is true, or replaced by nil - return sense ? ty : nilType; + return sense ? ty : emptySetTy; }; } @@ -5008,9 +5086,9 @@ std::pair, bool> TypeChecker::filterMap(TypeId type, TypeI } } -std::pair, bool> TypeChecker::pickTypesFromSense(TypeId type, bool sense) +std::pair, bool> TypeChecker::pickTypesFromSense(TypeId type, bool sense, TypeId emptySetTy) { - return filterMap(type, mkTruthyPredicate(sense)); + return filterMap(type, mkTruthyPredicate(sense, emptySetTy)); } TypeId TypeChecker::addTV(TypeVar&& tv) @@ -5779,7 +5857,7 @@ void TypeChecker::resolve(const TruthyPredicate& truthyP, RefinementMap& refis, if (ty && fromOr) return addRefinement(refis, truthyP.lvalue, *ty); - refineLValue(truthyP.lvalue, refis, scope, mkTruthyPredicate(sense)); + refineLValue(truthyP.lvalue, refis, scope, mkTruthyPredicate(sense, nilType)); } void TypeChecker::resolve(const AndPredicate& andP, RefinementMap& refis, const ScopePtr& scope, bool sense) @@ -5972,13 +6050,57 @@ void TypeChecker::resolve(const EqPredicate& eqP, RefinementMap& refis, const Sc if (maybeSingleton(eqP.type)) { - // Normally we'd write option <: eqP.type, but singletons are always the subtype, so we flip this. - if (!sense || canUnify(eqP.type, option, scope, eqP.location).empty()) - return sense ? eqP.type : option; + if (FFlag::LuauImplicitElseRefinement) + { + bool optionIsSubtype = canUnify(option, eqP.type, scope, eqP.location).empty(); + bool targetIsSubtype = canUnify(eqP.type, option, scope, eqP.location).empty(); - // local variable works around an odd gcc 9.3 warning: may be used uninitialized - std::optional res = std::nullopt; - return res; + // terminology refresher: + // - option is the type of the expression `x`, and + // - eqP.type is the type of the expression `"hello"` + // + // "hello" == x where + // x : "hello" | "world" -> x : "hello" + // x : number | string -> x : "hello" + // x : number -> x : never + // + // "hello" ~= x where + // x : "hello" | "world" -> x : "world" + // x : number | string -> x : number | string + // x : number -> x : number + + // local variable works around an odd gcc 9.3 warning: may be used uninitialized + std::optional nope = std::nullopt; + + if (sense) + { + if (optionIsSubtype && !targetIsSubtype) + return option; + else if (!optionIsSubtype && targetIsSubtype) + return eqP.type; + else if (!optionIsSubtype && !targetIsSubtype) + return nope; + else if (optionIsSubtype && targetIsSubtype) + return eqP.type; + } + else + { + bool isOptionSingleton = get(option); + if (!isOptionSingleton) + return option; + else if (optionIsSubtype && targetIsSubtype) + return nope; + } + } + else + { + if (!sense || canUnify(eqP.type, option, scope, eqP.location).empty()) + return sense ? eqP.type : option; + + // local variable works around an odd gcc 9.3 warning: may be used uninitialized + std::optional res = std::nullopt; + return res; + } } return option; diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 94d633c7..814eca0d 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -3,6 +3,7 @@ #include "Luau/BuiltinDefinitions.h" #include "Luau/Common.h" +#include "Luau/ConstraintSolver.h" #include "Luau/DenseHash.h" #include "Luau/Error.h" #include "Luau/RecursionCounter.h" @@ -26,6 +27,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauMaybeGenericIntersectionTypes, false) LUAU_FASTFLAGVARIABLE(LuauNoMoreGlobalSingletonTypes, false) +LUAU_FASTFLAGVARIABLE(LuauNewLibraryTypeNames, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) namespace Luau @@ -33,15 +35,19 @@ namespace Luau std::optional> magicFunctionFormat( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static bool dcrMagicFunctionFormat(MagicFunctionCallContext context); static std::optional> magicFunctionGmatch( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context); static std::optional> magicFunctionMatch( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static bool dcrMagicFunctionMatch(MagicFunctionCallContext context); static std::optional> magicFunctionFind( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static bool dcrMagicFunctionFind(MagicFunctionCallContext context); TypeId follow(TypeId t) { @@ -57,13 +63,6 @@ TypeId follow(TypeId t, std::function mapper) return btv->boundTo; else if (auto ttv = get(mapper(ty))) return ttv->boundTo; - else if (auto utv = get(mapper(ty))) - { - std::optional ty = utv->scope->lookup(utv->def); - if (!ty) - throwRuntimeError("UseTypeVar must map to another TypeId"); - return *ty; - } else return std::nullopt; }; @@ -761,6 +760,7 @@ SingletonTypes::SingletonTypes() , stringType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true})) , booleanType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true})) , threadType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true})) + , functionType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Function}, /*persistent*/ true})) , trueType(arena->addType(TypeVar{SingletonTypeVar{BooleanSingleton{true}}, /*persistent*/ true})) , falseType(arena->addType(TypeVar{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true})) , anyType(arena->addType(TypeVar{AnyTypeVar{}, /*persistent*/ true})) @@ -806,6 +806,7 @@ TypeId SingletonTypes::makeStringMetatable() FunctionTypeVar formatFTV{arena->addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack}; formatFTV.magicFunction = &magicFunctionFormat; const TypeId formatFn = arena->addType(formatFTV); + attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); const TypePackId emptyPack = arena->addTypePack({}); const TypePackId stringVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{stringType}}); @@ -820,14 +821,17 @@ TypeId SingletonTypes::makeStringMetatable() const TypeId gmatchFunc = makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionTypeVar{emptyPack, stringVariadicList})}); attachMagicFunction(gmatchFunc, magicFunctionGmatch); + attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch); const TypeId matchFunc = arena->addType( FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})}); attachMagicFunction(matchFunc, magicFunctionMatch); + attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch); const TypeId findFunc = arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})}); attachMagicFunction(findFunc, magicFunctionFind); + attachDcrMagicFunction(findFunc, dcrMagicFunctionFind); TableTypeVar::Props stringLib = { {"byte", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}}, @@ -861,7 +865,7 @@ TypeId SingletonTypes::makeStringMetatable() TypeId tableType = arena->addType(TableTypeVar{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); if (TableTypeVar* ttv = getMutable(tableType)) - ttv->name = "string"; + ttv->name = FFlag::LuauNewLibraryTypeNames ? "typeof(string)" : "string"; return arena->addType(TableTypeVar{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); } @@ -946,7 +950,8 @@ void persist(TypeId ty) queue.push_back(mtv->table); queue.push_back(mtv->metatable); } - else if (get(t) || get(t) || get(t) || get(t) || get(t) || get(t)) + else if (get(t) || get(t) || get(t) || get(t) || get(t) || + get(t)) { } else @@ -1077,7 +1082,7 @@ IntersectionTypeVarIterator end(const IntersectionTypeVar* itv) return IntersectionTypeVarIterator{}; } -static std::vector parseFormatString(TypeChecker& typechecker, const char* data, size_t size) +static std::vector parseFormatString(NotNull singletonTypes, const char* data, size_t size) { const char* options = "cdiouxXeEfgGqs*"; @@ -1100,13 +1105,13 @@ static std::vector parseFormatString(TypeChecker& typechecker, const cha break; if (data[i] == 'q' || data[i] == 's') - result.push_back(typechecker.stringType); + result.push_back(singletonTypes->stringType); else if (data[i] == '*') - result.push_back(typechecker.unknownType); + result.push_back(singletonTypes->unknownType); else if (strchr(options, data[i])) - result.push_back(typechecker.numberType); + result.push_back(singletonTypes->numberType); else - result.push_back(typechecker.errorRecoveryType(typechecker.anyType)); + result.push_back(singletonTypes->errorRecoveryType(singletonTypes->anyType)); } } @@ -1135,7 +1140,7 @@ std::optional> magicFunctionFormat( if (!fmt) return std::nullopt; - std::vector expected = parseFormatString(typechecker, fmt->value.data, fmt->value.size); + std::vector expected = parseFormatString(typechecker.singletonTypes, fmt->value.data, fmt->value.size); const auto& [params, tail] = flatten(paramPack); size_t paramOffset = 1; @@ -1159,7 +1164,50 @@ std::optional> magicFunctionFormat( return WithPredicate{arena.addTypePack({typechecker.stringType})}; } -static std::vector parsePatternString(TypeChecker& typechecker, const char* data, size_t size) +static bool dcrMagicFunctionFormat(MagicFunctionCallContext context) +{ + TypeArena* arena = context.solver->arena; + + AstExprConstantString* fmt = nullptr; + if (auto index = context.callSite->func->as(); index && context.callSite->self) + { + if (auto group = index->expr->as()) + fmt = group->expr->as(); + else + fmt = index->expr->as(); + } + + if (!context.callSite->self && context.callSite->args.size > 0) + fmt = context.callSite->args.data[0]->as(); + + if (!fmt) + return false; + + std::vector expected = parseFormatString(context.solver->singletonTypes, fmt->value.data, fmt->value.size); + const auto& [params, tail] = flatten(context.arguments); + + size_t paramOffset = 1; + + // unify the prefix one argument at a time + for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) + { + context.solver->unify(params[i + paramOffset], expected[i], context.solver->rootScope); + } + + // if we know the argument count or if we have too many arguments for sure, we can issue an error + size_t numActualParams = params.size(); + size_t numExpectedParams = expected.size() + 1; // + 1 for the format string + + if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) + context.solver->reportError(TypeError{context.callSite->location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}}); + + TypePackId resultPack = arena->addTypePack({context.solver->singletonTypes->stringType}); + asMutable(context.result)->ty.emplace(resultPack); + + return true; +} + +static std::vector parsePatternString(NotNull singletonTypes, const char* data, size_t size) { std::vector result; int depth = 0; @@ -1191,12 +1239,12 @@ static std::vector parsePatternString(TypeChecker& typechecker, const ch if (i + 1 < size && data[i + 1] == ')') { i++; - result.push_back(typechecker.numberType); + result.push_back(singletonTypes->numberType); continue; } ++depth; - result.push_back(typechecker.stringType); + result.push_back(singletonTypes->stringType); } else if (data[i] == ')') { @@ -1214,7 +1262,7 @@ static std::vector parsePatternString(TypeChecker& typechecker, const ch return std::vector(); if (result.empty()) - result.push_back(typechecker.stringType); + result.push_back(singletonTypes->stringType); return result; } @@ -1238,7 +1286,7 @@ static std::optional> magicFunctionGmatch( if (!pattern) return std::nullopt; - std::vector returnTypes = parsePatternString(typechecker, pattern->value.data, pattern->value.size); + std::vector returnTypes = parsePatternString(typechecker.singletonTypes, pattern->value.data, pattern->value.size); if (returnTypes.empty()) return std::nullopt; @@ -1251,6 +1299,39 @@ static std::optional> magicFunctionGmatch( return WithPredicate{arena.addTypePack({iteratorType})}; } +static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context) +{ + const auto& [params, tail] = flatten(context.arguments); + + if (params.size() != 2) + return false; + + TypeArena* arena = context.solver->arena; + + AstExprConstantString* pattern = nullptr; + size_t index = context.callSite->self ? 0 : 1; + if (context.callSite->args.size > index) + pattern = context.callSite->args.data[index]->as(); + + if (!pattern) + return false; + + std::vector returnTypes = parsePatternString(context.solver->singletonTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return false; + + context.solver->unify(params[0], context.solver->singletonTypes->stringType, context.solver->rootScope); + + const TypePackId emptyPack = arena->addTypePack({}); + const TypePackId returnList = arena->addTypePack(returnTypes); + const TypeId iteratorType = arena->addType(FunctionTypeVar{emptyPack, returnList}); + const TypePackId resTypePack = arena->addTypePack({iteratorType}); + asMutable(context.result)->ty.emplace(resTypePack); + + return true; +} + static std::optional> magicFunctionMatch( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { @@ -1270,7 +1351,7 @@ static std::optional> magicFunctionMatch( if (!pattern) return std::nullopt; - std::vector returnTypes = parsePatternString(typechecker, pattern->value.data, pattern->value.size); + std::vector returnTypes = parsePatternString(typechecker.singletonTypes, pattern->value.data, pattern->value.size); if (returnTypes.empty()) return std::nullopt; @@ -1287,6 +1368,42 @@ static std::optional> magicFunctionMatch( return WithPredicate{returnList}; } +static bool dcrMagicFunctionMatch(MagicFunctionCallContext context) +{ + const auto& [params, tail] = flatten(context.arguments); + + if (params.size() < 2 || params.size() > 3) + return false; + + TypeArena* arena = context.solver->arena; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = context.callSite->self ? 0 : 1; + if (context.callSite->args.size > patternIndex) + pattern = context.callSite->args.data[patternIndex]->as(); + + if (!pattern) + return false; + + std::vector returnTypes = parsePatternString(context.solver->singletonTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return false; + + context.solver->unify(params[0], context.solver->singletonTypes->stringType, context.solver->rootScope); + + const TypeId optionalNumber = arena->addType(UnionTypeVar{{context.solver->singletonTypes->nilType, context.solver->singletonTypes->numberType}}); + + size_t initIndex = context.callSite->self ? 1 : 2; + if (params.size() == 3 && context.callSite->args.size > initIndex) + context.solver->unify(params[2], optionalNumber, context.solver->rootScope); + + const TypePackId returnList = arena->addTypePack(returnTypes); + asMutable(context.result)->ty.emplace(returnList); + + return true; +} + static std::optional> magicFunctionFind( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { @@ -1317,7 +1434,7 @@ static std::optional> magicFunctionFind( std::vector returnTypes; if (!plain) { - returnTypes = parsePatternString(typechecker, pattern->value.data, pattern->value.size); + returnTypes = parsePatternString(typechecker.singletonTypes, pattern->value.data, pattern->value.size); if (returnTypes.empty()) return std::nullopt; @@ -1341,6 +1458,60 @@ static std::optional> magicFunctionFind( return WithPredicate{returnList}; } +static bool dcrMagicFunctionFind(MagicFunctionCallContext context) +{ + const auto& [params, tail] = flatten(context.arguments); + + if (params.size() < 2 || params.size() > 4) + return false; + + TypeArena* arena = context.solver->arena; + NotNull singletonTypes = context.solver->singletonTypes; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = context.callSite->self ? 0 : 1; + if (context.callSite->args.size > patternIndex) + pattern = context.callSite->args.data[patternIndex]->as(); + + if (!pattern) + return false; + + bool plain = false; + size_t plainIndex = context.callSite->self ? 2 : 3; + if (context.callSite->args.size > plainIndex) + { + AstExprConstantBool* p = context.callSite->args.data[plainIndex]->as(); + plain = p && p->value; + } + + std::vector returnTypes; + if (!plain) + { + returnTypes = parsePatternString(singletonTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return false; + } + + context.solver->unify(params[0], singletonTypes->stringType, context.solver->rootScope); + + const TypeId optionalNumber = arena->addType(UnionTypeVar{{singletonTypes->nilType, singletonTypes->numberType}}); + const TypeId optionalBoolean = arena->addType(UnionTypeVar{{singletonTypes->nilType, singletonTypes->booleanType}}); + + size_t initIndex = context.callSite->self ? 1 : 2; + if (params.size() >= 3 && context.callSite->args.size > initIndex) + context.solver->unify(params[2], optionalNumber, context.solver->rootScope); + + if (params.size() == 4 && context.callSite->args.size > plainIndex) + context.solver->unify(params[3], optionalBoolean, context.solver->rootScope); + + returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber}); + + const TypePackId returnList = arena->addTypePack(returnTypes); + asMutable(context.result)->ty.emplace(returnList); + return true; +} + std::vector filterMap(TypeId type, TypeIdPredicate predicate) { type = follow(type); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index b5eba980..4dc90983 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -8,6 +8,7 @@ #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" #include "Luau/TimeTrace.h" +#include "Luau/TypeVar.h" #include "Luau/VisitTypeVar.h" #include "Luau/ToString.h" @@ -21,8 +22,10 @@ LUAU_FASTFLAGVARIABLE(LuauSubtypeNormalizer, false); LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauOverloadedFunctionSubtypingPerf, false); +LUAU_FASTFLAGVARIABLE(LuauScalarShapeUnifyToMtOwner, false) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauNegatedFunctionTypes) namespace Luau { @@ -363,7 +366,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) { - reportError(TypeError{location, UnificationTooComplex{}}); + reportError(location, UnificationTooComplex{}); return; } @@ -404,7 +407,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (subGeneric && !subsumes(useScopes, subGeneric, superFree)) { // TODO: a more informative error message? CLI-39912 - reportError(TypeError{location, GenericError{"Generic subtype escaping scope"}}); + reportError(location, GenericError{"Generic subtype escaping scope"}); return; } @@ -433,7 +436,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (superGeneric && !subsumes(useScopes, superGeneric, subFree)) { // TODO: a more informative error message? CLI-39912 - reportError(TypeError{location, GenericError{"Generic supertype escaping scope"}}); + reportError(location, GenericError{"Generic supertype escaping scope"}); return; } @@ -450,15 +453,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool return tryUnifyWithAny(subTy, superTy); if (get(subTy)) - { - if (anyIsTop) - { - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); - return; - } - else - return tryUnifyWithAny(superTy, subTy); - } + return tryUnifyWithAny(superTy, subTy); if (log.get(subTy)) return tryUnifyWithAny(superTy, subTy); @@ -478,7 +473,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (auto error = sharedState.cachedUnifyError.find({subTy, superTy})) { - reportError(TypeError{location, *error}); + reportError(location, *error); return; } } @@ -520,6 +515,12 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if ((log.getMutable(superTy) || log.getMutable(superTy)) && log.getMutable(subTy)) tryUnifySingletons(subTy, superTy); + else if (auto ptv = get(superTy); + FFlag::LuauNegatedFunctionTypes && ptv && ptv->type == PrimitiveTypeVar::Function && get(subTy)) + { + // Ok. Do nothing. forall functions F, F <: function + } + else if (log.getMutable(superTy) && log.getMutable(subTy)) tryUnifyFunctions(subTy, superTy, isFunctionCall); @@ -559,7 +560,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool tryUnifyNegationWithType(subTy, superTy); else - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(location, TypeMismatch{superTy, subTy}); if (cacheEnabled) cacheResult(subTy, superTy, errorCount); @@ -633,9 +634,9 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* subUnion, else if (failed) { if (firstFailedOption) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}}); + reportError(location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}); else - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(location, TypeMismatch{superTy, subTy}); } } @@ -734,7 +735,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp const NormalizedType* subNorm = normalizer->normalize(subTy); const NormalizedType* superNorm = normalizer->normalize(superTy); if (!subNorm || !superNorm) - reportError(TypeError{location, UnificationTooComplex{}}); + reportError(location, UnificationTooComplex{}); else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); else @@ -743,9 +744,9 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp else if (!found) { if ((failedOptionCount == 1 || foundHeuristic) && failedOption) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}}); + reportError(location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}); else - reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); + reportError(location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}); } } @@ -774,7 +775,7 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I if (unificationTooComplex) reportError(*unificationTooComplex); else if (firstFailedOption) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); + reportError(location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}); } void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall) @@ -832,11 +833,11 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeV if (subNorm && superNorm) tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); else - reportError(TypeError{location, UnificationTooComplex{}}); + reportError(location, UnificationTooComplex{}); } else if (!found) { - reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); + reportError(location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}); } } @@ -848,37 +849,37 @@ void Unifier::tryUnifyNormalizedTypes( if (get(superNorm.tops) || get(superNorm.tops) || get(subNorm.tops)) return; else if (get(subNorm.tops)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); if (get(subNorm.errors)) if (!get(superNorm.errors)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); if (get(subNorm.booleans)) { if (!get(superNorm.booleans)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); } else if (const SingletonTypeVar* stv = get(subNorm.booleans)) { if (!get(superNorm.booleans) && stv != get(superNorm.booleans)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); } if (get(subNorm.nils)) if (!get(superNorm.nils)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); if (get(subNorm.numbers)) if (!get(superNorm.numbers)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); if (!isSubtype(subNorm.strings, superNorm.strings)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); if (get(subNorm.threads)) if (!get(superNorm.errors)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); for (TypeId subClass : subNorm.classes) { @@ -894,7 +895,7 @@ void Unifier::tryUnifyNormalizedTypes( } } if (!found) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); } for (TypeId subTable : subNorm.tables) @@ -919,21 +920,19 @@ void Unifier::tryUnifyNormalizedTypes( return reportError(*e); } if (!found) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); } - if (subNorm.functions) + if (!subNorm.functions.isNever()) { - if (!superNorm.functions) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); - if (superNorm.functions->empty()) - return; - for (TypeId superFun : *superNorm.functions) + if (superNorm.functions.isNever()) + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); + for (TypeId superFun : *superNorm.functions.parts) { Unifier innerState = makeChildUnifier(); const FunctionTypeVar* superFtv = get(superFun); if (!superFtv) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); TypePackId tgt = innerState.tryApplyOverloadedFunction(subTy, subNorm.functions, superFtv->argTypes); innerState.tryUnify_(tgt, superFtv->retTypes); if (innerState.errors.empty()) @@ -941,7 +940,7 @@ void Unifier::tryUnifyNormalizedTypes( else if (auto e = hasUnificationTooComplex(innerState.errors)) return reportError(*e); else - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); } } @@ -959,15 +958,15 @@ void Unifier::tryUnifyNormalizedTypes( TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const NormalizedFunctionType& overloads, TypePackId args) { - if (!overloads || overloads->empty()) + if (overloads.isNever()) { - reportError(TypeError{location, CannotCallNonFunction{function}}); + reportError(location, CannotCallNonFunction{function}); return singletonTypes->errorRecoveryTypePack(); } std::optional result; const FunctionTypeVar* firstFun = nullptr; - for (TypeId overload : *overloads) + for (TypeId overload : *overloads.parts) { if (const FunctionTypeVar* ftv = get(overload)) { @@ -1015,12 +1014,12 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized // TODO: better error reporting? // The logic for error reporting overload resolution // is currently over in TypeInfer.cpp, should we move it? - reportError(TypeError{location, GenericError{"No matching overload."}}); + reportError(location, GenericError{"No matching overload."}); return singletonTypes->errorRecoveryTypePack(firstFun->retTypes); } else { - reportError(TypeError{location, CannotCallNonFunction{function}}); + reportError(location, CannotCallNonFunction{function}); return singletonTypes->errorRecoveryTypePack(); } } @@ -1199,7 +1198,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) { - reportError(TypeError{location, UnificationTooComplex{}}); + reportError(location, UnificationTooComplex{}); return; } @@ -1372,7 +1371,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal size_t actualSize = size(subTp); if (ctx == CountMismatch::FunctionResult || ctx == CountMismatch::ExprListResult) std::swap(expectedSize, actualSize); - reportError(TypeError{location, CountMismatch{expectedSize, std::nullopt, actualSize, ctx}}); + reportError(location, CountMismatch{expectedSize, std::nullopt, actualSize, ctx}); while (superIter.good()) { @@ -1394,9 +1393,9 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal else { if (FFlag::LuauReportTypeMismatchForTypePackUnificationFailure) - reportError(TypeError{location, TypePackMismatch{subTp, superTp}}); + reportError(location, TypePackMismatch{subTp, superTp}); else - reportError(TypeError{location, GenericError{"Failed to unify type packs"}}); + reportError(location, GenericError{"Failed to unify type packs"}); } } @@ -1408,7 +1407,7 @@ void Unifier::tryUnifyPrimitives(TypeId subTy, TypeId superTy) ice("passed non primitive types to unifyPrimitives"); if (superPrim->type != subPrim->type) - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(location, TypeMismatch{superTy, subTy}); } void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) @@ -1429,7 +1428,7 @@ void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) if (superPrim && superPrim->type == PrimitiveTypeVar::String && get(subSingleton) && variance == Covariant) return; - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(location, TypeMismatch{superTy, subTy}); } void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall) @@ -1465,21 +1464,21 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal } else { - reportError(TypeError{location, UnificationTooComplex{}}); + reportError(location, UnificationTooComplex{}); } } else if (numGenerics != subFunction->generics.size()) { numGenerics = std::min(superFunction->generics.size(), subFunction->generics.size()); - reportError(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}}); + reportError(location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}); } if (numGenericPacks != subFunction->genericPacks.size()) { numGenericPacks = std::min(superFunction->genericPacks.size(), subFunction->genericPacks.size()); - reportError(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}}); + reportError(location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}); } for (size_t i = 0; i < numGenerics; i++) @@ -1506,11 +1505,10 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal if (auto e = hasUnificationTooComplex(innerState.errors)) reportError(*e); else if (!innerState.errors.empty() && innerState.firstPackErrorPos) - reportError( - TypeError{location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), - innerState.errors.front()}}); + reportError(location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), + innerState.errors.front()}); else if (!innerState.errors.empty()) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); + reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}); innerState.ctx = CountMismatch::FunctionResult; innerState.tryUnify_(subFunction->retTypes, superFunction->retTypes); @@ -1520,13 +1518,12 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal if (auto e = hasUnificationTooComplex(innerState.errors)) reportError(*e); else if (!innerState.errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes)) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}}); + reportError(location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}); else if (!innerState.errors.empty() && innerState.firstPackErrorPos) - reportError( - TypeError{location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), - innerState.errors.front()}}); + reportError(location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), + innerState.errors.front()}); else if (!innerState.errors.empty()) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); + reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}); } log.concat(std::move(innerState.log)); @@ -1608,7 +1605,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } else { - reportError(TypeError{location, UnificationTooComplex{}}); + reportError(location, UnificationTooComplex{}); } } } @@ -1626,7 +1623,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (!missingProperties.empty()) { - reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); + reportError(location, MissingProperties{superTy, subTy, std::move(missingProperties)}); return; } } @@ -1644,7 +1641,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (!extraProperties.empty()) { - reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); + reportError(location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}); return; } } @@ -1703,8 +1700,20 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // Recursive unification can change the txn log, and invalidate the old // table. If we detect that this has happened, we start over, with the updated // txn log. - TableTypeVar* newSuperTable = log.getMutable(superTy); - TableTypeVar* newSubTable = log.getMutable(subTy); + TypeId superTyNew = FFlag::LuauScalarShapeUnifyToMtOwner ? log.follow(superTy) : superTy; + TypeId subTyNew = FFlag::LuauScalarShapeUnifyToMtOwner ? log.follow(subTy) : subTy; + + if (FFlag::LuauScalarShapeUnifyToMtOwner) + { + // If one of the types stopped being a table altogether, we need to restart from the top + if ((superTy != superTyNew || subTy != subTyNew) && errors.empty()) + return tryUnify(subTy, superTy, false, isIntersection); + } + + // Otherwise, restart only the table unification + TableTypeVar* newSuperTable = log.getMutable(superTyNew); + TableTypeVar* newSubTable = log.getMutable(subTyNew); + if (superTable != newSuperTable || (subTable != newSubTable && subTable != instantiatedSubTable)) { if (errors.empty()) @@ -1825,13 +1834,13 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (!missingProperties.empty()) { - reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); + reportError(location, MissingProperties{superTy, subTy, std::move(missingProperties)}); return; } if (!extraProperties.empty()) { - reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); + reportError(location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}); return; } @@ -1866,15 +1875,17 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) if (reversed) std::swap(subTy, superTy); - if (auto ttv = log.get(superTy); !ttv || ttv->state != TableState::Free) - return reportError(TypeError{location, TypeMismatch{osuperTy, osubTy}}); + TableTypeVar* superTable = log.getMutable(superTy); + + if (!superTable || superTable->state != TableState::Free) + return reportError(location, TypeMismatch{osuperTy, osubTy}); auto fail = [&](std::optional e) { std::string reason = "The former's metatable does not satisfy the requirements."; if (e) - reportError(TypeError{location, TypeMismatch{osuperTy, osubTy, reason, *e}}); + reportError(location, TypeMismatch{osuperTy, osubTy, reason, *e}); else - reportError(TypeError{location, TypeMismatch{osuperTy, osubTy, reason}}); + reportError(location, TypeMismatch{osuperTy, osubTy, reason}); }; // Given t1 where t1 = { lower: (t1) -> (a, b...) } @@ -1891,6 +1902,20 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) Unifier child = makeChildUnifier(); child.tryUnify_(ty, superTy); + if (FFlag::LuauScalarShapeUnifyToMtOwner) + { + // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table + // There is a chance that it was unified with the origial subtype, but then, (subtype's metatable) <: subtype could've failed + // Here we check if we have a new supertype instead of the original free table and try original subtype <: new supertype check + TypeId newSuperTy = child.log.follow(superTy); + + if (superTy != newSuperTy && canUnify(subTy, newSuperTy).empty()) + { + log.replace(superTy, BoundTypeVar{subTy}); + return; + } + } + if (auto e = hasUnificationTooComplex(child.errors)) reportError(*e); else if (!child.errors.empty()) @@ -1898,6 +1923,14 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) log.concat(std::move(child.log)); + if (FFlag::LuauScalarShapeUnifyToMtOwner) + { + // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table + // We return success because subtype <: free table which means that correct unification is to replace free table with the subtype + if (child.errors.empty()) + log.replace(superTy, BoundTypeVar{subTy}); + } + return; } else @@ -1906,7 +1939,7 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) } } - reportError(TypeError{location, TypeMismatch{osuperTy, osubTy}}); + reportError(location, TypeMismatch{osuperTy, osubTy}); return; } @@ -1947,7 +1980,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) if (auto e = hasUnificationTooComplex(innerState.errors)) reportError(*e); else if (!innerState.errors.empty()) - reportError(TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}}); + reportError(location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}); log.concat(std::move(innerState.log)); } @@ -2024,9 +2057,9 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) auto fail = [&]() { if (!reversed) - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(location, TypeMismatch{superTy, subTy}); else - reportError(TypeError{location, TypeMismatch{subTy, superTy}}); + reportError(location, TypeMismatch{subTy, superTy}); }; const ClassTypeVar* superClass = get(superTy); @@ -2071,7 +2104,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) if (!classProp) { ok = false; - reportError(TypeError{location, UnknownProperty{superTy, propName}}); + reportError(location, UnknownProperty{superTy, propName}); } else { @@ -2095,7 +2128,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) { ok = false; std::string msg = "Class " + superClass->name + " does not have an indexer"; - reportError(TypeError{location, GenericError{msg}}); + reportError(location, GenericError{msg}); } if (!ok) @@ -2116,13 +2149,13 @@ void Unifier::tryUnifyTypeWithNegation(TypeId subTy, TypeId superTy) const NormalizedType* subNorm = normalizer->normalize(subTy); const NormalizedType* superNorm = normalizer->normalize(superTy); if (!subNorm || !superNorm) - return reportError(TypeError{location, UnificationTooComplex{}}); + return reportError(location, UnificationTooComplex{}); // T & queue, DenseHashSet& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) @@ -2195,7 +2228,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever } else if (get(tail)) { - reportError(TypeError{location, GenericError{"Cannot unify variadic and generic packs"}}); + reportError(location, GenericError{"Cannot unify variadic and generic packs"}); } else if (get(tail)) { @@ -2209,7 +2242,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever } else { - reportError(TypeError{location, GenericError{"Failed to unify variadic packs"}}); + reportError(location, GenericError{"Failed to unify variadic packs"}); } } @@ -2351,7 +2384,7 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays if (needle == haystack) { - reportError(TypeError{location, OccursCheckFailed{}}); + reportError(location, OccursCheckFailed{}); log.replace(needle, *singletonTypes->errorRecoveryType()); return true; @@ -2402,7 +2435,7 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ { if (needle == haystack) { - reportError(TypeError{location, OccursCheckFailed{}}); + reportError(location, OccursCheckFailed{}); log.replace(needle, *singletonTypes->errorRecoveryTypePack()); return true; @@ -2423,18 +2456,31 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ Unifier Unifier::makeChildUnifier() { Unifier u = Unifier{normalizer, mode, scope, location, variance, &log}; - u.anyIsTop = anyIsTop; u.normalize = normalize; + u.useScopes = useScopes; return u; } // A utility function that appends the given error to the unifier's error log. // This allows setting a breakpoint wherever the unifier reports an error. +// +// Note: report error accepts its arguments by value intentionally to reduce the stack usage of functions which call `reportError`. +void Unifier::reportError(Location location, TypeErrorData data) +{ + errors.emplace_back(std::move(location), std::move(data)); +} + +// A utility function that appends the given error to the unifier's error log. +// This allows setting a breakpoint wherever the unifier reports an error. +// +// Note: to conserve stack space in calling functions it is generally preferred to call `Unifier::reportError(Location location, TypeErrorData data)` +// instead of this method. void Unifier::reportError(TypeError err) { errors.push_back(std::move(err)); } + bool Unifier::isNonstrictMode() const { return (mode == Mode::Nonstrict) || (mode == Mode::NoCheck); @@ -2445,7 +2491,7 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId if (auto e = hasUnificationTooComplex(innerErrors)) reportError(*e); else if (!innerErrors.empty()) - reportError(TypeError{location, TypeMismatch{wantedType, givenType}}); + reportError(location, TypeMismatch{wantedType, givenType}); } void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType) diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index d93f2ccb..66436acd 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -641,8 +641,8 @@ Lexeme Lexer::readInterpolatedStringSection(Position start, Lexeme::Type formatT return brokenDoubleBrace; } - Lexeme lexemeOutput(Location(start, position()), Lexeme::InterpStringBegin, &buffer[startOffset], offset - startOffset); consume(); + Lexeme lexemeOutput(Location(start, position()), Lexeme::InterpStringBegin, &buffer[startOffset], offset - startOffset - 1); return lexemeOutput; } diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 85c5f5c6..8338a04a 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -25,6 +25,9 @@ LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false) LUAU_FASTFLAGVARIABLE(LuauInterpolatedStringBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauCommaParenWarnings, false) +LUAU_FASTFLAGVARIABLE(LuauTableConstructorRecovery, false) + +LUAU_FASTFLAGVARIABLE(LuauParserErrorsOnMissingDefaultTypePackArgument, false) bool lua_telemetry_parsed_out_of_range_bin_integer = false; bool lua_telemetry_parsed_out_of_range_hex_integer = false; @@ -2310,9 +2313,13 @@ AstExpr* Parser::parseTableConstructor() MatchLexeme matchBrace = lexer.current(); expectAndConsume('{', "table literal"); + unsigned lastElementIndent = 0; while (lexer.current().type != '}') { + if (FFlag::LuauTableConstructorRecovery) + lastElementIndent = lexer.current().location.begin.column; + if (lexer.current().type == '[') { MatchLexeme matchLocationBracket = lexer.current(); @@ -2357,10 +2364,14 @@ AstExpr* Parser::parseTableConstructor() { nextLexeme(); } - else + else if (FFlag::LuauTableConstructorRecovery && (lexer.current().type == '[' || lexer.current().type == Lexeme::Name) && + lexer.current().location.begin.column == lastElementIndent) { - if (lexer.current().type != '}') - break; + report(lexer.current().location, "Expected ',' after table constructor element"); + } + else if (lexer.current().type != '}') + { + break; } } @@ -2494,7 +2505,7 @@ std::pair, AstArray> Parser::parseG namePacks.push_back({name, nameLocation, typePack}); } - else if (lexer.current().type == '(') + else if (!FFlag::LuauParserErrorsOnMissingDefaultTypePackArgument && lexer.current().type == '(') { auto [type, typePack] = parseTypeOrPackAnnotation(); @@ -2503,6 +2514,15 @@ std::pair, AstArray> Parser::parseG namePacks.push_back({name, nameLocation, typePack}); } + else if (FFlag::LuauParserErrorsOnMissingDefaultTypePackArgument) + { + auto [type, typePack] = parseTypeOrPackAnnotation(); + + if (type) + report(type->location, "Expected type pack after '=', got type"); + + namePacks.push_back({name, nameLocation, typePack}); + } } else { diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 87e19db8..e567725e 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -978,7 +978,8 @@ int replMain(int argc, char** argv) if (compileFormat == CompileFormat::Null) printf("Compiled %d KLOC into %d KB bytecode\n", int(stats.lines / 1000), int(stats.bytecode / 1024)); else if (compileFormat == CompileFormat::CodegenNull) - printf("Compiled %d KLOC into %d KB bytecode => %d KB native code\n", int(stats.lines / 1000), int(stats.bytecode / 1024), int(stats.codegen / 1024)); + printf("Compiled %d KLOC into %d KB bytecode => %d KB native code\n", int(stats.lines / 1000), int(stats.bytecode / 1024), + int(stats.codegen / 1024)); return failed ? 1 : 0; } diff --git a/CodeGen/include/Luau/AddressA64.h b/CodeGen/include/Luau/AddressA64.h new file mode 100644 index 00000000..53efd3c3 --- /dev/null +++ b/CodeGen/include/Luau/AddressA64.h @@ -0,0 +1,53 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/RegisterA64.h" + +namespace Luau +{ +namespace CodeGen +{ + +enum class AddressKindA64 : uint8_t +{ + imm, // reg + imm + reg, // reg + reg + + // TODO: + // reg + reg << shift + // reg + sext(reg) << shift + // reg + uext(reg) << shift +}; + +struct AddressA64 +{ + AddressA64(RegisterA64 base, int off = 0) + : kind(AddressKindA64::imm) + , base(base) + , offset(xzr) + , data(off) + { + LUAU_ASSERT(base.kind == KindA64::x || base == sp); + LUAU_ASSERT(off >= -256 && off < 4096); + } + + AddressA64(RegisterA64 base, RegisterA64 offset) + : kind(AddressKindA64::reg) + , base(base) + , offset(offset) + , data(0) + { + LUAU_ASSERT(base.kind == KindA64::x); + LUAU_ASSERT(offset.kind == KindA64::x); + } + + AddressKindA64 kind; + RegisterA64 base; + RegisterA64 offset; + int data; +}; + +using mem = AddressA64; + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h new file mode 100644 index 00000000..9e12168a --- /dev/null +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -0,0 +1,161 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/RegisterA64.h" +#include "Luau/AddressA64.h" +#include "Luau/ConditionA64.h" +#include "Luau/Label.h" + +#include +#include + +namespace Luau +{ +namespace CodeGen +{ + +class AssemblyBuilderA64 +{ +public: + explicit AssemblyBuilderA64(bool logText); + ~AssemblyBuilderA64(); + + // Moves + void mov(RegisterA64 dst, RegisterA64 src); + void mov(RegisterA64 dst, uint16_t src, int shift = 0); + void movk(RegisterA64 dst, uint16_t src, int shift = 0); + + // Arithmetics + void add(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); + void add(RegisterA64 dst, RegisterA64 src1, int src2); + void sub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); + void sub(RegisterA64 dst, RegisterA64 src1, int src2); + void neg(RegisterA64 dst, RegisterA64 src); + + // Comparisons + // Note: some arithmetic instructions also have versions that update flags (ADDS etc) but we aren't using them atm + void cmp(RegisterA64 src1, RegisterA64 src2); + void cmp(RegisterA64 src1, int src2); + + // Bitwise + // Note: shifted-register support and bitfield operations are omitted for simplicity + // TODO: support immediate arguments (they have odd encoding and forbid many values) + void and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void orr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void mvn(RegisterA64 dst, RegisterA64 src); + + // Shifts + void lsl(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void lsr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void asr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void ror(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void clz(RegisterA64 dst, RegisterA64 src); + void rbit(RegisterA64 dst, RegisterA64 src); + + // Load + // Note: paired loads are currently omitted for simplicity + void ldr(RegisterA64 dst, AddressA64 src); + void ldrb(RegisterA64 dst, AddressA64 src); + void ldrh(RegisterA64 dst, AddressA64 src); + void ldrsb(RegisterA64 dst, AddressA64 src); + void ldrsh(RegisterA64 dst, AddressA64 src); + void ldrsw(RegisterA64 dst, AddressA64 src); + + // Store + void str(RegisterA64 src, AddressA64 dst); + void strb(RegisterA64 src, AddressA64 dst); + void strh(RegisterA64 src, AddressA64 dst); + + // Control flow + // Note: tbz/tbnz are currently not supported because they have 15-bit offsets and we don't support branch thunks + void b(Label& label); + void b(ConditionA64 cond, Label& label); + void cbz(RegisterA64 src, Label& label); + void cbnz(RegisterA64 src, Label& label); + void br(RegisterA64 src); + void blr(RegisterA64 src); + void ret(); + + // Address of embedded data + void adr(RegisterA64 dst, const void* ptr, size_t size); + void adr(RegisterA64 dst, uint64_t value); + void adr(RegisterA64 dst, double value); + + // Run final checks + bool finalize(); + + // Places a label at current location and returns it + Label setLabel(); + + // Assigns label position to the current location + void setLabel(Label& label); + + void logAppend(const char* fmt, ...) LUAU_PRINTF_ATTR(2, 3); + + uint32_t getCodeSize() const; + + // Resulting data and code that need to be copied over one after the other + // The *end* of 'data' has to be aligned to 16 bytes, this will also align 'code' + std::vector data; + std::vector code; + + std::string text; + + const bool logText = false; + +private: + // Instruction archetypes + void place0(const char* name, uint32_t word); + void placeSR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift = 0); + void placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op, uint8_t op2 = 0); + void placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t op2); + void placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op); + void placeI12(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op); + void placeI16(const char* name, RegisterA64 dst, int src, uint8_t op, int shift = 0); + void placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size); + void placeBC(const char* name, Label& label, uint8_t op, uint8_t cond); + void placeBCR(const char* name, Label& label, uint8_t op, RegisterA64 cond); + void placeBR(const char* name, RegisterA64 src, uint32_t op); + void placeADR(const char* name, RegisterA64 src, uint8_t op); + + void place(uint32_t word); + + void patchLabel(Label& label); + void patchImm19(uint32_t location, int value); + + void commit(); + LUAU_NOINLINE void extend(); + + // Data + size_t allocateData(size_t size, size_t align); + + // Logging of assembly in text form + LUAU_NOINLINE void log(const char* opcode); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src1, int src2); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, int src, int shift = 0); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, AddressA64 src); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 src, Label label); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 src); + LUAU_NOINLINE void log(const char* opcode, Label label); + LUAU_NOINLINE void log(Label label); + LUAU_NOINLINE void log(RegisterA64 reg); + LUAU_NOINLINE void log(AddressA64 addr); + + uint32_t nextLabel = 1; + std::vector