diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h new file mode 100644 index 00000000..c62166e2 --- /dev/null +++ b/Analysis/include/Luau/Constraint.h @@ -0,0 +1,82 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Location.h" +#include "Luau/NotNull.h" +#include "Luau/Variant.h" + +#include +#include + +namespace Luau +{ + +struct Scope2; +struct TypeVar; +using TypeId = const TypeVar*; + +struct TypePackVar; +using TypePackId = const TypePackVar*; + +// subType <: superType +struct SubtypeConstraint +{ + TypeId subType; + TypeId superType; +}; + +// subPack <: superPack +struct PackSubtypeConstraint +{ + TypePackId subPack; + TypePackId superPack; +}; + +// subType ~ gen superType +struct GeneralizationConstraint +{ + TypeId generalizedType; + TypeId sourceType; + Scope2* scope; +}; + +// subType ~ inst superType +struct InstantiationConstraint +{ + TypeId subType; + TypeId superType; +}; + +using ConstraintV = Variant; +using ConstraintPtr = std::unique_ptr; + +struct Constraint +{ + Constraint(ConstraintV&& c, Location location); + + Constraint(const Constraint&) = delete; + Constraint& operator=(const Constraint&) = delete; + + ConstraintV c; + Location location; + std::vector> dependencies; +}; + +inline Constraint& asMutable(const Constraint& c) +{ + return const_cast(c); +} + +template +T* getMutable(Constraint& c) +{ + return ::Luau::get_if(&c.c); +} + +template +const T* get(const Constraint& c) +{ + return getMutable(asMutable(c)); +} + +} // namespace Luau diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 4234f2f6..da774a2a 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -4,9 +4,12 @@ #include #include +#include #include "Luau/Ast.h" +#include "Luau/Constraint.h" #include "Luau/Module.h" +#include "Luau/NotNull.h" #include "Luau/Symbol.h" #include "Luau/TypeVar.h" #include "Luau/Variant.h" @@ -14,69 +17,6 @@ namespace Luau { -struct Scope2; - -// subType <: superType -struct SubtypeConstraint -{ - TypeId subType; - TypeId superType; -}; - -// subPack <: superPack -struct PackSubtypeConstraint -{ - TypePackId subPack; - TypePackId superPack; -}; - -// subType ~ gen superType -struct GeneralizationConstraint -{ - TypeId subType; - TypeId superType; - Scope2* scope; -}; - -// subType ~ inst superType -struct InstantiationConstraint -{ - TypeId subType; - TypeId superType; -}; - -using ConstraintV = Variant; -using ConstraintPtr = std::unique_ptr; - -struct Constraint -{ - Constraint(ConstraintV&& c); - Constraint(ConstraintV&& c, std::vector dependencies); - - Constraint(const Constraint&) = delete; - Constraint& operator=(const Constraint&) = delete; - - ConstraintV c; - std::vector dependencies; -}; - -inline Constraint& asMutable(const Constraint& c) -{ - return const_cast(c); -} - -template -T* getMutable(Constraint& c) -{ - return ::Luau::get_if(&c.c); -} - -template -const T* get(const Constraint& c) -{ - return getMutable(asMutable(c)); -} - struct Scope2 { // The parent scope of this scope. Null if there is no parent (i.e. this @@ -102,6 +42,11 @@ struct ConstraintGraphBuilder TypeArena* const arena; // The root scope of the module we're generating constraints for. Scope2* rootScope; + // A mapping of AST node to TypeId. + DenseHashMap astTypes{nullptr}; + // A mapping of AST node to TypePackId. + DenseHashMap astTypePacks{nullptr}; + DenseHashMap astOriginalCallTypes{nullptr}; explicit ConstraintGraphBuilder(TypeArena* arena); @@ -128,8 +73,9 @@ struct ConstraintGraphBuilder * Adds a new constraint with no dependencies to a given scope. * @param scope the scope to add the constraint to. Must not be null. * @param cv the constraint variant to add. + * @param location the location to attribute to the constraint. */ - void addConstraint(Scope2* scope, ConstraintV cv); + void addConstraint(Scope2* scope, ConstraintV cv, Location location); /** * Adds a constraint to a given scope. @@ -148,15 +94,48 @@ struct ConstraintGraphBuilder void visit(Scope2* scope, AstStat* stat); void visit(Scope2* scope, AstStatBlock* block); void visit(Scope2* scope, AstStatLocal* local); - void visit(Scope2* scope, AstStatLocalFunction* local); - void visit(Scope2* scope, AstStatReturn* local); + void visit(Scope2* scope, AstStatLocalFunction* function); + void visit(Scope2* scope, AstStatFunction* function); + void visit(Scope2* scope, AstStatReturn* ret); + void visit(Scope2* scope, AstStatAssign* assign); + void visit(Scope2* scope, AstStatIf* ifStatement); + + TypePackId checkExprList(Scope2* scope, const AstArray& exprs); TypePackId checkPack(Scope2* scope, AstArray exprs); TypePackId checkPack(Scope2* scope, AstExpr* expr); + /** + * Checks an expression that is expected to evaluate to one type. + * @param scope the scope the expression is contained within. + * @param expr the expression to check. + * @return the type of the expression. + */ TypeId check(Scope2* scope, AstExpr* expr); + + TypeId checkExprTable(Scope2* scope, AstExprTable* expr); + TypeId check(Scope2* scope, AstExprIndexName* indexName); + + std::pair checkFunctionSignature(Scope2* parent, AstExprFunction* fn); + + /** + * Checks the body of a function expression. + * @param scope the interior scope of the body of the function. + * @param fn the function expression to check. + */ + void checkFunctionBody(Scope2* scope, AstExprFunction* fn); }; -std::vector collectConstraints(Scope2* rootScope); +/** + * Collects a vector of borrowed constraints from the scope and all its child + * scopes. It is important to only call this function when you're done adding + * constraints to the scope or its descendants, lest the borrowed pointers + * become invalid due to a container reallocation. + * @param rootScope the root scope of the scope graph to collect constraints + * from. + * @return a list of pointers to constraints contained within the scope graph. + * None of these pointers should be null. + */ +std::vector> collectConstraints(Scope2* rootScope); } // namespace Luau diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 85006e68..7e6d4461 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -4,7 +4,8 @@ #include "Luau/Error.h" #include "Luau/Variant.h" -#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/Constraint.h" +#include "Luau/ConstraintSolverLogger.h" #include "Luau/TypeVar.h" #include @@ -20,39 +21,81 @@ struct ConstraintSolver { TypeArena* arena; InternalErrorReporter iceReporter; - // The entire set of constraints that the solver is trying to resolve. - std::vector constraints; + // The entire set of constraints that the solver is trying to resolve. It + // is important to not add elements to this vector, lest the underlying + // storage that we retain pointers to be mutated underneath us. + const std::vector> constraints; Scope2* rootScope; - std::vector errors; // This includes every constraint that has not been fully solved. // A constraint can be both blocked and unsolved, for instance. - std::unordered_set unsolvedConstraints; + std::vector> unsolvedConstraints; // A mapping of constraint pointer to how many things the constraint is // blocked on. Can be empty or 0 for constraints that are not blocked on // anything. - std::unordered_map blockedConstraints; + std::unordered_map, size_t> blockedConstraints; // A mapping of type/pack pointers to the constraints they block. - std::unordered_map> blocked; + std::unordered_map>> blocked; + + ConstraintSolverLogger logger; explicit ConstraintSolver(TypeArena* arena, Scope2* rootScope); /** * Attempts to dispatch all pending constraints and reach a type solution - * that satisfies all of the constraints, recording any errors that are - * encountered. + * that satisfies all of the constraints. **/ void run(); bool done(); - bool tryDispatch(const Constraint* c); - bool tryDispatch(const SubtypeConstraint& c); - bool tryDispatch(const PackSubtypeConstraint& c); - bool tryDispatch(const GeneralizationConstraint& c); - bool tryDispatch(const InstantiationConstraint& c, const Constraint* constraint); + bool tryDispatch(NotNull c, bool force); + bool tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const GeneralizationConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const InstantiationConstraint& c, NotNull constraint, bool force); + void block(NotNull target, NotNull constraint); + /** + * Block a constraint on the resolution of a TypeVar. + * @returns false always. This is just to allow tryDispatch to return the result of block() + */ + bool block(TypeId target, NotNull constraint); + bool block(TypePackId target, NotNull constraint); + + void unblock(NotNull progressed); + void unblock(TypeId progressed); + void unblock(TypePackId progressed); + + /** + * @returns true if the TypeId is in a blocked state. + */ + bool isBlocked(TypeId ty); + + /** + * Returns whether the constraint is blocked on anything. + * @param constraint the constraint to check. + */ + bool isBlocked(NotNull constraint); + + /** + * Creates a new Unifier and performs a single unification operation. Commits + * the result. + * @param subType the sub-type to unify. + * @param superType the super-type to unify. + */ + void unify(TypeId subType, TypeId superType, Location location); + + /** + * Creates a new Unifier and performs a single unification operation. Commits + * the result. + * @param subPack the sub-type pack to unify. + * @param superPack the super-type pack to unify. + */ + void unify(TypePackId subPack, TypePackId superPack, Location location); + +private: /** * Marks a constraint as being blocked on a type or type pack. The constraint * solver will not attempt to dispatch blocked constraints until their @@ -60,10 +103,7 @@ struct ConstraintSolver * @param target the type or type pack pointer that the constraint is blocked on. * @param constraint the constraint to block. **/ - void block_(BlockedConstraintId target, const Constraint* constraint); - void block(const Constraint* target, const Constraint* constraint); - void block(TypeId target, const Constraint* constraint); - void block(TypePackId target, const Constraint* constraint); + void block_(BlockedConstraintId target, NotNull constraint); /** * Informs the solver that progress has been made on a type or type pack. The @@ -72,33 +112,6 @@ struct ConstraintSolver * @param progressed the type or type pack pointer that has progressed. **/ void unblock_(BlockedConstraintId progressed); - void unblock(const Constraint* progressed); - void unblock(TypeId progressed); - void unblock(TypePackId progressed); - - /** - * Returns whether the constraint is blocked on anything. - * @param constraint the constraint to check. - */ - bool isBlocked(const Constraint* constraint); - - void reportErrors(const std::vector& errors); - - /** - * Creates a new Unifier and performs a single unification operation. Commits - * the result and reports errors if necessary. - * @param subType the sub-type to unify. - * @param superType the super-type to unify. - */ - void unify(TypeId subType, TypeId superType); - - /** - * Creates a new Unifier and performs a single unification operation. Commits - * the result and reports errors if necessary. - * @param subPack the sub-type pack to unify. - * @param superPack the super-type pack to unify. - */ - void unify(TypePackId subPack, TypePackId superPack); }; void dump(Scope2* rootScope, struct ToStringOptions& opts); diff --git a/Analysis/include/Luau/ConstraintSolverLogger.h b/Analysis/include/Luau/ConstraintSolverLogger.h new file mode 100644 index 00000000..2b195d71 --- /dev/null +++ b/Analysis/include/Luau/ConstraintSolverLogger.h @@ -0,0 +1,26 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/ToString.h" + +#include +#include +#include + +namespace Luau +{ + +struct ConstraintSolverLogger +{ + std::string compileOutput(); + void captureBoundarySnapshot(const Scope2* rootScope, std::vector>& unsolvedConstraints); + void prepareStepSnapshot(const Scope2* rootScope, NotNull current, std::vector>& unsolvedConstraints); + void commitPreparedStepSnapshot(); + +private: + std::vector snapshots; + std::optional preparedSnapshot; + ToStringOptions opts; +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 58be0ffe..f4226cc1 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -66,7 +66,7 @@ struct SourceNode } ModuleName name; - std::unordered_set requires; + std::unordered_set requireSet; std::vector> requireLocations; bool dirtySourceModule = true; bool dirtyModule = true; @@ -186,7 +186,7 @@ public: std::unordered_map sourceNodes; std::unordered_map sourceModules; - std::unordered_map requires; + std::unordered_map requireTrace; Stats stats = {}; }; diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index f6e077dc..e979b3f0 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -69,6 +69,7 @@ struct Module std::vector>> scope2s; // never empty DenseHashMap astTypes{nullptr}; + DenseHashMap astTypePacks{nullptr}; DenseHashMap astExpectedTypes{nullptr}; DenseHashMap astOriginalCallTypes{nullptr}; DenseHashMap astOverloadResolvedTypes{nullptr}; diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 262b54b2..d4c7698b 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -10,6 +10,7 @@ namespace Luau struct InternalErrorReporter; bool isSubtype(TypeId superTy, TypeId subTy, InternalErrorReporter& ice); +bool isSubtype(TypePackId superTy, TypePackId subTy, InternalErrorReporter& ice); std::pair normalize(TypeId ty, TypeArena& arena, InternalErrorReporter& ice); std::pair normalize(TypeId ty, const ModulePtr& module, InternalErrorReporter& ice); diff --git a/Analysis/include/Luau/NotNull.h b/Analysis/include/Luau/NotNull.h index 3d05fdea..f6043e9c 100644 --- a/Analysis/include/Luau/NotNull.h +++ b/Analysis/include/Luau/NotNull.h @@ -9,20 +9,22 @@ namespace Luau { /** A non-owning, non-null pointer to a T. - * - * A NotNull is notionally identical to a T* with the added restriction that it - * can never store nullptr. - * - * The sole conversion rule from T* to NotNull is the single-argument constructor, which - * is intentionally marked explicit. This constructor performs a runtime test to verify - * that the passed pointer is never nullptr. - * - * Pointer arithmetic, increment, decrement, and array indexing are all forbidden. - * - * An implicit coersion from NotNull to T* is afforded, as are the pointer indirection and member - * access operators. (*p and p->prop) * - * The explicit delete statement is permitted on a NotNull through this implicit conversion. + * A NotNull is notionally identical to a T* with the added restriction that + * it can never store nullptr. + * + * The sole conversion rule from T* to NotNull is the single-argument + * constructor, which is intentionally marked explicit. This constructor + * performs a runtime test to verify that the passed pointer is never nullptr. + * + * Pointer arithmetic, increment, decrement, and array indexing are all + * forbidden. + * + * An implicit coersion from NotNull to T* is afforded, as are the pointer + * indirection and member access operators. (*p and p->prop) + * + * The explicit delete statement is permitted (but not recommended) on a + * NotNull through this implicit conversion. */ template struct NotNull @@ -36,6 +38,11 @@ struct NotNull explicit NotNull(std::nullptr_t) = delete; void operator=(std::nullptr_t) = delete; + template + NotNull(NotNull other) + : ptr(other.get()) + {} + operator T*() const noexcept { return ptr; @@ -56,6 +63,12 @@ struct NotNull T& operator+(int) = delete; T& operator-(int) = delete; + T* get() const noexcept + { + return ptr; + } + +private: T* ptr; }; @@ -68,7 +81,7 @@ template struct hash> { size_t operator()(const Luau::NotNull& p) const { - return std::hash()(p.ptr); + return std::hash()(p.get()); } }; diff --git a/Analysis/include/Luau/Quantify.h b/Analysis/include/Luau/Quantify.h index b32d684e..f46f0cb5 100644 --- a/Analysis/include/Luau/Quantify.h +++ b/Analysis/include/Luau/Quantify.h @@ -6,9 +6,10 @@ namespace Luau { +struct TypeArena; struct Scope2; void quantify(TypeId ty, TypeLevel level); -void quantify(TypeId ty, Scope2* scope); +TypeId quantify(TypeArena* arena, TypeId ty, Scope2* scope); } // namespace Luau diff --git a/Analysis/include/Luau/RequireTracer.h b/Analysis/include/Luau/RequireTracer.h index c25545f5..f69d133e 100644 --- a/Analysis/include/Luau/RequireTracer.h +++ b/Analysis/include/Luau/RequireTracer.h @@ -19,7 +19,7 @@ struct RequireTraceResult { DenseHashMap exprs{nullptr}; - std::vector> requires; + std::vector> requireList; }; RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName); diff --git a/Analysis/include/Luau/TypeChecker2.h b/Analysis/include/Luau/TypeChecker2.h new file mode 100644 index 00000000..a6c7a3e3 --- /dev/null +++ b/Analysis/include/Luau/TypeChecker2.h @@ -0,0 +1,13 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/Ast.h" +#include "Luau/Module.h" + +namespace Luau +{ + +void check(const SourceModule& sourceModule, Module* module); + +} // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 183cc053..28adc9d9 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -138,25 +138,25 @@ struct TypeChecker void checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement); void checkBlockTypeAliases(const ScopePtr& scope, std::vector& sorted); - ExprResult checkExpr( + WithPredicate checkExpr( const ScopePtr& scope, const AstExpr& expr, std::optional expectedType = std::nullopt, bool forceSingleton = false); - ExprResult checkExpr(const ScopePtr& scope, const AstExprLocal& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprGlobal& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprVarargs& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprCall& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprIndexName& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType = std::nullopt); - ExprResult checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType = std::nullopt); - ExprResult checkExpr(const ScopePtr& scope, const AstExprUnary& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprLocal& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprGlobal& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprVarargs& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprCall& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprIndexName& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType = std::nullopt); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType = std::nullopt); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprUnary& expr); TypeId checkRelationalOperation( const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); TypeId checkBinaryOperation( const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); - ExprResult checkExpr(const ScopePtr& scope, const AstExprBinary& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprError& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType = std::nullopt); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprBinary& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprError& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType = std::nullopt); TypeId checkExprTable(const ScopePtr& scope, const AstExprTable& expr, const std::vector>& fieldTypes, std::optional expectedType); @@ -179,11 +179,11 @@ struct TypeChecker void checkArgumentList( const ScopePtr& scope, Unifier& state, TypePackId paramPack, TypePackId argPack, const std::vector& argLocations); - ExprResult checkExprPack(const ScopePtr& scope, const AstExpr& expr); - ExprResult checkExprPack(const ScopePtr& scope, const AstExprCall& expr); + WithPredicate checkExprPack(const ScopePtr& scope, const AstExpr& expr); + WithPredicate checkExprPack(const ScopePtr& scope, const AstExprCall& expr); std::vector> getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall); - std::optional> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, - TypePackId argPack, TypePack* args, const std::vector* argLocations, const ExprResult& argListResult, + std::optional> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, + TypePackId argPack, TypePack* args, const std::vector* argLocations, const WithPredicate& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors); bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector& argLocations, const std::vector& errors); @@ -191,7 +191,7 @@ struct TypeChecker const std::vector& argLocations, const std::vector& overloads, const std::vector& overloadsThatMatchArgCount, const std::vector& errors); - ExprResult checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, + WithPredicate checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, bool substituteFreeForNil = false, const std::vector& lhsAnnotations = {}, const std::vector>& expectedTypes = {}); @@ -234,7 +234,7 @@ struct TypeChecker ErrorVec canUnify(TypeId subTy, TypeId superTy, const Location& location); ErrorVec canUnify(TypePackId subTy, TypePackId superTy, const Location& location); - void unifyLowerBound(TypePackId subTy, TypePackId superTy, const Location& location); + void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const Location& location); std::optional findMetatableEntry(TypeId type, std::string entry, const Location& location); std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location); @@ -412,7 +412,6 @@ public: const TypeId booleanType; const TypeId threadType; const TypeId anyType; - const TypeId optionalNumberType; const TypePackId anyTypePack; diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index b59e7c64..ff7708d4 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -84,6 +84,24 @@ using Tags = std::vector; using ModuleName = std::string; +/** A TypeVar that cannot be computed. + * + * BlockedTypeVars essentially serve as a way to encode partial ordering on the + * constraint graph. Until a BlockedTypeVar is unblocked by its owning + * constraint, nothing at all can be said about it. Constraints that need to + * process a BlockedTypeVar cannot be dispatched. + * + * Whenever a BlockedTypeVar is added to the graph, we also record a constraint + * that will eventually unblock it. + */ +struct BlockedTypeVar +{ + BlockedTypeVar(); + int index; + + static int nextIndex; +}; + struct PrimitiveTypeVar { enum Type @@ -231,29 +249,29 @@ struct FunctionDefinition // TODO: Do we actually need this? We'll find out later if we can delete this. // Does not exactly belong in TypeVar.h, but this is the only way to appease the compiler. template -struct ExprResult +struct WithPredicate { T type; PredicateVec predicates; }; -using MagicFunction = std::function>( - struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, ExprResult)>; +using MagicFunction = std::function>( + struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate)>; struct FunctionTypeVar { // Global monomorphic function - FunctionTypeVar(TypePackId argTypes, TypePackId retType, std::optional defn = {}, bool hasSelf = false); + FunctionTypeVar(TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); // Global polymorphic function - FunctionTypeVar(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retType, + FunctionTypeVar(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); // Local monomorphic function - FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retType, std::optional defn = {}, bool hasSelf = false); + FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); // Local polymorphic function - FunctionTypeVar(TypeLevel level, std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retType, + FunctionTypeVar(TypeLevel level, std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); TypeLevel level; @@ -263,7 +281,7 @@ struct FunctionTypeVar std::vector genericPacks; TypePackId argTypes; std::vector> argNames; - TypePackId retType; + TypePackId retTypes; std::optional definition; MagicFunction magicFunction = nullptr; // Function pointer, can be nullptr. bool hasSelf; @@ -442,7 +460,7 @@ struct LazyTypeVar using ErrorTypeVar = Unifiable::Error; -using TypeVariant = Unifiable::Variant; struct TypeVar final @@ -555,7 +573,6 @@ struct SingletonTypes const TypeId trueType; const TypeId falseType; const TypeId anyType; - const TypeId optionalNumberType; const TypePackId anyTypePack; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 627b52ca..b51a485e 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -110,7 +110,7 @@ private: void tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy); public: - void unifyLowerBound(TypePackId subTy, TypePackId superTy); + void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel); // Report an "infinite type error" if the type "needle" already occurs within "haystack" void occursCheck(TypeId needle, TypeId haystack); diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index f3839915..642522c9 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -209,7 +209,7 @@ struct GenericTypeVarVisitor if (visit(ty, *ftv)) { traverse(ftv->argTypes); - traverse(ftv->retType); + traverse(ftv->retTypes); } } diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index a8319c59..8a63901f 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -13,7 +13,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauIfElseExprFixCompletionIssue, false); LUAU_FASTFLAG(LuauSelfCallAutocompleteFix2) static const std::unordered_set kStatementStartingKeywords = { @@ -268,14 +267,14 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ auto checkFunctionType = [typeArena, &canUnify, &expectedType](const FunctionTypeVar* ftv) { if (FFlag::LuauSelfCallAutocompleteFix2) { - if (std::optional firstRetTy = first(ftv->retType)) + if (std::optional firstRetTy = first(ftv->retTypes)) return checkTypeMatch(typeArena, *firstRetTy, expectedType); return false; } else { - auto [retHead, retTail] = flatten(ftv->retType); + auto [retHead, retTail] = flatten(ftv->retTypes); if (!retHead.empty() && canUnify(retHead.front(), expectedType)) return true; @@ -454,7 +453,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId } else if (auto indexFunction = get(followed)) { - std::optional indexFunctionResult = first(indexFunction->retType); + std::optional indexFunctionResult = first(indexFunction->retTypes); if (indexFunctionResult) autocompleteProps(module, typeArena, rootTy, *indexFunctionResult, indexType, nodes, result, seen); } @@ -493,7 +492,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId autocompleteProps(module, typeArena, rootTy, followed, indexType, nodes, result, seen); else if (auto indexFunction = get(followed)) { - std::optional indexFunctionResult = first(indexFunction->retType); + std::optional indexFunctionResult = first(indexFunction->retTypes); if (indexFunctionResult) autocompleteProps(module, typeArena, rootTy, *indexFunctionResult, indexType, nodes, result, seen); } @@ -742,7 +741,7 @@ static std::optional findTypeElementAt(AstType* astType, TypeId ty, Posi if (auto element = findTypeElementAt(type->argTypes, ftv->argTypes, position)) return element; - if (auto element = findTypeElementAt(type->returnTypes, ftv->retType, position)) + if (auto element = findTypeElementAt(type->returnTypes, ftv->retTypes, position)) return element; } @@ -958,7 +957,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi { if (const FunctionTypeVar* ftv = get(follow(*it))) { - if (auto ty = tryGetTypePackTypeAt(ftv->retType, tailPos)) + if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, tailPos)) inferredType = *ty; } } @@ -1050,7 +1049,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi { if (const FunctionTypeVar* ftv = tryGetExpectedFunctionType(module, node)) { - if (auto ty = tryGetTypePackTypeAt(ftv->retType, i)) + if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, i)) tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); } @@ -1067,7 +1066,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi { if (const FunctionTypeVar* ftv = tryGetExpectedFunctionType(module, node)) { - if (auto ty = tryGetTypePackTypeAt(ftv->retType, ~0u)) + if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, ~0u)) tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); } } @@ -1266,7 +1265,7 @@ static bool autocompleteIfElseExpression( if (!parent) return false; - if (FFlag::LuauIfElseExprFixCompletionIssue && node->is()) + if (node->is()) { // Don't try to complete when the current node is an if-else expression (i.e. only try to complete when the node is a child of an if-else // expression. diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 98737b43..2f57e23c 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -19,16 +19,16 @@ LUAU_FASTFLAGVARIABLE(LuauSetMetaTableArgsCheck, false) namespace Luau { -static std::optional> magicFunctionSelect( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); -static std::optional> magicFunctionSetMetaTable( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); -static std::optional> magicFunctionAssert( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); -static std::optional> magicFunctionPack( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); -static std::optional> magicFunctionRequire( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); +static std::optional> magicFunctionSelect( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static std::optional> magicFunctionSetMetaTable( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static std::optional> magicFunctionAssert( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static std::optional> magicFunctionPack( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static std::optional> magicFunctionRequire( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); TypeId makeUnion(TypeArena& arena, std::vector&& types) { @@ -263,10 +263,10 @@ void registerBuiltinTypes(TypeChecker& typeChecker) attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire); } -static std::optional> magicFunctionSelect( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +static std::optional> magicFunctionSelect( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { - auto [paramPack, _predicates] = exprResult; + auto [paramPack, _predicates] = withPredicate; (void)scope; @@ -287,10 +287,10 @@ static std::optional> magicFunctionSelect( if (size_t(offset) < v.size()) { std::vector result(v.begin() + offset, v.end()); - return ExprResult{typechecker.currentModule->internalTypes.addTypePack(TypePack{std::move(result), tail})}; + return WithPredicate{typechecker.currentModule->internalTypes.addTypePack(TypePack{std::move(result), tail})}; } else if (tail) - return ExprResult{*tail}; + return WithPredicate{*tail}; } typechecker.reportError(TypeError{arg1->location, GenericError{"bad argument #1 to select (index out of range)"}}); @@ -298,16 +298,16 @@ static std::optional> magicFunctionSelect( else if (AstExprConstantString* str = arg1->as()) { if (str->value.size == 1 && str->value.data[0] == '#') - return ExprResult{typechecker.currentModule->internalTypes.addTypePack({typechecker.numberType})}; + return WithPredicate{typechecker.currentModule->internalTypes.addTypePack({typechecker.numberType})}; } return std::nullopt; } -static std::optional> magicFunctionSetMetaTable( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +static std::optional> magicFunctionSetMetaTable( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { - auto [paramPack, _predicates] = exprResult; + auto [paramPack, _predicates] = withPredicate; TypeArena& arena = typechecker.currentModule->internalTypes; @@ -343,7 +343,7 @@ static std::optional> magicFunctionSetMetaTable( if (FFlag::LuauSetMetaTableArgsCheck && expr.args.size < 1) { - return ExprResult{}; + return WithPredicate{}; } if (!FFlag::LuauSetMetaTableArgsCheck || !expr.self) @@ -356,7 +356,7 @@ static std::optional> magicFunctionSetMetaTable( } } - return ExprResult{arena.addTypePack({mtTy})}; + return WithPredicate{arena.addTypePack({mtTy})}; } } else if (get(target) || get(target) || isTableIntersection(target)) @@ -367,13 +367,13 @@ static std::optional> magicFunctionSetMetaTable( typechecker.reportError(TypeError{expr.location, GenericError{"setmetatable should take a table"}}); } - return ExprResult{arena.addTypePack({target})}; + return WithPredicate{arena.addTypePack({target})}; } -static std::optional> magicFunctionAssert( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +static std::optional> magicFunctionAssert( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { - auto [paramPack, predicates] = exprResult; + auto [paramPack, predicates] = withPredicate; TypeArena& arena = typechecker.currentModule->internalTypes; @@ -382,7 +382,7 @@ static std::optional> magicFunctionAssert( { std::optional fst = first(*tail); if (!fst) - return ExprResult{paramPack}; + return WithPredicate{paramPack}; head.push_back(*fst); } @@ -397,13 +397,13 @@ static std::optional> magicFunctionAssert( head[0] = *newhead; } - return ExprResult{arena.addTypePack(TypePack{std::move(head), tail})}; + return WithPredicate{arena.addTypePack(TypePack{std::move(head), tail})}; } -static std::optional> magicFunctionPack( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +static std::optional> magicFunctionPack( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { - auto [paramPack, _predicates] = exprResult; + auto [paramPack, _predicates] = withPredicate; TypeArena& arena = typechecker.currentModule->internalTypes; @@ -436,7 +436,7 @@ static std::optional> magicFunctionPack( TypeId packedTable = arena.addType( TableTypeVar{{{"n", {typechecker.numberType}}}, TableIndexer(typechecker.numberType, result), scope->level, TableState::Sealed}); - return ExprResult{arena.addTypePack({packedTable})}; + return WithPredicate{arena.addTypePack({packedTable})}; } static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr) @@ -461,8 +461,8 @@ static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr) return good; } -static std::optional> magicFunctionRequire( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +static std::optional> magicFunctionRequire( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { TypeArena& arena = typechecker.currentModule->internalTypes; @@ -476,7 +476,7 @@ static std::optional> magicFunctionRequire( return std::nullopt; if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, expr)) - return ExprResult{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})}; + return WithPredicate{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})}; return std::nullopt; } diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 9180f309..248262ce 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -47,6 +47,7 @@ struct TypeCloner void operator()(const Unifiable::Generic& t); void operator()(const Unifiable::Bound& t); void operator()(const Unifiable::Error& t); + void operator()(const BlockedTypeVar& t); void operator()(const PrimitiveTypeVar& t); void operator()(const ConstrainedTypeVar& t); void operator()(const SingletonTypeVar& t); @@ -158,6 +159,11 @@ void TypeCloner::operator()(const Unifiable::Error& t) defaultClone(t); } +void TypeCloner::operator()(const BlockedTypeVar& t) +{ + defaultClone(t); +} + void TypeCloner::operator()(const PrimitiveTypeVar& t) { defaultClone(t); @@ -200,7 +206,7 @@ void TypeCloner::operator()(const FunctionTypeVar& t) ftv->tags = t.tags; ftv->argTypes = clone(t.argTypes, dest, cloneState); ftv->argNames = t.argNames; - ftv->retType = clone(t.retType, dest, cloneState); + ftv->retTypes = clone(t.retTypes, dest, cloneState); ftv->hasNoGenerics = t.hasNoGenerics; } @@ -391,7 +397,7 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log) if (const FunctionTypeVar* ftv = get(ty)) { - FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; + FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; clone.generics = ftv->generics; clone.genericPacks = ftv->genericPacks; clone.magicFunction = ftv->magicFunction; diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp new file mode 100644 index 00000000..6cb0e4ee --- /dev/null +++ b/Analysis/src/Constraint.cpp @@ -0,0 +1,14 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Constraint.h" + +namespace Luau +{ + +Constraint::Constraint(ConstraintV&& c, Location location) + : c(std::move(c)) + , location(location) +{ +} + +} // namespace Luau diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index c8f77ddf..fa627e7a 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -5,16 +5,7 @@ namespace Luau { -Constraint::Constraint(ConstraintV&& c) - : c(std::move(c)) -{ -} - -Constraint::Constraint(ConstraintV&& c, std::vector dependencies) - : c(std::move(c)) - , dependencies(dependencies) -{ -} +const AstStat* getFallthrough(const AstStat* node); // TypeInfer.cpp std::optional Scope2::lookup(Symbol sym) { @@ -68,10 +59,10 @@ Scope2* ConstraintGraphBuilder::childScope(Location location, Scope2* parent) return borrow; } -void ConstraintGraphBuilder::addConstraint(Scope2* scope, ConstraintV cv) +void ConstraintGraphBuilder::addConstraint(Scope2* scope, ConstraintV cv, Location location) { LUAU_ASSERT(scope); - scope->constraints.emplace_back(new Constraint{std::move(cv)}); + scope->constraints.emplace_back(new Constraint{std::move(cv), location}); } void ConstraintGraphBuilder::addConstraint(Scope2* scope, std::unique_ptr c) @@ -99,10 +90,18 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStat* stat) visit(scope, s); else if (auto s = stat->as()) visit(scope, s); + else if (auto f = stat->as()) + visit(scope, f); else if (auto f = stat->as()) visit(scope, f); else if (auto r = stat->as()) visit(scope, r); + else if (auto a = stat->as()) + visit(scope, a); + else if (auto e = stat->as()) + checkPack(scope, e->expr); + else if (auto i = stat->as()) + visit(scope, i); else LUAU_ASSERT(0); } @@ -121,12 +120,30 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocal* local) scope->bindings[local] = ty; } - for (size_t i = 0; i < local->vars.size; ++i) + for (size_t i = 0; i < local->values.size; ++i) { - if (i < local->values.size) + if (local->values.data[i]->is()) + { + // HACK: we leave nil-initialized things floating under the assumption that they will later be populated. + // See the test TypeInfer/infer_locals_with_nil_value. + // Better flow awareness should make this obsolete. + } + else if (i == local->values.size - 1) + { + TypePackId exprPack = checkPack(scope, local->values.data[i]); + + if (i < local->vars.size) + { + std::vector tailValues{varTypes.begin() + i, varTypes.end()}; + TypePackId tailPack = arena->addTypePack(std::move(tailValues)); + addConstraint(scope, PackSubtypeConstraint{exprPack, tailPack}, local->location); + } + } + else { TypeId exprType = check(scope, local->values.data[i]); - addConstraint(scope, SubtypeConstraint{varTypes[i], exprType}); + if (i < varTypes.size()) + addConstraint(scope, SubtypeConstraint{varTypes[i], exprType}, local->vars.data[i]->location); } } } @@ -138,7 +155,7 @@ void addConstraints(Constraint* constraint, Scope2* scope) scope->constraints.reserve(scope->constraints.size() + scope->constraints.size()); for (const auto& c : scope->constraints) - constraint->dependencies.push_back(c.get()); + constraint->dependencies.push_back(NotNull{c.get()}); for (Scope2* childScope : scope->children) addConstraints(constraint, childScope); @@ -155,31 +172,75 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocalFunction* function TypeId functionType = nullptr; auto ty = scope->lookup(function->name); - LUAU_ASSERT(!ty.has_value()); // The parser ensures that every local function has a distinct Symbol for its name. - - functionType = freshType(scope); - scope->bindings[function->name] = functionType; - - Scope2* innerScope = childScope(function->func->body->location, scope); - TypePackId returnType = freshTypePack(scope); - innerScope->returnType = returnType; - - std::vector argTypes; - - for (AstLocal* local : function->func->args) + if (ty.has_value()) { - TypeId t = freshType(innerScope); - argTypes.push_back(t); - innerScope->bindings[local] = t; // TODO annotations + // TODO: This is duplicate definition of a local function. Is this allowed? + functionType = *ty; + } + else + { + functionType = arena->addType(BlockedTypeVar{}); + scope->bindings[function->name] = functionType; } - for (AstStat* stat : function->func->body->body) - visit(innerScope, stat); + auto [actualFunctionType, innerScope] = checkFunctionSignature(scope, function->func); + innerScope->bindings[function->name] = actualFunctionType; - FunctionTypeVar actualFunction{arena->addTypePack(argTypes), returnType}; - TypeId actualFunctionType = arena->addType(std::move(actualFunction)); + checkFunctionBody(innerScope, function->func); - std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}}}; + std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}, function->location}}; + addConstraints(c.get(), innerScope); + + addConstraint(scope, std::move(c)); +} + +void ConstraintGraphBuilder::visit(Scope2* scope, AstStatFunction* function) +{ + // Name could be AstStatLocal, AstStatGlobal, AstStatIndexName. + // With or without self + + TypeId functionType = nullptr; + + auto [actualFunctionType, innerScope] = checkFunctionSignature(scope, function->func); + + if (AstExprLocal* localName = function->name->as()) + { + std::optional existingFunctionTy = scope->lookup(localName->local); + if (existingFunctionTy) + { + // Duplicate definition + functionType = *existingFunctionTy; + } + else + { + functionType = arena->addType(BlockedTypeVar{}); + scope->bindings[localName->local] = functionType; + } + innerScope->bindings[localName->local] = actualFunctionType; + } + else if (AstExprGlobal* globalName = function->name->as()) + { + std::optional existingFunctionTy = scope->lookup(globalName->name); + if (existingFunctionTy) + { + // Duplicate definition + functionType = *existingFunctionTy; + } + else + { + functionType = arena->addType(BlockedTypeVar{}); + rootScope->bindings[globalName->name] = functionType; + } + innerScope->bindings[globalName->name] = actualFunctionType; + } + else if (AstExprIndexName* indexName = function->name->as()) + { + LUAU_ASSERT(0); // not yet implemented + } + + checkFunctionBody(innerScope, function->func); + + std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}, function->location}}; addConstraints(c.get(), innerScope); addConstraint(scope, std::move(c)); @@ -190,7 +251,7 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatReturn* ret) LUAU_ASSERT(scope); TypePackId exprTypes = checkPack(scope, ret->list); - addConstraint(scope, PackSubtypeConstraint{exprTypes, scope->returnType}); + addConstraint(scope, PackSubtypeConstraint{exprTypes, scope->returnType}, ret->location); } void ConstraintGraphBuilder::visit(Scope2* scope, AstStatBlock* block) @@ -201,6 +262,28 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatBlock* block) visit(scope, stat); } +void ConstraintGraphBuilder::visit(Scope2* scope, AstStatAssign* assign) +{ + TypePackId varPackId = checkExprList(scope, assign->vars); + TypePackId valuePack = checkPack(scope, assign->values); + + addConstraint(scope, PackSubtypeConstraint{valuePack, varPackId}, assign->location); +} + +void ConstraintGraphBuilder::visit(Scope2* scope, AstStatIf* ifStatement) +{ + check(scope, ifStatement->condition); + + Scope2* thenScope = childScope(ifStatement->thenbody->location, scope); + visit(thenScope, ifStatement->thenbody); + + if (ifStatement->elsebody) + { + Scope2* elseScope = childScope(ifStatement->elsebody->location, scope); + visit(elseScope, ifStatement->elsebody); + } +} + TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstArray exprs) { LUAU_ASSERT(scope); @@ -224,75 +307,256 @@ TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstArray e return arena->addTypePack(TypePack{std::move(types), last}); } +TypePackId ConstraintGraphBuilder::checkExprList(Scope2* scope, const AstArray& exprs) +{ + TypePackId result = arena->addTypePack({}); + TypePack* resultPack = getMutable(result); + LUAU_ASSERT(resultPack); + + for (size_t i = 0; i < exprs.size; ++i) + { + AstExpr* expr = exprs.data[i]; + if (i < exprs.size - 1) + resultPack->head.push_back(check(scope, expr)); + else + resultPack->tail = checkPack(scope, expr); + } + + if (resultPack->head.empty() && resultPack->tail) + return *resultPack->tail; + else + return result; +} + TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstExpr* expr) { LUAU_ASSERT(scope); - // TEMP TEMP TEMP HACK HACK HACK FIXME FIXME - TypeId t = check(scope, expr); - return arena->addTypePack({t}); + TypePackId result = nullptr; + + if (AstExprCall* call = expr->as()) + { + std::vector args; + + for (AstExpr* arg : call->args) + { + args.push_back(check(scope, arg)); + } + + // TODO self + + TypeId fnType = check(scope, call->func); + + astOriginalCallTypes[call->func] = fnType; + + TypeId instantiatedType = freshType(scope); + addConstraint(scope, InstantiationConstraint{instantiatedType, fnType}, expr->location); + + TypePackId rets = freshTypePack(scope); + FunctionTypeVar ftv(arena->addTypePack(TypePack{args, {}}), rets); + TypeId inferredFnType = arena->addType(ftv); + + addConstraint(scope, SubtypeConstraint{inferredFnType, instantiatedType}, expr->location); + result = rets; + } + else + { + TypeId t = check(scope, expr); + result = arena->addTypePack({t}); + } + + LUAU_ASSERT(result); + astTypePacks[expr] = result; + return result; } TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExpr* expr) { LUAU_ASSERT(scope); - if (auto a = expr->as()) - return singletonTypes.stringType; - else if (auto a = expr->as()) - return singletonTypes.numberType; - else if (auto a = expr->as()) - return singletonTypes.booleanType; - else if (auto a = expr->as()) - return singletonTypes.nilType; + TypeId result = nullptr; + + if (auto group = expr->as()) + result = check(scope, group->expr); + else if (expr->is()) + result = singletonTypes.stringType; + else if (expr->is()) + result = singletonTypes.numberType; + else if (expr->is()) + result = singletonTypes.booleanType; + else if (expr->is()) + result = singletonTypes.nilType; else if (auto a = expr->as()) { std::optional ty = scope->lookup(a->local); if (ty) - return *ty; + result = *ty; else - return singletonTypes.errorRecoveryType(singletonTypes.anyType); // FIXME? Record an error at this point? + result = singletonTypes.errorRecoveryType(); // FIXME? Record an error at this point? + } + else if (auto g = expr->as()) + { + std::optional ty = scope->lookup(g->name); + if (ty) + result = *ty; + else + result = singletonTypes.errorRecoveryType(); // FIXME? Record an error at this point? } else if (auto a = expr->as()) { - std::vector args; - - for (AstExpr* arg : a->args) + TypePackId packResult = checkPack(scope, expr); + if (auto f = first(packResult)) + return *f; + else if (get(packResult)) { - args.push_back(check(scope, arg)); + TypeId typeResult = freshType(scope); + TypePack onePack{{typeResult}, freshTypePack(scope)}; + TypePackId oneTypePack = arena->addTypePack(std::move(onePack)); + + addConstraint(scope, PackSubtypeConstraint{packResult, oneTypePack}, expr->location); + + return typeResult; } - - TypeId fnType = check(scope, a->func); - TypeId instantiatedType = freshType(scope); - addConstraint(scope, InstantiationConstraint{instantiatedType, fnType}); - - TypeId firstRet = freshType(scope); - TypePackId rets = arena->addTypePack(TypePack{{firstRet}, arena->addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})}); - FunctionTypeVar ftv(arena->addTypePack(TypePack{args, {}}), rets); - TypeId inferredFnType = arena->addType(ftv); - - addConstraint(scope, SubtypeConstraint{inferredFnType, instantiatedType}); - return firstRet; + } + else if (auto a = expr->as()) + { + auto [fnType, functionScope] = checkFunctionSignature(scope, a); + checkFunctionBody(functionScope, a); + return fnType; + } + else if (auto indexName = expr->as()) + { + result = check(scope, indexName); + } + else if (auto table = expr->as()) + { + result = checkExprTable(scope, table); } else { LUAU_ASSERT(0); - return freshType(scope); + result = freshType(scope); + } + + LUAU_ASSERT(result); + astTypes[expr] = result; + return result; +} + +TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExprIndexName* indexName) +{ + TypeId obj = check(scope, indexName->expr); + TypeId result = freshType(scope); + + TableTypeVar::Props props{{indexName->index.value, Property{result}}}; + const std::optional indexer; + TableTypeVar ttv{std::move(props), indexer, TypeLevel{}, TableState::Free}; + + TypeId expectedTableType = arena->addType(std::move(ttv)); + + addConstraint(scope, SubtypeConstraint{obj, expectedTableType}, indexName->location); + + return result; +} + +TypeId ConstraintGraphBuilder::checkExprTable(Scope2* scope, AstExprTable* expr) +{ + TypeId ty = arena->addType(TableTypeVar{}); + TableTypeVar* ttv = getMutable(ty); + LUAU_ASSERT(ttv); + + auto createIndexer = [this, scope, ttv]( + TypeId currentIndexType, TypeId currentResultType, Location itemLocation, std::optional keyLocation) { + if (!ttv->indexer) + { + TypeId indexType = this->freshType(scope); + TypeId resultType = this->freshType(scope); + ttv->indexer = TableIndexer{indexType, resultType}; + } + + addConstraint(scope, SubtypeConstraint{ttv->indexer->indexType, currentIndexType}, keyLocation ? *keyLocation : itemLocation); + addConstraint(scope, SubtypeConstraint{ttv->indexer->indexResultType, currentResultType}, itemLocation); + }; + + for (const AstExprTable::Item& item : expr->items) + { + TypeId itemTy = check(scope, item.value); + + if (item.key) + { + // Even though we don't need to use the type of the item's key if + // it's a string constant, we still want to check it to populate + // astTypes. + TypeId keyTy = check(scope, item.key); + + if (AstExprConstantString* key = item.key->as()) + { + ttv->props[key->value.begin()] = {itemTy}; + } + else + { + createIndexer(keyTy, itemTy, item.value->location, item.key->location); + } + } + else + { + TypeId numberType = singletonTypes.numberType; + createIndexer(numberType, itemTy, item.value->location, std::nullopt); + } + } + + return ty; +} + +std::pair ConstraintGraphBuilder::checkFunctionSignature(Scope2* parent, AstExprFunction* fn) +{ + Scope2* innerScope = childScope(fn->body->location, parent); + TypePackId returnType = freshTypePack(innerScope); + innerScope->returnType = returnType; + + std::vector argTypes; + + for (AstLocal* local : fn->args) + { + TypeId t = freshType(innerScope); + argTypes.push_back(t); + innerScope->bindings[local] = t; // TODO annotations + } + + FunctionTypeVar actualFunction{arena->addTypePack(argTypes), returnType}; + TypeId actualFunctionType = arena->addType(std::move(actualFunction)); + LUAU_ASSERT(actualFunctionType); + astTypes[fn] = actualFunctionType; + + return {actualFunctionType, innerScope}; +} + +void ConstraintGraphBuilder::checkFunctionBody(Scope2* scope, AstExprFunction* fn) +{ + for (AstStat* stat : fn->body->body) + visit(scope, stat); + + // If it is possible for execution to reach the end of the function, the return type must be compatible with () + + if (nullptr != getFallthrough(fn->body)) + { + TypePackId empty = arena->addTypePack({}); // TODO we could have CSG retain one of these forever + addConstraint(scope, PackSubtypeConstraint{scope->returnType, empty}, fn->body->location); } } -static void collectConstraints(std::vector& result, Scope2* scope) +void collectConstraints(std::vector>& result, Scope2* scope) { for (const auto& c : scope->constraints) - result.push_back(c.get()); + result.push_back(NotNull{c.get()}); for (Scope2* child : scope->children) collectConstraints(result, child); } -std::vector collectConstraints(Scope2* rootScope) +std::vector> collectConstraints(Scope2* rootScope) { - std::vector result; + std::vector> result; collectConstraints(result, rootScope); return result; } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index f40cd4b3..41dfd892 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -7,6 +7,7 @@ #include "Luau/Unifier.h" LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); +LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); namespace Luau { @@ -58,11 +59,11 @@ ConstraintSolver::ConstraintSolver(TypeArena* arena, Scope2* rootScope) , constraints(collectConstraints(rootScope)) , rootScope(rootScope) { - for (const Constraint* c : constraints) + for (NotNull c : constraints) { - unsolvedConstraints.insert(c); + unsolvedConstraints.push_back(c); - for (const Constraint* dep : c->dependencies) + for (NotNull dep : c->dependencies) { block(dep, c); } @@ -74,8 +75,6 @@ void ConstraintSolver::run() if (done()) return; - bool progress = false; - ToStringOptions opts; if (FFlag::DebugLuauLogSolver) @@ -84,44 +83,80 @@ void ConstraintSolver::run() dump(this, opts); } - do + if (FFlag::DebugLuauLogSolverToJson) { - progress = false; + logger.captureBoundarySnapshot(rootScope, unsolvedConstraints); + } - auto it = begin(unsolvedConstraints); - auto endIt = end(unsolvedConstraints); + auto runSolverPass = [&](bool force) { + bool progress = false; - while (it != endIt) + size_t i = 0; + while (i < unsolvedConstraints.size()) { - if (isBlocked(*it)) + NotNull c = unsolvedConstraints[i]; + if (!force && isBlocked(c)) { - ++it; + ++i; continue; } - std::string saveMe = FFlag::DebugLuauLogSolver ? toString(**it, opts) : std::string{}; + std::string saveMe = FFlag::DebugLuauLogSolver ? toString(*c, opts) : std::string{}; - bool success = tryDispatch(*it); - progress = progress || success; + if (FFlag::DebugLuauLogSolverToJson) + { + logger.prepareStepSnapshot(rootScope, c, unsolvedConstraints); + } + + bool success = tryDispatch(c, force); + + progress |= success; - auto saveIt = it; - ++it; if (success) { - unsolvedConstraints.erase(saveIt); + unsolvedConstraints.erase(unsolvedConstraints.begin() + i); + + if (FFlag::DebugLuauLogSolverToJson) + { + logger.commitPreparedStepSnapshot(); + } + if (FFlag::DebugLuauLogSolver) { + if (force) + printf("Force "); printf("Dispatched\n\t%s\n", saveMe.c_str()); dump(this, opts); } } + else + ++i; + + if (force && success) + return true; } + + return progress; + }; + + bool progress = false; + do + { + progress = runSolverPass(false); + if (!progress) + progress |= runSolverPass(true); } while (progress); if (FFlag::DebugLuauLogSolver) + { dumpBindings(rootScope, opts); + } - LUAU_ASSERT(done()); + if (FFlag::DebugLuauLogSolverToJson) + { + logger.captureBoundarySnapshot(rootScope, unsolvedConstraints); + printf("Logger output:\n%s\n", logger.compileOutput().c_str()); + } } bool ConstraintSolver::done() @@ -129,21 +164,21 @@ bool ConstraintSolver::done() return unsolvedConstraints.empty(); } -bool ConstraintSolver::tryDispatch(const Constraint* constraint) +bool ConstraintSolver::tryDispatch(NotNull constraint, bool force) { - if (isBlocked(constraint)) + if (!force && isBlocked(constraint)) return false; bool success = false; if (auto sc = get(*constraint)) - success = tryDispatch(*sc); + success = tryDispatch(*sc, constraint, force); else if (auto psc = get(*constraint)) - success = tryDispatch(*psc); + success = tryDispatch(*psc, constraint, force); else if (auto gc = get(*constraint)) - success = tryDispatch(*gc); + success = tryDispatch(*gc, constraint, force); else if (auto ic = get(*constraint)) - success = tryDispatch(*ic, constraint); + success = tryDispatch(*ic, constraint, force); else LUAU_ASSERT(0); @@ -155,65 +190,66 @@ bool ConstraintSolver::tryDispatch(const Constraint* constraint) return success; } -bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c) +bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force) { - unify(c.subType, c.superType); + if (isBlocked(c.subType)) + return block(c.subType, constraint); + else if (isBlocked(c.superType)) + return block(c.superType, constraint); + + unify(c.subType, c.superType, constraint->location); + unblock(c.subType); unblock(c.superType); return true; } -bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c) +bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force) { - unify(c.subPack, c.superPack); + unify(c.subPack, c.superPack, constraint->location); unblock(c.subPack); unblock(c.superPack); return true; } -bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& constraint) +bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull constraint, bool force) { - unify(constraint.subType, constraint.superType); + if (isBlocked(c.sourceType)) + return block(c.sourceType, constraint); - quantify(constraint.superType, constraint.scope); - unblock(constraint.subType); - unblock(constraint.superType); + if (isBlocked(c.generalizedType)) + asMutable(c.generalizedType)->ty.emplace(c.sourceType); + else + unify(c.generalizedType, c.sourceType, constraint->location); + + TypeId generalized = quantify(arena, c.sourceType, c.scope); + *asMutable(c.sourceType) = *generalized; + + unblock(c.generalizedType); + unblock(c.sourceType); return true; } -bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, const Constraint* constraint) +bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNull constraint, bool force) { - TypeId superType = follow(c.superType); - if (const FunctionTypeVar* ftv = get(superType)) - { - if (!ftv->generalized) - { - block(superType, constraint); - return false; - } - } - else if (get(superType)) - { - block(superType, constraint); - return false; - } - // TODO: Error if it's a primitive or something + if (isBlocked(c.superType)) + return block(c.superType, constraint); Instantiation inst(TxnLog::empty(), arena, TypeLevel{}); std::optional instantiated = inst.substitute(c.superType); LUAU_ASSERT(instantiated); // TODO FIXME HANDLE THIS - unify(c.subType, *instantiated); + unify(c.subType, *instantiated, constraint->location); unblock(c.subType); return true; } -void ConstraintSolver::block_(BlockedConstraintId target, const Constraint* constraint) +void ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) { blocked[target].push_back(constraint); @@ -221,19 +257,21 @@ void ConstraintSolver::block_(BlockedConstraintId target, const Constraint* cons count += 1; } -void ConstraintSolver::block(const Constraint* target, const Constraint* constraint) +void ConstraintSolver::block(NotNull target, NotNull constraint) { block_(target, constraint); } -void ConstraintSolver::block(TypeId target, const Constraint* constraint) +bool ConstraintSolver::block(TypeId target, NotNull constraint) { block_(target, constraint); + return false; } -void ConstraintSolver::block(TypePackId target, const Constraint* constraint) +bool ConstraintSolver::block(TypePackId target, NotNull constraint) { block_(target, constraint); + return false; } void ConstraintSolver::unblock_(BlockedConstraintId progressed) @@ -243,7 +281,7 @@ void ConstraintSolver::unblock_(BlockedConstraintId progressed) return; // unblocked should contain a value always, because of the above check - for (const Constraint* unblockedConstraint : it->second) + for (NotNull unblockedConstraint : it->second) { auto& count = blockedConstraints[unblockedConstraint]; // This assertion being hit indicates that `blocked` and @@ -257,7 +295,7 @@ void ConstraintSolver::unblock_(BlockedConstraintId progressed) blocked.erase(it); } -void ConstraintSolver::unblock(const Constraint* progressed) +void ConstraintSolver::unblock(NotNull progressed) { return unblock_(progressed); } @@ -272,35 +310,33 @@ void ConstraintSolver::unblock(TypePackId progressed) return unblock_(progressed); } -bool ConstraintSolver::isBlocked(const Constraint* constraint) +bool ConstraintSolver::isBlocked(TypeId ty) +{ + return nullptr != get(follow(ty)); +} + +bool ConstraintSolver::isBlocked(NotNull constraint) { auto blockedIt = blockedConstraints.find(constraint); return blockedIt != blockedConstraints.end() && blockedIt->second > 0; } -void ConstraintSolver::reportErrors(const std::vector& errors) -{ - this->errors.insert(end(this->errors), begin(errors), end(errors)); -} - -void ConstraintSolver::unify(TypeId subType, TypeId superType) +void ConstraintSolver::unify(TypeId subType, TypeId superType, Location location) { UnifierSharedState sharedState{&iceReporter}; - Unifier u{arena, Mode::Strict, Location{}, Covariant, sharedState}; + Unifier u{arena, Mode::Strict, location, Covariant, sharedState}; u.tryUnify(subType, superType); u.log.commit(); - reportErrors(u.errors); } -void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack) +void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack, Location location) { UnifierSharedState sharedState{&iceReporter}; - Unifier u{arena, Mode::Strict, Location{}, Covariant, sharedState}; + Unifier u{arena, Mode::Strict, location, Covariant, sharedState}; u.tryUnify(subPack, superPack); u.log.commit(); - reportErrors(u.errors); } } // namespace Luau diff --git a/Analysis/src/ConstraintSolverLogger.cpp b/Analysis/src/ConstraintSolverLogger.cpp new file mode 100644 index 00000000..2f93c280 --- /dev/null +++ b/Analysis/src/ConstraintSolverLogger.cpp @@ -0,0 +1,139 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/ConstraintSolverLogger.h" + +namespace Luau +{ + +static std::string dumpScopeAndChildren(const Scope2* scope, ToStringOptions& opts) +{ + std::string output = "{\"bindings\":{"; + + bool comma = false; + for (const auto& [name, type] : scope->bindings) + { + if (comma) + output += ","; + + output += "\""; + output += name.c_str(); + output += "\": \""; + + ToStringResult result = toStringDetailed(type, opts); + opts.nameMap = std::move(result.nameMap); + output += result.name; + output += "\""; + + comma = true; + } + + output += "},\"children\":["; + comma = false; + + for (const Scope2* child : scope->children) + { + if (comma) + output += ","; + + output += dumpScopeAndChildren(child, opts); + comma = true; + } + + output += "]}"; + return output; +} + +static std::string dumpConstraintsToDot(std::vector>& constraints, ToStringOptions& opts) +{ + std::string result = "digraph Constraints {\\n"; + + std::unordered_set> contained; + for (NotNull c : constraints) + { + contained.insert(c); + } + + for (NotNull c : constraints) + { + std::string id = std::to_string(reinterpret_cast(c.get())); + result += id; + result += " [label=\\\""; + result += toString(*c, opts).c_str(); + result += "\\\"];\\n"; + + for (NotNull dep : c->dependencies) + { + if (contained.count(dep) == 0) + continue; + + result += std::to_string(reinterpret_cast(dep.get())); + result += " -> "; + result += id; + result += ";\\n"; + } + } + + result += "}"; + + return result; +} + +std::string ConstraintSolverLogger::compileOutput() +{ + std::string output = "["; + bool comma = false; + + for (const std::string& snapshot : snapshots) + { + if (comma) + output += ","; + output += snapshot; + + comma = true; + } + + output += "]"; + return output; +} + +void ConstraintSolverLogger::captureBoundarySnapshot(const Scope2* rootScope, std::vector>& unsolvedConstraints) +{ + std::string snapshot = "{\"type\":\"boundary\",\"rootScope\":"; + + snapshot += dumpScopeAndChildren(rootScope, opts); + snapshot += ",\"constraintGraph\":\""; + snapshot += dumpConstraintsToDot(unsolvedConstraints, opts); + snapshot += "\"}"; + + snapshots.push_back(std::move(snapshot)); +} + +void ConstraintSolverLogger::prepareStepSnapshot( + const Scope2* rootScope, NotNull current, std::vector>& unsolvedConstraints) +{ + // LUAU_ASSERT(!preparedSnapshot); + + std::string snapshot = "{\"type\":\"step\",\"rootScope\":"; + + snapshot += dumpScopeAndChildren(rootScope, opts); + snapshot += ",\"constraintGraph\":\""; + snapshot += dumpConstraintsToDot(unsolvedConstraints, opts); + snapshot += "\",\"currentId\":\""; + snapshot += std::to_string(reinterpret_cast(current.get())); + snapshot += "\",\"current\":\""; + snapshot += toString(*current, opts); + snapshot += "\"}"; + + preparedSnapshot = std::move(snapshot); +} + +void ConstraintSolverLogger::commitPreparedStepSnapshot() +{ + if (preparedSnapshot) + { + snapshots.push_back(std::move(*preparedSnapshot)); + preparedSnapshot = std::nullopt; + } +} + +} // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 741a35cf..9e025062 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1,16 +1,17 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Frontend.h" -#include "Luau/Common.h" #include "Luau/Clone.h" +#include "Luau/Common.h" #include "Luau/Config.h" -#include "Luau/FileResolver.h" #include "Luau/ConstraintGraphBuilder.h" #include "Luau/ConstraintSolver.h" +#include "Luau/FileResolver.h" #include "Luau/Parser.h" #include "Luau/Scope.h" #include "Luau/StringUtils.h" #include "Luau/TimeTrace.h" +#include "Luau/TypeChecker2.h" #include "Luau/TypeInfer.h" #include "Luau/Variant.h" @@ -216,7 +217,7 @@ ErrorVec accumulateErrors( continue; const SourceNode& sourceNode = it->second; - queue.insert(queue.end(), sourceNode.requires.begin(), sourceNode.requires.end()); + queue.insert(queue.end(), sourceNode.requireSet.begin(), sourceNode.requireSet.end()); // FIXME: If a module has a syntax error, we won't be able to re-report it here. // The solution is probably to move errors from Module to SourceNode @@ -586,7 +587,7 @@ bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& chec path.push_back(top); // push children - for (const ModuleName& dep : top->requires) + for (const ModuleName& dep : top->requireSet) { auto it = sourceNodes.find(dep); if (it != sourceNodes.end()) @@ -738,7 +739,7 @@ void Frontend::markDirty(const ModuleName& name, std::vector* marked std::unordered_map> reverseDeps; for (const auto& module : sourceNodes) { - for (const auto& dep : module.second.requires) + for (const auto& dep : module.second.requireSet) reverseDeps[dep].push_back(module.first); } @@ -797,9 +798,14 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, const Sco cs.run(); result->scope2s = std::move(cgb.scopes); + result->astTypes = std::move(cgb.astTypes); + result->astTypePacks = std::move(cgb.astTypePacks); + result->astOriginalCallTypes = std::move(cgb.astOriginalCallTypes); result->clonePublicInterface(iceHandler); + Luau::check(sourceModule, result.get()); + return result; } @@ -841,8 +847,8 @@ std::pair Frontend::getSourceNode(CheckResult& check SourceModule result = parse(name, source->source, opts); result.type = source->type; - RequireTraceResult& requireTrace = requires[name]; - requireTrace = traceRequires(fileResolver, result.root, name); + RequireTraceResult& require = requireTrace[name]; + require = traceRequires(fileResolver, result.root, name); SourceNode& sourceNode = sourceNodes[name]; SourceModule& sourceModule = sourceModules[name]; @@ -851,7 +857,7 @@ std::pair Frontend::getSourceNode(CheckResult& check sourceModule.environmentName = environmentName; sourceNode.name = name; - sourceNode.requires.clear(); + sourceNode.requireSet.clear(); sourceNode.requireLocations.clear(); sourceNode.dirtySourceModule = false; @@ -861,10 +867,10 @@ std::pair Frontend::getSourceNode(CheckResult& check sourceNode.dirtyModuleForAutocomplete = true; } - for (const auto& [moduleName, location] : requireTrace.requires) - sourceNode.requires.insert(moduleName); + for (const auto& [moduleName, location] : require.requireList) + sourceNode.requireSet.insert(moduleName); - sourceNode.requireLocations = requireTrace.requires; + sourceNode.requireLocations = require.requireList; return {&sourceNode, &sourceModule}; } @@ -925,8 +931,8 @@ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const std::optional FrontendModuleResolver::resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) { // FIXME I think this can be pushed into the FileResolver. - auto it = frontend->requires.find(currentModuleName); - if (it == frontend->requires.end()) + auto it = frontend->requireTrace.find(currentModuleName); + if (it == frontend->requireTrace.end()) { // CLI-43699 // If we can't find the current module name, that's because we bypassed the frontend's initializer @@ -1025,7 +1031,7 @@ void Frontend::clear() sourceModules.clear(); moduleResolver.modules.clear(); moduleResolverForAutocomplete.modules.clear(); - requires.clear(); + requireTrace.clear(); } } // namespace Luau diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index f145a511..77c62422 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -40,7 +40,7 @@ TypeId Instantiation::clean(TypeId ty) const FunctionTypeVar* ftv = log->getMutable(ty); LUAU_ASSERT(ftv); - FunctionTypeVar clone = FunctionTypeVar{level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; + FunctionTypeVar clone = FunctionTypeVar{level, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; clone.magicFunction = ftv->magicFunction; clone.tags = ftv->tags; clone.argNames = ftv->argNames; diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 200b7d1b..50868e56 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -2282,7 +2282,7 @@ private: size_t getReturnCount(TypeId ty) { if (auto ftv = get(ty)) - return size(ftv->retType); + return size(ftv->retTypes); if (auto itv = get(ty)) { @@ -2291,7 +2291,7 @@ private: for (TypeId part : itv->parts) if (auto ftv = get(follow(part))) - result = std::max(result, size(ftv->retType)); + result = std::max(result, size(ftv->retTypes)); return result; } diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 11403be5..d36665e2 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -17,6 +17,7 @@ LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeFlagIsConservative, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineEqFix, false); LUAU_FASTFLAGVARIABLE(LuauReplaceReplacer, false); +LUAU_FASTFLAG(LuauQuantifyConstrained) namespace Luau { @@ -273,6 +274,18 @@ bool isSubtype(TypeId subTy, TypeId superTy, InternalErrorReporter& ice) return ok; } +bool isSubtype(TypePackId subPack, TypePackId superPack, InternalErrorReporter& ice) +{ + UnifierSharedState sharedState{&ice}; + TypeArena arena; + Unifier u{&arena, Mode::Strict, Location{}, Covariant, sharedState}; + u.anyIsTop = true; + + u.tryUnify(subPack, superPack); + const bool ok = u.errors.empty() && u.log.empty(); + return ok; +} + template static bool areNormal_(const T& t, const std::unordered_set& seen, InternalErrorReporter& ice) { @@ -390,6 +403,7 @@ struct Normalize final : TypeVarVisitor bool visit(TypeId ty, const ConstrainedTypeVar& ctvRef) override { CHECK_ITERATION_LIMIT(false); + LUAU_ASSERT(!ty->normal); ConstrainedTypeVar* ctv = const_cast(&ctvRef); @@ -401,14 +415,21 @@ struct Normalize final : TypeVarVisitor std::vector newParts = normalizeUnion(parts); - const bool normal = areNormal(newParts, seen, ice); - - if (newParts.size() == 1) - *asMutable(ty) = BoundTypeVar{newParts[0]}; + if (FFlag::LuauQuantifyConstrained) + { + ctv->parts = std::move(newParts); + } else - *asMutable(ty) = UnionTypeVar{std::move(newParts)}; + { + const bool normal = areNormal(newParts, seen, ice); - asMutable(ty)->normal = normal; + if (newParts.size() == 1) + *asMutable(ty) = BoundTypeVar{newParts[0]}; + else + *asMutable(ty) = UnionTypeVar{std::move(newParts)}; + + asMutable(ty)->normal = normal; + } return false; } @@ -421,9 +442,9 @@ struct Normalize final : TypeVarVisitor return false; traverse(ftv.argTypes); - traverse(ftv.retType); + traverse(ftv.retTypes); - asMutable(ty)->normal = areNormal(ftv.argTypes, seen, ice) && areNormal(ftv.retType, seen, ice); + asMutable(ty)->normal = areNormal(ftv.argTypes, seen, ice) && areNormal(ftv.retTypes, seen, ice); return false; } @@ -465,7 +486,14 @@ struct Normalize final : TypeVarVisitor checkNormal(ttv.indexer->indexResultType); } - asMutable(ty)->normal = normal; + // An unsealed table can never be normal, ditto for free tables iff the type it is bound to is also not normal. + if (FFlag::LuauQuantifyConstrained) + { + if (ttv.state == TableState::Generic || ttv.state == TableState::Sealed || (ttv.state == TableState::Free && follow(ty)->normal)) + asMutable(ty)->normal = normal; + } + else + asMutable(ty)->normal = normal; return false; } diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 21775373..2004d153 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -2,15 +2,32 @@ #include "Luau/Quantify.h" +#include "Luau/ConstraintGraphBuilder.h" // TODO for Scope2; move to separate header +#include "Luau/TxnLog.h" +#include "Luau/Substitution.h" #include "Luau/VisitTypeVar.h" #include "Luau/ConstraintGraphBuilder.h" // TODO for Scope2; move to separate header LUAU_FASTFLAG(LuauAlwaysQuantify); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAGVARIABLE(LuauQuantifyConstrained, false) namespace Luau { +/// @return true if outer encloses inner +static bool subsumes(Scope2* outer, Scope2* inner) +{ + while (inner) + { + if (inner == outer) + return true; + inner = inner->parent; + } + + return false; +} + struct Quantifier final : TypeVarOnceVisitor { TypeLevel level; @@ -62,6 +79,34 @@ struct Quantifier final : TypeVarOnceVisitor return false; } + bool visit(TypeId ty, const ConstrainedTypeVar&) override + { + if (FFlag::LuauQuantifyConstrained) + { + ConstrainedTypeVar* ctv = getMutable(ty); + + seenMutableType = true; + + if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ctv->scope) : !level.subsumes(ctv->level)) + return false; + + std::vector opts = std::move(ctv->parts); + + // We might transmute, so it's not safe to rely on the builtin traversal logic + for (TypeId opt : opts) + traverse(opt); + + if (opts.size() == 1) + *asMutable(ty) = BoundTypeVar{opts[0]}; + else + *asMutable(ty) = UnionTypeVar{std::move(opts)}; + + return false; + } + else + return true; + } + bool visit(TypeId ty, const TableTypeVar&) override { LUAU_ASSERT(getMutable(ty)); @@ -73,8 +118,12 @@ struct Quantifier final : TypeVarOnceVisitor if (ttv.state == TableState::Free) seenMutableType = true; - if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic) - return false; + if (!FFlag::LuauQuantifyConstrained) + { + if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic) + return false; + } + if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ttv.scope) : !level.subsumes(ttv.level)) { if (ttv.state == TableState::Unsealed) @@ -156,4 +205,104 @@ void quantify(TypeId ty, Scope2* scope) ftv->generalized = true; } +struct PureQuantifier : Substitution +{ + Scope2* scope; + std::vector insertedGenerics; + std::vector insertedGenericPacks; + + PureQuantifier(const TxnLog* log, TypeArena* arena, Scope2* scope) + : Substitution(log, arena) + , scope(scope) + { + } + + bool isDirty(TypeId ty) override + { + LUAU_ASSERT(ty == follow(ty)); + + if (auto ftv = get(ty)) + { + return subsumes(scope, ftv->scope); + } + else if (auto ttv = get(ty)) + { + return ttv->state == TableState::Free && subsumes(scope, ttv->scope); + } + + return false; + } + + bool isDirty(TypePackId tp) override + { + if (auto ftp = get(tp)) + { + return subsumes(scope, ftp->scope); + } + + return false; + } + + TypeId clean(TypeId ty) override + { + if (auto ftv = get(ty)) + { + TypeId result = arena->addType(GenericTypeVar{}); + insertedGenerics.push_back(result); + return result; + } + else if (auto ttv = get(ty)) + { + TypeId result = arena->addType(TableTypeVar{}); + TableTypeVar* resultTable = getMutable(result); + LUAU_ASSERT(resultTable); + + *resultTable = *ttv; + resultTable->scope = nullptr; + resultTable->state = TableState::Generic; + + return result; + } + + return ty; + } + + TypePackId clean(TypePackId tp) override + { + if (auto ftp = get(tp)) + { + TypePackId result = arena->addTypePack(TypePackVar{GenericTypePack{}}); + insertedGenericPacks.push_back(result); + return result; + } + + return tp; + } + + bool ignoreChildren(TypeId ty) override + { + return ty->persistent; + } + bool ignoreChildren(TypePackId ty) override + { + return ty->persistent; + } +}; + +TypeId quantify(TypeArena* arena, TypeId ty, Scope2* scope) +{ + PureQuantifier quantifier{TxnLog::empty(), arena, scope}; + std::optional result = quantifier.substitute(ty); + LUAU_ASSERT(result); + + FunctionTypeVar* ftv = getMutable(*result); + LUAU_ASSERT(ftv); + ftv->generics.insert(ftv->generics.end(), quantifier.insertedGenerics.begin(), quantifier.insertedGenerics.end()); + ftv->genericPacks.insert(ftv->genericPacks.end(), quantifier.insertedGenericPacks.begin(), quantifier.insertedGenericPacks.end()); + + // TODO: Set hasNoGenerics. + + return *result; +} + } // namespace Luau diff --git a/Analysis/src/RequireTracer.cpp b/Analysis/src/RequireTracer.cpp index 8ed245fb..c036a7a5 100644 --- a/Analysis/src/RequireTracer.cpp +++ b/Analysis/src/RequireTracer.cpp @@ -28,7 +28,7 @@ struct RequireTracer : AstVisitor AstExprGlobal* global = expr->func->as(); if (global && global->name == "require" && expr->args.size >= 1) - requires.push_back(expr); + requireCalls.push_back(expr); return true; } @@ -84,9 +84,9 @@ struct RequireTracer : AstVisitor ModuleInfo moduleContext{currentModuleName}; // seed worklist with require arguments - work.reserve(requires.size()); + work.reserve(requireCalls.size()); - for (AstExprCall* require : requires) + for (AstExprCall* require : requireCalls) work.push_back(require->args.data[0]); // push all dependent expressions to the work stack; note that the vector is modified during traversal @@ -125,15 +125,15 @@ struct RequireTracer : AstVisitor } // resolve all requires according to their argument - result.requires.reserve(requires.size()); + result.requireList.reserve(requireCalls.size()); - for (AstExprCall* require : requires) + for (AstExprCall* require : requireCalls) { AstExpr* arg = require->args.data[0]; if (const ModuleInfo* info = result.exprs.find(arg)) { - result.requires.push_back({info->name, require->location}); + result.requireList.push_back({info->name, require->location}); ModuleInfo infoCopy = *info; // copy *info out since next line invalidates info! result.exprs[require] = std::move(infoCopy); @@ -151,7 +151,7 @@ struct RequireTracer : AstVisitor DenseHashMap locals; std::vector work; - std::vector requires; + std::vector requireCalls; }; RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName) diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 5a22deeb..9c4ce829 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -27,7 +27,7 @@ void Tarjan::visitChildren(TypeId ty, int index) if (const FunctionTypeVar* ftv = get(ty)) { visitChild(ftv->argTypes); - visitChild(ftv->retType); + visitChild(ftv->retTypes); } else if (const TableTypeVar* ttv = get(ty)) { @@ -442,7 +442,7 @@ void Substitution::replaceChildren(TypeId ty) if (FunctionTypeVar* ftv = getMutable(ty)) { ftv->argTypes = replace(ftv->argTypes); - ftv->retType = replace(ftv->retType); + ftv->retTypes = replace(ftv->retTypes); } else if (TableTypeVar* ttv = getMutable(ty)) { diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index 9b396c80..6b677bb8 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -154,7 +154,7 @@ void StateDot::visitChildren(TypeId ty, int index) finishNode(); visitChild(ftv->argTypes, index, "arg"); - visitChild(ftv->retType, index, "ret"); + visitChild(ftv->retTypes, index, "ret"); } else if (const TableTypeVar* ttv = get(ty)) { diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 04d15cf7..81dc0467 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -226,6 +226,11 @@ struct StringifierState result.name += s; } + void emit(int i) + { + emit(std::to_string(i).c_str()); + } + void indent() { indentation += 4; @@ -394,6 +399,13 @@ struct TypeVarStringifier state.emit("]]"); } + void operator()(TypeId, const BlockedTypeVar& btv) + { + state.emit("*blocked-"); + state.emit(btv.index); + state.emit("*"); + } + void operator()(TypeId, const PrimitiveTypeVar& ptv) { switch (ptv.type) @@ -480,8 +492,8 @@ struct TypeVarStringifier if (FFlag::LuauLowerBoundsCalculation) { - auto retBegin = begin(ftv.retType); - auto retEnd = end(ftv.retType); + auto retBegin = begin(ftv.retTypes); + auto retEnd = end(ftv.retTypes); if (retBegin != retEnd) { ++retBegin; @@ -491,7 +503,7 @@ struct TypeVarStringifier } else { - if (auto retPack = get(follow(ftv.retType))) + if (auto retPack = get(follow(ftv.retTypes))) { if (retPack->head.size() == 1 && !retPack->tail) plural = false; @@ -501,7 +513,7 @@ struct TypeVarStringifier if (plural) state.emit("("); - stringify(ftv.retType); + stringify(ftv.retTypes); if (plural) state.emit(")"); @@ -1303,14 +1315,14 @@ std::string toStringNamedFunction(const std::string& funcName, const FunctionTyp state.emit("): "); - size_t retSize = size(ftv.retType); - bool hasTail = !finite(ftv.retType); - bool wrap = get(follow(ftv.retType)) && (hasTail ? retSize != 0 : retSize != 1); + size_t retSize = size(ftv.retTypes); + bool hasTail = !finite(ftv.retTypes); + bool wrap = get(follow(ftv.retTypes)) && (hasTail ? retSize != 0 : retSize != 1); if (wrap) state.emit("("); - tvs.stringify(ftv.retType); + tvs.stringify(ftv.retTypes); if (wrap) state.emit(")"); @@ -1385,9 +1397,9 @@ std::string toString(const Constraint& c, ToStringOptions& opts) } else if (const GeneralizationConstraint* gc = Luau::get_if(&c.c)) { - ToStringResult subStr = toStringDetailed(gc->subType, opts); + ToStringResult subStr = toStringDetailed(gc->generalizedType, opts); opts.nameMap = std::move(subStr.nameMap); - ToStringResult superStr = toStringDetailed(gc->superType, opts); + ToStringResult superStr = toStringDetailed(gc->sourceType, opts); opts.nameMap = std::move(superStr.nameMap); return subStr.name + " ~ gen " + superStr.name; } diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 0f4534b7..6cca7127 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -94,6 +94,11 @@ public: } } + AstType* operator()(const BlockedTypeVar& btv) + { + return allocator->alloc(Location(), std::nullopt, AstName("*blocked*")); + } + AstType* operator()(const ConstrainedTypeVar& ctv) { AstArray types; @@ -271,7 +276,7 @@ public: } AstArray returnTypes; - const auto& [retVector, retTail] = flatten(ftv.retType); + const auto& [retVector, retTail] = flatten(ftv.retTypes); returnTypes.size = retVector.size(); returnTypes.data = static_cast(allocator->allocate(sizeof(AstType*) * returnTypes.size)); for (size_t i = 0; i < returnTypes.size; ++i) diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp new file mode 100644 index 00000000..7f5ba683 --- /dev/null +++ b/Analysis/src/TypeChecker2.cpp @@ -0,0 +1,160 @@ + +#include "Luau/TypeChecker2.h" + +#include + +#include "Luau/Ast.h" +#include "Luau/AstQuery.h" +#include "Luau/Clone.h" +#include "Luau/Normalize.h" + +namespace Luau +{ + +struct TypeChecker2 : public AstVisitor +{ + const SourceModule* sourceModule; + Module* module; + InternalErrorReporter ice; // FIXME accept a pointer from Frontend + + TypeChecker2(const SourceModule* sourceModule, Module* module) + : sourceModule(sourceModule) + , module(module) + { + } + + using AstVisitor::visit; + + TypePackId lookupPack(AstExpr* expr) + { + TypePackId* tp = module->astTypePacks.find(expr); + LUAU_ASSERT(tp); + return follow(*tp); + } + + TypeId lookupType(AstExpr* expr) + { + TypeId* ty = module->astTypes.find(expr); + LUAU_ASSERT(ty); + return follow(*ty); + } + + bool visit(AstStatAssign* assign) override + { + size_t count = std::min(assign->vars.size, assign->values.size); + + for (size_t i = 0; i < count; ++i) + { + AstExpr* lhs = assign->vars.data[i]; + TypeId* lhsType = module->astTypes.find(lhs); + LUAU_ASSERT(lhsType); + + AstExpr* rhs = assign->values.data[i]; + TypeId* rhsType = module->astTypes.find(rhs); + LUAU_ASSERT(rhsType); + + if (!isSubtype(*rhsType, *lhsType, ice)) + { + reportError(TypeMismatch{*lhsType, *rhsType}, rhs->location); + } + } + + return true; + } + + bool visit(AstExprCall* call) override + { + TypePackId expectedRetType = lookupPack(call); + TypeId functionType = lookupType(call->func); + + TypeArena arena; + TypePack args; + for (const auto& arg : call->args) + { + TypeId argTy = module->astTypes[arg]; + LUAU_ASSERT(argTy); + args.head.push_back(argTy); + } + + TypePackId argsTp = arena.addTypePack(args); + FunctionTypeVar ftv{argsTp, expectedRetType}; + TypeId expectedType = arena.addType(ftv); + if (!isSubtype(expectedType, functionType, ice)) + { + unfreeze(module->interfaceTypes); + CloneState cloneState; + expectedType = clone(expectedType, module->interfaceTypes, cloneState); + freeze(module->interfaceTypes); + reportError(TypeMismatch{expectedType, functionType}, call->location); + } + + return true; + } + + bool visit(AstExprIndexName* indexName) override + { + TypeId leftType = lookupType(indexName->expr); + TypeId resultType = lookupType(indexName); + + // leftType must have a property called indexName->index + + if (auto ttv = get(leftType)) + { + auto it = ttv->props.find(indexName->index.value); + if (it == ttv->props.end()) + { + reportError(UnknownProperty{leftType, indexName->index.value}, indexName->location); + } + else if (!isSubtype(resultType, it->second.type, ice)) + { + reportError(TypeMismatch{resultType, it->second.type}, indexName->location); + } + } + else + { + reportError(UnknownProperty{leftType, indexName->index.value}, indexName->location); + } + + return true; + } + + bool visit(AstExprConstantNumber* number) override + { + TypeId actualType = lookupType(number); + TypeId numberType = getSingletonTypes().numberType; + + if (!isSubtype(actualType, numberType, ice)) + { + reportError(TypeMismatch{actualType, numberType}, number->location); + } + + return true; + } + + bool visit(AstExprConstantString* string) override + { + TypeId actualType = lookupType(string); + TypeId stringType = getSingletonTypes().stringType; + + if (!isSubtype(actualType, stringType, ice)) + { + reportError(TypeMismatch{actualType, stringType}, string->location); + } + + return true; + } + + void reportError(TypeErrorData&& data, const Location& location) + { + module->errors.emplace_back(location, sourceModule->name, std::move(data)); + } +}; + +void check(const SourceModule& sourceModule, Module* module) +{ + TypeChecker2 typeChecker{&sourceModule, module}; + + sourceModule.root->visit(&typeChecker); +} + +} // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 447cd029..fd1b3b85 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -18,6 +18,7 @@ #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" #include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" #include #include @@ -30,7 +31,6 @@ LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) -LUAU_FASTFLAGVARIABLE(LuauExpectedPropTypeFromIndexer, false) LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix2, false) @@ -42,9 +42,9 @@ LUAU_FASTFLAG(LuauNormalizeFlagIsConservative) LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) LUAU_FASTFLAGVARIABLE(LuauRecursionLimitException, false); LUAU_FASTFLAGVARIABLE(LuauApplyTypeFunctionFix, false); -LUAU_FASTFLAGVARIABLE(LuauSuccessTypingForEqualityOperations, false) LUAU_FASTFLAGVARIABLE(LuauAlwaysQuantify, false); LUAU_FASTFLAGVARIABLE(LuauReportErrorsOnIndexerKeyMismatch, false) +LUAU_FASTFLAG(LuauQuantifyConstrained) LUAU_FASTFLAGVARIABLE(LuauFalsyPredicateReturnsNilInstead, false) LUAU_FASTFLAGVARIABLE(LuauNonCopyableTypeVarFields, false) @@ -260,7 +260,6 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan , booleanType(getSingletonTypes().booleanType) , threadType(getSingletonTypes().threadType) , anyType(getSingletonTypes().anyType) - , optionalNumberType(getSingletonTypes().optionalNumberType) , anyTypePack(getSingletonTypes().anyTypePack) , duplicateTypeAliases{{false, {}}} { @@ -679,7 +678,7 @@ static std::optional tryGetTypeGuardPredicate(const AstExprBinary& ex void TypeChecker::check(const ScopePtr& scope, const AstStatIf& statement) { - ExprResult result = checkExpr(scope, *statement.condition); + WithPredicate result = checkExpr(scope, *statement.condition); ScopePtr ifScope = childScope(scope, statement.thenbody->location); resolve(result.predicates, ifScope, true); @@ -712,7 +711,7 @@ ErrorVec TypeChecker::canUnify(TypePackId subTy, TypePackId superTy, const Locat void TypeChecker::check(const ScopePtr& scope, const AstStatWhile& statement) { - ExprResult result = checkExpr(scope, *statement.condition); + WithPredicate result = checkExpr(scope, *statement.condition); ScopePtr whileScope = childScope(scope, statement.body->location); resolve(result.predicates, whileScope, true); @@ -728,16 +727,64 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement) checkExpr(repScope, *statement.condition); } -void TypeChecker::unifyLowerBound(TypePackId subTy, TypePackId superTy, const Location& location) +void TypeChecker::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const Location& location) { Unifier state = mkUnifier(location); - state.unifyLowerBound(subTy, superTy); + state.unifyLowerBound(subTy, superTy, demotedLevel); state.log.commit(); reportErrors(state.errors); } +struct Demoter : Substitution +{ + Demoter(TypeArena* arena) + : Substitution(TxnLog::empty(), arena) + { + } + + bool isDirty(TypeId ty) override + { + return get(ty); + } + + bool isDirty(TypePackId tp) override + { + return get(tp); + } + + TypeId clean(TypeId ty) override + { + auto ftv = get(ty); + LUAU_ASSERT(ftv); + return addType(FreeTypeVar{demotedLevel(ftv->level)}); + } + + TypePackId clean(TypePackId tp) override + { + auto ftp = get(tp); + LUAU_ASSERT(ftp); + return addTypePack(TypePackVar{FreeTypePack{demotedLevel(ftp->level)}}); + } + + TypeLevel demotedLevel(TypeLevel level) + { + return TypeLevel{level.level + 5000, level.subLevel}; + } + + void demote(std::vector>& expectedTypes) + { + if (!FFlag::LuauQuantifyConstrained) + return; + for (std::optional& ty : expectedTypes) + { + if (ty) + ty = substitute(*ty); + } + } +}; + void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) { std::vector> expectedTypes; @@ -760,11 +807,14 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) } } + Demoter demoter{¤tModule->internalTypes}; + demoter.demote(expectedTypes); + TypePackId retPack = checkExprList(scope, return_.location, return_.list, false, {}, expectedTypes).type; if (FFlag::LuauReturnTypeInferenceInNonstrict ? FFlag::LuauLowerBoundsCalculation : useConstrainedIntersections()) { - unifyLowerBound(retPack, scope->returnType, return_.location); + unifyLowerBound(retPack, scope->returnType, demoter.demotedLevel(scope->level), return_.location); return; } @@ -1230,7 +1280,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) unify(retPack, varPack, forin.location); } else - unify(iterFunc->retType, varPack, forin.location); + unify(iterFunc->retTypes, varPack, forin.location); check(loopScope, *forin.body); } @@ -1611,7 +1661,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo currentModule->getModuleScope()->bindings[global.name] = Binding{fnType, global.location}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional expectedType, bool forceSingleton) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional expectedType, bool forceSingleton) { RecursionCounter _rc(&checkRecursionCount); if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) @@ -1620,7 +1670,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& return {errorRecoveryType(scope)}; } - ExprResult result; + WithPredicate result; if (auto a = expr.as()) result = checkExpr(scope, *a->expr, expectedType); @@ -1682,7 +1732,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& return result; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprLocal& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprLocal& expr) { std::optional lvalue = tryGetLValue(expr); LUAU_ASSERT(lvalue); // Guaranteed to not be nullopt - AstExprLocal is an LValue. @@ -1696,7 +1746,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprLo return {errorRecoveryType(scope)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGlobal& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGlobal& expr) { std::optional lvalue = tryGetLValue(expr); LUAU_ASSERT(lvalue); // Guaranteed to not be nullopt - AstExprGlobal is an LValue. @@ -1708,7 +1758,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGl return {errorRecoveryType(scope)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVarargs& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVarargs& expr) { TypePackId varargPack = checkExprPack(scope, expr).type; @@ -1738,9 +1788,9 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVa ice("Unknown TypePack type in checkExpr(AstExprVarargs)!"); } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCall& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCall& expr) { - ExprResult result = checkExprPack(scope, expr); + WithPredicate result = checkExprPack(scope, expr); TypePackId retPack = follow(result.type); if (auto pack = get(retPack)) @@ -1770,7 +1820,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa ice("Unknown TypePack type!", expr.location); } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexName& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexName& expr) { Name name = expr.index.value; @@ -2031,7 +2081,7 @@ TypeId TypeChecker::stripFromNilAndReport(TypeId ty, const Location& location) return ty; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr) { TypeId ty = checkLValue(scope, expr); @@ -2042,7 +2092,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIn return {ty}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType) { auto [funTy, funScope] = checkFunctionSignature(scope, 0, expr, std::nullopt, expectedType); @@ -2108,8 +2158,7 @@ TypeId TypeChecker::checkExprTable( if (errors.empty()) exprType = expectedProp.type; } - else if (expectedTable->indexer && (FFlag::LuauExpectedPropTypeFromIndexer ? maybeString(expectedTable->indexer->indexType) - : isString(expectedTable->indexer->indexType))) + else if (expectedTable->indexer && maybeString(expectedTable->indexer->indexType)) { ErrorVec errors = tryUnify(exprType, expectedTable->indexer->indexResultType, k->location); if (errors.empty()) @@ -2147,7 +2196,7 @@ TypeId TypeChecker::checkExprTable( return addType(table); } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType) { RecursionCounter _rc(&checkRecursionCount); if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) @@ -2201,7 +2250,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa { if (auto prop = expectedTable->props.find(key->value.data); prop != expectedTable->props.end()) expectedResultType = prop->second.type; - else if (FFlag::LuauExpectedPropTypeFromIndexer && expectedIndexType && maybeString(*expectedIndexType)) + else if (expectedIndexType && maybeString(*expectedIndexType)) expectedResultType = expectedIndexResultType; } else if (expectedUnion) @@ -2236,9 +2285,9 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa return {checkExprTable(scope, expr, fieldTypes, expectedType)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUnary& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUnary& expr) { - ExprResult result = checkExpr(scope, *expr.expr); + WithPredicate result = checkExpr(scope, *expr.expr); TypeId operandType = follow(result.type); switch (expr.op) @@ -2466,62 +2515,50 @@ TypeId TypeChecker::checkRelationalOperation( std::optional leftMetatable = isString(lhsType) ? std::nullopt : getMetatable(follow(lhsType)); std::optional rightMetatable = isString(rhsType) ? std::nullopt : getMetatable(follow(rhsType)); - if (FFlag::LuauSuccessTypingForEqualityOperations) + if (leftMetatable != rightMetatable) { - if (leftMetatable != rightMetatable) + bool matches = false; + if (isEquality) { - bool matches = false; - if (isEquality) + if (const UnionTypeVar* utv = get(leftType); utv && rightMetatable) { - if (const UnionTypeVar* utv = get(leftType); utv && rightMetatable) + for (TypeId leftOption : utv) { - for (TypeId leftOption : utv) + if (getMetatable(follow(leftOption)) == rightMetatable) { - if (getMetatable(follow(leftOption)) == rightMetatable) + matches = true; + break; + } + } + } + + if (!matches) + { + if (const UnionTypeVar* utv = get(rhsType); utv && leftMetatable) + { + for (TypeId rightOption : utv) + { + if (getMetatable(follow(rightOption)) == leftMetatable) { matches = true; break; } } } - - if (!matches) - { - if (const UnionTypeVar* utv = get(rhsType); utv && leftMetatable) - { - for (TypeId rightOption : utv) - { - if (getMetatable(follow(rightOption)) == leftMetatable) - { - matches = true; - break; - } - } - } - } - } - - - if (!matches) - { - reportError( - expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", - toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); - return errorRecoveryType(booleanType); } } - } - else - { - if (bool(leftMetatable) != bool(rightMetatable) && leftMetatable != rightMetatable) + + + if (!matches) { reportError( expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", - toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); + toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); return errorRecoveryType(booleanType); } } + if (leftMetatable) { std::optional metamethod = findMetatableEntry(lhsType, metamethodName, expr.location); @@ -2532,7 +2569,7 @@ TypeId TypeChecker::checkRelationalOperation( if (isEquality) { Unifier state = mkUnifier(expr.location); - state.tryUnify(addTypePack({booleanType}), ftv->retType); + state.tryUnify(addTypePack({booleanType}), ftv->retTypes); if (!state.errors.empty()) { @@ -2721,7 +2758,7 @@ TypeId TypeChecker::checkBinaryOperation( } } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBinary& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBinary& expr) { if (expr.op == AstExprBinary::And) { @@ -2752,8 +2789,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi if (auto predicate = tryGetTypeGuardPredicate(expr)) return {booleanType, {std::move(*predicate)}}; - ExprResult lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/true); - ExprResult rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/true); + WithPredicate lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/true); + WithPredicate rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/true); PredicateVec predicates; @@ -2770,18 +2807,18 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi } else { - ExprResult lhs = checkExpr(scope, *expr.left); - ExprResult rhs = checkExpr(scope, *expr.right); + WithPredicate lhs = checkExpr(scope, *expr.left); + WithPredicate rhs = checkExpr(scope, *expr.right); // Intentionally discarding predicates with other operators. return {checkBinaryOperation(scope, expr, lhs.type, rhs.type, lhs.predicates)}; } } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr) { TypeId annotationType = resolveType(scope, *expr.annotation); - ExprResult result = checkExpr(scope, *expr.expr, annotationType); + WithPredicate result = checkExpr(scope, *expr.expr, annotationType); // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. if (canUnify(annotationType, result.type, expr.location).empty()) @@ -2794,7 +2831,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTy return {errorRecoveryType(annotationType), std::move(result.predicates)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprError& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprError& expr) { const size_t oldSize = currentModule->errors.size(); @@ -2808,17 +2845,17 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprEr return {errorRecoveryType(scope)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType) { - ExprResult result = checkExpr(scope, *expr.condition); + WithPredicate result = checkExpr(scope, *expr.condition); ScopePtr trueScope = childScope(scope, expr.trueExpr->location); resolve(result.predicates, trueScope, true); - ExprResult trueType = checkExpr(trueScope, *expr.trueExpr, expectedType); + WithPredicate trueType = checkExpr(trueScope, *expr.trueExpr, expectedType); ScopePtr falseScope = childScope(scope, expr.falseExpr->location); resolve(result.predicates, falseScope, false); - ExprResult falseType = checkExpr(falseScope, *expr.falseExpr, expectedType); + WithPredicate falseType = checkExpr(falseScope, *expr.falseExpr, expectedType); if (falseType.type == trueType.type) return {trueType.type}; @@ -3170,7 +3207,7 @@ std::pair TypeChecker::checkFunctionSignature( retPack = anyTypePack; else if (expectedFunctionType) { - auto [head, tail] = flatten(expectedFunctionType->retType); + auto [head, tail] = flatten(expectedFunctionType->retTypes); // Do not infer 'nil' as function return type if (!tail && head.size() == 1 && isNil(head[0])) @@ -3354,7 +3391,7 @@ void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstE if (useConstrainedIntersections()) { - TypePackId retPack = follow(funTy->retType); + TypePackId retPack = follow(funTy->retTypes); // It is possible for a function to have no annotation and no return statement, and yet still have an ascribed return type // if it is expected to conform to some other interface. (eg the function may be a lambda passed as a callback) if (!hasReturn(function.body) && !function.returnAnnotation.has_value() && get(retPack)) @@ -3367,20 +3404,20 @@ void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstE else { // We explicitly don't follow here to check if we have a 'true' free type instead of bound one - if (get_if(&funTy->retType->ty)) - *asMutable(funTy->retType) = TypePack{{}, std::nullopt}; + if (get_if(&funTy->retTypes->ty)) + *asMutable(funTy->retTypes) = TypePack{{}, std::nullopt}; } bool reachesImplicitReturn = getFallthrough(function.body) != nullptr; - if (reachesImplicitReturn && !allowsNoReturnValues(follow(funTy->retType))) + if (reachesImplicitReturn && !allowsNoReturnValues(follow(funTy->retTypes))) { // If we're in nonstrict mode we want to only report this missing return // statement if there are type annotations on the function. In strict mode // we report it regardless. if (!isNonstrictMode() || function.returnAnnotation) { - reportError(getEndLocation(function), FunctionExitsWithoutReturning{funTy->retType}); + reportError(getEndLocation(function), FunctionExitsWithoutReturning{funTy->retTypes}); } } } @@ -3388,7 +3425,7 @@ void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstE ice("Checking non functional type"); } -ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const AstExpr& expr) +WithPredicate TypeChecker::checkExprPack(const ScopePtr& scope, const AstExpr& expr) { if (auto a = expr.as()) return checkExprPack(scope, *a); @@ -3654,7 +3691,7 @@ void TypeChecker::checkArgumentList( } } -ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const AstExprCall& expr) +WithPredicate TypeChecker::checkExprPack(const ScopePtr& scope, const AstExprCall& expr) { // evaluate type of function // decompose an intersection into its component overloads @@ -3722,7 +3759,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A std::vector> expectedTypes = getExpectedTypesForCall(overloads, expr.args.size, expr.self); - ExprResult argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes); + WithPredicate argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes); TypePackId argPack = argListResult.type; if (get(argPack)) @@ -3766,7 +3803,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A if (!overload && !overloadsThatDont.empty()) overload = get(overloadsThatDont[0]); if (overload) - return {errorRecoveryTypePack(overload->retType)}; + return {errorRecoveryTypePack(overload->retTypes)}; return {errorRecoveryTypePack(retPack)}; } @@ -3775,7 +3812,7 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st { std::vector> expectedTypes; - auto assignOption = [this, &expectedTypes](size_t index, std::optional ty) { + auto assignOption = [this, &expectedTypes](size_t index, TypeId ty) { if (index == expectedTypes.size()) { expectedTypes.push_back(ty); @@ -3790,7 +3827,7 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st } else { - std::vector result = reduceUnion({*el, *ty}); + std::vector result = reduceUnion({*el, ty}); el = result.size() == 1 ? result[0] : addType(UnionTypeVar{std::move(result)}); } } @@ -3810,7 +3847,8 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st if (argsTail) { - if (const VariadicTypePack* vtp = get(follow(*argsTail))) + argsTail = follow(*argsTail); + if (const VariadicTypePack* vtp = get(*argsTail)) { while (index < argumentCount) assignOption(index++, vtp->ty); @@ -3819,11 +3857,14 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st } } + Demoter demoter{¤tModule->internalTypes}; + demoter.demote(expectedTypes); + return expectedTypes; } -std::optional> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, - TypePackId argPack, TypePack* args, const std::vector* argLocations, const ExprResult& argListResult, +std::optional> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, + TypePackId argPack, TypePack* args, const std::vector* argLocations, const WithPredicate& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors) { LUAU_ASSERT(argLocations); @@ -3918,14 +3959,14 @@ std::optional> TypeChecker::checkCallOverload(const Scope if (ftv->magicFunction) { // TODO: We're passing in the wrong TypePackId. Should be argPack, but a unit test fails otherwise. CLI-40458 - if (std::optional> ret = ftv->magicFunction(*this, scope, expr, argListResult)) + if (std::optional> ret = ftv->magicFunction(*this, scope, expr, argListResult)) return *ret; } Unifier state = mkUnifier(expr.location); // Unify return types - checkArgumentList(scope, state, retPack, ftv->retType, /*argLocations*/ {}); + checkArgumentList(scope, state, retPack, ftv->retTypes, /*argLocations*/ {}); if (!state.errors.empty()) { return {}; @@ -3996,7 +4037,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal // we eagerly assume that that's what you actually meant and we commit to it. // This could be incorrect if the function has an additional overload that // actually works. - // checkArgumentList(scope, editedState, retPack, ftv->retType, retLocations, CountMismatch::Return); + // checkArgumentList(scope, editedState, retPack, ftv->retTypes, retLocations, CountMismatch::Return); return true; } } @@ -4027,7 +4068,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal // we eagerly assume that that's what you actually meant and we commit to it. // This could be incorrect if the function has an additional overload that // actually works. - // checkArgumentList(scope, editedState, retPack, ftv->retType, retLocations, CountMismatch::Return); + // checkArgumentList(scope, editedState, retPack, ftv->retTypes, retLocations, CountMismatch::Return); return true; } } @@ -4085,7 +4126,7 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast // Unify return types if (const FunctionTypeVar* ftv = get(overload)) { - checkArgumentList(scope, state, retPack, ftv->retType, {}); + checkArgumentList(scope, state, retPack, ftv->retTypes, {}); checkArgumentList(scope, state, argPack, ftv->argTypes, argLocations); } @@ -4110,7 +4151,7 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast return; } -ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, +WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, bool substituteFreeForNil, const std::vector& instantiateGenerics, const std::vector>& expectedTypes) { TypePackId pack = addTypePack(TypePack{}); @@ -4401,10 +4442,24 @@ TypeId Anyification::clean(TypeId ty) } else if (auto ctv = get(ty)) { - auto [t, ok] = normalize(ty, *arena, *iceHandler); - if (!ok) - normalizationTooComplex = true; - return t; + if (FFlag::LuauQuantifyConstrained) + { + std::vector copy = ctv->parts; + for (TypeId& ty : copy) + ty = replace(ty); + TypeId res = copy.size() == 1 ? copy[0] : addType(UnionTypeVar{std::move(copy)}); + auto [t, ok] = normalize(res, *arena, *iceHandler); + if (!ok) + normalizationTooComplex = true; + return t; + } + else + { + auto [t, ok] = normalize(ty, *arena, *iceHandler); + if (!ok) + normalizationTooComplex = true; + return t; + } } else return anyType; diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index ba09df5f..3d97e6eb 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -66,7 +66,7 @@ std::optional findTablePropertyRespectingMeta(ErrorVec& errors, TypeId t } else if (const auto& itf = get(index)) { - std::optional r = first(follow(itf->retType)); + std::optional r = first(follow(itf->retTypes)); if (!r) return getSingletonTypes().nilType; else diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 33bfe254..57762937 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -29,8 +29,8 @@ LUAU_FASTFLAG(LuauNonCopyableTypeVarFields) namespace Luau { -std::optional> magicFunctionFormat( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); +std::optional> magicFunctionFormat( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); TypeId follow(TypeId t) { @@ -408,41 +408,48 @@ bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) return false; } -FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retType, std::optional defn, bool hasSelf) +BlockedTypeVar::BlockedTypeVar() + : index(++nextIndex) +{ +} + +int BlockedTypeVar::nextIndex = 0; + +FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) : argTypes(argTypes) - , retType(retType) + , retTypes(retTypes) , definition(std::move(defn)) , hasSelf(hasSelf) { } -FunctionTypeVar::FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retType, std::optional defn, bool hasSelf) +FunctionTypeVar::FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) : level(level) , argTypes(argTypes) - , retType(retType) + , retTypes(retTypes) , definition(std::move(defn)) , hasSelf(hasSelf) { } -FunctionTypeVar::FunctionTypeVar(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retType, +FunctionTypeVar::FunctionTypeVar(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) : generics(generics) , genericPacks(genericPacks) , argTypes(argTypes) - , retType(retType) + , retTypes(retTypes) , definition(std::move(defn)) , hasSelf(hasSelf) { } FunctionTypeVar::FunctionTypeVar(TypeLevel level, std::vector generics, std::vector genericPacks, TypePackId argTypes, - TypePackId retType, std::optional defn, bool hasSelf) + TypePackId retTypes, std::optional defn, bool hasSelf) : level(level) , generics(generics) , genericPacks(genericPacks) , argTypes(argTypes) - , retType(retType) + , retTypes(retTypes) , definition(std::move(defn)) , hasSelf(hasSelf) { @@ -488,7 +495,7 @@ bool areEqual(SeenSet& seen, const FunctionTypeVar& lhs, const FunctionTypeVar& if (!areEqual(seen, *lhs.argTypes, *rhs.argTypes)) return false; - if (!areEqual(seen, *lhs.retType, *rhs.retType)) + if (!areEqual(seen, *lhs.retTypes, *rhs.retTypes)) return false; return true; @@ -678,7 +685,6 @@ static TypeVar trueType_{SingletonTypeVar{BooleanSingleton{true}}, /*persistent* static TypeVar falseType_{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true}; static TypeVar anyType_{AnyTypeVar{}, /*persistent*/ true}; static TypeVar errorType_{ErrorTypeVar{}, /*persistent*/ true}; -static TypeVar optionalNumberType_{UnionTypeVar{{&numberType_, &nilType_}}, /*persistent*/ true}; static TypePackVar anyTypePack_{VariadicTypePack{&anyType_}, true}; static TypePackVar errorTypePack_{Unifiable::Error{}}; @@ -692,7 +698,6 @@ SingletonTypes::SingletonTypes() , trueType(&trueType_) , falseType(&falseType_) , anyType(&anyType_) - , optionalNumberType(&optionalNumberType_) , anyTypePack(&anyTypePack_) , arena(new TypeArena) { @@ -825,7 +830,7 @@ void persist(TypeId ty) else if (auto ftv = get(t)) { persist(ftv->argTypes); - persist(ftv->retType); + persist(ftv->retTypes); } else if (auto ttv = get(t)) { @@ -1100,10 +1105,10 @@ static std::vector parseFormatString(TypeChecker& typechecker, const cha return result; } -std::optional> magicFunctionFormat( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +std::optional> magicFunctionFormat( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { - auto [paramPack, _predicates] = exprResult; + auto [paramPack, _predicates] = withPredicate; TypeArena& arena = typechecker.currentModule->internalTypes; @@ -1142,7 +1147,7 @@ std::optional> magicFunctionFormat( if (expected.size() != actualParamSize && (!tail || expected.size() < actualParamSize)) typechecker.reportError(TypeError{expr.location, CountMismatch{expected.size(), actualParamSize}}); - return ExprResult{arena.addTypePack({typechecker.stringType})}; + return WithPredicate{arena.addTypePack({typechecker.stringType})}; } std::vector filterMap(TypeId type, TypeIdPredicate predicate) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 414b05f4..877663de 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -22,6 +22,7 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogRefreshFunctionPointers, false) +LUAU_FASTFLAG(LuauQuantifyConstrained) namespace Luau { @@ -1288,13 +1289,13 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); innerState.ctx = CountMismatch::Result; - innerState.tryUnify_(subFunction->retType, superFunction->retType); + innerState.tryUnify_(subFunction->retTypes, superFunction->retTypes); if (!reported) { if (auto e = hasUnificationTooComplex(innerState.errors)) reportError(*e); - else if (!innerState.errors.empty() && size(superFunction->retType) == 1 && finite(superFunction->retType)) + else if (!innerState.errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes)) reportError(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}}); else if (!innerState.errors.empty() && innerState.firstPackErrorPos) reportError( @@ -1312,7 +1313,7 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); ctx = CountMismatch::Result; - tryUnify_(subFunction->retType, superFunction->retType); + tryUnify_(subFunction->retTypes, superFunction->retTypes); } if (FFlag::LuauTxnLogRefreshFunctionPointers) @@ -2177,7 +2178,7 @@ static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHas else if (auto fun = state.log.getMutable(ty)) { queueTypePack(queue, seenTypePacks, state, fun->argTypes, anyTypePack); - queueTypePack(queue, seenTypePacks, state, fun->retType, anyTypePack); + queueTypePack(queue, seenTypePacks, state, fun->retTypes, anyTypePack); } else if (auto table = state.log.getMutable(ty)) { @@ -2322,7 +2323,7 @@ void Unifier::tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy) superC->parts.push_back(subTy); } -void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy) +void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel) { // The duplication between this and regular typepack unification is tragic. @@ -2357,7 +2358,7 @@ void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy) if (!freeTailPack) return; - TypeLevel level = freeTailPack->level; + TypeLevel level = FFlag::LuauQuantifyConstrained ? demotedLevel : freeTailPack->level; TypePack* tp = getMutable(log.replace(tailPack, TypePack{})); diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 597b2f0a..a34f7603 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -1075,6 +1075,8 @@ void BytecodeBuilder::validate() const LUAU_ASSERT(i <= insns.size()); } + std::vector openCaptures; + // second pass: validate the rest of the bytecode for (size_t i = 0; i < insns.size();) { @@ -1121,6 +1123,8 @@ void BytecodeBuilder::validate() const case LOP_CLOSEUPVALS: VREG(LUAU_INSN_A(insn)); + while (openCaptures.size() && openCaptures.back() >= LUAU_INSN_A(insn)) + openCaptures.pop_back(); break; case LOP_GETIMPORT: @@ -1388,8 +1392,12 @@ void BytecodeBuilder::validate() const switch (LUAU_INSN_A(insn)) { case LCT_VAL: + VREG(LUAU_INSN_B(insn)); + break; + case LCT_REF: VREG(LUAU_INSN_B(insn)); + openCaptures.push_back(LUAU_INSN_B(insn)); break; case LCT_UPVAL: @@ -1409,6 +1417,12 @@ void BytecodeBuilder::validate() const LUAU_ASSERT(i <= insns.size()); } + // all CAPTURE REF instructions must have a CLOSEUPVALS instruction after them in the bytecode stream + // this doesn't guarantee safety as it doesn't perform basic block based analysis, but if this fails + // then the bytecode is definitely unsafe to run since the compiler won't generate backwards branches + // except for loop edges + LUAU_ASSERT(openCaptures.empty()); + #undef VREG #undef VREGEND #undef VUPVAL diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 7431cde4..52dc9242 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -246,6 +246,14 @@ struct Compiler f.canInline = true; f.stackSize = stackSize; f.costModel = modelCost(func->body, func->args.data, func->args.size); + + // track functions that only ever return a single value so that we can convert multret calls to fixedret calls + if (allPathsEndWithReturn(func->body)) + { + ReturnVisitor returnVisitor(this); + stat->visit(&returnVisitor); + f.returnsOne = returnVisitor.returnsOne; + } } upvals.clear(); // note: instead of std::move above, we copy & clear to preserve capacity for future pushes @@ -260,6 +268,19 @@ struct Compiler { if (AstExprCall* expr = node->as()) { + // Optimization: convert multret calls to functions that always return one value to fixedret calls; this facilitates inlining + if (options.optimizationLevel >= 2) + { + AstExprFunction* func = getFunctionExpr(expr->func); + Function* fi = func ? functions.find(func) : nullptr; + + if (fi && fi->returnsOne) + { + compileExprTemp(node, target); + return false; + } + } + // We temporarily swap out regTop to have targetTop work correctly... // This is a crude hack but it's necessary for correctness :( RegScope rs(this, target); @@ -447,7 +468,9 @@ struct Compiler return false; } - // TODO: we can compile multret functions if all returns of the function are multret as well + // we can't inline multret functions because the caller expects L->top to be adjusted: + // - inlined return compiles to a JUMP, and we don't have an instruction that adjusts L->top arbitrarily + // - even if we did, right now all L->top adjustments are immediately consumed by the next instruction, and for now we want to preserve that if (multRet) { bytecode.addDebugRemark("inlining failed: can't convert fixed returns to multret"); @@ -492,7 +515,7 @@ struct Compiler size_t oldLocals = localStack.size(); // note that we push the frame early; this is needed to block recursive inline attempts - inlineFrames.push_back({func, target, targetCount}); + inlineFrames.push_back({func, oldLocals, target, targetCount}); // evaluate all arguments; note that we don't emit code for constant arguments (relying on constant folding) for (size_t i = 0; i < func->args.size; ++i) @@ -593,6 +616,8 @@ struct Compiler { for (size_t i = 0; i < targetCount; ++i) bytecode.emitABC(LOP_LOADNIL, uint8_t(target + i), 0, 0); + + closeLocals(oldLocals); } popLocals(oldLocals); @@ -2355,6 +2380,8 @@ struct Compiler compileExprListTemp(stat->list, frame.target, frame.targetCount, /* targetTop= */ false); + closeLocals(frame.localOffset); + if (!fallthrough) { size_t jumpLabel = bytecode.emitLabel(); @@ -3316,6 +3343,48 @@ struct Compiler std::vector upvals; }; + struct ReturnVisitor: AstVisitor + { + Compiler* self; + bool returnsOne = true; + + ReturnVisitor(Compiler* self) + : self(self) + { + } + + bool visit(AstExpr* expr) override + { + return false; + } + + bool visit(AstStatReturn* stat) override + { + if (stat->list.size == 1) + { + AstExpr* value = stat->list.data[0]; + + if (AstExprCall* expr = value->as()) + { + AstExprFunction* func = self->getFunctionExpr(expr->func); + Function* fi = func ? self->functions.find(func) : nullptr; + + returnsOne &= fi && fi->returnsOne; + } + else if (value->is()) + { + returnsOne = false; + } + } + else + { + returnsOne = false; + } + + return false; + } + }; + struct RegScope { RegScope(Compiler* self) @@ -3351,6 +3420,7 @@ struct Compiler uint64_t costModel = 0; unsigned int stackSize = 0; bool canInline = false; + bool returnsOne = false; }; struct Local @@ -3384,6 +3454,8 @@ struct Compiler { AstExprFunction* func; + size_t localOffset; + uint8_t target; uint8_t targetCount; diff --git a/Sources.cmake b/Sources.cmake index 99007e89..f261cba6 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -65,12 +65,13 @@ target_sources(Luau.CodeGen PRIVATE target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/AstQuery.h Analysis/include/Luau/Autocomplete.h - Analysis/include/Luau/NotNull.h Analysis/include/Luau/BuiltinDefinitions.h Analysis/include/Luau/Clone.h Analysis/include/Luau/Config.h + Analysis/include/Luau/Constraint.h Analysis/include/Luau/ConstraintGraphBuilder.h Analysis/include/Luau/ConstraintSolver.h + Analysis/include/Luau/ConstraintSolverLogger.h Analysis/include/Luau/Documentation.h Analysis/include/Luau/Error.h Analysis/include/Luau/FileResolver.h @@ -97,6 +98,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/TxnLog.h Analysis/include/Luau/TypeArena.h Analysis/include/Luau/TypeAttach.h + Analysis/include/Luau/TypeChecker2.h Analysis/include/Luau/TypedAllocator.h Analysis/include/Luau/TypeInfer.h Analysis/include/Luau/TypePack.h @@ -113,8 +115,10 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/BuiltinDefinitions.cpp Analysis/src/Clone.cpp Analysis/src/Config.cpp + Analysis/src/Constraint.cpp Analysis/src/ConstraintGraphBuilder.cpp Analysis/src/ConstraintSolver.cpp + Analysis/src/ConstraintSolverLogger.cpp Analysis/src/Error.cpp Analysis/src/Frontend.cpp Analysis/src/Instantiation.cpp @@ -136,6 +140,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/TxnLog.cpp Analysis/src/TypeArena.cpp Analysis/src/TypeAttach.cpp + Analysis/src/TypeChecker2.cpp Analysis/src/TypedAllocator.cpp Analysis/src/TypeInfer.cpp Analysis/src/TypePack.cpp @@ -245,7 +250,6 @@ if(TARGET Luau.UnitTest) tests/AstQuery.test.cpp tests/AstVisitor.test.cpp tests/Autocomplete.test.cpp - tests/NotNull.test.cpp tests/BuiltinDefinitions.test.cpp tests/Compiler.test.cpp tests/Config.test.cpp diff --git a/VM/src/lobject.h b/VM/src/lobject.h index 5e02c2ea..bdcb85cb 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -418,7 +418,7 @@ typedef struct Table CommonHeader; - uint8_t flags; /* 1<

