diff --git a/Analysis/include/Luau/Clone.h b/Analysis/include/Luau/Clone.h index 217e1cc3..f003c242 100644 --- a/Analysis/include/Luau/Clone.h +++ b/Analysis/include/Luau/Clone.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include #include "Luau/TypeArena.h" #include "Luau/TypeVar.h" @@ -26,5 +27,6 @@ TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState); TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState); TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone = false); +TypeId shallowClone(TypeId ty, NotNull dest); } // namespace Luau diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 0e19f13f..7f092f5b 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -2,9 +2,10 @@ #pragma once #include "Luau/Ast.h" // Used for some of the enumerations +#include "Luau/Def.h" #include "Luau/NotNull.h" -#include "Luau/Variant.h" #include "Luau/TypeVar.h" +#include "Luau/Variant.h" #include #include @@ -131,9 +132,15 @@ struct HasPropConstraint std::string prop; }; -using ConstraintV = - Variant; +struct RefinementConstraint +{ + DefId def; + TypeId discriminantType; +}; + +using ConstraintV = Variant; struct Constraint { @@ -143,7 +150,7 @@ struct Constraint Constraint& operator=(const Constraint&) = delete; NotNull scope; - Location location; + Location location; // TODO: Extract this out into only the constraints that needs a location. Not all constraints needs locations. ConstraintV c; std::vector> dependencies; diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 973c0a8e..dc5d4598 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -1,13 +1,9 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details - #pragma once -#include -#include -#include - #include "Luau/Ast.h" #include "Luau/Constraint.h" +#include "Luau/DataFlowGraphBuilder.h" #include "Luau/Module.h" #include "Luau/ModuleResolver.h" #include "Luau/NotNull.h" @@ -15,6 +11,10 @@ #include "Luau/TypeVar.h" #include "Luau/Variant.h" +#include +#include +#include + namespace Luau { @@ -48,6 +48,7 @@ struct ConstraintGraphBuilder DenseHashMap astResolvedTypePacks{nullptr}; // Defining scopes for AST nodes. DenseHashMap astTypeAliasDefiningScopes{nullptr}; + NotNull dfg; int recursionCount = 0; @@ -63,7 +64,8 @@ struct ConstraintGraphBuilder DcrLogger* logger; ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, NotNull moduleResolver, - NotNull singletonTypes, NotNull ice, const ScopePtr& globalScope, DcrLogger* logger); + NotNull singletonTypes, NotNull ice, const ScopePtr& globalScope, DcrLogger* logger, + NotNull dfg); /** * Fabricates a new free type belonging to a given scope. @@ -88,15 +90,17 @@ struct ConstraintGraphBuilder * Adds a new constraint with no dependencies to a given scope. * @param scope the scope to add the constraint to. * @param cv the constraint variant to add. + * @return the pointer to the inserted constraint */ - void addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv); + NotNull addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv); /** * Adds a constraint to a given scope. * @param scope the scope to add the constraint to. Must not be null. * @param c the constraint to add. + * @return the pointer to the inserted constraint */ - void addConstraint(const ScopePtr& scope, std::unique_ptr c); + NotNull addConstraint(const ScopePtr& scope, std::unique_ptr c); /** * The entry point to the ConstraintGraphBuilder. This will construct a set @@ -139,13 +143,20 @@ struct ConstraintGraphBuilder */ TypeId check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType = {}); - TypeId check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType); + TypeId check(const ScopePtr& scope, AstExprLocal* local); + TypeId check(const ScopePtr& scope, AstExprGlobal* global); TypeId check(const ScopePtr& scope, AstExprIndexName* indexName); TypeId check(const ScopePtr& scope, AstExprIndexExpr* indexExpr); TypeId check(const ScopePtr& scope, AstExprUnary* unary); - TypeId check(const ScopePtr& scope, AstExprBinary* binary); + TypeId check_(const ScopePtr& scope, AstExprUnary* unary); + TypeId check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); TypeId check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType); TypeId check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert); + TypeId check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType); + + TypePackId checkLValues(const ScopePtr& scope, AstArray exprs); + + TypeId checkLValue(const ScopePtr& scope, AstExpr* expr); struct FunctionSignature { diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 0bf6d1bc..5cc63e65 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -110,6 +110,7 @@ struct ConstraintSolver bool tryDispatch(const FunctionCallConstraint& c, NotNull constraint); bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint); bool tryDispatch(const HasPropConstraint& c, NotNull constraint); + bool tryDispatch(const RefinementConstraint& c, NotNull constraint); // for a, ... in some_table do // also handles __iter metamethod @@ -215,6 +216,8 @@ private: TypeId errorRecoveryType() const; TypePackId errorRecoveryTypePack() const; + TypeId unionOfTypes(TypeId a, TypeId b, NotNull scope, bool unifyFreeTypes); + ToStringOptions opts; }; diff --git a/Analysis/include/Luau/DataFlowGraphBuilder.h b/Analysis/include/Luau/DataFlowGraphBuilder.h new file mode 100644 index 00000000..3a72403e --- /dev/null +++ b/Analysis/include/Luau/DataFlowGraphBuilder.h @@ -0,0 +1,115 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +// Do not include LValue. It should never be used here. +#include "Luau/Ast.h" +#include "Luau/DenseHash.h" +#include "Luau/Def.h" +#include "Luau/Symbol.h" + +#include + +namespace Luau +{ + +struct DataFlowGraph +{ + DataFlowGraph(DataFlowGraph&&) = default; + DataFlowGraph& operator=(DataFlowGraph&&) = default; + + // TODO: AstExprLocal, AstExprGlobal, and AstLocal* are guaranteed never to return nullopt. + // We leave them to return an optional as we build it out, but the end state is for them to return a non-optional DefId. + std::optional getDef(const AstExpr* expr) const; + std::optional getDef(const AstLocal* local) const; + + /// Retrieve the Def that corresponds to the given Symbol. + /// + /// We do not perform dataflow analysis on globals, so this function always + /// yields nullopt when passed a global Symbol. + std::optional getDef(const Symbol& symbol) const; + +private: + DataFlowGraph() = default; + + DataFlowGraph(const DataFlowGraph&) = delete; + DataFlowGraph& operator=(const DataFlowGraph&) = delete; + + DefArena arena; + DenseHashMap astDefs{nullptr}; + DenseHashMap localDefs{nullptr}; + + friend struct DataFlowGraphBuilder; +}; + +struct DfgScope +{ + DfgScope* parent; + DenseHashMap bindings{Symbol{}}; +}; + +struct ExpressionFlowGraph +{ + std::optional def; +}; + +// Currently unsound. We do not presently track the control flow of the program. +// Additionally, we do not presently track assignments. +struct DataFlowGraphBuilder +{ + static DataFlowGraph build(AstStatBlock* root, NotNull handle); + +private: + DataFlowGraphBuilder() = default; + + DataFlowGraphBuilder(const DataFlowGraphBuilder&) = delete; + DataFlowGraphBuilder& operator=(const DataFlowGraphBuilder&) = delete; + + DataFlowGraph graph; + NotNull arena{&graph.arena}; + struct InternalErrorReporter* handle; + std::vector> scopes; + + DfgScope* childScope(DfgScope* scope); + + std::optional use(DfgScope* scope, Symbol symbol, AstExpr* e); + + void visit(DfgScope* scope, AstStatBlock* b); + void visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b); + + // TODO: visit type aliases + void visit(DfgScope* scope, AstStat* s); + void visit(DfgScope* scope, AstStatIf* i); + void visit(DfgScope* scope, AstStatWhile* w); + void visit(DfgScope* scope, AstStatRepeat* r); + void visit(DfgScope* scope, AstStatBreak* b); + void visit(DfgScope* scope, AstStatContinue* c); + void visit(DfgScope* scope, AstStatReturn* r); + void visit(DfgScope* scope, AstStatExpr* e); + void visit(DfgScope* scope, AstStatLocal* l); + void visit(DfgScope* scope, AstStatFor* f); + void visit(DfgScope* scope, AstStatForIn* f); + void visit(DfgScope* scope, AstStatAssign* a); + void visit(DfgScope* scope, AstStatCompoundAssign* c); + void visit(DfgScope* scope, AstStatFunction* f); + void visit(DfgScope* scope, AstStatLocalFunction* l); + + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExpr* e); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprLocal* l); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprGlobal* g); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprCall* c); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIndexName* i); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIndexExpr* i); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprFunction* f); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprTable* t); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprUnary* u); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprBinary* b); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprTypeAssertion* t); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIfElse* i); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprInterpString* i); + + // TODO: visitLValue + // TODO: visitTypes (because of typeof which has access to values namespace, needs unreachable scope) + // TODO: visitTypePacks (because of typeof which has access to values namespace, needs unreachable scope) +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Def.h b/Analysis/include/Luau/Def.h new file mode 100644 index 00000000..ac1fa132 --- /dev/null +++ b/Analysis/include/Luau/Def.h @@ -0,0 +1,78 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/NotNull.h" +#include "Luau/TypedAllocator.h" +#include "Luau/Variant.h" + +namespace Luau +{ + +using Def = Variant; + +/** + * We statically approximate a value at runtime using a symbolic value, which we call a Def. + * + * DataFlowGraphBuilder will allocate these defs as a stand-in for some Luau values, and bind them to places that + * can hold a Luau value, and then observes how those defs will commute as it statically evaluate the program. + * + * It must also be noted that defs are a cyclic graph, so it is not safe to recursively traverse into it expecting it to terminate. + */ +using DefId = NotNull; + +/** + * A "single-object" value. + * + * Leaky implementation note: sometimes "multiple-object" values, but none of which were interesting enough to warrant creating a phi node instead. + * That can happen because there's no point in creating a phi node that points to either resultant in `if math.random() > 0.5 then 5 else "hello"`. + * This might become of utmost importance if we wanted to do some backward reasoning, e.g. if `5` is taken, then `cond` must be `truthy`. + */ +struct Undefined +{ +}; + +/** + * A phi node is a union of defs. + * + * We need this because we're statically evaluating a program, and sometimes a place may be assigned with + * different defs, and when that happens, we need a special data type that merges in all the defs + * that will flow into that specific place. For example, consider this simple program: + * + * ``` + * x-1 + * if cond() then + * x-2 = 5 + * else + * x-3 = "hello" + * end + * x-4 : {x-2, x-3} + * ``` + * + * At x-4, we know for a fact statically that either `5` or `"hello"` can flow into the variable `x` after the branch, but + * we cannot make any definitive decisions about which one, so we just take in both. + */ +struct Phi +{ + std::vector operands; +}; + +template +T* getMutable(DefId def) +{ + return get_if(def.get()); +} + +template +const T* get(DefId def) +{ + return getMutable(def); +} + +struct DefArena +{ + TypedAllocator allocator; + + DefId freshDef(); +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/LValue.h b/Analysis/include/Luau/LValue.h index 1a92d52d..518cbfaf 100644 --- a/Analysis/include/Luau/LValue.h +++ b/Analysis/include/Luau/LValue.h @@ -14,6 +14,8 @@ struct TypeVar; using TypeId = const TypeVar*; struct Field; + +// Deprecated. Do not use in new work. using LValue = Variant; struct Field diff --git a/Analysis/include/Luau/Metamethods.h b/Analysis/include/Luau/Metamethods.h new file mode 100644 index 00000000..84b0092f --- /dev/null +++ b/Analysis/include/Luau/Metamethods.h @@ -0,0 +1,32 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" + +#include + +namespace Luau +{ + +static const std::unordered_map kBinaryOpMetamethods{ + {AstExprBinary::Op::CompareEq, "__eq"}, + {AstExprBinary::Op::CompareNe, "__eq"}, + {AstExprBinary::Op::CompareGe, "__lt"}, + {AstExprBinary::Op::CompareGt, "__le"}, + {AstExprBinary::Op::CompareLe, "__le"}, + {AstExprBinary::Op::CompareLt, "__lt"}, + {AstExprBinary::Op::Add, "__add"}, + {AstExprBinary::Op::Sub, "__sub"}, + {AstExprBinary::Op::Mul, "__mul"}, + {AstExprBinary::Op::Div, "__div"}, + {AstExprBinary::Op::Pow, "__pow"}, + {AstExprBinary::Op::Mod, "__mod"}, + {AstExprBinary::Op::Concat, "__concat"}, +}; + +static const std::unordered_map kUnaryOpMetamethods{ + {AstExprUnary::Op::Minus, "__unm"}, + {AstExprUnary::Op::Len, "__len"}, +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 72ea9558..a23d0fda 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -22,15 +22,6 @@ bool isSubtype( bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, bool anyIsTop = true); -std::pair normalize( - TypeId ty, NotNull scope, TypeArena& arena, NotNull singletonTypes, InternalErrorReporter& ice); -std::pair normalize(TypeId ty, NotNull module, NotNull singletonTypes, InternalErrorReporter& ice); -std::pair normalize(TypeId ty, const ModulePtr& module, NotNull singletonTypes, InternalErrorReporter& ice); -std::pair normalize( - TypePackId ty, NotNull scope, TypeArena& arena, NotNull singletonTypes, InternalErrorReporter& ice); -std::pair normalize(TypePackId ty, NotNull module, NotNull singletonTypes, InternalErrorReporter& ice); -std::pair normalize(TypePackId ty, const ModulePtr& module, NotNull singletonTypes, InternalErrorReporter& ice); - class TypeIds { private: diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index b2da7bc0..ccf2964c 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -54,7 +54,9 @@ struct Scope DenseHashSet builtinTypeNames{""}; void addBuiltinTypeBinding(const Name& name, const TypeFun& tyFun); - std::optional lookup(Symbol sym); + std::optional lookup(Symbol sym) const; + std::optional lookup(DefId def) const; + std::optional> lookupEx(Symbol sym); std::optional lookupType(const Name& name); std::optional lookupImportedType(const Name& moduleAlias, const Name& name); @@ -66,6 +68,7 @@ struct Scope std::optional linearSearchForBinding(const std::string& name, bool traverseScopeChain = true) const; RefinementMap refinements; + DenseHashMap dcrRefinements{nullptr}; // For mutually recursive type aliases, it's important that // they use the same types for the same names. diff --git a/Analysis/include/Luau/Symbol.h b/Analysis/include/Luau/Symbol.h index 1fe037e5..0432946c 100644 --- a/Analysis/include/Luau/Symbol.h +++ b/Analysis/include/Luau/Symbol.h @@ -6,10 +6,11 @@ #include +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) + namespace Luau { -// TODO Rename this to Name once the old type alias is gone. struct Symbol { Symbol() @@ -40,9 +41,12 @@ struct Symbol { if (local) return local == rhs.local; - if (global.value) + else if (global.value) return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity. - return false; + else if (FFlag::DebugLuauDeferredConstraintResolution) + return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is. + else + return false; } bool operator!=(const Symbol& rhs) const @@ -58,8 +62,8 @@ struct Symbol return global < rhs.global; else if (local) return true; - else - return false; + + return false; } AstName astName() const diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 1c4d1cb4..384637bb 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -192,18 +192,12 @@ struct TypeChecker ErrorVec canUnify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location); ErrorVec canUnify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location); - void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const ScopePtr& scope, const Location& location); - std::optional findMetatableEntry(TypeId type, std::string entry, const Location& location, bool addErrors); std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors); std::optional getIndexTypeFromType(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors); std::optional getIndexTypeFromTypeImpl(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors); - // Reduces the union to its simplest possible shape. - // (A | B) | B | C yields A | B | C - std::vector reduceUnion(const std::vector& types); - std::optional tryStripUnionFromNil(TypeId ty); TypeId stripFromNilAndReport(TypeId ty, const Location& location); diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index e5a205ba..7409dbe7 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -29,4 +29,23 @@ std::pair> getParameterExtents(const TxnLog* log, // various other things to get there. std::vector flatten(TypeArena& arena, NotNull singletonTypes, TypePackId pack, size_t length); +/** + * Reduces a union by decomposing to the any/error type if it appears in the + * type list, and by merging child unions. Also strips out duplicate (by pointer + * identity) types. + * @param types the input type list to reduce. + * @returns the reduced type list. +*/ +std::vector reduceUnion(const std::vector& types); + +/** + * Tries to remove nil from a union type, if there's another option. T | nil + * reduces to T, but nil itself does not reduce. + * @param singletonTypes the singleton types to use + * @param arena the type arena to allocate the new type in, if necessary + * @param ty the type to remove nil from + * @returns a type with nil removed, or nil itself if that were the only option. +*/ +TypeId stripNil(NotNull singletonTypes, TypeArena& arena, TypeId ty); + } // namespace Luau diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 1d587ffe..70c12cb9 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -2,22 +2,23 @@ #pragma once #include "Luau/Ast.h" +#include "Luau/Common.h" #include "Luau/DenseHash.h" +#include "Luau/Def.h" +#include "Luau/NotNull.h" #include "Luau/Predicate.h" #include "Luau/Unifiable.h" #include "Luau/Variant.h" -#include "Luau/Common.h" -#include "Luau/NotNull.h" -#include -#include -#include -#include -#include -#include #include +#include #include #include +#include +#include +#include +#include +#include LUAU_FASTINT(LuauTableTypeMaximumStringifierLength) LUAU_FASTINT(LuauTypeMaximumStringifierLength) @@ -131,24 +132,6 @@ struct PrimitiveTypeVar } }; -struct ConstrainedTypeVar -{ - explicit ConstrainedTypeVar(TypeLevel level) - : level(level) - { - } - - explicit ConstrainedTypeVar(TypeLevel level, const std::vector& parts) - : parts(parts) - , level(level) - { - } - - std::vector parts; - TypeLevel level; - Scope* scope = nullptr; -}; - // Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md // Types for true and false struct BooleanSingleton @@ -496,11 +479,13 @@ struct AnyTypeVar { }; +// T | U struct UnionTypeVar { std::vector options; }; +// T & U struct IntersectionTypeVar { std::vector parts; @@ -519,12 +504,27 @@ struct NeverTypeVar { }; +// Invariant 1: there should never be a reason why such UseTypeVar exists without it mapping to another type. +// Invariant 2: UseTypeVar should always disappear across modules. +struct UseTypeVar +{ + DefId def; + NotNull scope; +}; + +// ~T +// TODO: Some simplification step that overwrites the type graph to make sure negation +// types disappear from the user's view, and (?) a debug flag to disable that +struct NegationTypeVar +{ + TypeId ty; +}; + using ErrorTypeVar = Unifiable::Error; -using TypeVariant = - Unifiable::Variant; - +using TypeVariant = Unifiable::Variant; struct TypeVar final { @@ -541,7 +541,6 @@ struct TypeVar final TypeVar(const TypeVariant& ty, bool persistent) : ty(ty) , persistent(persistent) - , normal(persistent) // We assume that all persistent types are irreducable. { } @@ -549,7 +548,6 @@ struct TypeVar final void reassign(const TypeVar& rhs) { ty = rhs.ty; - normal = rhs.normal; documentationSymbol = rhs.documentationSymbol; } @@ -560,10 +558,6 @@ struct TypeVar final // Persistent TypeVars do not get cloned. bool persistent = false; - // Normalization sets this for types that are fully normalized. - // This implies that they are transitively immutable. - bool normal = false; - std::optional documentationSymbol; // Pointer to the type arena that allocated this type. @@ -656,6 +650,8 @@ public: const TypeId unknownType; const TypeId neverType; const TypeId errorType; + const TypeId falsyType; // No type binding! + const TypeId truthyType; // No type binding! const TypePackId anyTypePack; const TypePackId neverTypePack; @@ -703,7 +699,6 @@ T* getMutable(TypeId tv) const std::vector& getTypes(const UnionTypeVar* utv); const std::vector& getTypes(const IntersectionTypeVar* itv); -const std::vector& getTypes(const ConstrainedTypeVar* ctv); template struct TypeIterator; @@ -716,10 +711,6 @@ using IntersectionTypeVarIterator = TypeIterator; IntersectionTypeVarIterator begin(const IntersectionTypeVar* itv); IntersectionTypeVarIterator end(const IntersectionTypeVar* itv); -using ConstrainedTypeVarIterator = TypeIterator; -ConstrainedTypeVarIterator begin(const ConstrainedTypeVar* ctv); -ConstrainedTypeVarIterator end(const ConstrainedTypeVar* ctv); - /* Traverses the type T yielding each TypeId. * If the iterator encounters a nested type T, it will instead yield each TypeId within. */ @@ -793,7 +784,6 @@ struct TypeIterator // with templates portability in this area, so not worth it. Thanks MSVC. friend UnionTypeVarIterator end(const UnionTypeVar*); friend IntersectionTypeVarIterator end(const IntersectionTypeVar*); - friend ConstrainedTypeVarIterator end(const ConstrainedTypeVar*); private: TypeIterator() = default; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 10f3f48c..c15cae31 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -119,12 +119,7 @@ private: std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name); - void tryUnifyWithConstrainedSubTypeVar(TypeId subTy, TypeId superTy); - void tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy); - public: - void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel); - // Returns true if the type "needle" already occurs within "haystack" and reports an "infinite type error" bool occursCheck(TypeId needle, TypeId haystack); bool occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack); diff --git a/Analysis/include/Luau/Variant.h b/Analysis/include/Luau/Variant.h index f637222e..76812c9b 100644 --- a/Analysis/include/Luau/Variant.h +++ b/Analysis/include/Luau/Variant.h @@ -105,7 +105,7 @@ public: tableDtor[typeId](&storage); typeId = tid; - new (&storage) TT(std::forward(args)...); + new (&storage) TT{std::forward(args)...}; return *reinterpret_cast(&storage); } diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index 315e5992..d4f8528f 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -103,10 +103,6 @@ struct GenericTypeVarVisitor { return visit(ty); } - virtual bool visit(TypeId ty, const ConstrainedTypeVar& ctv) - { - return visit(ty); - } virtual bool visit(TypeId ty, const PrimitiveTypeVar& ptv) { return visit(ty); @@ -159,6 +155,14 @@ struct GenericTypeVarVisitor { return visit(ty); } + virtual bool visit(TypeId ty, const UseTypeVar& utv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const NegationTypeVar& ntv) + { + return visit(ty); + } virtual bool visit(TypePackId tp) { @@ -216,14 +220,6 @@ struct GenericTypeVarVisitor visit(ty, *gtv); else if (auto etv = get(ty)) visit(ty, *etv); - else if (auto ctv = get(ty)) - { - if (visit(ty, *ctv)) - { - for (TypeId part : ctv->parts) - traverse(part); - } - } else if (auto ptv = get(ty)) visit(ty, *ptv); else if (auto ftv = get(ty)) @@ -325,6 +321,10 @@ struct GenericTypeVarVisitor traverse(a); } } + else if (auto utv = get(ty)) + visit(ty, *utv); + else if (auto ntv = get(ty)) + visit(ty, *ntv); else if (!FFlag::LuauCompleteVisitor) return visit_detail::unsee(seen, ty); else diff --git a/Analysis/src/Anyification.cpp b/Analysis/src/Anyification.cpp index cc9796ee..5dd761c2 100644 --- a/Analysis/src/Anyification.cpp +++ b/Analysis/src/Anyification.cpp @@ -37,8 +37,6 @@ bool Anyification::isDirty(TypeId ty) return (ttv->state == TableState::Free || ttv->state == TableState::Unsealed); else if (log->getMutable(ty)) return true; - else if (get(ty)) - return true; else return false; } @@ -65,20 +63,8 @@ TypeId Anyification::clean(TypeId ty) clone.syntheticName = ttv->syntheticName; clone.tags = ttv->tags; TypeId res = addType(std::move(clone)); - asMutable(res)->normal = ty->normal; return res; } - else if (auto ctv = get(ty)) - { - std::vector copy = ctv->parts; - for (TypeId& ty : copy) - ty = replace(ty); - TypeId res = copy.size() == 1 ? copy[0] : addType(UnionTypeVar{std::move(copy)}); - auto [t, ok] = normalize(res, scope, *arena, singletonTypes, *iceHandler); - if (!ok) - normalizationTooComplex = true; - return t; - } else return anyType; } diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index c5250a6d..6051e117 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -10,6 +10,7 @@ #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" +#include "Luau/TypeUtils.h" #include @@ -41,6 +42,7 @@ static std::optional> magicFunctionRequire( static bool dcrMagicFunctionSelect(MagicFunctionCallContext context); static bool dcrMagicFunctionRequire(MagicFunctionCallContext context); +static bool dcrMagicFunctionPack(MagicFunctionCallContext context); TypeId makeUnion(TypeArena& arena, std::vector&& types) { @@ -333,6 +335,7 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); + attachDcrMagicFunction(ttv->props["pack"].type, dcrMagicFunctionPack); } attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire); @@ -660,7 +663,7 @@ static std::optional> magicFunctionPack( options.push_back(vtp->ty); } - options = typechecker.reduceUnion(options); + options = reduceUnion(options); // table.pack() -> {| n: number, [number]: nil |} // table.pack(1) -> {| n: number, [number]: number |} @@ -679,6 +682,47 @@ static std::optional> magicFunctionPack( return WithPredicate{arena.addTypePack({packedTable})}; } +static bool dcrMagicFunctionPack(MagicFunctionCallContext context) +{ + + TypeArena* arena = context.solver->arena; + + const auto& [paramTypes, paramTail] = flatten(context.arguments); + + std::vector options; + options.reserve(paramTypes.size()); + for (auto type : paramTypes) + options.push_back(type); + + if (paramTail) + { + if (const VariadicTypePack* vtp = get(*paramTail)) + options.push_back(vtp->ty); + } + + options = reduceUnion(options); + + // table.pack() -> {| n: number, [number]: nil |} + // table.pack(1) -> {| n: number, [number]: number |} + // table.pack(1, "foo") -> {| n: number, [number]: number | string |} + TypeId result = nullptr; + if (options.empty()) + result = context.solver->singletonTypes->nilType; + else if (options.size() == 1) + result = options[0]; + else + result = arena->addType(UnionTypeVar{std::move(options)}); + + TypeId numberType = context.solver->singletonTypes->numberType; + TypeId packedTable = arena->addType( + TableTypeVar{{{"n", {numberType}}}, TableIndexer(numberType, result), {}, TableState::Sealed}); + + TypePackId tableTypePack = arena->addTypePack({packedTable}); + asMutable(context.result)->ty.emplace(tableTypePack); + + return true; +} + static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr) { // require(foo.parent.bar) will technically work, but it depends on legacy goop that diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index fd3a089b..85408919 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -1,6 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details - #include "Luau/Clone.h" + #include "Luau/RecursionCounter.h" #include "Luau/TxnLog.h" #include "Luau/TypePack.h" @@ -51,7 +51,6 @@ struct TypeCloner void operator()(const BlockedTypeVar& t); void operator()(const PendingExpansionTypeVar& t); void operator()(const PrimitiveTypeVar& t); - void operator()(const ConstrainedTypeVar& t); void operator()(const SingletonTypeVar& t); void operator()(const FunctionTypeVar& t); void operator()(const TableTypeVar& t); @@ -63,6 +62,8 @@ struct TypeCloner void operator()(const LazyTypeVar& t); void operator()(const UnknownTypeVar& t); void operator()(const NeverTypeVar& t); + void operator()(const UseTypeVar& t); + void operator()(const NegationTypeVar& t); }; struct TypePackCloner @@ -198,21 +199,6 @@ void TypeCloner::operator()(const PrimitiveTypeVar& t) defaultClone(t); } -void TypeCloner::operator()(const ConstrainedTypeVar& t) -{ - TypeId res = dest.addType(ConstrainedTypeVar{t.level}); - ConstrainedTypeVar* ctv = getMutable(res); - LUAU_ASSERT(ctv); - - seenTypes[typeId] = res; - - std::vector parts; - for (TypeId part : t.parts) - parts.push_back(clone(part, dest, cloneState)); - - ctv->parts = std::move(parts); -} - void TypeCloner::operator()(const SingletonTypeVar& t) { defaultClone(t); @@ -352,6 +338,21 @@ void TypeCloner::operator()(const NeverTypeVar& t) defaultClone(t); } +void TypeCloner::operator()(const UseTypeVar& t) +{ + TypeId result = dest.addType(BoundTypeVar{follow(typeId)}); + seenTypes[typeId] = result; +} + +void TypeCloner::operator()(const NegationTypeVar& t) +{ + TypeId result = dest.addType(AnyTypeVar{}); + seenTypes[typeId] = result; + + TypeId ty = clone(t.ty, dest, cloneState); + asMutable(result)->ty = NegationTypeVar{ty}; +} + } // anonymous namespace TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState) @@ -390,7 +391,6 @@ TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState) if (!res->persistent) { asMutable(res)->documentationSymbol = typeId->documentationSymbol; - asMutable(res)->normal = typeId->normal; } } @@ -478,11 +478,6 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl clone.parts = itv->parts; result = dest.addType(std::move(clone)); } - else if (const ConstrainedTypeVar* ctv = get(ty)) - { - ConstrainedTypeVar clone{ctv->level, ctv->parts}; - result = dest.addType(std::move(clone)); - } else if (const PendingExpansionTypeVar* petv = get(ty)) { PendingExpansionTypeVar clone{petv->prefix, petv->name, petv->typeArguments, petv->packArguments}; @@ -497,6 +492,10 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl { result = dest.addType(*ty); } + else if (const NegationTypeVar* ntv = get(ty)) + { + result = dest.addType(NegationTypeVar{ntv->ty}); + } else return result; @@ -504,4 +503,9 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl return result; } +TypeId shallowClone(TypeId ty, NotNull dest) +{ + return shallowClone(ty, *dest, TxnLog::empty()); +} + } // namespace Luau diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 8436fb30..de2b0a4e 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -1,20 +1,21 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details - #include "Luau/ConstraintGraphBuilder.h" + #include "Luau/Ast.h" +#include "Luau/Clone.h" #include "Luau/Common.h" #include "Luau/Constraint.h" +#include "Luau/DcrLogger.h" #include "Luau/ModuleResolver.h" #include "Luau/RecursionCounter.h" +#include "Luau/Scope.h" #include "Luau/ToString.h" -#include "Luau/DcrLogger.h" +#include "Luau/TypeUtils.h" LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAG(DebugLuauMagicTypes); -#include "Luau/Scope.h" - namespace Luau { @@ -53,12 +54,13 @@ static bool matchSetmetatable(const AstExprCall& call) ConstraintGraphBuilder::ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, NotNull moduleResolver, NotNull singletonTypes, NotNull ice, const ScopePtr& globalScope, - DcrLogger* logger) + DcrLogger* logger, NotNull dfg) : moduleName(moduleName) , module(module) , singletonTypes(singletonTypes) , arena(arena) , rootScope(nullptr) + , dfg(dfg) , moduleResolver(moduleResolver) , ice(ice) , globalScope(globalScope) @@ -95,14 +97,14 @@ ScopePtr ConstraintGraphBuilder::childScope(AstNode* node, const ScopePtr& paren return scope; } -void ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv) +NotNull ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv) { - scope->constraints.emplace_back(new Constraint{NotNull{scope.get()}, location, std::move(cv)}); + return NotNull{scope->constraints.emplace_back(new Constraint{NotNull{scope.get()}, location, std::move(cv)}).get()}; } -void ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, std::unique_ptr c) +NotNull ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, std::unique_ptr c) { - scope->constraints.emplace_back(std::move(c)); + return NotNull{scope->constraints.emplace_back(std::move(c)).get()}; } void ConstraintGraphBuilder::visit(AstStatBlock* block) @@ -229,22 +231,16 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStat* stat) void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) { std::vector varTypes; + varTypes.reserve(local->vars.size); for (AstLocal* local : local->vars) { TypeId ty = nullptr; - Location location = local->location; if (local->annotation) - { - location = local->annotation->location; ty = resolveType(scope, local->annotation, /* topLevel */ true); - } - else - ty = freshType(scope); varTypes.push_back(ty); - scope->bindings[local] = Binding{ty, location}; } for (size_t i = 0; i < local->values.size; ++i) @@ -257,6 +253,9 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) // HACK: we leave nil-initialized things floating under the assumption that they will later be populated. // See the test TypeInfer/infer_locals_with_nil_value. // Better flow awareness should make this obsolete. + + if (!varTypes[i]) + varTypes[i] = freshType(scope); } else if (i == local->values.size - 1) { @@ -268,6 +267,20 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) if (i < local->vars.size) { + std::vector packTypes = flatten(*arena, singletonTypes, exprPack, varTypes.size() - i); + + // fill out missing values in varTypes with values from exprPack + for (size_t j = i; j < varTypes.size(); ++j) + { + if (!varTypes[j]) + { + if (j - i < packTypes.size()) + varTypes[j] = packTypes[j - i]; + else + varTypes[j] = freshType(scope); + } + } + std::vector tailValues{varTypes.begin() + i, varTypes.end()}; TypePackId tailPack = arena->addTypePack(std::move(tailValues)); addConstraint(scope, local->location, PackSubtypeConstraint{exprPack, tailPack}); @@ -281,10 +294,31 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) TypeId exprType = check(scope, value, expectedType); if (i < varTypes.size()) - addConstraint(scope, local->location, SubtypeConstraint{varTypes[i], exprType}); + { + if (varTypes[i]) + addConstraint(scope, local->location, SubtypeConstraint{varTypes[i], exprType}); + else + varTypes[i] = exprType; + } } } + for (size_t i = 0; i < local->vars.size; ++i) + { + AstLocal* l = local->vars.data[i]; + Location location = l->location; + + if (!varTypes[i]) + varTypes[i] = freshType(scope); + + scope->bindings[l] = Binding{varTypes[i], location}; + + // HACK: In the greedy solver, we say the type state of a variable is the type annotation itself, but + // the actual type state is the corresponding initializer expression (if it exists) or nil otherwise. + if (auto def = dfg->getDef(l)) + scope->dcrRefinements[*def] = varTypes[i]; + } + if (local->values.size > 0) { // To correctly handle 'require', we need to import the exported type bindings into the variable 'namespace'. @@ -510,7 +544,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block) void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) { - TypePackId varPackId = checkPack(scope, assign->vars); + TypePackId varPackId = checkLValues(scope, assign->vars); TypePackId valuePack = checkPack(scope, assign->values); addConstraint(scope, assign->location, PackSubtypeConstraint{valuePack, varPackId}); @@ -532,7 +566,10 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign* void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement) { - check(scope, ifStatement->condition); + // TODO: Optimization opportunity, the interior scope of the condition could be + // reused for the then body, so we don't need to refine twice. + ScopePtr condScope = childScope(ifStatement->condition, scope); + check(condScope, ifStatement->condition, std::nullopt); ScopePtr thenScope = childScope(ifStatement->thenbody, scope); visit(thenScope, ifStatement->thenbody); @@ -893,7 +930,7 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std:: TypeId result = nullptr; if (auto group = expr->as()) - result = check(scope, group->expr); + result = check(scope, group->expr, expectedType); else if (auto stringExpr = expr->as()) { if (expectedType) @@ -937,32 +974,14 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std:: } else if (expr->is()) result = singletonTypes->nilType; - else if (auto a = expr->as()) - { - std::optional ty = scope->lookup(a->local); - if (ty) - result = *ty; - else - result = singletonTypes->errorRecoveryType(); // FIXME? Record an error at this point? - } - else if (auto g = expr->as()) - { - std::optional ty = scope->lookup(g->name); - if (ty) - result = *ty; - else - { - /* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any - * global that is not already in-scope is definitely an unknown symbol. - */ - reportError(g->location, UnknownSymbol{g->name.value}); - result = singletonTypes->errorRecoveryType(); // FIXME? Record an error at this point? - } - } + else if (auto local = expr->as()) + result = check(scope, local); + else if (auto global = expr->as()) + result = check(scope, global); else if (expr->is()) result = flattenPack(scope, expr->location, checkPack(scope, expr)); else if (expr->is()) - result = flattenPack(scope, expr->location, checkPack(scope, expr)); + result = flattenPack(scope, expr->location, checkPack(scope, expr)); // TODO: needs predicates too else if (auto a = expr->as()) { FunctionSignature sig = checkFunctionSignature(scope, a); @@ -978,7 +997,7 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std:: else if (auto unary = expr->as()) result = check(scope, unary); else if (auto binary = expr->as()) - result = check(scope, binary); + result = check(scope, binary, expectedType); else if (auto ifElse = expr->as()) result = check(scope, ifElse, expectedType); else if (auto typeAssert = expr->as()) @@ -1002,6 +1021,37 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std:: return result; } +TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local) +{ + std::optional resultTy; + + if (auto def = dfg->getDef(local)) + resultTy = scope->lookup(*def); + + if (!resultTy) + { + if (auto ty = scope->lookup(local->local)) + resultTy = *ty; + } + + if (!resultTy) + return singletonTypes->errorRecoveryType(); // TODO: replace with ice, locals should never exist before its definition. + + return *resultTy; +} + +TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* global) +{ + if (std::optional ty = scope->lookup(global->name)) + return *ty; + + /* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any + * global that is not already in-scope is definitely an unknown symbol. + */ + reportError(global->location, UnknownSymbol{global->name.value}); + return singletonTypes->errorRecoveryType(); +} + TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName) { TypeId obj = check(scope, indexName->expr); @@ -1036,54 +1086,32 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* in TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary) { - TypeId operandType = check(scope, unary->expr); - + TypeId operandType = check_(scope, unary); TypeId resultType = arena->addType(BlockedTypeVar{}); addConstraint(scope, unary->location, UnaryConstraint{unary->op, operandType, resultType}); return resultType; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary) +TypeId ConstraintGraphBuilder::check_(const ScopePtr& scope, AstExprUnary* unary) { - TypeId leftType = check(scope, binary->left); - TypeId rightType = check(scope, binary->right); - switch (binary->op) + if (unary->op == AstExprUnary::Not) { - case AstExprBinary::And: - case AstExprBinary::Or: - { - addConstraint(scope, binary->location, SubtypeConstraint{leftType, rightType}); - return leftType; - } - case AstExprBinary::Add: - case AstExprBinary::Sub: - case AstExprBinary::Mul: - case AstExprBinary::Div: - case AstExprBinary::Mod: - case AstExprBinary::Pow: - case AstExprBinary::CompareNe: - case AstExprBinary::CompareEq: - case AstExprBinary::CompareLt: - case AstExprBinary::CompareLe: - case AstExprBinary::CompareGt: - case AstExprBinary::CompareGe: - { - TypeId resultType = arena->addType(BlockedTypeVar{}); - addConstraint(scope, binary->location, BinaryConstraint{binary->op, leftType, rightType, resultType}); - return resultType; - } - case AstExprBinary::Concat: - { - addConstraint(scope, binary->left->location, SubtypeConstraint{leftType, singletonTypes->stringType}); - addConstraint(scope, binary->right->location, SubtypeConstraint{rightType, singletonTypes->stringType}); - return singletonTypes->stringType; - } - default: - LUAU_ASSERT(0); + TypeId ty = check(scope, unary->expr, std::nullopt); + + return ty; } - LUAU_ASSERT(0); - return nullptr; + return check(scope, unary->expr); +} + +TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType) +{ + TypeId leftType = check(scope, binary->left, expectedType); + TypeId rightType = check(scope, binary->right, expectedType); + + TypeId resultType = arena->addType(BlockedTypeVar{}); + addConstraint(scope, binary->location, BinaryConstraint{binary->op, leftType, rightType, resultType}); + return resultType; } TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType) @@ -1106,10 +1134,182 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifEls TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert) { - check(scope, typeAssert->expr); + check(scope, typeAssert->expr, std::nullopt); return resolveType(scope, typeAssert->annotation); } +TypePackId ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray exprs) +{ + std::vector types; + types.reserve(exprs.size); + + for (size_t i = 0; i < exprs.size; ++i) + { + AstExpr* const expr = exprs.data[i]; + types.push_back(checkLValue(scope, expr)); + } + + return arena->addTypePack(std::move(types)); +} + +static bool isUnsealedTable(TypeId ty) +{ + ty = follow(ty); + const TableTypeVar* ttv = get(ty); + return ttv && ttv->state == TableState::Unsealed; +}; + +/** + * If the expr is a dotted set of names, and if the root symbol refers to an + * unsealed table, return that table type, plus the indeces that follow as a + * vector. + */ +static std::optional>> extractDottedName(AstExpr* expr) +{ + std::vector names; + + while (expr) + { + if (auto global = expr->as()) + { + std::reverse(begin(names), end(names)); + return std::pair{global->name, std::move(names)}; + } + else if (auto local = expr->as()) + { + std::reverse(begin(names), end(names)); + return std::pair{local->local, std::move(names)}; + } + else if (auto indexName = expr->as()) + { + names.push_back(indexName->index.value); + expr = indexName->expr; + } + else + return std::nullopt; + } + + return std::nullopt; +} + +/** + * Create a shallow copy of `ty` and its properties along `path`. Insert a new + * property (the last segment of `path`) into the tail table with the value `t`. + * + * On success, returns the new outermost table type. If the root table or any + * of its subkeys are not unsealed tables, the function fails and returns + * std::nullopt. + * + * TODO: Prove that we completely give up in the face of indexers and + * metatables. + */ +static std::optional updateTheTableType(NotNull arena, TypeId ty, const std::vector& path, TypeId replaceTy) +{ + if (path.empty()) + return std::nullopt; + + // First walk the path and ensure that it's unsealed tables all the way + // to the end. + { + TypeId t = ty; + for (size_t i = 0; i < path.size() - 1; ++i) + { + if (!isUnsealedTable(t)) + return std::nullopt; + + const TableTypeVar* tbl = get(t); + auto it = tbl->props.find(path[i]); + if (it == tbl->props.end()) + return std::nullopt; + + t = it->second.type; + } + + // The last path segment should not be a property of the table at all. + // We are not changing property types. We are only admitting this one + // new property to be appended. + if (!isUnsealedTable(t)) + return std::nullopt; + const TableTypeVar* tbl = get(t); + auto it = tbl->props.find(path.back()); + if (it != tbl->props.end()) + return std::nullopt; + } + + const TypeId res = shallowClone(ty, arena); + TypeId t = res; + + for (size_t i = 0; i < path.size() - 1; ++i) + { + const std::string segment = path[i]; + + TableTypeVar* ttv = getMutable(t); + LUAU_ASSERT(ttv); + + auto propIt = ttv->props.find(segment); + if (propIt != ttv->props.end()) + { + LUAU_ASSERT(isUnsealedTable(propIt->second.type)); + t = shallowClone(follow(propIt->second.type), arena); + ttv->props[segment].type = t; + } + else + return std::nullopt; + } + + TableTypeVar* ttv = getMutable(t); + LUAU_ASSERT(ttv); + + const std::string lastSegment = path.back(); + LUAU_ASSERT(0 == ttv->props.count(lastSegment)); + ttv->props[lastSegment] = Property{replaceTy}; + return res; +} + +/** + * This function is mostly about identifying properties that are being inserted into unsealed tables. + * + * If expr has the form name.a.b.c + */ +TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) +{ + if (auto indexExpr = expr->as()) + { + if (auto constantString = indexExpr->index->as()) + { + AstName syntheticIndex{constantString->value.data}; + AstExprIndexName synthetic{ + indexExpr->location, indexExpr->expr, syntheticIndex, constantString->location, indexExpr->expr->location.end, '.'}; + return checkLValue(scope, &synthetic); + } + } + + auto dottedPath = extractDottedName(expr); + if (!dottedPath) + return check(scope, expr); + const auto [sym, segments] = std::move(*dottedPath); + + if (!sym.local) + return check(scope, expr); + + auto lookupResult = scope->lookupEx(sym); + if (!lookupResult) + return check(scope, expr); + const auto [ty, symbolScope] = std::move(*lookupResult); + + TypeId replaceTy = arena->freshType(scope.get()); + + std::optional updatedType = updateTheTableType(arena, ty, segments, replaceTy); + if (!updatedType) + return check(scope, expr); + + std::optional def = dfg->getDef(sym); + LUAU_ASSERT(def); + symbolScope->bindings[sym].typeId = *updatedType; + symbolScope->dcrRefinements[*def] = *updatedType; + return replaceTy; +} + TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType) { TypeId ty = arena->addType(TableTypeVar{}); @@ -1275,6 +1475,9 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS argTypes.push_back(t); signatureScope->bindings[local] = Binding{t, local->location}; + if (auto def = dfg->getDef(local)) + signatureScope->dcrRefinements[*def] = t; + if (local->annotation) { TypeId argAnnotation = resolveType(signatureScope, local->annotation, /* topLevel */ true); diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index e29eeaaa..60f4666a 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -3,14 +3,16 @@ #include "Luau/Anyification.h" #include "Luau/ApplyTypeFunction.h" #include "Luau/ConstraintSolver.h" +#include "Luau/DcrLogger.h" #include "Luau/Instantiation.h" #include "Luau/Location.h" +#include "Luau/Metamethods.h" #include "Luau/ModuleResolver.h" #include "Luau/Quantify.h" #include "Luau/ToString.h" +#include "Luau/TypeUtils.h" #include "Luau/TypeVar.h" #include "Luau/Unifier.h" -#include "Luau/DcrLogger.h" #include "Luau/VisitTypeVar.h" #include "Luau/TypeUtils.h" @@ -438,6 +440,8 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*fcc, constraint); else if (auto hpc = get(*constraint)) success = tryDispatch(*hpc, constraint); + else if (auto rc = get(*constraint)) + success = tryDispatch(*rc, constraint); else LUAU_ASSERT(false); @@ -564,44 +568,192 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull - * - * This constraint is the one that is meant to unblock A, so it doesn't - * make any sense to stop and wait for someone else to do it. - */ - if (leftType != resultType && rightType != resultType) - { - block(c.leftType, constraint); - block(c.rightType, constraint); - return false; - } - } + bool isLogical = c.op == AstExprBinary::Op::And || c.op == AstExprBinary::Op::Or; - if (isNumber(leftType)) - { - unify(leftType, rightType, constraint->scope); - asMutable(resultType)->ty.emplace(leftType); - return true; - } + /* Compound assignments create constraints of the form + * + * A <: Binary + * + * This constraint is the one that is meant to unblock A, so it doesn't + * make any sense to stop and wait for someone else to do it. + */ + + if (isBlocked(leftType) && leftType != resultType) + return block(c.leftType, constraint); + + if (isBlocked(rightType) && rightType != resultType) + return block(c.rightType, constraint); if (!force) { - if (get(leftType)) + // Logical expressions may proceed if the LHS is free. + if (get(leftType) && !isLogical) return block(leftType, constraint); } - if (isBlocked(leftType)) + // Logical expressions may proceed if the LHS is free. + if (isBlocked(leftType) || (get(leftType) && !isLogical)) { asMutable(resultType)->ty.emplace(errorRecoveryType()); - // reportError(constraint->location, CannotInferBinaryOperation{c.op, std::nullopt, CannotInferBinaryOperation::Operation}); + unblock(resultType); return true; } - // TODO metatables, classes + // For or expressions, the LHS will never have nil as a possible output. + // Consider: + // local foo = nil or 2 + // `foo` will always be 2. + if (c.op == AstExprBinary::Op::Or) + leftType = stripNil(singletonTypes, *arena, leftType); + + // Metatables go first, even if there is primitive behavior. + if (auto it = kBinaryOpMetamethods.find(c.op); it != kBinaryOpMetamethods.end()) + { + // Metatables are not the same. The metamethod will not be invoked. + if ((c.op == AstExprBinary::Op::CompareEq || c.op == AstExprBinary::Op::CompareNe) && + getMetatable(leftType, singletonTypes) != getMetatable(rightType, singletonTypes)) + { + // TODO: Boolean singleton false? The result is _always_ boolean false. + asMutable(resultType)->ty.emplace(singletonTypes->booleanType); + unblock(resultType); + return true; + } + + std::optional mm; + + // The LHS metatable takes priority over the RHS metatable, where + // present. + if (std::optional leftMm = findMetatableEntry(singletonTypes, errors, leftType, it->second, constraint->location)) + mm = leftMm; + else if (std::optional rightMm = findMetatableEntry(singletonTypes, errors, rightType, it->second, constraint->location)) + mm = rightMm; + + if (mm) + { + // TODO: Is a table with __call legal here? + // TODO: Overloads + if (const FunctionTypeVar* ftv = get(follow(*mm))) + { + TypePackId inferredArgs; + // For >= and > we invoke __lt and __le respectively with + // swapped argument ordering. + if (c.op == AstExprBinary::Op::CompareGe || c.op == AstExprBinary::Op::CompareGt) + { + inferredArgs = arena->addTypePack({rightType, leftType}); + } + else + { + inferredArgs = arena->addTypePack({leftType, rightType}); + } + + unify(inferredArgs, ftv->argTypes, constraint->scope); + + TypeId mmResult; + + // Comparison operations always evaluate to a boolean, + // regardless of what the metamethod returns. + switch (c.op) + { + case AstExprBinary::Op::CompareEq: + case AstExprBinary::Op::CompareNe: + case AstExprBinary::Op::CompareGe: + case AstExprBinary::Op::CompareGt: + case AstExprBinary::Op::CompareLe: + case AstExprBinary::Op::CompareLt: + mmResult = singletonTypes->booleanType; + break; + default: + mmResult = first(ftv->retTypes).value_or(errorRecoveryType()); + } + + asMutable(resultType)->ty.emplace(mmResult); + unblock(resultType); + return true; + } + } + + // If there's no metamethod available, fall back to primitive behavior. + } + + // If any is present, the expression must evaluate to any as well. + bool leftAny = get(leftType) || get(leftType); + bool rightAny = get(rightType) || get(rightType); + bool anyPresent = leftAny || rightAny; + + switch (c.op) + { + // For arithmetic operators, if the LHS is a number, the RHS must be a + // number as well. The result will also be a number. + case AstExprBinary::Op::Add: + case AstExprBinary::Op::Sub: + case AstExprBinary::Op::Mul: + case AstExprBinary::Op::Div: + case AstExprBinary::Op::Pow: + case AstExprBinary::Op::Mod: + if (isNumber(leftType)) + { + unify(leftType, rightType, constraint->scope); + asMutable(resultType)->ty.emplace(anyPresent ? singletonTypes->anyType : leftType); + unblock(resultType); + return true; + } + + break; + // For concatenation, if the LHS is a string, the RHS must be a string as + // well. The result will also be a string. + case AstExprBinary::Op::Concat: + if (isString(leftType)) + { + unify(leftType, rightType, constraint->scope); + asMutable(resultType)->ty.emplace(anyPresent ? singletonTypes->anyType : leftType); + unblock(resultType); + return true; + } + + break; + // Inexact comparisons require that the types be both numbers or both + // strings, and evaluate to a boolean. + case AstExprBinary::Op::CompareGe: + case AstExprBinary::Op::CompareGt: + case AstExprBinary::Op::CompareLe: + case AstExprBinary::Op::CompareLt: + if ((isNumber(leftType) && isNumber(rightType)) || (isString(leftType) && isString(rightType))) + { + asMutable(resultType)->ty.emplace(singletonTypes->booleanType); + unblock(resultType); + return true; + } + + break; + // == and ~= always evaluate to a boolean, and impose no other constraints + // on their parameters. + case AstExprBinary::Op::CompareEq: + case AstExprBinary::Op::CompareNe: + asMutable(resultType)->ty.emplace(singletonTypes->booleanType); + unblock(resultType); + return true; + // And evalutes to a boolean if the LHS is falsey, and the RHS type if LHS is + // truthy. + case AstExprBinary::Op::And: + asMutable(resultType)->ty.emplace(unionOfTypes(rightType, singletonTypes->booleanType, constraint->scope, false)); + unblock(resultType); + return true; + // Or evaluates to the LHS type if the LHS is truthy, and the RHS type if + // LHS is falsey. + case AstExprBinary::Op::Or: + asMutable(resultType)->ty.emplace(unionOfTypes(rightType, leftType, constraint->scope, true)); + unblock(resultType); + return true; + default: + iceReporter.ice("Unhandled AstExprBinary::Op for binary operation", constraint->location); + break; + } + + // We failed to either evaluate a metamethod or invoke primitive behavior. + unify(leftType, errorRecoveryType(), constraint->scope); + unify(rightType, errorRecoveryType(), constraint->scope); + asMutable(resultType)->ty.emplace(errorRecoveryType()); + unblock(resultType); return true; } @@ -943,6 +1095,31 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull callMm = findMetatableEntry(singletonTypes, errors, fn, "__call", constraint->location)) + { + std::vector args{fn}; + + for (TypeId arg : c.argsPack) + args.push_back(arg); + + TypeId instantiatedType = arena->addType(BlockedTypeVar{}); + TypeId inferredFnType = + arena->addType(FunctionTypeVar(TypeLevel{}, constraint->scope.get(), arena->addTypePack(TypePack{args, {}}), c.result)); + + // Alter the inner constraints. + LUAU_ASSERT(c.innerConstraints.size() == 2); + + asMutable(*c.innerConstraints.at(0)).c = InstantiationConstraint{instantiatedType, *callMm}; + asMutable(*c.innerConstraints.at(1)).c = SubtypeConstraint{inferredFnType, instantiatedType}; + + unsolvedConstraints.insert(end(unsolvedConstraints), begin(c.innerConstraints), end(c.innerConstraints)); + + asMutable(c.result)->ty.emplace(constraint->scope); + unblock(c.result); + return true; + } + const FunctionTypeVar* ftv = get(fn); bool usedMagic = false; @@ -1059,6 +1236,29 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull constraint) +{ + // TODO: Figure out exact details on when refinements need to be blocked. + // It's possible that it never needs to be, since we can just use intersection types with the discriminant type? + + if (!constraint->scope->parent) + iceReporter.ice("No parent scope"); + + std::optional previousTy = constraint->scope->parent->lookup(c.def); + if (!previousTy) + iceReporter.ice("No previous type"); + + std::optional useTy = constraint->scope->lookup(c.def); + if (!useTy) + iceReporter.ice("The def is not bound to a type"); + + TypeId resultTy = follow(*useTy); + std::vector parts{*previousTy, c.discriminantType}; + asMutable(resultTy)->ty.emplace(std::move(parts)); + + return true; +} + bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force) { auto block_ = [&](auto&& t) { @@ -1502,4 +1702,39 @@ TypePackId ConstraintSolver::errorRecoveryTypePack() const return singletonTypes->errorRecoveryTypePack(); } +TypeId ConstraintSolver::unionOfTypes(TypeId a, TypeId b, NotNull scope, bool unifyFreeTypes) +{ + a = follow(a); + b = follow(b); + + if (unifyFreeTypes && (get(a) || get(b))) + { + Unifier u{normalizer, Mode::Strict, scope, Location{}, Covariant}; + u.useScopes = true; + u.tryUnify(b, a); + + if (u.errors.empty()) + { + u.log.commit(); + return a; + } + else + { + return singletonTypes->errorRecoveryType(singletonTypes->anyType); + } + } + + if (*a == *b) + return a; + + std::vector types = reduceUnion({a, b}); + if (types.empty()) + return singletonTypes->neverType; + + if (types.size() == 1) + return types[0]; + + return arena->addType(UnionTypeVar{types}); +} + } // namespace Luau diff --git a/Analysis/src/DataFlowGraphBuilder.cpp b/Analysis/src/DataFlowGraphBuilder.cpp new file mode 100644 index 00000000..e2c4c285 --- /dev/null +++ b/Analysis/src/DataFlowGraphBuilder.cpp @@ -0,0 +1,440 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/DataFlowGraphBuilder.h" + +#include "Luau/Error.h" + +LUAU_FASTFLAG(DebugLuauFreezeArena) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) + +namespace Luau +{ + +std::optional DataFlowGraph::getDef(const AstExpr* expr) const +{ + if (auto def = astDefs.find(expr)) + return NotNull{*def}; + return std::nullopt; +} + +std::optional DataFlowGraph::getDef(const AstLocal* local) const +{ + if (auto def = localDefs.find(local)) + return NotNull{*def}; + return std::nullopt; +} + +std::optional DataFlowGraph::getDef(const Symbol& symbol) const +{ + if (symbol.local) + return getDef(symbol.local); + else + return std::nullopt; +} + +DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull handle) +{ + LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); + + DataFlowGraphBuilder builder; + builder.handle = handle; + builder.visit(nullptr, block); // nullptr is the root DFG scope. + if (FFlag::DebugLuauFreezeArena) + builder.arena->allocator.freeze(); + return std::move(builder.graph); +} + +DfgScope* DataFlowGraphBuilder::childScope(DfgScope* scope) +{ + return scopes.emplace_back(new DfgScope{scope}).get(); +} + +std::optional DataFlowGraphBuilder::use(DfgScope* scope, Symbol symbol, AstExpr* e) +{ + for (DfgScope* current = scope; current; current = current->parent) + { + if (auto loc = current->bindings.find(symbol)) + { + graph.astDefs[e] = *loc; + return NotNull{*loc}; + } + } + + return std::nullopt; +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBlock* b) +{ + DfgScope* child = childScope(scope); + return visitBlockWithoutChildScope(child, b); +} + +void DataFlowGraphBuilder::visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b) +{ + for (AstStat* s : b->body) + visit(scope, s); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStat* s) +{ + if (auto b = s->as()) + return visit(scope, b); + else if (auto i = s->as()) + return visit(scope, i); + else if (auto w = s->as()) + return visit(scope, w); + else if (auto r = s->as()) + return visit(scope, r); + else if (auto b = s->as()) + return visit(scope, b); + else if (auto c = s->as()) + return visit(scope, c); + else if (auto r = s->as()) + return visit(scope, r); + else if (auto e = s->as()) + return visit(scope, e); + else if (auto l = s->as()) + return visit(scope, l); + else if (auto f = s->as()) + return visit(scope, f); + else if (auto f = s->as()) + return visit(scope, f); + else if (auto a = s->as()) + return visit(scope, a); + else if (auto c = s->as()) + return visit(scope, c); + else if (auto f = s->as()) + return visit(scope, f); + else if (auto l = s->as()) + return visit(scope, l); + else if (auto t = s->as()) + return; // ok + else if (auto d = s->as()) + return; // ok + else if (auto d = s->as()) + return; // ok + else if (auto d = s->as()) + return; // ok + else if (auto d = s->as()) + return; // ok + else if (auto _ = s->as()) + return; // ok + else + handle->ice("Unknown AstStat in DataFlowGraphBuilder"); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatIf* i) +{ + DfgScope* condScope = childScope(scope); + visitExpr(condScope, i->condition); + visit(condScope, i->thenbody); + + if (i->elsebody) + visit(scope, i->elsebody); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatWhile* w) +{ + // TODO(controlflow): entry point has a back edge from exit point + DfgScope* whileScope = childScope(scope); + visitExpr(whileScope, w->condition); + visit(whileScope, w->body); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatRepeat* r) +{ + // TODO(controlflow): entry point has a back edge from exit point + DfgScope* repeatScope = childScope(scope); // TODO: loop scope. + visitBlockWithoutChildScope(repeatScope, r->body); + visitExpr(repeatScope, r->condition); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBreak* b) +{ + // TODO: Control flow analysis + return; // ok +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatContinue* c) +{ + // TODO: Control flow analysis + return; // ok +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatReturn* r) +{ + // TODO: Control flow analysis + for (AstExpr* e : r->list) + visitExpr(scope, e); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatExpr* e) +{ + visitExpr(scope, e->expr); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l) +{ + // TODO: alias tracking + for (AstExpr* e : l->values) + visitExpr(scope, e); + + for (AstLocal* local : l->vars) + { + DefId def = arena->freshDef(); + graph.localDefs[local] = def; + scope->bindings[local] = def; + } +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFor* f) +{ + DfgScope* forScope = childScope(scope); // TODO: loop scope. + DefId def = arena->freshDef(); + graph.localDefs[f->var] = def; + scope->bindings[f->var] = def; + + // TODO(controlflow): entry point has a back edge from exit point + visit(forScope, f->body); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatForIn* f) +{ + DfgScope* forScope = childScope(scope); // TODO: loop scope. + + for (AstLocal* local : f->vars) + { + DefId def = arena->freshDef(); + graph.localDefs[local] = def; + forScope->bindings[local] = def; + } + + // TODO(controlflow): entry point has a back edge from exit point + for (AstExpr* e : f->values) + visitExpr(forScope, e); + + visit(forScope, f->body); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatAssign* a) +{ + for (AstExpr* r : a->values) + visitExpr(scope, r); + + for (AstExpr* l : a->vars) + { + AstExpr* root = l; + + bool isUpdatable = true; + while (true) + { + if (root->is() || root->is()) + break; + + AstExprIndexName* indexName = root->as(); + if (!indexName) + { + isUpdatable = false; + break; + } + + root = indexName->expr; + } + + if (isUpdatable) + { + // TODO global? + if (auto exprLocal = root->as()) + { + DefId def = arena->freshDef(); + graph.astDefs[exprLocal] = def; + + // Update the def in the scope that introduced the local. Not + // the current scope. + AstLocal* local = exprLocal->local; + DfgScope* s = scope; + while (s && !s->bindings.find(local)) + s = s->parent; + LUAU_ASSERT(s && s->bindings.find(local)); + s->bindings[local] = def; + } + } + + visitExpr(scope, l); // TODO: they point to a new def!! + } +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatCompoundAssign* c) +{ + // TODO(typestates): The lhs is being read and written to. This might or might not be annoying. + visitExpr(scope, c->value); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f) +{ + visitExpr(scope, f->name); + visitExpr(scope, f->func); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocalFunction* l) +{ + DefId def = arena->freshDef(); + graph.localDefs[l->name] = def; + scope->bindings[l->name] = def; + + visitExpr(scope, l->func); +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e) +{ + if (auto g = e->as()) + return visitExpr(scope, g->expr); + else if (auto c = e->as()) + return {}; // ok + else if (auto c = e->as()) + return {}; // ok + else if (auto c = e->as()) + return {}; // ok + else if (auto c = e->as()) + return {}; // ok + else if (auto l = e->as()) + return visitExpr(scope, l); + else if (auto g = e->as()) + return visitExpr(scope, g); + else if (auto v = e->as()) + return {}; // ok + else if (auto c = e->as()) + return visitExpr(scope, c); + else if (auto i = e->as()) + return visitExpr(scope, i); + else if (auto i = e->as()) + return visitExpr(scope, i); + else if (auto f = e->as()) + return visitExpr(scope, f); + else if (auto t = e->as()) + return visitExpr(scope, t); + else if (auto u = e->as()) + return visitExpr(scope, u); + else if (auto b = e->as()) + return visitExpr(scope, b); + else if (auto t = e->as()) + return visitExpr(scope, t); + else if (auto i = e->as()) + return visitExpr(scope, i); + else if (auto i = e->as()) + return visitExpr(scope, i); + else if (auto _ = e->as()) + return {}; // ok + else + handle->ice("Unknown AstExpr in DataFlowGraphBuilder"); +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprLocal* l) +{ + return {use(scope, l->local, l)}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGlobal* g) +{ + return {use(scope, g->name, g)}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprCall* c) +{ + visitExpr(scope, c->func); + + for (AstExpr* arg : c->args) + visitExpr(scope, arg); + + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexName* i) +{ + std::optional def = visitExpr(scope, i->expr).def; + if (!def) + return {}; + + // TODO: properties for the above def. + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexExpr* i) +{ + visitExpr(scope, i->expr); + visitExpr(scope, i->expr); + + if (i->index->as()) + { + // TODO: properties for the def + } + + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* f) +{ + if (AstLocal* self = f->self) + { + DefId def = arena->freshDef(); + graph.localDefs[self] = def; + scope->bindings[self] = def; + } + + for (AstLocal* param : f->args) + { + DefId def = arena->freshDef(); + graph.localDefs[param] = def; + scope->bindings[param] = def; + } + + visit(scope, f->body); + + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTable* t) +{ + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprUnary* u) +{ + visitExpr(scope, u->expr); + + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprBinary* b) +{ + visitExpr(scope, b->left); + visitExpr(scope, b->right); + + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTypeAssertion* t) +{ + ExpressionFlowGraph result = visitExpr(scope, t->expr); + // TODO: visit type + return result; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIfElse* i) +{ + DfgScope* condScope = childScope(scope); + visitExpr(condScope, i->condition); + visitExpr(condScope, i->trueExpr); + + visitExpr(scope, i->falseExpr); + + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprInterpString* i) +{ + for (AstExpr* e : i->expressions) + visitExpr(scope, e); + return {}; +} + +} // namespace Luau diff --git a/Analysis/src/Def.cpp b/Analysis/src/Def.cpp new file mode 100644 index 00000000..935301c8 --- /dev/null +++ b/Analysis/src/Def.cpp @@ -0,0 +1,12 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Def.h" + +namespace Luau +{ + +DefId DefArena::freshDef() +{ + return NotNull{allocator.allocate(Undefined{})}; +} + +} // namespace Luau diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 4e9b6882..e5553003 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -7,8 +7,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauTypeMismatchModuleNameResolution, false) - static std::string wrongNumberOfArgsString( size_t expectedCount, std::optional maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) { @@ -70,7 +68,7 @@ struct ErrorConverter { if (auto wantedDefinitionModule = getDefinitionModuleName(tm.wantedType)) { - if (FFlag::LuauTypeMismatchModuleNameResolution && fileResolver != nullptr) + if (fileResolver != nullptr) { std::string givenModuleName = fileResolver->getHumanReadableModuleName(*givenDefinitionModule); std::string wantedModuleName = fileResolver->getHumanReadableModuleName(*wantedDefinitionModule); @@ -96,14 +94,7 @@ struct ErrorConverter if (!tm.reason.empty()) result += tm.reason + " "; - if (FFlag::LuauTypeMismatchModuleNameResolution) - { - result += Luau::toString(*tm.error, TypeErrorToStringOptions{fileResolver}); - } - else - { - result += Luau::toString(*tm.error); - } + result += Luau::toString(*tm.error, TypeErrorToStringOptions{fileResolver}); } else if (!tm.reason.empty()) { diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index cfe710d9..e059a463 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1,11 +1,13 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Frontend.h" +#include "Luau/BuiltinDefinitions.h" #include "Luau/Clone.h" #include "Luau/Common.h" #include "Luau/Config.h" #include "Luau/ConstraintGraphBuilder.h" #include "Luau/ConstraintSolver.h" +#include "Luau/DataFlowGraphBuilder.h" #include "Luau/DcrLogger.h" #include "Luau/FileResolver.h" #include "Luau/Parser.h" @@ -15,7 +17,6 @@ #include "Luau/TypeChecker2.h" #include "Luau/TypeInfer.h" #include "Luau/Variant.h" -#include "Luau/BuiltinDefinitions.h" #include #include @@ -26,7 +27,6 @@ LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAG(LuauNoMoreGlobalSingletonTypes) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) -LUAU_FASTFLAGVARIABLE(LuauAutocompleteDynamicLimits, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAG(DebugLuauLogSolverToJson); @@ -489,23 +489,19 @@ CheckResult Frontend::check(const ModuleName& name, std::optional 0) - typeCheckerForAutocomplete.instantiationChildLimit = - std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckerForAutocomplete.instantiationChildLimit = std::nullopt; + // TODO: This is a dirty ad hoc solution for autocomplete timeouts + // We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit + // so that we'll have type information for the whole file at lower quality instead of a full abort in the middle + if (FInt::LuauTarjanChildLimit > 0) + typeCheckerForAutocomplete.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckerForAutocomplete.instantiationChildLimit = std::nullopt; - if (FInt::LuauTypeInferIterationLimit > 0) - typeCheckerForAutocomplete.unifierIterationLimit = - std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckerForAutocomplete.unifierIterationLimit = std::nullopt; - } + if (FInt::LuauTypeInferIterationLimit > 0) + typeCheckerForAutocomplete.unifierIterationLimit = + std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckerForAutocomplete.unifierIterationLimit = std::nullopt; ModulePtr moduleForAutocomplete = FFlag::DebugLuauDeferredConstraintResolution ? check(sourceModule, mode, environmentScope, requireCycles, /*forAutocomplete*/ true) @@ -519,10 +515,9 @@ CheckResult Frontend::check(const ModuleName& name, std::optional mr{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}; const ScopePtr& globalScope{forAutocomplete ? typeCheckerForAutocomplete.globalScope : typeChecker.globalScope}; Normalizer normalizer{&result->internalTypes, singletonTypes, NotNull{&typeChecker.unifierState}}; ConstraintGraphBuilder cgb{ - sourceModule.name, result, &result->internalTypes, mr, singletonTypes, NotNull(&iceHandler), globalScope, logger.get()}; + sourceModule.name, + result, + &result->internalTypes, + mr, + singletonTypes, + NotNull(&iceHandler), + globalScope, + logger.get(), + NotNull{&dfg}, + }; + cgb.visit(sourceModule.root); result->errors = std::move(cgb.errors); diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 31a089a4..0412f007 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -60,36 +60,6 @@ bool isWithinComment(const SourceModule& sourceModule, Position pos) return contains(pos, *iter); } -struct ForceNormal : TypeVarOnceVisitor -{ - const TypeArena* typeArena = nullptr; - - ForceNormal(const TypeArena* typeArena) - : typeArena(typeArena) - { - } - - bool visit(TypeId ty) override - { - if (ty->owningArena != typeArena) - return false; - - asMutable(ty)->normal = true; - return true; - } - - bool visit(TypeId ty, const FreeTypeVar& ftv) override - { - visit(ty); - return true; - } - - bool visit(TypePackId tp, const FreeTypePack& ftp) override - { - return true; - } -}; - struct ClonePublicInterface : Substitution { NotNull singletonTypes; @@ -241,8 +211,6 @@ void Module::clonePublicInterface(NotNull singletonTypes, Intern moduleScope->varargPack = varargPack; } - ForceNormal forceNormal{&interfaceTypes}; - if (exportedTypeBindings) { for (auto& [name, tf] : *exportedTypeBindings) @@ -262,7 +230,6 @@ void Module::clonePublicInterface(NotNull singletonTypes, Intern { auto t = asMutable(ty); t->ty = AnyTypeVar{}; - t->normal = true; } } } diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 81114b76..cea159c3 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -16,11 +16,11 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) // This could theoretically be 2000 on amd64, but x86 requires this. LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); -LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauTypeNormalization2, false); LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauOverloadedFunctionSubtypingPerf); namespace Luau { @@ -1269,19 +1269,35 @@ std::optional Normalizer::intersectionOfFunctions(TypeId here, TypeId th return std::nullopt; if (hftv->genericPacks != tftv->genericPacks) return std::nullopt; - if (hftv->retTypes != tftv->retTypes) + + TypePackId argTypes; + TypePackId retTypes; + + if (hftv->retTypes == tftv->retTypes) + { + std::optional argTypesOpt = unionOfTypePacks(hftv->argTypes, tftv->argTypes); + if (!argTypesOpt) + return std::nullopt; + argTypes = *argTypesOpt; + retTypes = hftv->retTypes; + } + else if (FFlag::LuauOverloadedFunctionSubtypingPerf && hftv->argTypes == tftv->argTypes) + { + std::optional retTypesOpt = intersectionOfTypePacks(hftv->argTypes, tftv->argTypes); + if (!retTypesOpt) + return std::nullopt; + argTypes = hftv->argTypes; + retTypes = *retTypesOpt; + } + else return std::nullopt; - std::optional argTypes = unionOfTypePacks(hftv->argTypes, tftv->argTypes); - if (!argTypes) - return std::nullopt; - - if (*argTypes == hftv->argTypes) + if (argTypes == hftv->argTypes && retTypes == hftv->retTypes) return here; - if (*argTypes == tftv->argTypes) + if (argTypes == tftv->argTypes && retTypes == tftv->retTypes) return there; - FunctionTypeVar result{*argTypes, hftv->retTypes}; + FunctionTypeVar result{argTypes, retTypes}; result.generics = hftv->generics; result.genericPacks = hftv->genericPacks; return arena->addType(std::move(result)); @@ -1762,610 +1778,4 @@ bool isSubtype( return ok; } -template -static bool areNormal_(const T& t, const std::unordered_set& seen, InternalErrorReporter& ice) -{ - int count = 0; - auto isNormal = [&](TypeId ty) { - ++count; - if (count >= FInt::LuauNormalizeIterationLimit) - ice.ice("Luau::areNormal hit iteration limit"); - - return ty->normal; - }; - - return std::all_of(begin(t), end(t), isNormal); -} - -static bool areNormal(const std::vector& types, const std::unordered_set& seen, InternalErrorReporter& ice) -{ - return areNormal_(types, seen, ice); -} - -static bool areNormal(TypePackId tp, const std::unordered_set& seen, InternalErrorReporter& ice) -{ - tp = follow(tp); - if (get(tp)) - return false; - - auto [head, tail] = flatten(tp); - - if (!areNormal_(head, seen, ice)) - return false; - - if (!tail) - return true; - - if (auto vtp = get(*tail)) - return vtp->ty->normal || follow(vtp->ty)->normal || seen.find(asMutable(vtp->ty)) != seen.end(); - - return true; -} - -#define CHECK_ITERATION_LIMIT(...) \ - do \ - { \ - if (iterationLimit > FInt::LuauNormalizeIterationLimit) \ - { \ - limitExceeded = true; \ - return __VA_ARGS__; \ - } \ - ++iterationLimit; \ - } while (false) - -struct Normalize final : TypeVarVisitor -{ - using TypeVarVisitor::Set; - - Normalize(TypeArena& arena, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice) - : arena(arena) - , scope(scope) - , singletonTypes(singletonTypes) - , ice(ice) - { - } - - TypeArena& arena; - NotNull scope; - NotNull singletonTypes; - InternalErrorReporter& ice; - - int iterationLimit = 0; - bool limitExceeded = false; - - bool visit(TypeId ty, const FreeTypeVar&) override - { - LUAU_ASSERT(!ty->normal); - return false; - } - - bool visit(TypeId ty, const BoundTypeVar& btv) override - { - // A type could be considered normal when it is in the stack, but we will eventually find out it is not normal as normalization progresses. - // So we need to avoid eagerly saying that this bound type is normal if the thing it is bound to is in the stack. - if (seen.find(asMutable(btv.boundTo)) != seen.end()) - return false; - - // It should never be the case that this TypeVar is normal, but is bound to a non-normal type, except in nontrivial cases. - LUAU_ASSERT(!ty->normal || ty->normal == btv.boundTo->normal); - - if (!ty->normal) - asMutable(ty)->normal = btv.boundTo->normal; - return !ty->normal; - } - - bool visit(TypeId ty, const PrimitiveTypeVar&) override - { - LUAU_ASSERT(ty->normal); - return false; - } - - bool visit(TypeId ty, const GenericTypeVar&) override - { - if (!ty->normal) - asMutable(ty)->normal = true; - return false; - } - - bool visit(TypeId ty, const ErrorTypeVar&) override - { - if (!ty->normal) - asMutable(ty)->normal = true; - return false; - } - - bool visit(TypeId ty, const UnknownTypeVar&) override - { - if (!ty->normal) - asMutable(ty)->normal = true; - return false; - } - - bool visit(TypeId ty, const NeverTypeVar&) override - { - if (!ty->normal) - asMutable(ty)->normal = true; - return false; - } - - bool visit(TypeId ty, const ConstrainedTypeVar& ctvRef) override - { - CHECK_ITERATION_LIMIT(false); - LUAU_ASSERT(!ty->normal); - - ConstrainedTypeVar* ctv = const_cast(&ctvRef); - - std::vector parts = std::move(ctv->parts); - - // We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar - for (TypeId part : parts) - traverse(part); - - std::vector newParts = normalizeUnion(parts); - ctv->parts = std::move(newParts); - - return false; - } - - bool visit(TypeId ty, const FunctionTypeVar& ftv) override - { - CHECK_ITERATION_LIMIT(false); - - if (ty->normal) - return false; - - traverse(ftv.argTypes); - traverse(ftv.retTypes); - - asMutable(ty)->normal = areNormal(ftv.argTypes, seen, ice) && areNormal(ftv.retTypes, seen, ice); - - return false; - } - - bool visit(TypeId ty, const TableTypeVar& ttv) override - { - CHECK_ITERATION_LIMIT(false); - - if (ty->normal) - return false; - - bool normal = true; - - auto checkNormal = [&](TypeId t) { - // if t is on the stack, it is possible that this type is normal. - // If t is not normal and it is not on the stack, this type is definitely not normal. - if (!t->normal && seen.find(asMutable(t)) == seen.end()) - normal = false; - }; - - if (ttv.boundTo) - { - traverse(*ttv.boundTo); - asMutable(ty)->normal = (*ttv.boundTo)->normal; - return false; - } - - for (const auto& [_name, prop] : ttv.props) - { - traverse(prop.type); - checkNormal(prop.type); - } - - if (ttv.indexer) - { - traverse(ttv.indexer->indexType); - checkNormal(ttv.indexer->indexType); - traverse(ttv.indexer->indexResultType); - checkNormal(ttv.indexer->indexResultType); - } - - // An unsealed table can never be normal, ditto for free tables iff the type it is bound to is also not normal. - if (ttv.state == TableState::Generic || ttv.state == TableState::Sealed || (ttv.state == TableState::Free && follow(ty)->normal)) - asMutable(ty)->normal = normal; - - return false; - } - - bool visit(TypeId ty, const MetatableTypeVar& mtv) override - { - CHECK_ITERATION_LIMIT(false); - - if (ty->normal) - return false; - - traverse(mtv.table); - traverse(mtv.metatable); - - asMutable(ty)->normal = mtv.table->normal && mtv.metatable->normal; - - return false; - } - - bool visit(TypeId ty, const ClassTypeVar& ctv) override - { - if (!ty->normal) - asMutable(ty)->normal = true; - return false; - } - - bool visit(TypeId ty, const AnyTypeVar&) override - { - LUAU_ASSERT(ty->normal); - return false; - } - - bool visit(TypeId ty, const UnionTypeVar& utvRef) override - { - CHECK_ITERATION_LIMIT(false); - - if (ty->normal) - return false; - - UnionTypeVar* utv = &const_cast(utvRef); - - // We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar - for (TypeId option : utv->options) - traverse(option); - - std::vector newOptions = normalizeUnion(utv->options); - - const bool normal = areNormal(newOptions, seen, ice); - - LUAU_ASSERT(!newOptions.empty()); - - if (newOptions.size() == 1) - *asMutable(ty) = BoundTypeVar{newOptions[0]}; - else - utv->options = std::move(newOptions); - - asMutable(ty)->normal = normal; - - return false; - } - - bool visit(TypeId ty, const IntersectionTypeVar& itvRef) override - { - CHECK_ITERATION_LIMIT(false); - - if (ty->normal) - return false; - - IntersectionTypeVar* itv = &const_cast(itvRef); - - std::vector oldParts = itv->parts; - IntersectionTypeVar newIntersection; - - for (TypeId part : oldParts) - traverse(part); - - std::vector tables; - for (TypeId part : oldParts) - { - part = follow(part); - if (get(part)) - tables.push_back(part); - else - { - Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD - combineIntoIntersection(replacer, &newIntersection, part); - } - } - - // Don't allocate a new table if there's just one in the intersection. - if (tables.size() == 1) - newIntersection.parts.push_back(tables[0]); - else if (!tables.empty()) - { - const TableTypeVar* first = get(tables[0]); - LUAU_ASSERT(first); - - TypeId newTable = arena.addType(TableTypeVar{first->state, first->level}); - TableTypeVar* ttv = getMutable(newTable); - for (TypeId part : tables) - { - // Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need - // to be rewritten to point at 'newTable' in the clone. - Replacer replacer{&arena, part, newTable}; - combineIntoTable(replacer, ttv, part); - } - - newIntersection.parts.push_back(newTable); - } - - itv->parts = std::move(newIntersection.parts); - - asMutable(ty)->normal = areNormal(itv->parts, seen, ice); - - if (itv->parts.size() == 1) - { - TypeId part = itv->parts[0]; - *asMutable(ty) = BoundTypeVar{part}; - } - - return false; - } - - std::vector normalizeUnion(const std::vector& options) - { - if (options.size() == 1) - return options; - - std::vector result; - - for (TypeId part : options) - { - // AnyTypeVar always win the battle no matter what we do, so we're done. - if (FFlag::LuauUnknownAndNeverType && get(follow(part))) - return {part}; - - combineIntoUnion(result, part); - } - - return result; - } - - void combineIntoUnion(std::vector& result, TypeId ty) - { - ty = follow(ty); - if (auto utv = get(ty)) - { - for (TypeId t : utv) - { - // AnyTypeVar always win the battle no matter what we do, so we're done. - if (FFlag::LuauUnknownAndNeverType && get(t)) - { - result = {t}; - return; - } - - combineIntoUnion(result, t); - } - - return; - } - - for (TypeId& part : result) - { - if (isSubtype(ty, part, scope, singletonTypes, ice)) - return; // no need to do anything - else if (isSubtype(part, ty, scope, singletonTypes, ice)) - { - part = ty; // replace the less general type by the more general one - return; - } - } - - result.push_back(ty); - } - - /** - * @param replacer knows how to clone a type such that any recursive references point at the new containing type. - * @param result is an intersection that is safe for us to mutate in-place. - */ - void combineIntoIntersection(Replacer& replacer, IntersectionTypeVar* result, TypeId ty) - { - // Note: this check guards against running out of stack space - // so if you increase the size of a stack frame, you'll need to decrease the limit. - CHECK_ITERATION_LIMIT(); - - ty = follow(ty); - if (auto itv = get(ty)) - { - for (TypeId part : itv->parts) - combineIntoIntersection(replacer, result, part); - return; - } - - // Let's say that the last part of our result intersection is always a table, if any table is part of this intersection - if (get(ty)) - { - if (result->parts.empty()) - result->parts.push_back(arena.addType(TableTypeVar{TableState::Sealed, TypeLevel{}})); - - TypeId theTable = result->parts.back(); - - if (!get(follow(theTable))) - { - result->parts.push_back(arena.addType(TableTypeVar{TableState::Sealed, TypeLevel{}})); - theTable = result->parts.back(); - } - - TypeId newTable = replacer.smartClone(theTable); - result->parts.back() = newTable; - - combineIntoTable(replacer, getMutable(newTable), ty); - } - else if (auto ftv = get(ty)) - { - bool merged = false; - for (TypeId& part : result->parts) - { - if (isSubtype(part, ty, scope, singletonTypes, ice)) - { - merged = true; - break; // no need to do anything - } - else if (isSubtype(ty, part, scope, singletonTypes, ice)) - { - merged = true; - part = ty; // replace the less general type by the more general one - break; - } - } - - if (!merged) - result->parts.push_back(ty); - } - else - result->parts.push_back(ty); - } - - TableState combineTableStates(TableState lhs, TableState rhs) - { - if (lhs == rhs) - return lhs; - - if (lhs == TableState::Free || rhs == TableState::Free) - return TableState::Free; - - if (lhs == TableState::Unsealed || rhs == TableState::Unsealed) - return TableState::Unsealed; - - return lhs; - } - - /** - * @param replacer gives us a way to clone a type such that recursive references are rewritten to the new - * "containing" type. - * @param table always points into a table that is safe for us to mutate. - */ - void combineIntoTable(Replacer& replacer, TableTypeVar* table, TypeId ty) - { - // Note: this check guards against running out of stack space - // so if you increase the size of a stack frame, you'll need to decrease the limit. - CHECK_ITERATION_LIMIT(); - - LUAU_ASSERT(table); - - ty = follow(ty); - - TableTypeVar* tyTable = getMutable(ty); - LUAU_ASSERT(tyTable); - - for (const auto& [propName, prop] : tyTable->props) - { - if (auto it = table->props.find(propName); it != table->props.end()) - { - /** - * If we are going to recursively merge intersections of tables, we need to ensure that we never mutate - * a table that comes from somewhere else in the type graph. - * - * smarClone() does some nice things for us: It will perform a clone that is as shallow as possible - * while still rewriting any cyclic references back to the new 'root' table. - * - * replacer also keeps a mapping of types that have previously been copied, so we have the added - * advantage here of knowing that, whether or not a new copy was actually made, the resulting TypeVar is - * safe for us to mutate in-place. - */ - TypeId clone = replacer.smartClone(it->second.type); - it->second.type = combine(replacer, clone, prop.type); - } - else - table->props.insert({propName, prop}); - } - - if (tyTable->indexer) - { - if (table->indexer) - { - table->indexer->indexType = combine(replacer, replacer.smartClone(tyTable->indexer->indexType), table->indexer->indexType); - table->indexer->indexResultType = - combine(replacer, replacer.smartClone(tyTable->indexer->indexResultType), table->indexer->indexResultType); - } - else - { - table->indexer = - TableIndexer{replacer.smartClone(tyTable->indexer->indexType), replacer.smartClone(tyTable->indexer->indexResultType)}; - } - } - - table->state = combineTableStates(table->state, tyTable->state); - table->level = max(table->level, tyTable->level); - } - - /** - * @param a is always cloned by the caller. It is safe to mutate in-place. - * @param b will never be mutated. - */ - TypeId combine(Replacer& replacer, TypeId a, TypeId b) - { - b = follow(b); - - if (FFlag::LuauNormalizeCombineTableFix && a == b) - return a; - - if (!get(a) && !get(a)) - { - if (!FFlag::LuauNormalizeCombineTableFix && a == b) - return a; - else - return arena.addType(IntersectionTypeVar{{a, b}}); - } - - if (auto itv = getMutable(a)) - { - combineIntoIntersection(replacer, itv, b); - return a; - } - else if (auto ttv = getMutable(a)) - { - if (FFlag::LuauNormalizeCombineTableFix && !get(b)) - return arena.addType(IntersectionTypeVar{{a, b}}); - combineIntoTable(replacer, ttv, b); - return a; - } - - LUAU_ASSERT(!"Impossible"); - LUAU_UNREACHABLE(); - } -}; - -#undef CHECK_ITERATION_LIMIT - -/** - * @returns A tuple of TypeId and a success indicator. (true indicates that the normalization completed successfully) - */ -std::pair normalize( - TypeId ty, NotNull scope, TypeArena& arena, NotNull singletonTypes, InternalErrorReporter& ice) -{ - CloneState state; - if (FFlag::DebugLuauCopyBeforeNormalizing) - (void)clone(ty, arena, state); - - Normalize n{arena, scope, singletonTypes, ice}; - n.traverse(ty); - - return {ty, !n.limitExceeded}; -} - -// TODO: Think about using a temporary arena and cloning types out of it so that we -// reclaim memory used by wantonly allocated intermediate types here. -// The main wrinkle here is that we don't want clone() to copy a type if the source and dest -// arena are the same. -std::pair normalize(TypeId ty, NotNull module, NotNull singletonTypes, InternalErrorReporter& ice) -{ - return normalize(ty, NotNull{module->getModuleScope().get()}, module->internalTypes, singletonTypes, ice); -} - -std::pair normalize(TypeId ty, const ModulePtr& module, NotNull singletonTypes, InternalErrorReporter& ice) -{ - return normalize(ty, NotNull{module.get()}, singletonTypes, ice); -} - -/** - * @returns A tuple of TypeId and a success indicator. (true indicates that the normalization completed successfully) - */ -std::pair normalize( - TypePackId tp, NotNull scope, TypeArena& arena, NotNull singletonTypes, InternalErrorReporter& ice) -{ - CloneState state; - if (FFlag::DebugLuauCopyBeforeNormalizing) - (void)clone(tp, arena, state); - - Normalize n{arena, scope, singletonTypes, ice}; - n.traverse(tp); - - return {tp, !n.limitExceeded}; -} - -std::pair normalize(TypePackId tp, NotNull module, NotNull singletonTypes, InternalErrorReporter& ice) -{ - return normalize(tp, NotNull{module->getModuleScope().get()}, module->internalTypes, singletonTypes, ice); -} - -std::pair normalize(TypePackId tp, const ModulePtr& module, NotNull singletonTypes, InternalErrorReporter& ice) -{ - return normalize(tp, NotNull{module.get()}, singletonTypes, ice); -} - } // namespace Luau diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index e4c069bd..e9de094b 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -57,29 +57,6 @@ struct Quantifier final : TypeVarOnceVisitor return false; } - bool visit(TypeId ty, const ConstrainedTypeVar&) override - { - ConstrainedTypeVar* ctv = getMutable(ty); - - seenMutableType = true; - - if (!level.subsumes(ctv->level)) - return false; - - std::vector opts = std::move(ctv->parts); - - // We might transmute, so it's not safe to rely on the builtin traversal logic - for (TypeId opt : opts) - traverse(opt); - - if (opts.size() == 1) - *asMutable(ty) = BoundTypeVar{opts[0]}; - else - *asMutable(ty) = UnionTypeVar{std::move(opts)}; - - return false; - } - bool visit(TypeId ty, const TableTypeVar&) override { LUAU_ASSERT(getMutable(ty)); diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index 9a7d3609..84925f79 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -27,6 +27,44 @@ void Scope::addBuiltinTypeBinding(const Name& name, const TypeFun& tyFun) builtinTypeNames.insert(name); } +std::optional Scope::lookup(Symbol sym) const +{ + auto r = const_cast(this)->lookupEx(sym); + if (r) + return r->first; + else + return std::nullopt; +} + +std::optional> Scope::lookupEx(Symbol sym) +{ + Scope* s = this; + + while (true) + { + auto it = s->bindings.find(sym); + if (it != s->bindings.end()) + return std::pair{it->second.typeId, s}; + + if (s->parent) + s = s->parent.get(); + else + return std::nullopt; + } +} + +// TODO: We might kill Scope::lookup(Symbol) once data flow is fully fleshed out with type states and control flow analysis. +std::optional Scope::lookup(DefId def) const +{ + for (const Scope* current = this; current; current = current->parent.get()) + { + if (auto ty = current->dcrRefinements.find(def)) + return *ty; + } + + return std::nullopt; +} + std::optional Scope::lookupType(const Name& name) { const Scope* scope = this; @@ -111,23 +149,6 @@ std::optional Scope::linearSearchForBinding(const std::string& name, bo return std::nullopt; } -std::optional Scope::lookup(Symbol sym) -{ - Scope* s = this; - - while (true) - { - auto it = s->bindings.find(sym); - if (it != s->bindings.end()) - return it->second.typeId; - - if (s->parent) - s = s->parent.get(); - else - return std::nullopt; - } -} - bool subsumesStrict(Scope* left, Scope* right) { while (right) diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 2137d73e..20ed34f6 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -73,11 +73,6 @@ void Tarjan::visitChildren(TypeId ty, int index) for (TypeId part : itv->parts) visitChild(part); } - else if (const ConstrainedTypeVar* ctv = get(ty)) - { - for (TypeId part : ctv->parts) - visitChild(part); - } else if (const PendingExpansionTypeVar* petv = get(ty)) { for (TypeId a : petv->typeArguments) @@ -97,6 +92,10 @@ void Tarjan::visitChildren(TypeId ty, int index) if (ctv->metatable) visitChild(*ctv->metatable); } + else if (const NegationTypeVar* ntv = get(ty)) + { + visitChild(ntv->ty); + } } void Tarjan::visitChildren(TypePackId tp, int index) @@ -605,11 +604,6 @@ void Substitution::replaceChildren(TypeId ty) for (TypeId& part : itv->parts) part = replace(part); } - else if (ConstrainedTypeVar* ctv = getMutable(ty)) - { - for (TypeId& part : ctv->parts) - part = replace(part); - } else if (PendingExpansionTypeVar* petv = getMutable(ty)) { for (TypeId& a : petv->typeArguments) @@ -629,6 +623,10 @@ void Substitution::replaceChildren(TypeId ty) if (ctv->metatable) ctv->metatable = replace(*ctv->metatable); } + else if (NegationTypeVar* ntv = getMutable(ty)) + { + ntv->ty = replace(ntv->ty); + } } void Substitution::replaceChildren(TypePackId tp) diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index 0d989ca0..68fa5393 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -237,15 +237,6 @@ void StateDot::visitChildren(TypeId ty, int index) finishNodeLabel(ty); finishNode(); } - else if (const ConstrainedTypeVar* ctv = get(ty)) - { - formatAppend(result, "ConstrainedTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - for (TypeId part : ctv->parts) - visitChild(part, index); - } else if (get(ty)) { formatAppend(result, "ErrorTypeVar %d", index); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 9572ef19..31641af8 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -400,29 +400,6 @@ struct TypeVarStringifier state.emit(state.getName(ty)); } - void operator()(TypeId, const ConstrainedTypeVar& ctv) - { - state.result.invalid = true; - - state.emit("["); - if (FFlag::DebugLuauVerboseTypeNames) - state.emit(ctv.level); - state.emit("["); - - bool first = true; - for (TypeId ty : ctv.parts) - { - if (first) - first = false; - else - state.emit("|"); - - stringify(ty); - } - - state.emit("]]"); - } - void operator()(TypeId, const BlockedTypeVar& btv) { state.emit("*blocked-"); @@ -871,6 +848,28 @@ struct TypeVarStringifier { state.emit("never"); } + + void operator()(TypeId ty, const UseTypeVar&) + { + stringify(follow(ty)); + } + + void operator()(TypeId, const NegationTypeVar& ntv) + { + state.emit("~"); + + // The precedence of `~` should be less than `|` and `&`. + TypeId followed = follow(ntv.ty); + bool parens = get(followed) || get(followed); + + if (parens) + state.emit("("); + + stringify(ntv.ty); + + if (parens) + state.emit(")"); + } }; struct TypePackStringifier @@ -1442,7 +1441,7 @@ std::string generateName(size_t i) std::string toString(const Constraint& constraint, ToStringOptions& opts) { - auto go = [&opts](auto&& c) { + auto go = [&opts](auto&& c) -> std::string { using T = std::decay_t; // TODO: Inline and delete this function when clipping FFlag::LuauFixNameMaps @@ -1526,6 +1525,10 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) { return tos(c.resultType, opts) + " ~ hasProp " + tos(c.subjectType, opts) + ", \"" + c.prop + "\""; } + else if constexpr (std::is_same_v) + { + return "TODO"; + } else static_assert(always_false_v, "Non-exhaustive constraint switch"); }; diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 06bde195..034aeaec 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -251,7 +251,7 @@ PendingType* TxnLog::bindTable(TypeId ty, std::optional newBoundTo) PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel) { - LUAU_ASSERT(get(ty) || get(ty) || get(ty) || get(ty)); + LUAU_ASSERT(get(ty) || get(ty) || get(ty)); PendingType* newTy = queue(ty); if (FreeTypeVar* ftv = Luau::getMutable(newTy)) @@ -267,11 +267,6 @@ PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel) { ftv->level = newLevel; } - else if (ConstrainedTypeVar* ctv = Luau::getMutable(newTy)) - { - if (FFlag::LuauUnknownAndNeverType) - ctv->level = newLevel; - } return newTy; } @@ -291,7 +286,7 @@ PendingTypePack* TxnLog::changeLevel(TypePackId tp, TypeLevel newLevel) PendingType* TxnLog::changeScope(TypeId ty, NotNull newScope) { - LUAU_ASSERT(get(ty) || get(ty) || get(ty) || get(ty)); + LUAU_ASSERT(get(ty) || get(ty) || get(ty)); PendingType* newTy = queue(ty); if (FreeTypeVar* ftv = Luau::getMutable(newTy)) @@ -307,10 +302,6 @@ PendingType* TxnLog::changeScope(TypeId ty, NotNull newScope) { ftv->scope = newScope; } - else if (ConstrainedTypeVar* ctv = Luau::getMutable(newTy)) - { - ctv->scope = newScope; - } return newTy; } diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 84494083..f2613cae 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -104,16 +104,6 @@ public: return allocator->alloc(Location(), std::nullopt, AstName("*pending-expansion*")); } - AstType* operator()(const ConstrainedTypeVar& ctv) - { - AstArray types; - types.size = ctv.parts.size(); - types.data = static_cast(allocator->allocate(sizeof(AstType*) * ctv.parts.size())); - for (size_t i = 0; i < ctv.parts.size(); ++i) - types.data[i] = Luau::visit(*this, ctv.parts[i]->ty); - return allocator->alloc(Location(), types); - } - AstType* operator()(const SingletonTypeVar& stv) { if (const BooleanSingleton* bs = get(&stv)) @@ -348,6 +338,17 @@ public: { return allocator->alloc(Location(), std::nullopt, AstName{"never"}); } + AstType* operator()(const UseTypeVar& utv) + { + std::optional ty = utv.scope->lookup(utv.def); + LUAU_ASSERT(ty); + return Luau::visit(*this, (*ty)->ty); + } + AstType* operator()(const NegationTypeVar& ntv) + { + // FIXME: do the same thing we do with ErrorTypeVar + throw std::runtime_error("Cannot convert NegationTypeVar into AstNode"); + } private: Allocator* allocator; diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 4753a7c2..bd220e9c 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -5,6 +5,7 @@ #include "Luau/AstQuery.h" #include "Luau/Clone.h" #include "Luau/Instantiation.h" +#include "Luau/Metamethods.h" #include "Luau/Normalize.h" #include "Luau/ToString.h" #include "Luau/TxnLog.h" @@ -62,6 +63,23 @@ struct StackPusher } }; +static std::optional getIdentifierOfBaseVar(AstExpr* node) +{ + if (AstExprGlobal* expr = node->as()) + return expr->name.value; + + if (AstExprLocal* expr = node->as()) + return expr->local->name.value; + + if (AstExprIndexExpr* expr = node->as()) + return getIdentifierOfBaseVar(expr->expr); + + if (AstExprIndexName* expr = node->as()) + return getIdentifierOfBaseVar(expr->expr); + + return std::nullopt; +} + struct TypeChecker2 { NotNull singletonTypes; @@ -750,7 +768,7 @@ struct TypeChecker2 TypeId actualType = lookupType(string); TypeId stringType = singletonTypes->stringType; - if (!isSubtype(stringType, actualType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(actualType, stringType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) { reportError(TypeMismatch{actualType, stringType}, string->location); } @@ -783,26 +801,55 @@ struct TypeChecker2 TypePackId expectedRetType = lookupPack(call); TypeId functionType = lookupType(call->func); - LUAU_ASSERT(functionType); + TypeId testFunctionType = functionType; + TypePack args; if (get(functionType) || get(functionType)) return; - - // TODO: Lots of other types are callable: intersections of functions - // and things with the __call metamethod. - if (!get(functionType)) + else if (std::optional callMm = findMetatableEntry(singletonTypes, module->errors, functionType, "__call", call->func->location)) + { + if (get(follow(*callMm))) + { + if (std::optional instantiatedCallMm = instantiation.substitute(*callMm)) + { + args.head.push_back(functionType); + testFunctionType = follow(*instantiatedCallMm); + } + else + { + reportError(UnificationTooComplex{}, call->func->location); + return; + } + } + else + { + // TODO: This doesn't flag the __call metamethod as the problem + // very clearly. + reportError(CannotCallNonFunction{*callMm}, call->func->location); + return; + } + } + else if (get(functionType)) + { + if (std::optional instantiatedFunctionType = instantiation.substitute(functionType)) + { + testFunctionType = *instantiatedFunctionType; + } + else + { + reportError(UnificationTooComplex{}, call->func->location); + return; + } + } + else { reportError(CannotCallNonFunction{functionType}, call->func->location); return; } - TypeId instantiatedFunctionType = follow(instantiation.substitute(functionType).value_or(nullptr)); - - TypePack args; for (AstExpr* arg : call->args) { - TypeId argTy = module->astTypes[arg]; - LUAU_ASSERT(argTy); + TypeId argTy = lookupType(arg); args.head.push_back(argTy); } @@ -810,7 +857,7 @@ struct TypeChecker2 FunctionTypeVar ftv{argsTp, expectedRetType}; TypeId expectedType = arena.addType(ftv); - if (!isSubtype(instantiatedFunctionType, expectedType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(testFunctionType, expectedType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) { CloneState cloneState; expectedType = clone(expectedType, module->internalTypes, cloneState); @@ -893,9 +940,204 @@ struct TypeChecker2 void visit(AstExprBinary* expr) { - // TODO! visit(expr->left); visit(expr->right); + + NotNull scope = stack.back(); + + bool isEquality = expr->op == AstExprBinary::Op::CompareEq || expr->op == AstExprBinary::Op::CompareNe; + bool isComparison = expr->op >= AstExprBinary::Op::CompareEq && expr->op <= AstExprBinary::Op::CompareGe; + bool isLogical = expr->op == AstExprBinary::Op::And || expr->op == AstExprBinary::Op::Or; + + TypeId leftType = lookupType(expr->left); + TypeId rightType = lookupType(expr->right); + + if (expr->op == AstExprBinary::Op::Or) + { + leftType = stripNil(singletonTypes, module->internalTypes, leftType); + } + + bool isStringOperation = isString(leftType) && isString(rightType); + + if (get(leftType) || get(leftType) || get(rightType) || get(rightType)) + return; + + if ((get(leftType) || get(leftType)) && !isEquality && !isLogical) + { + auto name = getIdentifierOfBaseVar(expr->left); + reportError(CannotInferBinaryOperation{expr->op, name, + isComparison ? CannotInferBinaryOperation::OpKind::Comparison : CannotInferBinaryOperation::OpKind::Operation}, + expr->location); + return; + } + + if (auto it = kBinaryOpMetamethods.find(expr->op); it != kBinaryOpMetamethods.end()) + { + std::optional leftMt = getMetatable(leftType, singletonTypes); + std::optional rightMt = getMetatable(rightType, singletonTypes); + + bool matches = leftMt == rightMt; + if (isEquality && !matches) + { + auto testUnion = [&matches, singletonTypes = this->singletonTypes](const UnionTypeVar* utv, std::optional otherMt) { + for (TypeId option : utv) + { + if (getMetatable(follow(option), singletonTypes) == otherMt) + { + matches = true; + break; + } + } + }; + + if (const UnionTypeVar* utv = get(leftType); utv && rightMt) + { + testUnion(utv, rightMt); + } + + if (const UnionTypeVar* utv = get(rightType); utv && leftMt && !matches) + { + testUnion(utv, leftMt); + } + } + + if (!matches && isComparison) + { + reportError(GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", + toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str())}, + expr->location); + + return; + } + + std::optional mm; + if (std::optional leftMm = findMetatableEntry(singletonTypes, module->errors, leftType, it->second, expr->left->location)) + mm = leftMm; + else if (std::optional rightMm = findMetatableEntry(singletonTypes, module->errors, rightType, it->second, expr->right->location)) + mm = rightMm; + + if (mm) + { + if (const FunctionTypeVar* ftv = get(*mm)) + { + TypePackId expectedArgs; + // For >= and > we invoke __lt and __le respectively with + // swapped argument ordering. + if (expr->op == AstExprBinary::Op::CompareGe || expr->op == AstExprBinary::Op::CompareGt) + { + expectedArgs = module->internalTypes.addTypePack({rightType, leftType}); + } + else + { + expectedArgs = module->internalTypes.addTypePack({leftType, rightType}); + } + + reportErrors(tryUnify(scope, expr->location, ftv->argTypes, expectedArgs)); + + if (expr->op == AstExprBinary::CompareEq || expr->op == AstExprBinary::CompareNe || expr->op == AstExprBinary::CompareGe || + expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::Op::CompareLe || expr->op == AstExprBinary::Op::CompareLt) + { + TypePackId expectedRets = module->internalTypes.addTypePack({singletonTypes->booleanType}); + if (!isSubtype(ftv->retTypes, expectedRets, scope, singletonTypes, ice)) + { + reportError(GenericError{format("Metamethod '%s' must return type 'boolean'", it->second)}, expr->location); + } + } + else if (!first(ftv->retTypes)) + { + reportError(GenericError{format("Metamethod '%s' must return a value", it->second)}, expr->location); + } + } + else + { + reportError(CannotCallNonFunction{*mm}, expr->location); + } + + return; + } + // If this is a string comparison, or a concatenation of strings, we + // want to fall through to primitive behavior. + else if (!isEquality && !(isStringOperation && (expr->op == AstExprBinary::Op::Concat || isComparison))) + { + if (leftMt || rightMt) + { + if (isComparison) + { + reportError(GenericError{format( + "Types '%s' and '%s' cannot be compared with %s because neither type's metatable has a '%s' metamethod", + toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str(), it->second)}, + expr->location); + } + else + { + reportError(GenericError{format( + "Operator %s is not applicable for '%s' and '%s' because neither type's metatable has a '%s' metamethod", + toString(expr->op).c_str(), toString(leftType).c_str(), toString(rightType).c_str(), it->second)}, + expr->location); + } + + return; + } + else if (!leftMt && !rightMt && (get(leftType) || get(rightType))) + { + if (isComparison) + { + reportError(GenericError{format("Types '%s' and '%s' cannot be compared with %s because neither type has a metatable", + toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str())}, + expr->location); + } + else + { + reportError(GenericError{format("Operator %s is not applicable for '%s' and '%s' because neither type has a metatable", + toString(expr->op).c_str(), toString(leftType).c_str(), toString(rightType).c_str())}, + expr->location); + } + + return; + } + } + } + + switch (expr->op) + { + case AstExprBinary::Op::Add: + case AstExprBinary::Op::Sub: + case AstExprBinary::Op::Mul: + case AstExprBinary::Op::Div: + case AstExprBinary::Op::Pow: + case AstExprBinary::Op::Mod: + reportErrors(tryUnify(scope, expr->left->location, leftType, singletonTypes->numberType)); + reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->numberType)); + + break; + case AstExprBinary::Op::Concat: + reportErrors(tryUnify(scope, expr->left->location, leftType, singletonTypes->stringType)); + reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->stringType)); + + break; + case AstExprBinary::Op::CompareGe: + case AstExprBinary::Op::CompareGt: + case AstExprBinary::Op::CompareLe: + case AstExprBinary::Op::CompareLt: + if (isNumber(leftType)) + reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->numberType)); + else if (isString(leftType)) + reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->stringType)); + else + reportError(GenericError{format("Types '%s' and '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), + toString(rightType).c_str(), toString(expr->op).c_str())}, + expr->location); + + break; + case AstExprBinary::Op::And: + case AstExprBinary::Op::Or: + case AstExprBinary::Op::CompareEq: + case AstExprBinary::Op::CompareNe: + break; + default: + // Unhandled AstExprBinary::Op possibility. + LUAU_ASSERT(false); + } } void visit(AstExprTypeAssertion* expr) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index b806edb7..d5c6b2c4 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -31,7 +31,6 @@ LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) -LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTFLAG(LuauTypeNormalization2) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. @@ -280,11 +279,8 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo iceHandler->moduleName = module.name; normalizer.arena = ¤tModule->internalTypes; - if (FFlag::LuauAutocompleteDynamicLimits) - { - unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; - unifierState.counters.iterationLimit = unifierIterationLimit ? *unifierIterationLimit : FInt::LuauTypeInferIterationLimit; - } + unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; + unifierState.counters.iterationLimit = unifierIterationLimit ? *unifierIterationLimit : FInt::LuauTypeInferIterationLimit; ScopePtr parentScope = environmentScope.value_or(globalScope); ScopePtr moduleScope = std::make_shared(parentScope); @@ -773,16 +769,6 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement) checkExpr(repScope, *statement.condition); } -void TypeChecker::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const ScopePtr& scope, const Location& location) -{ - Unifier state = mkUnifier(scope, location); - state.unifyLowerBound(subTy, superTy, demotedLevel); - - state.log.commit(); - - reportErrors(state.errors); -} - struct Demoter : Substitution { Demoter(TypeArena* arena) @@ -2091,39 +2077,6 @@ std::optional TypeChecker::getIndexTypeFromTypeImpl( return std::nullopt; } -std::vector TypeChecker::reduceUnion(const std::vector& types) -{ - std::vector result; - for (TypeId t : types) - { - t = follow(t); - if (get(t)) - continue; - - if (get(t) || get(t)) - return {t}; - - if (const UnionTypeVar* utv = get(t)) - { - for (TypeId ty : utv) - { - ty = follow(ty); - if (get(ty)) - continue; - if (get(ty) || get(ty)) - return {ty}; - - if (result.end() == std::find(result.begin(), result.end(), ty)) - result.push_back(ty); - } - } - else if (std::find(result.begin(), result.end(), t) == result.end()) - result.push_back(t); - } - - return result; -} - std::optional TypeChecker::tryStripUnionFromNil(TypeId ty) { if (const UnionTypeVar* utv = get(ty)) @@ -4597,7 +4550,7 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat Instantiation instantiation{log, ¤tModule->internalTypes, scope->level, /*scope*/ nullptr}; - if (FFlag::LuauAutocompleteDynamicLimits && instantiationChildLimit) + if (instantiationChildLimit) instantiation.childLimit = *instantiationChildLimit; std::optional instantiated = instantiation.substitute(ty); diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 688c8767..72597c4a 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -6,6 +6,8 @@ #include "Luau/ToString.h" #include "Luau/TypeInfer.h" +#include + namespace Luau { @@ -146,18 +148,15 @@ std::optional getIndexTypeFromType(const ScopePtr& scope, ErrorVec& erro return std::nullopt; } + goodOptions = reduceUnion(goodOptions); + if (goodOptions.empty()) return singletonTypes->neverType; if (goodOptions.size() == 1) return goodOptions[0]; - // TODO: inefficient. - TypeId result = arena->addType(UnionTypeVar{std::move(goodOptions)}); - auto [ty, ok] = normalize(result, NotNull{scope.get()}, *arena, singletonTypes, handle); - if (!ok && addErrors) - errors.push_back(TypeError{location, NormalizationTooComplex{}}); - return ok ? ty : singletonTypes->anyType; + return arena->addType(UnionTypeVar{std::move(goodOptions)}); } else if (const IntersectionTypeVar* itv = get(type)) { @@ -264,4 +263,79 @@ std::vector flatten(TypeArena& arena, NotNull singletonT return result; } +std::vector reduceUnion(const std::vector& types) +{ + std::vector result; + for (TypeId t : types) + { + t = follow(t); + if (get(t)) + continue; + + if (get(t) || get(t)) + return {t}; + + if (const UnionTypeVar* utv = get(t)) + { + for (TypeId ty : utv) + { + ty = follow(ty); + if (get(ty)) + continue; + if (get(ty) || get(ty)) + return {ty}; + + if (result.end() == std::find(result.begin(), result.end(), ty)) + result.push_back(ty); + } + } + else if (std::find(result.begin(), result.end(), t) == result.end()) + result.push_back(t); + } + + return result; +} + +static std::optional tryStripUnionFromNil(TypeArena& arena, TypeId ty) +{ + if (const UnionTypeVar* utv = get(ty)) + { + if (!std::any_of(begin(utv), end(utv), isNil)) + return ty; + + std::vector result; + + for (TypeId option : utv) + { + if (!isNil(option)) + result.push_back(option); + } + + if (result.empty()) + return std::nullopt; + + return result.size() == 1 ? result[0] : arena.addType(UnionTypeVar{std::move(result)}); + } + + return std::nullopt; +} + +TypeId stripNil(NotNull singletonTypes, TypeArena& arena, TypeId ty) +{ + ty = follow(ty); + + if (get(ty)) + { + std::optional cleaned = tryStripUnionFromNil(arena, ty); + + // If there is no union option without 'nil' + if (!cleaned) + return singletonTypes->nilType; + + return follow(*cleaned); + } + + return follow(ty); +} + } // namespace Luau diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index bcdaff7d..19d3d266 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -57,6 +57,13 @@ TypeId follow(TypeId t, std::function mapper) return btv->boundTo; else if (auto ttv = get(mapper(ty))) return ttv->boundTo; + else if (auto utv = get(mapper(ty))) + { + std::optional ty = utv->scope->lookup(utv->def); + if (!ty) + throw std::runtime_error("UseTypeVar must map to another TypeId"); + return *ty; + } else return std::nullopt; }; @@ -760,6 +767,8 @@ SingletonTypes::SingletonTypes() , unknownType(arena->addType(TypeVar{UnknownTypeVar{}, /*persistent*/ true})) , neverType(arena->addType(TypeVar{NeverTypeVar{}, /*persistent*/ true})) , errorType(arena->addType(TypeVar{ErrorTypeVar{}, /*persistent*/ true})) + , falsyType(arena->addType(TypeVar{UnionTypeVar{{falseType, nilType}}, /*persistent*/ true})) + , truthyType(arena->addType(TypeVar{NegationTypeVar{falsyType}, /*persistent*/ true})) , anyTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, /*persistent*/ true})) , neverTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{neverType}, /*persistent*/ true})) , uninhabitableTypePack(arena->addTypePack({neverType}, neverTypePack)) @@ -896,7 +905,6 @@ void persist(TypeId ty) continue; asMutable(t)->persistent = true; - asMutable(t)->normal = true; // all persistent types are assumed to be normal if (auto btv = get(t)) queue.push_back(btv->boundTo); @@ -933,11 +941,6 @@ void persist(TypeId ty) for (TypeId opt : itv->parts) queue.push_back(opt); } - else if (auto ctv = get(t)) - { - for (TypeId opt : ctv->parts) - queue.push_back(opt); - } else if (auto mtv = get(t)) { queue.push_back(mtv->table); @@ -990,8 +993,6 @@ const TypeLevel* getLevel(TypeId ty) return &ttv->level; else if (auto ftv = get(ty)) return &ftv->level; - else if (auto ctv = get(ty)) - return &ctv->level; else return nullptr; } @@ -1056,11 +1057,6 @@ const std::vector& getTypes(const IntersectionTypeVar* itv) return itv->parts; } -const std::vector& getTypes(const ConstrainedTypeVar* ctv) -{ - return ctv->parts; -} - UnionTypeVarIterator begin(const UnionTypeVar* utv) { return UnionTypeVarIterator{utv}; @@ -1081,17 +1077,6 @@ IntersectionTypeVarIterator end(const IntersectionTypeVar* itv) return IntersectionTypeVarIterator{}; } -ConstrainedTypeVarIterator begin(const ConstrainedTypeVar* ctv) -{ - return ConstrainedTypeVarIterator{ctv}; -} - -ConstrainedTypeVarIterator end(const ConstrainedTypeVar* ctv) -{ - return ConstrainedTypeVarIterator{}; -} - - static std::vector parseFormatString(TypeChecker& typechecker, const char* data, size_t size) { const char* options = "cdiouxXeEfgGqs*"; diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 42fcd2fd..e23e6161 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -13,16 +13,13 @@ #include -LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); -LUAU_FASTINT(LuauTypeInferIterationLimit); -LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) -LUAU_FASTINTVARIABLE(LuauTypeInferLowerBoundsIterationLimit, 2000); LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauSubtypeNormalizer, false); LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) +LUAU_FASTFLAGVARIABLE(LuauOverloadedFunctionSubtypingPerf, false); LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) @@ -95,15 +92,6 @@ struct PromoteTypeLevels final : TypeVarOnceVisitor return true; } - bool visit(TypeId ty, const ConstrainedTypeVar&) override - { - if (!FFlag::LuauUnknownAndNeverType) - return visit(ty); - - promote(ty, log.getMutable(ty)); - return true; - } - bool visit(TypeId ty, const FunctionTypeVar&) override { // Type levels of types from other modules are already global, so we don't need to promote anything inside @@ -368,26 +356,14 @@ void Unifier::tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall, bool i void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) { - RecursionLimiter _ra(&sharedState.counters.recursionCount, - FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); ++sharedState.counters.iterationCount; - if (FFlag::LuauAutocompleteDynamicLimits) + if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) { - if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) - { - reportError(TypeError{location, UnificationTooComplex{}}); - return; - } - } - else - { - if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) - { - reportError(TypeError{location, UnificationTooComplex{}}); - return; - } + reportError(TypeError{location, UnificationTooComplex{}}); + return; } superTy = log.follow(superTy); @@ -396,9 +372,6 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (superTy == subTy) return; - if (log.get(superTy)) - return tryUnifyWithConstrainedSuperTypeVar(subTy, superTy); - auto superFree = log.getMutable(superTy); auto subFree = log.getMutable(subTy); @@ -520,9 +493,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool size_t errorCount = errors.size(); - if (log.get(subTy)) - tryUnifyWithConstrainedSubTypeVar(subTy, superTy); - else if (const UnionTypeVar* subUnion = log.getMutable(subTy)) + if (const UnionTypeVar* subUnion = log.getMutable(subTy)) { tryUnifyUnionWithType(subTy, subUnion, superTy); } @@ -1011,10 +982,17 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized log.concat(std::move(innerState.log)); if (result) { + if (FFlag::LuauOverloadedFunctionSubtypingPerf) + { + innerState.log.clear(); + innerState.tryUnify_(*result, ftv->retTypes); + } + if (FFlag::LuauOverloadedFunctionSubtypingPerf && innerState.errors.empty()) + log.concat(std::move(innerState.log)); // Annoyingly, since we don't support intersection of generic type packs, // the intersection may fail. We rather arbitrarily use the first matching overload // in that case. - if (std::optional intersect = normalizer->intersectionOfTypePacks(*result, ftv->retTypes)) + else if (std::optional intersect = normalizer->intersectionOfTypePacks(*result, ftv->retTypes)) result = intersect; } else @@ -1214,26 +1192,14 @@ void Unifier::tryUnify(TypePackId subTp, TypePackId superTp, bool isFunctionCall */ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCall) { - RecursionLimiter _ra(&sharedState.counters.recursionCount, - FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); ++sharedState.counters.iterationCount; - if (FFlag::LuauAutocompleteDynamicLimits) + if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) { - if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) - { - reportError(TypeError{location, UnificationTooComplex{}}); - return; - } - } - else - { - if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) - { - reportError(TypeError{location, UnificationTooComplex{}}); - return; - } + reportError(TypeError{location, UnificationTooComplex{}}); + return; } superTp = log.follow(superTp); @@ -2314,186 +2280,6 @@ std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, N return Luau::findTablePropertyRespectingMeta(singletonTypes, errors, lhsType, name, location); } -void Unifier::tryUnifyWithConstrainedSubTypeVar(TypeId subTy, TypeId superTy) -{ - const ConstrainedTypeVar* subConstrained = get(subTy); - if (!subConstrained) - ice("tryUnifyWithConstrainedSubTypeVar received non-ConstrainedTypeVar subTy!"); - - const std::vector& subTyParts = subConstrained->parts; - - // A | B <: T if A <: T and B <: T - bool failed = false; - std::optional unificationTooComplex; - - const size_t count = subTyParts.size(); - - for (size_t i = 0; i < count; ++i) - { - TypeId type = subTyParts[i]; - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(type, superTy); - - if (i == count - 1) - log.concat(std::move(innerState.log)); - - ++i; - - if (auto e = hasUnificationTooComplex(innerState.errors)) - unificationTooComplex = e; - - if (!innerState.errors.empty()) - { - failed = true; - break; - } - } - - if (unificationTooComplex) - reportError(*unificationTooComplex); - else if (failed) - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); - else - log.replace(subTy, BoundTypeVar{superTy}); -} - -void Unifier::tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy) -{ - ConstrainedTypeVar* superC = log.getMutable(superTy); - if (!superC) - ice("tryUnifyWithConstrainedSuperTypeVar received non-ConstrainedTypeVar superTy!"); - - // subTy could be a - // table - // metatable - // class - // function - // primitive - // free - // generic - // intersection - // union - // Do we really just tack it on? I think we might! - // We can certainly do some deduplication. - // Is there any point to deducing Player|Instance when we could just reduce to Instance? - // Is it actually ok to have multiple free types in a single intersection? What if they are later unified into the same type? - // Maybe we do a simplification step during quantification. - - auto it = std::find(superC->parts.begin(), superC->parts.end(), subTy); - if (it != superC->parts.end()) - return; - - superC->parts.push_back(subTy); -} - -void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel) -{ - // The duplication between this and regular typepack unification is tragic. - - auto superIter = begin(superTy, &log); - auto superEndIter = end(superTy); - - auto subIter = begin(subTy, &log); - auto subEndIter = end(subTy); - - int count = FInt::LuauTypeInferLowerBoundsIterationLimit; - - for (; subIter != subEndIter; ++subIter) - { - if (0 >= --count) - ice("Internal recursion counter limit exceeded in Unifier::unifyLowerBound"); - - if (superIter != superEndIter) - { - tryUnify_(*subIter, *superIter); - ++superIter; - continue; - } - - if (auto t = superIter.tail()) - { - TypePackId tailPack = follow(*t); - - if (log.get(tailPack) && occursCheck(tailPack, subTy)) - return; - - FreeTypePack* freeTailPack = log.getMutable(tailPack); - if (!freeTailPack) - return; - - TypePack* tp = getMutable(log.replace(tailPack, TypePack{})); - - for (; subIter != subEndIter; ++subIter) - { - tp->head.push_back(types->addType(ConstrainedTypeVar{demotedLevel, {follow(*subIter)}})); - } - - tp->tail = subIter.tail(); - } - - return; - } - - if (superIter != superEndIter) - { - if (auto subTail = subIter.tail()) - { - TypePackId subTailPack = follow(*subTail); - if (get(subTailPack)) - { - TypePack* tp = getMutable(log.replace(subTailPack, TypePack{})); - - for (; superIter != superEndIter; ++superIter) - tp->head.push_back(*superIter); - } - else if (const VariadicTypePack* subVariadic = log.getMutable(subTailPack)) - { - while (superIter != superEndIter) - { - tryUnify_(subVariadic->ty, *superIter); - ++superIter; - } - } - } - else - { - while (superIter != superEndIter) - { - if (!isOptional(*superIter)) - { - errors.push_back(TypeError{location, CountMismatch{size(superTy), std::nullopt, size(subTy), CountMismatch::Return}}); - return; - } - ++superIter; - } - } - - return; - } - - // Both iters are at their respective tails - auto subTail = subIter.tail(); - auto superTail = superIter.tail(); - if (subTail && superTail) - tryUnify(*subTail, *superTail); - else if (subTail) - { - const FreeTypePack* freeSubTail = log.getMutable(*subTail); - if (freeSubTail) - { - log.replace(*subTail, TypePack{}); - } - } - else if (superTail) - { - const FreeTypePack* freeSuperTail = log.getMutable(*superTail); - if (freeSuperTail) - { - log.replace(*superTail, TypePack{}); - } - } -} - bool Unifier::occursCheck(TypeId needle, TypeId haystack) { sharedState.tempSeenTy.clear(); @@ -2503,8 +2289,7 @@ bool Unifier::occursCheck(TypeId needle, TypeId haystack) bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack) { - RecursionLimiter _ra(&sharedState.counters.recursionCount, - FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); bool occurrence = false; @@ -2547,11 +2332,6 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays for (TypeId ty : a->parts) check(ty); } - else if (auto a = log.getMutable(haystack)) - { - for (TypeId ty : a->parts) - check(ty); - } return occurrence; } @@ -2579,8 +2359,7 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ if (!log.getMutable(needle)) ice("Expected needle pack to be free"); - RecursionLimiter _ra(&sharedState.counters.recursionCount, - FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); while (!log.getMutable(haystack)) { diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 7e4c5691..6257e2f3 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -14,7 +14,6 @@ #endif LUAU_FASTFLAG(DebugLuauTimeTracing) -LUAU_FASTFLAG(LuauTypeMismatchModuleNameResolution) enum class ReportFormat { @@ -55,11 +54,9 @@ static void reportError(const Luau::Frontend& frontend, ReportFormat format, con if (const Luau::SyntaxError* syntaxError = Luau::get_if(&error.data)) report(format, humanReadableName.c_str(), error.location, "SyntaxError", syntaxError->message.c_str()); - else if (FFlag::LuauTypeMismatchModuleNameResolution) + else report(format, humanReadableName.c_str(), error.location, "TypeError", Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str()); - else - report(format, humanReadableName.c_str(), error.location, "TypeError", Luau::toString(error).c_str()); } static void reportWarning(ReportFormat format, const char* name, const Luau::LintWarning& warning) diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index aecddf38..a9dd8970 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -16,6 +16,7 @@ #include "isocline.h" +#include #include #ifdef _WIN32 @@ -49,6 +50,8 @@ enum class CompileFormat Binary, Remarks, Codegen, + CodegenVerbose, + CodegenNull, Null }; @@ -673,21 +676,33 @@ static void reportError(const char* name, const Luau::CompileError& error) report(name, error.getLocation(), "CompileError", error.what()); } -static std::string getCodegenAssembly(const char* name, const std::string& bytecode) +static std::string getCodegenAssembly(const char* name, const std::string& bytecode, Luau::CodeGen::AssemblyOptions options) { std::unique_ptr globalState(luaL_newstate(), lua_close); lua_State* L = globalState.get(); - setupState(L); - if (luau_load(L, name, bytecode.data(), bytecode.size(), 0) == 0) - return Luau::CodeGen::getAssemblyText(L, -1); + return Luau::CodeGen::getAssembly(L, -1, options); fprintf(stderr, "Error loading bytecode %s\n", name); return ""; } -static bool compileFile(const char* name, CompileFormat format) +static void annotateInstruction(void* context, std::string& text, int fid, int instid) +{ + Luau::BytecodeBuilder& bcb = *(Luau::BytecodeBuilder*)context; + + bcb.annotateInstruction(text, fid, instid); +} + +struct CompileStats +{ + size_t lines; + size_t bytecode; + size_t codegen; +}; + +static bool compileFile(const char* name, CompileFormat format, CompileStats& stats) { std::optional source = readFile(name); if (!source) @@ -696,9 +711,12 @@ static bool compileFile(const char* name, CompileFormat format) return false; } + stats.lines += std::count(source->begin(), source->end(), '\n'); + try { Luau::BytecodeBuilder bcb; + Luau::CodeGen::AssemblyOptions options = {format == CompileFormat::CodegenNull, format == CompileFormat::Codegen, annotateInstruction, &bcb}; if (format == CompileFormat::Text) { @@ -711,8 +729,15 @@ static bool compileFile(const char* name, CompileFormat format) bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Remarks); bcb.setDumpSource(*source); } + else if (format == CompileFormat::Codegen || format == CompileFormat::CodegenVerbose) + { + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals | + Luau::BytecodeBuilder::Dump_Remarks); + bcb.setDumpSource(*source); + } Luau::compileOrThrow(bcb, *source, copts()); + stats.bytecode += bcb.getBytecode().size(); switch (format) { @@ -726,7 +751,11 @@ static bool compileFile(const char* name, CompileFormat format) fwrite(bcb.getBytecode().data(), 1, bcb.getBytecode().size(), stdout); break; case CompileFormat::Codegen: - printf("%s", getCodegenAssembly(name, bcb.getBytecode()).c_str()); + case CompileFormat::CodegenVerbose: + printf("%s", getCodegenAssembly(name, bcb.getBytecode(), options).c_str()); + break; + case CompileFormat::CodegenNull: + stats.codegen += getCodegenAssembly(name, bcb.getBytecode(), options).size(); break; case CompileFormat::Null: break; @@ -755,7 +784,7 @@ static void displayHelp(const char* argv0) printf("\n"); printf("Available modes:\n"); printf(" omitted: compile and run input files one by one\n"); - printf(" --compile[=format]: compile input files and output resulting formatted bytecode (binary, text, remarks, codegen or null)\n"); + printf(" --compile[=format]: compile input files and output resulting bytecode/assembly (binary, text, remarks, codegen)\n"); printf("\n"); printf("Available options:\n"); printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); @@ -812,6 +841,14 @@ int replMain(int argc, char** argv) { compileFormat = CompileFormat::Codegen; } + else if (strcmp(argv[1], "--compile=codegenverbose") == 0) + { + compileFormat = CompileFormat::CodegenVerbose; + } + else if (strcmp(argv[1], "--compile=codegennull") == 0) + { + compileFormat = CompileFormat::CodegenNull; + } else if (strcmp(argv[1], "--compile=null") == 0) { compileFormat = CompileFormat::Null; @@ -923,10 +960,16 @@ int replMain(int argc, char** argv) _setmode(_fileno(stdout), _O_BINARY); #endif + CompileStats stats = {}; int failed = 0; for (const std::string& path : files) - failed += !compileFile(path.c_str(), compileFormat); + failed += !compileFile(path.c_str(), compileFormat, stats); + + if (compileFormat == CompileFormat::Null) + printf("Compiled %d KLOC into %d KB bytecode\n", int(stats.lines / 1000), int(stats.bytecode / 1024)); + else if (compileFormat == CompileFormat::CodegenNull) + printf("Compiled %d KLOC into %d KB bytecode => %d KB native code\n", int(stats.lines / 1000), int(stats.bytecode / 1024), int(stats.codegen / 1024)); return failed ? 1 : 0; } diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index 1c755017..e48388c5 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -23,6 +23,19 @@ enum class RoundingModeX64 RoundToZero = 0b11, }; +enum class AlignmentDataX64 +{ + Nop, + Int3, + Ud2, // int3 will be used as a fall-back if it doesn't fit +}; + +enum class ABIX64 +{ + Windows, + SystemV, +}; + class AssemblyBuilderX64 { public: @@ -80,6 +93,10 @@ public: void int3(); + // Code alignment + void nop(uint32_t length = 1); + void align(uint32_t alignment, AlignmentDataX64 data = AlignmentDataX64::Nop); + // AVX void vaddpd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vaddps(OperandX64 dst, OperandX64 src1, OperandX64 src2); @@ -131,6 +148,8 @@ public: void logAppend(const char* fmt, ...) LUAU_PRINTF_ATTR(2, 3); + uint32_t getCodeSize() const; + // Resulting data and code that need to be copied over one after the other // The *end* of 'data' has to be aligned to 16 bytes, this will also align 'code' std::vector data; @@ -140,6 +159,8 @@ public: const bool logText = false; + const ABIX64 abi; + private: // Instruction archetypes void placeBinary(const char* name, OperandX64 lhs, OperandX64 rhs, uint8_t codeimm8, uint8_t codeimm, uint8_t codeimmImm8, uint8_t code8rev, @@ -177,7 +198,6 @@ private: void commit(); LUAU_NOINLINE void extend(); - uint32_t getCodeSize(); // Data size_t allocateData(size_t size, size_t align); @@ -192,8 +212,8 @@ private: LUAU_NOINLINE void log(const char* opcode, Label label); void log(OperandX64 op); - const char* getSizeName(SizeX64 size); - const char* getRegisterName(RegisterX64 reg); + const char* getSizeName(SizeX64 size) const; + const char* getRegisterName(RegisterX64 reg) const; uint32_t nextLabel = 1; std::vector