flags = 0 +#define invalidateTMcache(t) t->tmcache = 0 // empty hash data points to dummynode so that we can always dereference it const LuaNode luaH_dummynode = { @@ -479,7 +479,7 @@ Table* luaH_new(lua_State* L, int narray, int nhash) Table* t = luaM_newgco(L, Table, sizeof(Table), L->activememcat); luaC_init(L, t, LUA_TTABLE); t->metatable = NULL; - t->flags = cast_byte(~0); + t->tmcache = cast_byte(~0); t->array = NULL; t->sizearray = 0; t->lastfree = 0; @@ -778,7 +778,7 @@ Table* luaH_clone(lua_State* L, Table* tt) Table* t = luaM_newgco(L, Table, sizeof(Table), L->activememcat); luaC_init(L, t, LUA_TTABLE); t->metatable = tt->metatable; - t->flags = tt->flags; + t->tmcache = tt->tmcache; t->array = NULL; t->sizearray = 0; t->lsizenode = 0; @@ -835,5 +835,5 @@ void luaH_clear(Table* tt) } /* back to empty -> no tag methods present */ - tt->flags = cast_byte(~0); + tt->tmcache = cast_byte(~0); } diff --git a/VM/src/ltm.cpp b/VM/src/ltm.cpp index 9b99506b..e7df4e53 100644 --- a/VM/src/ltm.cpp +++ b/VM/src/ltm.cpp @@ -88,8 +88,8 @@ const TValue* luaT_gettm(Table* events, TMS event, TString* ename) const TValue* tm = luaH_getstr(events, ename); LUAU_ASSERT(event <= TM_EQ); if (ttisnil(tm)) - { /* no tag method? */ - events->flags |= cast_byte(1u << event); /* cache this fact */ + { /* no tag method? */ + events->tmcache |= cast_byte(1u << event); /* cache this fact */ return NULL; } else diff --git a/VM/src/ltm.h b/VM/src/ltm.h index e1b95c21..a5223941 100644 --- a/VM/src/ltm.h +++ b/VM/src/ltm.h @@ -41,10 +41,10 @@ typedef enum } TMS; // clang-format on -#define gfasttm(g, et, e) ((et) == NULL ? NULL : ((et)->flags & (1u << (e))) ? NULL : luaT_gettm(et, e, (g)->tmname[e])) +#define gfasttm(g, et, e) ((et) == NULL ? NULL : ((et)->tmcache & (1u << (e))) ? NULL : luaT_gettm(et, e, (g)->tmname[e])) #define fasttm(l, et, e) gfasttm(l->global, et, e) -#define fastnotm(et, e) ((et) == NULL || ((et)->flags & (1u << (e)))) +#define fastnotm(et, e) ((et) == NULL || ((et)->tmcache & (1u << (e)))) LUAI_DATA const char* const luaT_typenames[]; LUAI_DATA const char* const luaT_eventname[]; diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index dea1ab19..f3b0bcad 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -1992,6 +1992,7 @@ local fp: @1= f auto ac = autocomplete('1'); + REQUIRE_EQ("({| x: number, y: number |}) -> number", toString(requireType("f"))); CHECK(ac.entryMap.count("({ x: number, y: number }) -> number")); } @@ -2620,7 +2621,6 @@ a = if temp then even elseif true then temp else e@9 TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_else_regression") { - ScopedFastFlag FFlagLuauIfElseExprFixCompletionIssue("LuauIfElseExprFixCompletionIssue", true); check(R"( local abcdef = 0; local temp = false diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 6eee254e..036bf124 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -4992,6 +4992,147 @@ RETURN R1 1 )"); } +TEST_CASE("InlineCapture") +{ + // if the argument is captured by a nested closure, normally we can rely on capture by value + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return function() return a end +end + +local x = ... +local y = foo(x) +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +NEWCLOSURE R2 P1 +CAPTURE VAL R1 +RETURN R2 1 +)"); + + // if the argument is a constant, we move it to a register so that capture by value can happen + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return function() return a end +end + +local y = foo(42) +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R2 42 +NEWCLOSURE R1 P1 +CAPTURE VAL R2 +RETURN R1 1 +)"); + + // if the argument is an externally mutated variable, we copy it to an argument and capture it by value + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return function() return a end +end + +local x x = 42 +local y = foo(x) +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +LOADNIL R1 +LOADN R1 42 +MOVE R3 R1 +NEWCLOSURE R2 P1 +CAPTURE VAL R3 +RETURN R2 1 +)"); + + // finally, if the argument is mutated internally, we must capture it by reference and close the upvalue + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + a = a or 42 + return function() return a end +end + +local y = foo() +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +LOADNIL R2 +ORK R2 R2 K1 +NEWCLOSURE R1 P1 +CAPTURE REF R2 +CLOSEUPVALS R2 +RETURN R1 1 +)"); + + // note that capture might need to be performed during the fallthrough block + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + a = a or 42 + print(function() return a end) +end + +local x = ... +local y = foo(x) +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +MOVE R3 R1 +ORK R3 R3 K1 +GETIMPORT R4 3 +NEWCLOSURE R5 P1 +CAPTURE REF R3 +CALL R4 1 0 +LOADNIL R2 +CLOSEUPVALS R3 +RETURN R2 1 +)"); + + // note that mutation and capture might be inside internal control flow + // TODO: this has an oddly redundant CLOSEUPVALS after JUMP; it's not due to inlining, and is an artifact of how StatBlock/StatReturn interact + // fixing this would reduce the number of redundant CLOSEUPVALS a bit but it only affects bytecode size as these instructions aren't executed + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + if not a then + local b b = 42 + return function() return b end + end +end + +local x = ... +local y = foo(x) +return y, x +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +JUMPIF R1 L0 +LOADNIL R3 +LOADN R3 42 +NEWCLOSURE R2 P1 +CAPTURE REF R3 +CLOSEUPVALS R3 +JUMP L1 +CLOSEUPVALS R3 +L0: LOADNIL R2 +L1: MOVE R3 R2 +MOVE R4 R1 +RETURN R3 2 +)"); +} + TEST_CASE("InlineFallthrough") { // if the function doesn't return, we still fill the results with nil @@ -5044,27 +5185,6 @@ RETURN R1 -1 )"); } -TEST_CASE("InlineCapture") -{ - // can't inline function with nested functions that capture locals because they might be constants - CHECK_EQ("\n" + compileFunction(R"( -local function foo(a) - local function bar() - return a - end - return bar() -end -)", - 1, 2), - R"( -NEWCLOSURE R1 P0 -CAPTURE VAL R0 -MOVE R2 R1 -CALL R2 0 -1 -RETURN R2 -1 -)"); -} - TEST_CASE("InlineArgMismatch") { // when inlining a function, we must respect all the usual rules @@ -5491,6 +5611,96 @@ RETURN R2 1 )"); } +TEST_CASE("InlineMultret") +{ + // inlining a function in multret context is prohibited since we can't adjust L->top outside of CALL/GETVARARGS + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a() +end + +return foo(42) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +LOADN R2 42 +CALL R1 1 -1 +RETURN R1 -1 +)"); + + // however, if we can deduce statically that a function always returns a single value, the inlining will work + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +return foo(42) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R1 42 +RETURN R1 1 +)"); + + // this analysis will also propagate through other functions + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local function bar(a) + return foo(a) +end + +return bar(42) +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +DUPCLOSURE R1 K1 +LOADN R2 42 +RETURN R2 1 +)"); + + // we currently don't do this analysis fully for recursive functions since they can't be inlined anyway + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return foo(a) +end + +return foo(42) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +CAPTURE VAL R0 +MOVE R1 R0 +LOADN R2 42 +CALL R1 1 -1 +RETURN R1 -1 +)"); + + // and unfortunately we can't do this analysis for builtins or method calls due to getfenv + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return math.abs(a) +end + +return foo(42) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +LOADN R2 42 +CALL R1 1 -1 +RETURN R1 -1 +)"); +} + TEST_CASE("ReturnConsecutive") { // we can return a single local directly diff --git a/tests/ConstraintGraphBuilder.test.cpp b/tests/ConstraintGraphBuilder.test.cpp index ab5af4f6..96b21613 100644 --- a/tests/ConstraintGraphBuilder.test.cpp +++ b/tests/ConstraintGraphBuilder.test.cpp @@ -17,13 +17,13 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello_world") )"); cgb.visit(block); - std::vector constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(cgb.rootScope); REQUIRE(2 == constraints.size()); ToStringOptions opts; - CHECK("a <: string" == toString(*constraints[0], opts)); - CHECK("b <: a" == toString(*constraints[1], opts)); + CHECK("string <: a" == toString(*constraints[0], opts)); + CHECK("a <: b" == toString(*constraints[1], opts)); } TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "primitives") @@ -36,15 +36,34 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "primitives") )"); cgb.visit(block); - std::vector constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(cgb.rootScope); - REQUIRE(4 == constraints.size()); + REQUIRE(3 == constraints.size()); ToStringOptions opts; - CHECK("a <: string" == toString(*constraints[0], opts)); - CHECK("b <: number" == toString(*constraints[1], opts)); - CHECK("c <: boolean" == toString(*constraints[2], opts)); - CHECK("d <: nil" == toString(*constraints[3], opts)); + CHECK("string <: a" == toString(*constraints[0], opts)); + CHECK("number <: b" == toString(*constraints[1], opts)); + CHECK("boolean <: c" == toString(*constraints[2], opts)); +} + +TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "nil_primitive") +{ + AstStatBlock* block = parse(R"( + local function a() return nil end + local b = a() + )"); + + cgb.visit(block); + auto constraints = collectConstraints(cgb.rootScope); + + ToStringOptions opts; + REQUIRE(5 <= constraints.size()); + + CHECK("*blocked-1* ~ gen () -> (a...)" == toString(*constraints[0], opts)); + CHECK("b ~ inst *blocked-1*" == toString(*constraints[1], opts)); + CHECK("() -> (c...) <: b" == toString(*constraints[2], opts)); + CHECK("c... <: d" == toString(*constraints[3], opts)); + CHECK("nil <: a..." == toString(*constraints[4], opts)); } TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "function_application") @@ -55,15 +74,15 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "function_application") )"); cgb.visit(block); - std::vector constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(cgb.rootScope); REQUIRE(4 == constraints.size()); ToStringOptions opts; - CHECK("a <: string" == toString(*constraints[0], opts)); + CHECK("string <: a" == toString(*constraints[0], opts)); CHECK("b ~ inst a" == toString(*constraints[1], opts)); - CHECK("(string) -> (c, d...) <: b" == toString(*constraints[2], opts)); - CHECK("e <: c" == toString(*constraints[3], opts)); + CHECK("(string) -> (c...) <: b" == toString(*constraints[2], opts)); + CHECK("c... <: d" == toString(*constraints[3], opts)); } TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "local_function_definition") @@ -75,13 +94,13 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "local_function_definition") )"); cgb.visit(block); - std::vector constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(cgb.rootScope); REQUIRE(2 == constraints.size()); ToStringOptions opts; - CHECK("a ~ gen (b) -> (c...)" == toString(*constraints[0], opts)); - CHECK("b <: c..." == toString(*constraints[1], opts)); + CHECK("*blocked-1* ~ gen (a) -> (b...)" == toString(*constraints[0], opts)); + CHECK("a <: b..." == toString(*constraints[1], opts)); } TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "recursive_function") @@ -93,15 +112,15 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "recursive_function") )"); cgb.visit(block); - std::vector constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(cgb.rootScope); REQUIRE(4 == constraints.size()); ToStringOptions opts; - CHECK("a ~ gen (b) -> (c...)" == toString(*constraints[0], opts)); - CHECK("d ~ inst a" == toString(*constraints[1], opts)); - CHECK("(b) -> (e, f...) <: d" == toString(*constraints[2], opts)); - CHECK("e <: c..." == toString(*constraints[3], opts)); + CHECK("*blocked-1* ~ gen (a) -> (b...)" == toString(*constraints[0], opts)); + CHECK("c ~ inst (a) -> (b...)" == toString(*constraints[1], opts)); + CHECK("(a) -> (d...) <: c" == toString(*constraints[2], opts)); + CHECK("d... <: b..." == toString(*constraints[3], opts)); } TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 232ec2de..ac22f65b 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -345,7 +345,7 @@ void Fixture::dumpErrors(std::ostream& os, const std::vector& errors) if (error.location.begin.line >= lines.size()) { os << "\tSource not available?" << std::endl; - return; + continue; } std::string_view theLine = lines[error.location.begin.line]; @@ -430,6 +430,7 @@ ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() : Fixture() , forceTheFlag{"DebugLuauDeferredConstraintResolution", true} { + BlockedTypeVar::nextIndex = 0; } ModuleName fromString(std::string_view name) diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index c0554669..b9c24704 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -97,8 +97,8 @@ TEST_CASE_FIXTURE(FrontendFixture, "find_a_require") NaiveFileResolver naiveFileResolver; auto res = traceRequires(&naiveFileResolver, program, ""); - CHECK_EQ(1, res.requires.size()); - CHECK_EQ(res.requires[0].first, "Modules/Foo/Bar"); + CHECK_EQ(1, res.requireList.size()); + CHECK_EQ(res.requireList[0].first, "Modules/Foo/Bar"); } // It could be argued that this should not work. @@ -113,7 +113,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "find_a_require_inside_a_function") NaiveFileResolver naiveFileResolver; auto res = traceRequires(&naiveFileResolver, program, ""); - CHECK_EQ(1, res.requires.size()); + CHECK_EQ(1, res.requireList.size()); } TEST_CASE_FIXTURE(FrontendFixture, "real_source") @@ -138,7 +138,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "real_source") NaiveFileResolver naiveFileResolver; auto res = traceRequires(&naiveFileResolver, program, ""); - CHECK_EQ(8, res.requires.size()); + CHECK_EQ(8, res.requireList.size()); } TEST_CASE_FIXTURE(FrontendFixture, "automatically_check_dependent_scripts") diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 89b13ab1..d585b731 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -102,7 +102,7 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") const FunctionTypeVar* ftv = get(methodType); REQUIRE(ftv != nullptr); - std::optional methodReturnType = first(ftv->retType); + std::optional methodReturnType = first(ftv->retTypes); REQUIRE(methodReturnType); CHECK_EQ(methodReturnType, counterCopy); diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index c0556103..50dcbad0 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -13,6 +13,57 @@ using namespace Luau; TEST_SUITE_BEGIN("NonstrictModeTests"); +TEST_CASE_FIXTURE(Fixture, "globals") +{ + CheckResult result = check(R"( + --!nonstrict + foo = true + foo = "now i'm a string!" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("any", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "globals2") +{ + ScopedFastFlag sff[]{ + {"LuauReturnTypeInferenceInNonstrict", true}, + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + --!nonstrict + foo = function() return 1 end + foo = "now i'm a string!" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("() -> number", toString(tm->wantedType)); + CHECK_EQ("string", toString(tm->givenType)); + CHECK_EQ("() -> number", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "globals_everywhere") +{ + CheckResult result = check(R"( + --!nonstrict + foo = 1 + + if true then + bar = 2 + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("any", toString(requireType("foo"))); + CHECK_EQ("any", toString(requireType("bar"))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "function_returns_number_or_string") { ScopedFastFlag sff[]{{"LuauReturnTypeInferenceInNonstrict", true}, {"LuauLowerBoundsCalculation", true}}; @@ -51,7 +102,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_nullary_function") REQUIRE_EQ("any", toString(args[0])); REQUIRE_EQ("any", toString(args[1])); - auto rets = flatten(ftv->retType).first; + auto rets = flatten(ftv->retTypes).first; REQUIRE_EQ(0, rets.size()); } diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 2876175d..284230c9 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -837,6 +837,7 @@ TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersect { ScopedFastFlag flags[] = { {"LuauLowerBoundsCalculation", true}, + {"LuauQuantifyConstrained", true}, }; // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. @@ -867,16 +868,17 @@ TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersect CHECK("{+ y: number +}" == toString(args[2])); CHECK("{+ z: string +}" == toString(args[3])); - std::vector ret = flatten(ftv->retType).first; + std::vector ret = flatten(ftv->retTypes).first; REQUIRE(1 == ret.size()); - CHECK("{| x: a & {- w: boolean, y: number, z: string -} |}" == toString(ret[0])); + CHECK("{| x: a & {+ w: boolean, y: number, z: string +} |}" == toString(ret[0])); } TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersection_3") { ScopedFastFlag flags[] = { {"LuauLowerBoundsCalculation", true}, + {"LuauQuantifyConstrained", true}, }; // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. @@ -906,16 +908,17 @@ TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersect CHECK("t1 where t1 = {+ y: t1 +}" == toString(args[1])); CHECK("{+ z: string +}" == toString(args[2])); - std::vector ret = flatten(ftv->retType).first; + std::vector ret = flatten(ftv->retTypes).first; REQUIRE(1 == ret.size()); - CHECK("{| x: {- x: boolean, y: t1, z: string -} |} where t1 = {+ y: t1 +}" == toString(ret[0])); + CHECK("{| x: {+ x: boolean, y: t1, z: string +} |} where t1 = {+ y: t1 +}" == toString(ret[0])); } TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersection_4") { ScopedFastFlag flags[] = { {"LuauLowerBoundsCalculation", true}, + {"LuauQuantifyConstrained", true}, }; // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. @@ -944,13 +947,13 @@ TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersect REQUIRE(3 == args.size()); CHECK("{+ x: boolean +}" == toString(args[0])); - CHECK("{+ y: t1 +} where t1 = {| x: {- x: boolean, y: t1, z: string -} |}" == toString(args[1])); + CHECK("{+ y: t1 +} where t1 = {| x: {+ x: boolean, y: t1, z: string +} |}" == toString(args[1])); CHECK("{+ z: string +}" == toString(args[2])); - std::vector ret = flatten(ftv->retType).first; + std::vector ret = flatten(ftv->retTypes).first; REQUIRE(1 == ret.size()); - CHECK("t1 where t1 = {| x: {- x: boolean, y: t1, z: string -} |}" == toString(ret[0])); + CHECK("t1 where t1 = {| x: {+ x: boolean, y: t1, z: string +} |}" == toString(ret[0])); } TEST_CASE_FIXTURE(Fixture, "nested_table_normalization_with_non_table__no_ice") @@ -1062,4 +1065,29 @@ export type t0 = (((any)&({_:l0.t0,n0:t0,_G:any,}))&({_:any,}))&(((any)&({_:l0.t LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "normalization_does_not_convert_ever") +{ + ScopedFastFlag sff[]{ + {"LuauLowerBoundsCalculation", true}, + {"LuauQuantifyConstrained", true}, + }; + + CheckResult result = check(R"( + --!strict + local function f() + if math.random() > 0.5 then + return true + end + type Ret = typeof(f()) + if math.random() > 0.5 then + return "something" + end + return "something" :: Ret + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("() -> boolean | string", toString(requireType("f"))); +} + TEST_SUITE_END(); diff --git a/tests/NotNull.test.cpp b/tests/NotNull.test.cpp index 1a323c85..ed1c25ec 100644 --- a/tests/NotNull.test.cpp +++ b/tests/NotNull.test.cpp @@ -75,9 +75,9 @@ TEST_CASE("basic_stuff") t->y = 3.14f; const NotNull u = t; - // u->x = 44; // nope + u->x = 44; int v = u->x; - CHECK(v == 5); + CHECK(v == 44); bar(a); @@ -96,8 +96,11 @@ TEST_CASE("basic_stuff") TEST_CASE("hashable") { std::unordered_map, const char*> map; - NotNull a{new int(8)}; - NotNull b{new int(10)}; + int a_ = 8; + int b_ = 10; + + NotNull a{&a_}; + NotNull b{&b_}; std::string hello = "hello"; std::string world = "world"; @@ -108,9 +111,47 @@ TEST_CASE("hashable") CHECK_EQ(2, map.size()); CHECK_EQ(hello.c_str(), map[a]); CHECK_EQ(world.c_str(), map[b]); +} - delete a; - delete b; +TEST_CASE("const") +{ + int p = 0; + int q = 0; + + NotNull n{&p}; + + *n = 123; + + NotNull m = n; // Conversion from NotNull to NotNull is allowed + + CHECK(123 == *m); // readonly access of m is ok + + // *m = 321; // nope. m points at const data. + + // NotNull o = m; // nope. Conversion from NotNull to NotNull is forbidden + + NotNull n2{&q}; + m = n2; // ok. m points to const data, but is not itself const + + const NotNull m2 = n; + // m2 = n2; // nope. m2 is const. + *m2 = 321; // ok. m2 is const, but points to mutable data + + CHECK(321 == *n); +} + +TEST_CASE("const_compatibility") +{ + int* raw = new int(8); + + NotNull a(raw); + NotNull b(raw); + NotNull c = a; + // NotNull d = c; // nope - no conversion from const to non-const + + CHECK_EQ(*c, 8); + + delete raw; } TEST_SUITE_END(); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index b9e1ae96..ccdd2b37 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -70,7 +70,7 @@ TEST_CASE_FIXTURE(Fixture, "function_return_annotations_are_checked") const FunctionTypeVar* ftv = get(fiftyType); REQUIRE(ftv != nullptr); - TypePackId retPack = ftv->retType; + TypePackId retPack = ftv->retTypes; const TypePack* tp = get(retPack); REQUIRE(tp != nullptr); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index a28ba49e..036a667a 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -45,7 +45,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_return_type") const FunctionTypeVar* takeFiveType = get(requireType("take_five")); REQUIRE(takeFiveType != nullptr); - std::vector retVec = flatten(takeFiveType->retType).first; + std::vector retVec = flatten(takeFiveType->retTypes).first; REQUIRE(!retVec.empty()); REQUIRE_EQ(*follow(retVec[0]), *typeChecker.numberType); @@ -345,7 +345,7 @@ TEST_CASE_FIXTURE(Fixture, "local_function") const FunctionTypeVar* ftv = get(h); REQUIRE(ftv != nullptr); - std::optional rt = first(ftv->retType); + std::optional rt = first(ftv->retTypes); REQUIRE(bool(rt)); TypeId retType = follow(*rt); @@ -361,7 +361,7 @@ TEST_CASE_FIXTURE(Fixture, "func_expr_doesnt_leak_free") LUAU_REQUIRE_NO_ERRORS(result); const Luau::FunctionTypeVar* fn = get(requireType("p")); REQUIRE(fn); - auto ret = first(fn->retType); + auto ret = first(fn->retTypes); REQUIRE(ret); REQUIRE(get(follow(*ret))); } @@ -460,7 +460,7 @@ TEST_CASE_FIXTURE(Fixture, "complicated_return_types_require_an_explicit_annotat const FunctionTypeVar* functionType = get(requireType("most_of_the_natural_numbers")); - std::optional retType = first(functionType->retType); + std::optional retType = first(functionType->retTypes); REQUIRE(retType); CHECK(get(*retType)); } @@ -1619,4 +1619,56 @@ TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "quantify_constrained_types") +{ + ScopedFastFlag sff[]{ + {"LuauLowerBoundsCalculation", true}, + {"LuauQuantifyConstrained", true}, + }; + + CheckResult result = check(R"( + --!strict + local function foo(f) + f(5) + f("hi") + local function g() + return f + end + local h = g() + h(true) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("((boolean | number | string) -> (a...)) -> ()", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "call_o_with_another_argument_after_foo_was_quantified") +{ + ScopedFastFlag sff[]{ + {"LuauLowerBoundsCalculation", true}, + {"LuauQuantifyConstrained", true}, + }; + + CheckResult result = check(R"( + local function f(o) + local t = {} + t[o] = true + + local function foo(o) + o.m1(5) + t[o] = nil + end + + o.m1("hi") + + return t + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + // TODO: check the normalized type of f +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index fbda8bec..edb5adcf 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -224,7 +224,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_function") const FunctionTypeVar* idFun = get(idType); REQUIRE(idFun); auto [args, varargs] = flatten(idFun->argTypes); - auto [rets, varrets] = flatten(idFun->retType); + auto [rets, varrets] = flatten(idFun->retTypes); CHECK_EQ(idFun->generics.size(), 1); CHECK_EQ(idFun->genericPacks.size(), 0); @@ -247,7 +247,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_local_function") const FunctionTypeVar* idFun = get(idType); REQUIRE(idFun); auto [args, varargs] = flatten(idFun->argTypes); - auto [rets, varrets] = flatten(idFun->retType); + auto [rets, varrets] = flatten(idFun->retTypes); CHECK_EQ(idFun->generics.size(), 1); CHECK_EQ(idFun->genericPacks.size(), 0); @@ -882,7 +882,7 @@ TEST_CASE_FIXTURE(Fixture, "correctly_instantiate_polymorphic_member_functions") const FunctionTypeVar* foo = get(follow(fooProp->type)); REQUIRE(bool(foo)); - std::optional ret_ = first(foo->retType); + std::optional ret_ = first(foo->retTypes); REQUIRE(bool(ret_)); TypeId ret = follow(*ret_); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 03614938..fd9b1dd4 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -90,7 +90,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "primitive_arith_no_metatable") const FunctionTypeVar* functionType = get(requireType("add")); - std::optional retType = first(functionType->retType); + std::optional retType = first(functionType->retTypes); REQUIRE(retType.has_value()); CHECK_EQ(typeChecker.numberType, follow(*retType)); CHECK_EQ(requireType("n"), typeChecker.numberType); @@ -777,8 +777,6 @@ TEST_CASE_FIXTURE(Fixture, "infer_any_in_all_modes_when_lhs_is_unknown") TEST_CASE_FIXTURE(BuiltinsFixture, "equality_operations_succeed_if_any_union_branch_succeeds") { - ScopedFastFlag sff("LuauSuccessTypingForEqualityOperations", true); - CheckResult result = check(R"( local mm = {} type Foo = typeof(setmetatable({}, mm)) diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 22fb3b69..487e5979 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -472,6 +472,7 @@ TEST_CASE_FIXTURE(Fixture, "constrained_is_level_dependent") ScopedFastFlag sff[]{ {"LuauLowerBoundsCalculation", true}, {"LuauNormalizeFlagIsConservative", true}, + {"LuauQuantifyConstrained", true}, }; CheckResult result = check(R"( @@ -494,8 +495,8 @@ TEST_CASE_FIXTURE(Fixture, "constrained_is_level_dependent") )"); LUAU_REQUIRE_NO_ERRORS(result); - // TODO: We're missing generics a... and b... - CHECK_EQ("(t1) -> {| [t1]: boolean |} where t1 = t2 ; t2 = {+ m1: (t1) -> (a...), m2: (t2) -> (b...) +}", toString(requireType("f"))); + // TODO: We're missing generics b... + CHECK_EQ("(t1) -> {| [t1]: boolean |} where t1 = t2 ; t2 = {+ m1: (t1) -> (a...), m2: (t2) -> (b...) +}", toString(requireType("f"))); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 207b3cff..cefba4b2 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -13,8 +13,8 @@ using namespace Luau; namespace { -std::optional> magicFunctionInstanceIsA( - TypeChecker& typeChecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +std::optional> magicFunctionInstanceIsA( + TypeChecker& typeChecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { if (expr.args.size != 1) return std::nullopt; @@ -32,7 +32,7 @@ std::optional> magicFunctionInstanceIsA( unfreeze(typeChecker.globalTypes); TypePackId booleanPack = typeChecker.globalTypes.addTypePack({typeChecker.booleanType}); freeze(typeChecker.globalTypes); - return ExprResult{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; + return WithPredicate{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; } struct RefinementClassFixture : Fixture diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index d622d4af..87d49651 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -642,7 +642,7 @@ TEST_CASE_FIXTURE(Fixture, "indexers_quantification_2") const TableTypeVar* argType = get(follow(argVec[0])); REQUIRE(argType != nullptr); - std::vector retVec = flatten(ftv->retType).first; + std::vector retVec = flatten(ftv->retTypes).first; const TableTypeVar* retType = get(follow(retVec[0])); REQUIRE(retType != nullptr); @@ -691,7 +691,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_from_value_property_in_literal") const FunctionTypeVar* fType = get(requireType("f")); REQUIRE(fType != nullptr); - auto retType_ = first(fType->retType); + auto retType_ = first(fType->retTypes); REQUIRE(bool(retType_)); auto retType = get(follow(*retType_)); @@ -1881,7 +1881,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "quantifying_a_bound_var_works") REQUIRE(prop.type); const FunctionTypeVar* ftv = get(follow(prop.type)); REQUIRE(ftv); - const TypePack* res = get(follow(ftv->retType)); + const TypePack* res = get(follow(ftv->retTypes)); REQUIRE(res); REQUIRE(res->head.size() == 1); const MetatableTypeVar* mtv = get(follow(res->head[0])); @@ -2584,7 +2584,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dont_quantify_table_that_belongs_to_outer_sc const FunctionTypeVar* newType = get(follow(counterType->props["new"].type)); REQUIRE(newType); - std::optional newRetType = *first(newType->retType); + std::optional newRetType = *first(newType->retTypes); REQUIRE(newRetType); const MetatableTypeVar* newRet = get(follow(*newRetType)); @@ -2977,7 +2977,6 @@ TEST_CASE_FIXTURE(Fixture, "mixed_tables_with_implicit_numbered_keys") TEST_CASE_FIXTURE(Fixture, "expected_indexer_value_type_extra") { - ScopedFastFlag luauExpectedPropTypeFromIndexer{"LuauExpectedPropTypeFromIndexer", true}; ScopedFastFlag luauSubtypingAddOptPropsToUnsealedTables{"LuauSubtypingAddOptPropsToUnsealedTables", true}; CheckResult result = check(R"( @@ -2992,8 +2991,6 @@ TEST_CASE_FIXTURE(Fixture, "expected_indexer_value_type_extra") TEST_CASE_FIXTURE(Fixture, "expected_indexer_value_type_extra_2") { - ScopedFastFlag luauExpectedPropTypeFromIndexer{"LuauExpectedPropTypeFromIndexer", true}; - CheckResult result = check(R"( type X = {[any]: string | boolean} diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index cf0c9881..6257cda6 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -13,8 +13,9 @@ #include -LUAU_FASTFLAG(LuauLowerBoundsCalculation) -LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr) +LUAU_FASTFLAG(LuauLowerBoundsCalculation); +LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr); +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); using namespace Luau; @@ -43,10 +44,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_error") CheckResult result = check("local a = 7 local b = 'hi' a = b"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 35}, Position{0, 36}}, TypeMismatch{ - requireType("a"), - requireType("b"), - }})); + CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 35}, Position{0, 36}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); } TEST_CASE_FIXTURE(Fixture, "tc_error_2") @@ -86,6 +84,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_locals_via_assignment_from_its_call_site") TEST_CASE_FIXTURE(Fixture, "infer_in_nocheck_mode") { ScopedFastFlag sff[]{ + {"DebugLuauDeferredConstraintResolution", false}, {"LuauReturnTypeInferenceInNonstrict", true}, {"LuauLowerBoundsCalculation", true}, }; @@ -236,10 +235,14 @@ TEST_CASE_FIXTURE(Fixture, "type_errors_infer_types") CHECK_EQ("boolean", toString(err->table)); CHECK_EQ("x", err->key); - CHECK_EQ("*unknown*", toString(requireType("c"))); - CHECK_EQ("*unknown*", toString(requireType("d"))); - CHECK_EQ("*unknown*", toString(requireType("e"))); - CHECK_EQ("*unknown*", toString(requireType("f"))); + // TODO: Should we assert anything about these tests when DCR is being used? + if (!FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("*unknown*", toString(requireType("c"))); + CHECK_EQ("*unknown*", toString(requireType("d"))); + CHECK_EQ("*unknown*", toString(requireType("e"))); + CHECK_EQ("*unknown*", toString(requireType("f"))); + } } TEST_CASE_FIXTURE(Fixture, "should_be_able_to_infer_this_without_stack_overflowing") @@ -352,40 +355,6 @@ TEST_CASE_FIXTURE(Fixture, "check_expr_recursion_limit") CHECK(nullptr != get(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "globals") -{ - CheckResult result = check(R"( - --!nonstrict - foo = true - foo = "now i'm a string!" - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("any", toString(requireType("foo"))); -} - -TEST_CASE_FIXTURE(Fixture, "globals2") -{ - ScopedFastFlag sff[]{ - {"LuauReturnTypeInferenceInNonstrict", true}, - {"LuauLowerBoundsCalculation", true}, - }; - - CheckResult result = check(R"( - --!nonstrict - foo = function() return 1 end - foo = "now i'm a string!" - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("() -> number", toString(tm->wantedType)); - CHECK_EQ("string", toString(tm->givenType)); - CHECK_EQ("() -> number", toString(requireType("foo"))); -} - TEST_CASE_FIXTURE(Fixture, "globals_are_banned_in_strict_mode") { CheckResult result = check(R"( @@ -400,23 +369,6 @@ TEST_CASE_FIXTURE(Fixture, "globals_are_banned_in_strict_mode") CHECK_EQ("foo", us->name); } -TEST_CASE_FIXTURE(Fixture, "globals_everywhere") -{ - CheckResult result = check(R"( - --!nonstrict - foo = 1 - - if true then - bar = 2 - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("any", toString(requireType("foo"))); - CHECK_EQ("any", toString(requireType("bar"))); -} - TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_scope_locals_do") { CheckResult result = check(R"( @@ -447,21 +399,6 @@ TEST_CASE_FIXTURE(Fixture, "checking_should_not_ice") CHECK_EQ("any", toString(requireType("value"))); } -// TEST_CASE_FIXTURE(Fixture, "infer_method_signature_of_argument") -// { -// CheckResult result = check(R"( -// function f(a) -// if a.cond then -// return a.method() -// end -// end -// )"); - -// LUAU_REQUIRE_NO_ERRORS(result); - -// CHECK_EQ("A", toString(requireType("f"))); -// } - TEST_CASE_FIXTURE(Fixture, "cyclic_follow") { check(R"( diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 118863fe..bcd30498 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -26,7 +26,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_multi_return") const FunctionTypeVar* takeTwoType = get(requireType("take_two")); REQUIRE(takeTwoType != nullptr); - const auto& [returns, tail] = flatten(takeTwoType->retType); + const auto& [returns, tail] = flatten(takeTwoType->retTypes); CHECK_EQ(2, returns.size()); CHECK_EQ(typeChecker.numberType, follow(returns[0])); @@ -73,7 +73,7 @@ TEST_CASE_FIXTURE(Fixture, "last_element_of_return_statement_can_itself_be_a_pac const FunctionTypeVar* takeOneMoreType = get(requireType("take_three")); REQUIRE(takeOneMoreType != nullptr); - const auto& [rets, tail] = flatten(takeOneMoreType->retType); + const auto& [rets, tail] = flatten(takeOneMoreType->retTypes); REQUIRE_EQ(3, rets.size()); CHECK_EQ(typeChecker.numberType, follow(rets[0])); @@ -105,10 +105,10 @@ TEST_CASE_FIXTURE(Fixture, "return_type_should_be_empty_if_nothing_is_returned") LUAU_REQUIRE_NO_ERRORS(result); const FunctionTypeVar* fTy = get(requireType("f")); REQUIRE(fTy != nullptr); - CHECK_EQ(0, size(fTy->retType)); + CHECK_EQ(0, size(fTy->retTypes)); const FunctionTypeVar* gTy = get(requireType("g")); REQUIRE(gTy != nullptr); - CHECK_EQ(0, size(gTy->retType)); + CHECK_EQ(0, size(gTy->retTypes)); } TEST_CASE_FIXTURE(Fixture, "no_return_size_should_be_zero") @@ -125,15 +125,15 @@ TEST_CASE_FIXTURE(Fixture, "no_return_size_should_be_zero") const FunctionTypeVar* fTy = get(requireType("f")); REQUIRE(fTy != nullptr); - CHECK_EQ(1, size(follow(fTy->retType))); + CHECK_EQ(1, size(follow(fTy->retTypes))); const FunctionTypeVar* gTy = get(requireType("g")); REQUIRE(gTy != nullptr); - CHECK_EQ(0, size(gTy->retType)); + CHECK_EQ(0, size(gTy->retTypes)); const FunctionTypeVar* hTy = get(requireType("h")); REQUIRE(hTy != nullptr); - CHECK_EQ(0, size(hTy->retType)); + CHECK_EQ(0, size(hTy->retTypes)); } TEST_CASE_FIXTURE(Fixture, "varargs_inference_through_multiple_scopes") diff --git a/tools/natvis/Analysis.natvis b/tools/natvis/Analysis.natvis index 5de0140e..b9ea3141 100644 --- a/tools/natvis/Analysis.natvis +++ b/tools/natvis/Analysis.natvis @@ -6,40 +6,40 @@ - {{ index=0, value={*($T1*)storage} }} - {{ index=1, value={*($T2*)storage} }} - {{ index=2, value={*($T3*)storage} }} - {{ index=3, value={*($T4*)storage} }} - {{ index=4, value={*($T5*)storage} }} - {{ index=5, value={*($T6*)storage} }} - {{ index=6, value={*($T7*)storage} }} - {{ index=7, value={*($T8*)storage} }} - {{ index=8, value={*($T9*)storage} }} - {{ index=9, value={*($T10*)storage} }} - {{ index=10, value={*($T11*)storage} }} - {{ index=11, value={*($T12*)storage} }} - {{ index=12, value={*($T13*)storage} }} - {{ index=13, value={*($T14*)storage} }} - {{ index=14, value={*($T15*)storage} }} - {{ index=15, value={*($T16*)storage} }} - {{ index=16, value={*($T17*)storage} }} - {{ index=17, value={*($T18*)storage} }} - {{ index=18, value={*($T19*)storage} }} - {{ index=19, value={*($T20*)storage} }} - {{ index=20, value={*($T21*)storage} }} - {{ index=21, value={*($T22*)storage} }} - {{ index=22, value={*($T23*)storage} }} - {{ index=23, value={*($T24*)storage} }} - {{ index=24, value={*($T25*)storage} }} - {{ index=25, value={*($T26*)storage} }} - {{ index=26, value={*($T27*)storage} }} - {{ index=27, value={*($T28*)storage} }} - {{ index=28, value={*($T29*)storage} }} - {{ index=29, value={*($T30*)storage} }} - {{ index=30, value={*($T31*)storage} }} - {{ index=31, value={*($T32*)storage} }} + {{ typeId=0, value={*($T1*)storage} }} + {{ typeId=1, value={*($T2*)storage} }} + {{ typeId=2, value={*($T3*)storage} }} + {{ typeId=3, value={*($T4*)storage} }} + {{ typeId=4, value={*($T5*)storage} }} + {{ typeId=5, value={*($T6*)storage} }} + {{ typeId=6, value={*($T7*)storage} }} + {{ typeId=7, value={*($T8*)storage} }} + {{ typeId=8, value={*($T9*)storage} }} + {{ typeId=9, value={*($T10*)storage} }} + {{ typeId=10, value={*($T11*)storage} }} + {{ typeId=11, value={*($T12*)storage} }} + {{ typeId=12, value={*($T13*)storage} }} + {{ typeId=13, value={*($T14*)storage} }} + {{ typeId=14, value={*($T15*)storage} }} + {{ typeId=15, value={*($T16*)storage} }} + {{ typeId=16, value={*($T17*)storage} }} + {{ typeId=17, value={*($T18*)storage} }} + {{ typeId=18, value={*($T19*)storage} }} + {{ typeId=19, value={*($T20*)storage} }} + {{ typeId=20, value={*($T21*)storage} }} + {{ typeId=21, value={*($T22*)storage} }} + {{ typeId=22, value={*($T23*)storage} }} + {{ typeId=23, value={*($T24*)storage} }} + {{ typeId=24, value={*($T25*)storage} }} + {{ typeId=25, value={*($T26*)storage} }} + {{ typeId=26, value={*($T27*)storage} }} + {{ typeId=27, value={*($T28*)storage} }} + {{ typeId=28, value={*($T29*)storage} }} + {{ typeId=29, value={*($T30*)storage} }} + {{ typeId=30, value={*($T31*)storage} }} + {{ typeId=31, value={*($T32*)storage} }} - typeId + typeId *($T1*)storage *($T2*)storage *($T3*)storage