diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 4db40a62..7fb88e21 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -73,9 +73,11 @@ jobs: valgrind --tool=callgrind ./luau --compile=null -O0 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O0 | tee -a compile-output.txt valgrind --tool=callgrind ./luau --compile=null -O1 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O1 | tee -a compile-output.txt valgrind --tool=callgrind ./luau --compile=null -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2 | tee -a compile-output.txt + valgrind --tool=callgrind ./luau --compile=codegennull -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2-codegen | tee -a compile-output.txt valgrind --tool=callgrind ./luau --compile=null -O0 bench/other/regex.lua 2>&1 | filter regex-O0 | tee -a compile-output.txt valgrind --tool=callgrind ./luau --compile=null -O1 bench/other/regex.lua 2>&1 | filter regex-O1 | tee -a compile-output.txt valgrind --tool=callgrind ./luau --compile=null -O2 bench/other/regex.lua 2>&1 | filter regex-O2 | tee -a compile-output.txt + valgrind --tool=callgrind ./luau --compile=codegennull -O2 bench/other/regex.lua 2>&1 | filter regex-O2-codegen | tee -a compile-output.txt - name: Checkout benchmark results uses: actions/checkout@v3 diff --git a/Analysis/include/Luau/Clone.h b/Analysis/include/Luau/Clone.h index 217e1cc3..f003c242 100644 --- a/Analysis/include/Luau/Clone.h +++ b/Analysis/include/Luau/Clone.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include #include "Luau/TypeArena.h" #include "Luau/TypeVar.h" @@ -26,5 +27,6 @@ TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState); TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState); TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone = false); +TypeId shallowClone(TypeId ty, NotNull dest); } // namespace Luau diff --git a/Analysis/include/Luau/Connective.h b/Analysis/include/Luau/Connective.h new file mode 100644 index 00000000..c9daa0f9 --- /dev/null +++ b/Analysis/include/Luau/Connective.h @@ -0,0 +1,68 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Def.h" +#include "Luau/TypedAllocator.h" +#include "Luau/TypeVar.h" +#include "Luau/Variant.h" + +#include + +namespace Luau +{ + +struct Negation; +struct Conjunction; +struct Disjunction; +struct Equivalence; +struct Proposition; +using Connective = Variant; +using ConnectiveId = Connective*; // Can and most likely is nullptr. + +struct Negation +{ + ConnectiveId connective; +}; + +struct Conjunction +{ + ConnectiveId lhs; + ConnectiveId rhs; +}; + +struct Disjunction +{ + ConnectiveId lhs; + ConnectiveId rhs; +}; + +struct Equivalence +{ + ConnectiveId lhs; + ConnectiveId rhs; +}; + +struct Proposition +{ + DefId def; + TypeId discriminantTy; +}; + +template +const T* get(ConnectiveId connective) +{ + return get_if(connective); +} + +struct ConnectiveArena +{ + TypedAllocator allocator; + + ConnectiveId negation(ConnectiveId connective); + ConnectiveId conjunction(ConnectiveId lhs, ConnectiveId rhs); + ConnectiveId disjunction(ConnectiveId lhs, ConnectiveId rhs); + ConnectiveId equivalence(ConnectiveId lhs, ConnectiveId rhs); + ConnectiveId proposition(DefId def, TypeId discriminantTy); +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 0e19f13f..4370d0cf 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -2,9 +2,10 @@ #pragma once #include "Luau/Ast.h" // Used for some of the enumerations +#include "Luau/Def.h" #include "Luau/NotNull.h" -#include "Luau/Variant.h" #include "Luau/TypeVar.h" +#include "Luau/Variant.h" #include #include @@ -131,9 +132,16 @@ struct HasPropConstraint std::string prop; }; -using ConstraintV = - Variant; +// result ~ if isSingleton D then ~D else unknown where D = discriminantType +struct SingletonOrTopTypeConstraint +{ + TypeId resultType; + TypeId discriminantType; +}; + +using ConstraintV = Variant; struct Constraint { @@ -143,7 +151,7 @@ struct Constraint Constraint& operator=(const Constraint&) = delete; NotNull scope; - Location location; + Location location; // TODO: Extract this out into only the constraints that needs a location. Not all constraints needs locations. ConstraintV c; std::vector> dependencies; diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 973c0a8e..cb5900ea 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -1,13 +1,10 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details - #pragma once -#include -#include -#include - #include "Luau/Ast.h" +#include "Luau/Connective.h" #include "Luau/Constraint.h" +#include "Luau/DataFlowGraphBuilder.h" #include "Luau/Module.h" #include "Luau/ModuleResolver.h" #include "Luau/NotNull.h" @@ -15,6 +12,10 @@ #include "Luau/TypeVar.h" #include "Luau/Variant.h" +#include +#include +#include + namespace Luau { @@ -23,6 +24,34 @@ using ScopePtr = std::shared_ptr; struct DcrLogger; +struct Inference +{ + TypeId ty = nullptr; + ConnectiveId connective = nullptr; + + Inference() = default; + + explicit Inference(TypeId ty, ConnectiveId connective = nullptr) + : ty(ty) + , connective(connective) + { + } +}; + +struct InferencePack +{ + TypePackId tp = nullptr; + std::vector connectives; + + InferencePack() = default; + + explicit InferencePack(TypePackId tp, const std::vector& connectives = {}) + : tp(tp) + , connectives(connectives) + { + } +}; + struct ConstraintGraphBuilder { // A list of all the scopes in the module. This vector holds ownership of the @@ -48,6 +77,8 @@ struct ConstraintGraphBuilder DenseHashMap astResolvedTypePacks{nullptr}; // Defining scopes for AST nodes. DenseHashMap astTypeAliasDefiningScopes{nullptr}; + NotNull dfg; + ConnectiveArena connectiveArena; int recursionCount = 0; @@ -63,7 +94,8 @@ struct ConstraintGraphBuilder DcrLogger* logger; ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, NotNull moduleResolver, - NotNull singletonTypes, NotNull ice, const ScopePtr& globalScope, DcrLogger* logger); + NotNull singletonTypes, NotNull ice, const ScopePtr& globalScope, DcrLogger* logger, + NotNull dfg); /** * Fabricates a new free type belonging to a given scope. @@ -88,15 +120,19 @@ struct ConstraintGraphBuilder * Adds a new constraint with no dependencies to a given scope. * @param scope the scope to add the constraint to. * @param cv the constraint variant to add. + * @return the pointer to the inserted constraint */ - void addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv); + NotNull addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv); /** * Adds a constraint to a given scope. * @param scope the scope to add the constraint to. Must not be null. * @param c the constraint to add. + * @return the pointer to the inserted constraint */ - void addConstraint(const ScopePtr& scope, std::unique_ptr c); + NotNull addConstraint(const ScopePtr& scope, std::unique_ptr c); + + void applyRefinements(const ScopePtr& scope, Location location, ConnectiveId connective); /** * The entry point to the ConstraintGraphBuilder. This will construct a set @@ -126,8 +162,10 @@ struct ConstraintGraphBuilder void visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction); void visit(const ScopePtr& scope, AstStatError* error); - TypePackId checkPack(const ScopePtr& scope, AstArray exprs, const std::vector& expectedTypes = {}); - TypePackId checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector& expectedTypes = {}); + InferencePack checkPack(const ScopePtr& scope, AstArray exprs, const std::vector& expectedTypes = {}); + InferencePack checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector& expectedTypes = {}); + + InferencePack checkPack(const ScopePtr& scope, AstExprCall* call, const std::vector& expectedTypes); /** * Checks an expression that is expected to evaluate to one type. @@ -137,15 +175,24 @@ struct ConstraintGraphBuilder * surrounding context. Used to implement bidirectional type checking. * @return the type of the expression. */ - TypeId check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType = {}); + Inference check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType = {}, bool forceSingleton = false); - TypeId check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType); - TypeId check(const ScopePtr& scope, AstExprIndexName* indexName); - TypeId check(const ScopePtr& scope, AstExprIndexExpr* indexExpr); - TypeId check(const ScopePtr& scope, AstExprUnary* unary); - TypeId check(const ScopePtr& scope, AstExprBinary* binary); - TypeId check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType); - TypeId check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert); + Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType, bool forceSingleton); + Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional expectedType, bool forceSingleton); + Inference check(const ScopePtr& scope, AstExprLocal* local); + Inference check(const ScopePtr& scope, AstExprGlobal* global); + Inference check(const ScopePtr& scope, AstExprIndexName* indexName); + Inference check(const ScopePtr& scope, AstExprIndexExpr* indexExpr); + Inference check(const ScopePtr& scope, AstExprUnary* unary); + Inference check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); + Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType); + Inference check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert); + Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType); + std::tuple checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); + + TypePackId checkLValues(const ScopePtr& scope, AstArray exprs); + + TypeId checkLValue(const ScopePtr& scope, AstExpr* expr); struct FunctionSignature { @@ -191,7 +238,7 @@ struct ConstraintGraphBuilder std::vector> createGenerics(const ScopePtr& scope, AstArray generics); std::vector> createGenericPacks(const ScopePtr& scope, AstArray packs); - TypeId flattenPack(const ScopePtr& scope, Location location, TypePackId tp); + Inference flattenPack(const ScopePtr& scope, Location location, InferencePack pack); void reportError(Location location, TypeErrorData err); void reportCodeTooComplex(Location location); diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 0bf6d1bc..07f027ad 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -110,6 +110,7 @@ struct ConstraintSolver bool tryDispatch(const FunctionCallConstraint& c, NotNull constraint); bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint); bool tryDispatch(const HasPropConstraint& c, NotNull constraint); + bool tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull constraint); // for a, ... in some_table do // also handles __iter metamethod @@ -215,6 +216,8 @@ private: TypeId errorRecoveryType() const; TypePackId errorRecoveryTypePack() const; + TypeId unionOfTypes(TypeId a, TypeId b, NotNull scope, bool unifyFreeTypes); + ToStringOptions opts; }; diff --git a/Analysis/include/Luau/DataFlowGraphBuilder.h b/Analysis/include/Luau/DataFlowGraphBuilder.h new file mode 100644 index 00000000..3a72403e --- /dev/null +++ b/Analysis/include/Luau/DataFlowGraphBuilder.h @@ -0,0 +1,115 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +// Do not include LValue. It should never be used here. +#include "Luau/Ast.h" +#include "Luau/DenseHash.h" +#include "Luau/Def.h" +#include "Luau/Symbol.h" + +#include + +namespace Luau +{ + +struct DataFlowGraph +{ + DataFlowGraph(DataFlowGraph&&) = default; + DataFlowGraph& operator=(DataFlowGraph&&) = default; + + // TODO: AstExprLocal, AstExprGlobal, and AstLocal* are guaranteed never to return nullopt. + // We leave them to return an optional as we build it out, but the end state is for them to return a non-optional DefId. + std::optional getDef(const AstExpr* expr) const; + std::optional getDef(const AstLocal* local) const; + + /// Retrieve the Def that corresponds to the given Symbol. + /// + /// We do not perform dataflow analysis on globals, so this function always + /// yields nullopt when passed a global Symbol. + std::optional getDef(const Symbol& symbol) const; + +private: + DataFlowGraph() = default; + + DataFlowGraph(const DataFlowGraph&) = delete; + DataFlowGraph& operator=(const DataFlowGraph&) = delete; + + DefArena arena; + DenseHashMap astDefs{nullptr}; + DenseHashMap localDefs{nullptr}; + + friend struct DataFlowGraphBuilder; +}; + +struct DfgScope +{ + DfgScope* parent; + DenseHashMap bindings{Symbol{}}; +}; + +struct ExpressionFlowGraph +{ + std::optional def; +}; + +// Currently unsound. We do not presently track the control flow of the program. +// Additionally, we do not presently track assignments. +struct DataFlowGraphBuilder +{ + static DataFlowGraph build(AstStatBlock* root, NotNull handle); + +private: + DataFlowGraphBuilder() = default; + + DataFlowGraphBuilder(const DataFlowGraphBuilder&) = delete; + DataFlowGraphBuilder& operator=(const DataFlowGraphBuilder&) = delete; + + DataFlowGraph graph; + NotNull arena{&graph.arena}; + struct InternalErrorReporter* handle; + std::vector> scopes; + + DfgScope* childScope(DfgScope* scope); + + std::optional use(DfgScope* scope, Symbol symbol, AstExpr* e); + + void visit(DfgScope* scope, AstStatBlock* b); + void visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b); + + // TODO: visit type aliases + void visit(DfgScope* scope, AstStat* s); + void visit(DfgScope* scope, AstStatIf* i); + void visit(DfgScope* scope, AstStatWhile* w); + void visit(DfgScope* scope, AstStatRepeat* r); + void visit(DfgScope* scope, AstStatBreak* b); + void visit(DfgScope* scope, AstStatContinue* c); + void visit(DfgScope* scope, AstStatReturn* r); + void visit(DfgScope* scope, AstStatExpr* e); + void visit(DfgScope* scope, AstStatLocal* l); + void visit(DfgScope* scope, AstStatFor* f); + void visit(DfgScope* scope, AstStatForIn* f); + void visit(DfgScope* scope, AstStatAssign* a); + void visit(DfgScope* scope, AstStatCompoundAssign* c); + void visit(DfgScope* scope, AstStatFunction* f); + void visit(DfgScope* scope, AstStatLocalFunction* l); + + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExpr* e); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprLocal* l); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprGlobal* g); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprCall* c); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIndexName* i); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIndexExpr* i); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprFunction* f); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprTable* t); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprUnary* u); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprBinary* b); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprTypeAssertion* t); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIfElse* i); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprInterpString* i); + + // TODO: visitLValue + // TODO: visitTypes (because of typeof which has access to values namespace, needs unreachable scope) + // TODO: visitTypePacks (because of typeof which has access to values namespace, needs unreachable scope) +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Def.h b/Analysis/include/Luau/Def.h new file mode 100644 index 00000000..ac1fa132 --- /dev/null +++ b/Analysis/include/Luau/Def.h @@ -0,0 +1,78 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/NotNull.h" +#include "Luau/TypedAllocator.h" +#include "Luau/Variant.h" + +namespace Luau +{ + +using Def = Variant; + +/** + * We statically approximate a value at runtime using a symbolic value, which we call a Def. + * + * DataFlowGraphBuilder will allocate these defs as a stand-in for some Luau values, and bind them to places that + * can hold a Luau value, and then observes how those defs will commute as it statically evaluate the program. + * + * It must also be noted that defs are a cyclic graph, so it is not safe to recursively traverse into it expecting it to terminate. + */ +using DefId = NotNull; + +/** + * A "single-object" value. + * + * Leaky implementation note: sometimes "multiple-object" values, but none of which were interesting enough to warrant creating a phi node instead. + * That can happen because there's no point in creating a phi node that points to either resultant in `if math.random() > 0.5 then 5 else "hello"`. + * This might become of utmost importance if we wanted to do some backward reasoning, e.g. if `5` is taken, then `cond` must be `truthy`. + */ +struct Undefined +{ +}; + +/** + * A phi node is a union of defs. + * + * We need this because we're statically evaluating a program, and sometimes a place may be assigned with + * different defs, and when that happens, we need a special data type that merges in all the defs + * that will flow into that specific place. For example, consider this simple program: + * + * ``` + * x-1 + * if cond() then + * x-2 = 5 + * else + * x-3 = "hello" + * end + * x-4 : {x-2, x-3} + * ``` + * + * At x-4, we know for a fact statically that either `5` or `"hello"` can flow into the variable `x` after the branch, but + * we cannot make any definitive decisions about which one, so we just take in both. + */ +struct Phi +{ + std::vector operands; +}; + +template +T* getMutable(DefId def) +{ + return get_if(def.get()); +} + +template +const T* get(DefId def) +{ + return getMutable(def); +} + +struct DefArena +{ + TypedAllocator allocator; + + DefId freshDef(); +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 7338627c..f7bd9d50 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -7,6 +7,8 @@ #include "Luau/Variant.h" #include "Luau/TypeArena.h" +LUAU_FASTFLAG(LuauIceExceptionInheritanceChange) + namespace Luau { struct TypeError; @@ -302,12 +304,20 @@ struct NormalizationTooComplex } }; +struct TypePackMismatch +{ + TypePackId wantedTp; + TypePackId givenTp; + + bool operator==(const TypePackMismatch& rhs) const; +}; + using TypeErrorData = Variant; + TypesAreUnrelated, NormalizationTooComplex, TypePackMismatch>; struct TypeError { @@ -374,6 +384,10 @@ struct InternalErrorReporter class InternalCompilerError : public std::exception { public: + explicit InternalCompilerError(const std::string& message) + : message(message) + { + } explicit InternalCompilerError(const std::string& message, const std::string& moduleName) : message(message) , moduleName(moduleName) @@ -388,8 +402,14 @@ public: virtual const char* what() const throw(); const std::string message; - const std::string moduleName; + const std::optional moduleName; const std::optional location; }; +// These two function overloads only exist to facilitate fast flagging a change to InternalCompilerError +// Both functions can be removed when FFlagLuauIceExceptionInheritanceChange is removed and calling code +// can directly throw InternalCompilerError. +[[noreturn]] void throwRuntimeError(const std::string& message); +[[noreturn]] void throwRuntimeError(const std::string& message, const std::string& moduleName); + } // namespace Luau diff --git a/Analysis/include/Luau/LValue.h b/Analysis/include/Luau/LValue.h index 1a92d52d..518cbfaf 100644 --- a/Analysis/include/Luau/LValue.h +++ b/Analysis/include/Luau/LValue.h @@ -14,6 +14,8 @@ struct TypeVar; using TypeId = const TypeVar*; struct Field; + +// Deprecated. Do not use in new work. using LValue = Variant; struct Field diff --git a/Analysis/include/Luau/Metamethods.h b/Analysis/include/Luau/Metamethods.h new file mode 100644 index 00000000..84b0092f --- /dev/null +++ b/Analysis/include/Luau/Metamethods.h @@ -0,0 +1,32 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" + +#include + +namespace Luau +{ + +static const std::unordered_map kBinaryOpMetamethods{ + {AstExprBinary::Op::CompareEq, "__eq"}, + {AstExprBinary::Op::CompareNe, "__eq"}, + {AstExprBinary::Op::CompareGe, "__lt"}, + {AstExprBinary::Op::CompareGt, "__le"}, + {AstExprBinary::Op::CompareLe, "__le"}, + {AstExprBinary::Op::CompareLt, "__lt"}, + {AstExprBinary::Op::Add, "__add"}, + {AstExprBinary::Op::Sub, "__sub"}, + {AstExprBinary::Op::Mul, "__mul"}, + {AstExprBinary::Op::Div, "__div"}, + {AstExprBinary::Op::Pow, "__pow"}, + {AstExprBinary::Op::Mod, "__mod"}, + {AstExprBinary::Op::Concat, "__concat"}, +}; + +static const std::unordered_map kUnaryOpMetamethods{ + {AstExprUnary::Op::Minus, "__unm"}, + {AstExprUnary::Op::Len, "__len"}, +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 72ea9558..b28c06a5 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -17,19 +17,8 @@ struct SingletonTypes; using ModulePtr = std::shared_ptr; -bool isSubtype( - TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, bool anyIsTop = true); -bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, - bool anyIsTop = true); - -std::pair normalize( - TypeId ty, NotNull scope, TypeArena& arena, NotNull singletonTypes, InternalErrorReporter& ice); -std::pair normalize(TypeId ty, NotNull module, NotNull singletonTypes, InternalErrorReporter& ice); -std::pair normalize(TypeId ty, const ModulePtr& module, NotNull singletonTypes, InternalErrorReporter& ice); -std::pair normalize( - TypePackId ty, NotNull scope, TypeArena& arena, NotNull singletonTypes, InternalErrorReporter& ice); -std::pair normalize(TypePackId ty, NotNull module, NotNull singletonTypes, InternalErrorReporter& ice); -std::pair normalize(TypePackId ty, const ModulePtr& module, NotNull singletonTypes, InternalErrorReporter& ice); +bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice); +bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice); class TypeIds { @@ -115,16 +104,89 @@ struct std::equal_to namespace Luau { -// A normalized string type is either `string` (represented by `nullopt`) -// or a union of string singletons. -using NormalizedStringType = std::optional>; +/** A normalized string type is either `string` (represented by `nullopt`) or a + * union of string singletons. + * + * When FFlagLuauNegatedStringSingletons is unset, the representation is as + * follows: + * + * * The `string` data type is represented by the option `singletons` having the + * value `std::nullopt`. + * * The type `never` is represented by `singletons` being populated with an + * empty map. + * * A union of string singletons is represented by a map populated by the names + * and TypeIds of the singletons contained therein. + * + * When FFlagLuauNegatedStringSingletons is set, the representation is as + * follows: + * + * * A union of string singletons is finite and includes the singletons named by + * the `singletons` field. + * * An intersection of negated string singletons is cofinite and includes the + * singletons excluded by the `singletons` field. It is implied that cofinite + * values are exclusions from `string` itself. + * * The `string` data type is a cofinite set minus zero elements. + * * The `never` data type is a finite set plus zero elements. + */ +struct NormalizedStringType +{ + // When false, this type represents a union of singleton string types. + // eg "a" | "b" | "c" + // + // When true, this type represents string intersected with negated string + // singleton types. + // eg string & ~"a" & ~"b" & ... + bool isCofinite = false; -// A normalized function type is either `never` (represented by `nullopt`) + // TODO: This field cannot be nullopt when FFlagLuauNegatedStringSingletons + // is set. When clipping that flag, we can remove the wrapping optional. + std::optional> singletons; + + void resetToString(); + void resetToNever(); + + bool isNever() const; + bool isString() const; + + /// Returns true if the string has finite domain. + /// + /// Important subtlety: This method returns true for `never`. The empty set + /// is indeed an empty set. + bool isUnion() const; + + /// Returns true if the string has infinite domain. + bool isIntersection() const; + + bool includes(const std::string& str) const; + + static const NormalizedStringType never; + + NormalizedStringType() = default; + NormalizedStringType(bool isCofinite, std::optional> singletons); +}; + +bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr); + +// A normalized function type can be `never`, the top function type `function`, // or an intersection of function types. -// NOTE: type normalization can fail on function types with generics -// (e.g. because we do not support unions and intersections of generic type packs), -// so this type may contain `error`. -using NormalizedFunctionType = std::optional; +// +// NOTE: type normalization can fail on function types with generics (e.g. +// because we do not support unions and intersections of generic type packs), so +// this type may contain `error`. +struct NormalizedFunctionType +{ + NormalizedFunctionType(); + + bool isTop = false; + // TODO: Remove this wrapping optional when clipping + // FFlagLuauNegatedFunctionTypes. + std::optional parts; + + void resetToNever(); + void resetToTop(); + + bool isNever() const; +}; // A normalized generic/free type is a union, where each option is of the form (X & T) where // * X is either a free type or a generic @@ -166,7 +228,7 @@ struct NormalizedType // The string part of the type. // This may be the `string` type, or a union of singletons. - NormalizedStringType strings = std::map{}; + NormalizedStringType strings; // The thread part of the type. // This type is either never or thread. @@ -184,12 +246,14 @@ struct NormalizedType NormalizedType(NotNull singletonTypes); - NormalizedType(const NormalizedType&) = delete; - NormalizedType(NormalizedType&&) = default; NormalizedType() = delete; ~NormalizedType() = default; + + NormalizedType(const NormalizedType&) = delete; + NormalizedType& operator=(const NormalizedType&) = delete; + + NormalizedType(NormalizedType&&) = default; NormalizedType& operator=(NormalizedType&&) = default; - NormalizedType& operator=(NormalizedType&) = delete; }; class Normalizer @@ -240,8 +304,14 @@ public: bool unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); bool unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars = -1); + // ------- Negations + std::optional negateNormal(const NormalizedType& here); + TypeIds negateAll(const TypeIds& theres); + TypeId negate(TypeId there); + void subtractPrimitive(NormalizedType& here, TypeId ty); + void subtractSingleton(NormalizedType& here, TypeId ty); + // ------- Normalizing intersections - void intersectTysWithTy(TypeIds& here, TypeId there); TypeId intersectionOfTops(TypeId here, TypeId there); TypeId intersectionOfBools(TypeId here, TypeId there); void intersectClasses(TypeIds& heres, const TypeIds& theres); diff --git a/Analysis/include/Luau/RecursionCounter.h b/Analysis/include/Luau/RecursionCounter.h index f964dbfe..632afd19 100644 --- a/Analysis/include/Luau/RecursionCounter.h +++ b/Analysis/include/Luau/RecursionCounter.h @@ -2,6 +2,7 @@ #pragma once #include "Luau/Common.h" +#include "Luau/Error.h" #include #include @@ -9,10 +10,20 @@ namespace Luau { -struct RecursionLimitException : public std::exception +struct RecursionLimitException : public InternalCompilerError +{ + RecursionLimitException() + : InternalCompilerError("Internal recursion counter limit exceeded") + { + LUAU_ASSERT(FFlag::LuauIceExceptionInheritanceChange); + } +}; + +struct RecursionLimitException_DEPRECATED : public std::exception { const char* what() const noexcept { + LUAU_ASSERT(!FFlag::LuauIceExceptionInheritanceChange); return "Internal recursion counter limit exceeded"; } }; @@ -42,7 +53,14 @@ struct RecursionLimiter : RecursionCounter { if (limit > 0 && *count > limit) { - throw RecursionLimitException(); + if (FFlag::LuauIceExceptionInheritanceChange) + { + throw RecursionLimitException(); + } + else + { + throw RecursionLimitException_DEPRECATED(); + } } } }; diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index b2da7bc0..ccf2964c 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -54,7 +54,9 @@ struct Scope DenseHashSet builtinTypeNames{""}; void addBuiltinTypeBinding(const Name& name, const TypeFun& tyFun); - std::optional lookup(Symbol sym); + std::optional lookup(Symbol sym) const; + std::optional lookup(DefId def) const; + std::optional> lookupEx(Symbol sym); std::optional lookupType(const Name& name); std::optional lookupImportedType(const Name& moduleAlias, const Name& name); @@ -66,6 +68,7 @@ struct Scope std::optional linearSearchForBinding(const std::string& name, bool traverseScopeChain = true) const; RefinementMap refinements; + DenseHashMap dcrRefinements{nullptr}; // For mutually recursive type aliases, it's important that // they use the same types for the same names. diff --git a/Analysis/include/Luau/Symbol.h b/Analysis/include/Luau/Symbol.h index 1fe037e5..0432946c 100644 --- a/Analysis/include/Luau/Symbol.h +++ b/Analysis/include/Luau/Symbol.h @@ -6,10 +6,11 @@ #include +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) + namespace Luau { -// TODO Rename this to Name once the old type alias is gone. struct Symbol { Symbol() @@ -40,9 +41,12 @@ struct Symbol { if (local) return local == rhs.local; - if (global.value) + else if (global.value) return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity. - return false; + else if (FFlag::DebugLuauDeferredConstraintResolution) + return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is. + else + return false; } bool operator!=(const Symbol& rhs) const @@ -58,8 +62,8 @@ struct Symbol return global < rhs.global; else if (local) return true; - else - return false; + + return false; } AstName astName() const diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index dd2d709b..ff2561e6 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -117,6 +117,8 @@ inline std::string toStringNamedFunction(const std::string& funcName, const Func return toStringNamedFunction(funcName, ftv, opts); } +std::optional getFunctionNameAsString(const AstExpr& expr); + // It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class // These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression std::string dump(TypeId ty); diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 1c4d1cb4..c5d7501d 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -48,7 +48,17 @@ struct HashBoolNamePair size_t operator()(const std::pair& pair) const; }; -class TimeLimitError : public std::exception +class TimeLimitError : public InternalCompilerError +{ +public: + explicit TimeLimitError(const std::string& moduleName) + : InternalCompilerError("Typeinfer failed to complete in allotted time", moduleName) + { + LUAU_ASSERT(FFlag::LuauIceExceptionInheritanceChange); + } +}; + +class TimeLimitError_DEPRECATED : public std::exception { public: virtual const char* what() const throw(); @@ -192,18 +202,12 @@ struct TypeChecker ErrorVec canUnify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location); ErrorVec canUnify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location); - void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const ScopePtr& scope, const Location& location); - std::optional findMetatableEntry(TypeId type, std::string entry, const Location& location, bool addErrors); std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors); std::optional getIndexTypeFromType(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors); std::optional getIndexTypeFromTypeImpl(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors); - // Reduces the union to its simplest possible shape. - // (A | B) | B | C yields A | B | C - std::vector reduceUnion(const std::vector& types); - std::optional tryStripUnionFromNil(TypeId ty); TypeId stripFromNilAndReport(TypeId ty, const Location& location); @@ -242,6 +246,7 @@ public: [[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message); + [[noreturn]] void throwTimeLimitError(); ScopePtr childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel = 0); ScopePtr childScope(const ScopePtr& parent, const Location& location); diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index e5a205ba..085ee21b 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -29,4 +29,23 @@ std::pair> getParameterExtents(const TxnLog* log, // various other things to get there. std::vector flatten(TypeArena& arena, NotNull singletonTypes, TypePackId pack, size_t length); +/** + * Reduces a union by decomposing to the any/error type if it appears in the + * type list, and by merging child unions. Also strips out duplicate (by pointer + * identity) types. + * @param types the input type list to reduce. + * @returns the reduced type list. + */ +std::vector reduceUnion(const std::vector& types); + +/** + * Tries to remove nil from a union type, if there's another option. T | nil + * reduces to T, but nil itself does not reduce. + * @param singletonTypes the singleton types to use + * @param arena the type arena to allocate the new type in, if necessary + * @param ty the type to remove nil from + * @returns a type with nil removed, or nil itself if that were the only option. + */ +TypeId stripNil(NotNull singletonTypes, TypeArena& arena, TypeId ty); + } // namespace Luau diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 1d587ffe..0ab4d474 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -2,22 +2,23 @@ #pragma once #include "Luau/Ast.h" +#include "Luau/Common.h" #include "Luau/DenseHash.h" +#include "Luau/Def.h" +#include "Luau/NotNull.h" #include "Luau/Predicate.h" #include "Luau/Unifiable.h" #include "Luau/Variant.h" -#include "Luau/Common.h" -#include "Luau/NotNull.h" -#include -#include -#include -#include -#include -#include #include +#include #include #include +#include +#include +#include +#include +#include LUAU_FASTINT(LuauTableTypeMaximumStringifierLength) LUAU_FASTINT(LuauTypeMaximumStringifierLength) @@ -114,6 +115,7 @@ struct PrimitiveTypeVar Number, String, Thread, + Function, }; Type type; @@ -131,24 +133,6 @@ struct PrimitiveTypeVar } }; -struct ConstrainedTypeVar -{ - explicit ConstrainedTypeVar(TypeLevel level) - : level(level) - { - } - - explicit ConstrainedTypeVar(TypeLevel level, const std::vector& parts) - : parts(parts) - , level(level) - { - } - - std::vector parts; - TypeLevel level; - Scope* scope = nullptr; -}; - // Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md // Types for true and false struct BooleanSingleton @@ -496,11 +480,13 @@ struct AnyTypeVar { }; +// T | U struct UnionTypeVar { std::vector options; }; +// T & U struct IntersectionTypeVar { std::vector parts; @@ -519,12 +505,19 @@ struct NeverTypeVar { }; +// ~T +// TODO: Some simplification step that overwrites the type graph to make sure negation +// types disappear from the user's view, and (?) a debug flag to disable that +struct NegationTypeVar +{ + TypeId ty; +}; + using ErrorTypeVar = Unifiable::Error; using TypeVariant = - Unifiable::Variant; - + Unifiable::Variant; struct TypeVar final { @@ -541,7 +534,6 @@ struct TypeVar final TypeVar(const TypeVariant& ty, bool persistent) : ty(ty) , persistent(persistent) - , normal(persistent) // We assume that all persistent types are irreducable. { } @@ -549,7 +541,6 @@ struct TypeVar final void reassign(const TypeVar& rhs) { ty = rhs.ty; - normal = rhs.normal; documentationSymbol = rhs.documentationSymbol; } @@ -560,10 +551,6 @@ struct TypeVar final // Persistent TypeVars do not get cloned. bool persistent = false; - // Normalization sets this for types that are fully normalized. - // This implies that they are transitively immutable. - bool normal = false; - std::optional documentationSymbol; // Pointer to the type arena that allocated this type. @@ -650,12 +637,15 @@ public: const TypeId stringType; const TypeId booleanType; const TypeId threadType; + const TypeId functionType; const TypeId trueType; const TypeId falseType; const TypeId anyType; const TypeId unknownType; const TypeId neverType; const TypeId errorType; + const TypeId falsyType; // No type binding! + const TypeId truthyType; // No type binding! const TypePackId anyTypePack; const TypePackId neverTypePack; @@ -703,7 +693,6 @@ T* getMutable(TypeId tv) const std::vector& getTypes(const UnionTypeVar* utv); const std::vector& getTypes(const IntersectionTypeVar* itv); -const std::vector& getTypes(const ConstrainedTypeVar* ctv); template struct TypeIterator; @@ -716,10 +705,6 @@ using IntersectionTypeVarIterator = TypeIterator; IntersectionTypeVarIterator begin(const IntersectionTypeVar* itv); IntersectionTypeVarIterator end(const IntersectionTypeVar* itv); -using ConstrainedTypeVarIterator = TypeIterator; -ConstrainedTypeVarIterator begin(const ConstrainedTypeVar* ctv); -ConstrainedTypeVarIterator end(const ConstrainedTypeVar* ctv); - /* Traverses the type T yielding each TypeId. * If the iterator encounters a nested type T, it will instead yield each TypeId within. */ @@ -793,7 +778,6 @@ struct TypeIterator // with templates portability in this area, so not worth it. Thanks MSVC. friend UnionTypeVarIterator end(const UnionTypeVar*); friend IntersectionTypeVarIterator end(const IntersectionTypeVar*); - friend ConstrainedTypeVarIterator end(const ConstrainedTypeVar*); private: TypeIterator() = default; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 10f3f48c..b5f58d3c 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -61,7 +61,6 @@ struct Unifier ErrorVec errors; Location location; Variance variance = Covariant; - bool anyIsTop = false; // If true, we consider any to be a top type. If false, it is a familiar but weird mix of top and bottom all at once. bool normalize; // Normalize unions and intersections if necessary bool useScopes = false; // If true, we use the scope hierarchy rather than TypeLevels CountMismatch::Context ctx = CountMismatch::Arg; @@ -96,6 +95,8 @@ private: void tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed); + void tryUnifyTypeWithNegation(TypeId subTy, TypeId superTy); + void tryUnifyNegationWithType(TypeId subTy, TypeId superTy); TypePackId tryApplyOverloadedFunction(TypeId function, const NormalizedFunctionType& overloads, TypePackId args); @@ -119,12 +120,7 @@ private: std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name); - void tryUnifyWithConstrainedSubTypeVar(TypeId subTy, TypeId superTy); - void tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy); - public: - void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel); - // Returns true if the type "needle" already occurs within "haystack" and reports an "infinite type error" bool occursCheck(TypeId needle, TypeId haystack); bool occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack); @@ -134,6 +130,7 @@ public: Unifier makeChildUnifier(); void reportError(TypeError err); + LUAU_NOINLINE void reportError(Location location, TypeErrorData data); private: bool isNonstrictMode() const; diff --git a/Analysis/include/Luau/Variant.h b/Analysis/include/Luau/Variant.h index f637222e..016c51f6 100644 --- a/Analysis/include/Luau/Variant.h +++ b/Analysis/include/Luau/Variant.h @@ -58,13 +58,15 @@ public: constexpr int tid = getTypeId(); typeId = tid; - new (&storage) TT(value); + new (&storage) TT(std::forward(value)); } Variant(const Variant& other) { + static constexpr FnCopy table[sizeof...(Ts)] = {&fnCopy...}; + typeId = other.typeId; - tableCopy[typeId](&storage, &other.storage); + table[typeId](&storage, &other.storage); } Variant(Variant&& other) @@ -105,7 +107,7 @@ public: tableDtor[typeId](&storage); typeId = tid; - new (&storage) TT(std::forward(args)...); + new (&storage) TT{std::forward(args)...}; return *reinterpret_cast(&storage); } @@ -192,7 +194,6 @@ private: return *static_cast(lhs) == *static_cast(rhs); } - static constexpr FnCopy tableCopy[sizeof...(Ts)] = {&fnCopy...}; static constexpr FnMove tableMove[sizeof...(Ts)] = {&fnMove...}; static constexpr FnDtor tableDtor[sizeof...(Ts)] = {&fnDtor...}; diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index 315e5992..3dcddba1 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -103,10 +103,6 @@ struct GenericTypeVarVisitor { return visit(ty); } - virtual bool visit(TypeId ty, const ConstrainedTypeVar& ctv) - { - return visit(ty); - } virtual bool visit(TypeId ty, const PrimitiveTypeVar& ptv) { return visit(ty); @@ -159,6 +155,10 @@ struct GenericTypeVarVisitor { return visit(ty); } + virtual bool visit(TypeId ty, const NegationTypeVar& ntv) + { + return visit(ty); + } virtual bool visit(TypePackId tp) { @@ -216,14 +216,6 @@ struct GenericTypeVarVisitor visit(ty, *gtv); else if (auto etv = get(ty)) visit(ty, *etv); - else if (auto ctv = get(ty)) - { - if (visit(ty, *ctv)) - { - for (TypeId part : ctv->parts) - traverse(part); - } - } else if (auto ptv = get(ty)) visit(ty, *ptv); else if (auto ftv = get(ty)) @@ -325,6 +317,8 @@ struct GenericTypeVarVisitor traverse(a); } } + else if (auto ntv = get(ty)) + visit(ty, *ntv); else if (!FFlag::LuauCompleteVisitor) return visit_detail::unsee(seen, ty); else diff --git a/Analysis/src/Anyification.cpp b/Analysis/src/Anyification.cpp index cc9796ee..5dd761c2 100644 --- a/Analysis/src/Anyification.cpp +++ b/Analysis/src/Anyification.cpp @@ -37,8 +37,6 @@ bool Anyification::isDirty(TypeId ty) return (ttv->state == TableState::Free || ttv->state == TableState::Unsealed); else if (log->getMutable(ty)) return true; - else if (get(ty)) - return true; else return false; } @@ -65,20 +63,8 @@ TypeId Anyification::clean(TypeId ty) clone.syntheticName = ttv->syntheticName; clone.tags = ttv->tags; TypeId res = addType(std::move(clone)); - asMutable(res)->normal = ty->normal; return res; } - else if (auto ctv = get(ty)) - { - std::vector copy = ctv->parts; - for (TypeId& ty : copy) - ty = replace(ty); - TypeId res = copy.size() == 1 ? copy[0] : addType(UnionTypeVar{std::move(copy)}); - auto [t, ok] = normalize(res, scope, *arena, singletonTypes, *iceHandler); - if (!ok) - normalizationTooComplex = true; - return t; - } else return anyType; } diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index 50299704..b93c2cc2 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -11,6 +11,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauCheckOverloadedDocSymbol, false) + namespace Luau { @@ -427,6 +429,38 @@ ExprOrLocal findExprOrLocalAtPosition(const SourceModule& source, Position pos) return findVisitor.result; } +static std::optional checkOverloadedDocumentationSymbol( + const Module& module, const TypeId ty, const AstExpr* parentExpr, const std::optional documentationSymbol) +{ + LUAU_ASSERT(FFlag::LuauCheckOverloadedDocSymbol); + + if (!documentationSymbol) + return std::nullopt; + + // This might be an overloaded function. + if (get(follow(ty))) + { + TypeId matchingOverload = nullptr; + if (parentExpr && parentExpr->is()) + { + if (auto it = module.astOverloadResolvedTypes.find(parentExpr)) + { + matchingOverload = *it; + } + } + + if (matchingOverload) + { + std::string overloadSymbol = *documentationSymbol + "/overload/"; + // Default toString options are fine for this purpose. + overloadSymbol += toString(matchingOverload); + return overloadSymbol; + } + } + + return documentationSymbol; +} + std::optional getDocumentationSymbolAtPosition(const SourceModule& source, const Module& module, Position position) { std::vector ancestry = findAstAncestryOfPosition(source, position); @@ -436,31 +470,38 @@ std::optional getDocumentationSymbolAtPosition(const Source if (std::optional binding = findBindingAtPosition(module, source, position)) { - if (binding->documentationSymbol) + if (FFlag::LuauCheckOverloadedDocSymbol) { - // This might be an overloaded function binding. - if (get(follow(binding->typeId))) + return checkOverloadedDocumentationSymbol(module, binding->typeId, parentExpr, binding->documentationSymbol); + } + else + { + if (binding->documentationSymbol) { - TypeId matchingOverload = nullptr; - if (parentExpr && parentExpr->is()) + // This might be an overloaded function binding. + if (get(follow(binding->typeId))) { - if (auto it = module.astOverloadResolvedTypes.find(parentExpr)) + TypeId matchingOverload = nullptr; + if (parentExpr && parentExpr->is()) { - matchingOverload = *it; + if (auto it = module.astOverloadResolvedTypes.find(parentExpr)) + { + matchingOverload = *it; + } + } + + if (matchingOverload) + { + std::string overloadSymbol = *binding->documentationSymbol + "/overload/"; + // Default toString options are fine for this purpose. + overloadSymbol += toString(matchingOverload); + return overloadSymbol; } } - - if (matchingOverload) - { - std::string overloadSymbol = *binding->documentationSymbol + "/overload/"; - // Default toString options are fine for this purpose. - overloadSymbol += toString(matchingOverload); - return overloadSymbol; - } } - } - return binding->documentationSymbol; + return binding->documentationSymbol; + } } if (targetExpr) @@ -474,14 +515,20 @@ std::optional getDocumentationSymbolAtPosition(const Source { if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end()) { - return propIt->second.documentationSymbol; + if (FFlag::LuauCheckOverloadedDocSymbol) + return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); + else + return propIt->second.documentationSymbol; } } else if (const ClassTypeVar* ctv = get(parentTy)) { if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end()) { - return propIt->second.documentationSymbol; + if (FFlag::LuauCheckOverloadedDocSymbol) + return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); + else + return propIt->second.documentationSymbol; } } } diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index c5250a6d..ee53ae6b 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -10,6 +10,7 @@ #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" +#include "Luau/TypeUtils.h" #include @@ -41,6 +42,7 @@ static std::optional> magicFunctionRequire( static bool dcrMagicFunctionSelect(MagicFunctionCallContext context); static bool dcrMagicFunctionRequire(MagicFunctionCallContext context); +static bool dcrMagicFunctionPack(MagicFunctionCallContext context); TypeId makeUnion(TypeArena& arena, std::vector&& types) { @@ -333,6 +335,7 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); + attachDcrMagicFunction(ttv->props["pack"].type, dcrMagicFunctionPack); } attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire); @@ -660,7 +663,7 @@ static std::optional> magicFunctionPack( options.push_back(vtp->ty); } - options = typechecker.reduceUnion(options); + options = reduceUnion(options); // table.pack() -> {| n: number, [number]: nil |} // table.pack(1) -> {| n: number, [number]: number |} @@ -679,6 +682,46 @@ static std::optional> magicFunctionPack( return WithPredicate{arena.addTypePack({packedTable})}; } +static bool dcrMagicFunctionPack(MagicFunctionCallContext context) +{ + + TypeArena* arena = context.solver->arena; + + const auto& [paramTypes, paramTail] = flatten(context.arguments); + + std::vector options; + options.reserve(paramTypes.size()); + for (auto type : paramTypes) + options.push_back(type); + + if (paramTail) + { + if (const VariadicTypePack* vtp = get(*paramTail)) + options.push_back(vtp->ty); + } + + options = reduceUnion(options); + + // table.pack() -> {| n: number, [number]: nil |} + // table.pack(1) -> {| n: number, [number]: number |} + // table.pack(1, "foo") -> {| n: number, [number]: number | string |} + TypeId result = nullptr; + if (options.empty()) + result = context.solver->singletonTypes->nilType; + else if (options.size() == 1) + result = options[0]; + else + result = arena->addType(UnionTypeVar{std::move(options)}); + + TypeId numberType = context.solver->singletonTypes->numberType; + TypeId packedTable = arena->addType(TableTypeVar{{{"n", {numberType}}}, TableIndexer(numberType, result), {}, TableState::Sealed}); + + TypePackId tableTypePack = arena->addTypePack({packedTable}); + asMutable(context.result)->ty.emplace(tableTypePack); + + return true; +} + static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr) { // require(foo.parent.bar) will technically work, but it depends on legacy goop that diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index fd3a089b..86e1c7fc 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -1,6 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details - #include "Luau/Clone.h" + #include "Luau/RecursionCounter.h" #include "Luau/TxnLog.h" #include "Luau/TypePack.h" @@ -51,7 +51,6 @@ struct TypeCloner void operator()(const BlockedTypeVar& t); void operator()(const PendingExpansionTypeVar& t); void operator()(const PrimitiveTypeVar& t); - void operator()(const ConstrainedTypeVar& t); void operator()(const SingletonTypeVar& t); void operator()(const FunctionTypeVar& t); void operator()(const TableTypeVar& t); @@ -63,6 +62,7 @@ struct TypeCloner void operator()(const LazyTypeVar& t); void operator()(const UnknownTypeVar& t); void operator()(const NeverTypeVar& t); + void operator()(const NegationTypeVar& t); }; struct TypePackCloner @@ -198,21 +198,6 @@ void TypeCloner::operator()(const PrimitiveTypeVar& t) defaultClone(t); } -void TypeCloner::operator()(const ConstrainedTypeVar& t) -{ - TypeId res = dest.addType(ConstrainedTypeVar{t.level}); - ConstrainedTypeVar* ctv = getMutable(res); - LUAU_ASSERT(ctv); - - seenTypes[typeId] = res; - - std::vector parts; - for (TypeId part : t.parts) - parts.push_back(clone(part, dest, cloneState)); - - ctv->parts = std::move(parts); -} - void TypeCloner::operator()(const SingletonTypeVar& t) { defaultClone(t); @@ -352,6 +337,15 @@ void TypeCloner::operator()(const NeverTypeVar& t) defaultClone(t); } +void TypeCloner::operator()(const NegationTypeVar& t) +{ + TypeId result = dest.addType(AnyTypeVar{}); + seenTypes[typeId] = result; + + TypeId ty = clone(t.ty, dest, cloneState); + asMutable(result)->ty = NegationTypeVar{ty}; +} + } // anonymous namespace TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState) @@ -390,7 +384,6 @@ TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState) if (!res->persistent) { asMutable(res)->documentationSymbol = typeId->documentationSymbol; - asMutable(res)->normal = typeId->normal; } } @@ -478,11 +471,6 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl clone.parts = itv->parts; result = dest.addType(std::move(clone)); } - else if (const ConstrainedTypeVar* ctv = get(ty)) - { - ConstrainedTypeVar clone{ctv->level, ctv->parts}; - result = dest.addType(std::move(clone)); - } else if (const PendingExpansionTypeVar* petv = get(ty)) { PendingExpansionTypeVar clone{petv->prefix, petv->name, petv->typeArguments, petv->packArguments}; @@ -497,6 +485,10 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl { result = dest.addType(*ty); } + else if (const NegationTypeVar* ntv = get(ty)) + { + result = dest.addType(NegationTypeVar{ntv->ty}); + } else return result; @@ -504,4 +496,9 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl return result; } +TypeId shallowClone(TypeId ty, NotNull dest) +{ + return shallowClone(ty, *dest, TxnLog::empty()); +} + } // namespace Luau diff --git a/Analysis/src/Connective.cpp b/Analysis/src/Connective.cpp new file mode 100644 index 00000000..114b5f2f --- /dev/null +++ b/Analysis/src/Connective.cpp @@ -0,0 +1,32 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Connective.h" + +namespace Luau +{ + +ConnectiveId ConnectiveArena::negation(ConnectiveId connective) +{ + return NotNull{allocator.allocate(Negation{connective})}; +} + +ConnectiveId ConnectiveArena::conjunction(ConnectiveId lhs, ConnectiveId rhs) +{ + return NotNull{allocator.allocate(Conjunction{lhs, rhs})}; +} + +ConnectiveId ConnectiveArena::disjunction(ConnectiveId lhs, ConnectiveId rhs) +{ + return NotNull{allocator.allocate(Disjunction{lhs, rhs})}; +} + +ConnectiveId ConnectiveArena::equivalence(ConnectiveId lhs, ConnectiveId rhs) +{ + return NotNull{allocator.allocate(Equivalence{lhs, rhs})}; +} + +ConnectiveId ConnectiveArena::proposition(DefId def, TypeId discriminantTy) +{ + return NotNull{allocator.allocate(Proposition{def, discriminantTy})}; +} + +} // namespace Luau diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 8436fb30..79a69ca4 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -1,20 +1,21 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details - #include "Luau/ConstraintGraphBuilder.h" + #include "Luau/Ast.h" +#include "Luau/Clone.h" #include "Luau/Common.h" #include "Luau/Constraint.h" +#include "Luau/DcrLogger.h" #include "Luau/ModuleResolver.h" #include "Luau/RecursionCounter.h" +#include "Luau/Scope.h" #include "Luau/ToString.h" -#include "Luau/DcrLogger.h" +#include "Luau/TypeUtils.h" LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAG(DebugLuauMagicTypes); -#include "Luau/Scope.h" - namespace Luau { @@ -53,12 +54,13 @@ static bool matchSetmetatable(const AstExprCall& call) ConstraintGraphBuilder::ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, NotNull moduleResolver, NotNull singletonTypes, NotNull ice, const ScopePtr& globalScope, - DcrLogger* logger) + DcrLogger* logger, NotNull dfg) : moduleName(moduleName) , module(module) , singletonTypes(singletonTypes) , arena(arena) , rootScope(nullptr) + , dfg(dfg) , moduleResolver(moduleResolver) , ice(ice) , globalScope(globalScope) @@ -95,14 +97,109 @@ ScopePtr ConstraintGraphBuilder::childScope(AstNode* node, const ScopePtr& paren return scope; } -void ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv) +NotNull ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv) { - scope->constraints.emplace_back(new Constraint{NotNull{scope.get()}, location, std::move(cv)}); + return NotNull{scope->constraints.emplace_back(new Constraint{NotNull{scope.get()}, location, std::move(cv)}).get()}; } -void ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, std::unique_ptr c) +NotNull ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, std::unique_ptr c) { - scope->constraints.emplace_back(std::move(c)); + return NotNull{scope->constraints.emplace_back(std::move(c)).get()}; +} + +static void unionRefinements(const std::unordered_map& lhs, const std::unordered_map& rhs, + std::unordered_map& dest, NotNull arena) +{ + for (auto [def, ty] : lhs) + { + auto rhsIt = rhs.find(def); + if (rhsIt == rhs.end()) + continue; + + std::vector discriminants{{ty, rhsIt->second}}; + + if (auto destIt = dest.find(def); destIt != dest.end()) + discriminants.push_back(destIt->second); + + dest[def] = arena->addType(UnionTypeVar{std::move(discriminants)}); + } +} + +static void computeRefinement(const ScopePtr& scope, ConnectiveId connective, std::unordered_map* refis, bool sense, + NotNull arena, bool eq, std::vector* constraints) +{ + using RefinementMap = std::unordered_map; + + if (!connective) + return; + else if (auto negation = get(connective)) + return computeRefinement(scope, negation->connective, refis, !sense, arena, eq, constraints); + else if (auto conjunction = get(connective)) + { + RefinementMap lhsRefis; + RefinementMap rhsRefis; + + computeRefinement(scope, conjunction->lhs, sense ? refis : &lhsRefis, sense, arena, eq, constraints); + computeRefinement(scope, conjunction->rhs, sense ? refis : &rhsRefis, sense, arena, eq, constraints); + + if (!sense) + unionRefinements(lhsRefis, rhsRefis, *refis, arena); + } + else if (auto disjunction = get(connective)) + { + RefinementMap lhsRefis; + RefinementMap rhsRefis; + + computeRefinement(scope, disjunction->lhs, sense ? &lhsRefis : refis, sense, arena, eq, constraints); + computeRefinement(scope, disjunction->rhs, sense ? &rhsRefis : refis, sense, arena, eq, constraints); + + if (sense) + unionRefinements(lhsRefis, rhsRefis, *refis, arena); + } + else if (auto equivalence = get(connective)) + { + computeRefinement(scope, equivalence->lhs, refis, sense, arena, true, constraints); + computeRefinement(scope, equivalence->rhs, refis, sense, arena, true, constraints); + } + else if (auto proposition = get(connective)) + { + TypeId discriminantTy = proposition->discriminantTy; + if (!sense && !eq) + discriminantTy = arena->addType(NegationTypeVar{proposition->discriminantTy}); + else if (!sense && eq) + { + discriminantTy = arena->addType(BlockedTypeVar{}); + constraints->push_back(SingletonOrTopTypeConstraint{discriminantTy, proposition->discriminantTy}); + } + + if (auto it = refis->find(proposition->def); it != refis->end()) + (*refis)[proposition->def] = arena->addType(IntersectionTypeVar{{discriminantTy, it->second}}); + else + (*refis)[proposition->def] = discriminantTy; + } +} + +void ConstraintGraphBuilder::applyRefinements(const ScopePtr& scope, Location location, ConnectiveId connective) +{ + if (!connective) + return; + + std::unordered_map refinements; + std::vector constraints; + computeRefinement(scope, connective, &refinements, /*sense*/ true, arena, /*eq*/ false, &constraints); + + for (auto [def, discriminantTy] : refinements) + { + std::optional defTy = scope->lookup(def); + if (!defTy) + ice->ice("Every DefId must map to a type!"); + + TypeId resultTy = arena->addType(IntersectionTypeVar{{*defTy, discriminantTy}}); + scope->dcrRefinements[def] = resultTy; + } + + for (auto& c : constraints) + addConstraint(scope, location, c); } void ConstraintGraphBuilder::visit(AstStatBlock* block) @@ -229,22 +326,16 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStat* stat) void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) { std::vector varTypes; + varTypes.reserve(local->vars.size); for (AstLocal* local : local->vars) { TypeId ty = nullptr; - Location location = local->location; if (local->annotation) - { - location = local->annotation->location; ty = resolveType(scope, local->annotation, /* topLevel */ true); - } - else - ty = freshType(scope); varTypes.push_back(ty); - scope->bindings[local] = Binding{ty, location}; } for (size_t i = 0; i < local->values.size; ++i) @@ -254,35 +345,77 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) if (value->is()) { - // HACK: we leave nil-initialized things floating under the assumption that they will later be populated. - // See the test TypeInfer/infer_locals_with_nil_value. - // Better flow awareness should make this obsolete. - } - else if (i == local->values.size - 1) - { - std::vector expectedTypes; - if (hasAnnotation) - expectedTypes.insert(begin(expectedTypes), begin(varTypes) + i, end(varTypes)); + // HACK: we leave nil-initialized things floating under the + // assumption that they will later be populated. + // + // See the test TypeInfer/infer_locals_with_nil_value. Better flow + // awareness should make this obsolete. - TypePackId exprPack = checkPack(scope, value, expectedTypes); - - if (i < local->vars.size) - { - std::vector tailValues{varTypes.begin() + i, varTypes.end()}; - TypePackId tailPack = arena->addTypePack(std::move(tailValues)); - addConstraint(scope, local->location, PackSubtypeConstraint{exprPack, tailPack}); - } + if (!varTypes[i]) + varTypes[i] = freshType(scope); } - else + // Only function calls and vararg expressions can produce packs. All + // other expressions produce exactly one value. + else if (i != local->values.size - 1 || (!value->is() && !value->is())) { std::optional expectedType; if (hasAnnotation) expectedType = varTypes.at(i); - TypeId exprType = check(scope, value, expectedType); + TypeId exprType = check(scope, value, expectedType).ty; if (i < varTypes.size()) - addConstraint(scope, local->location, SubtypeConstraint{varTypes[i], exprType}); + { + if (varTypes[i]) + addConstraint(scope, local->location, SubtypeConstraint{exprType, varTypes[i]}); + else + varTypes[i] = exprType; + } } + else + { + std::vector expectedTypes; + if (hasAnnotation) + expectedTypes.insert(begin(expectedTypes), begin(varTypes) + i, end(varTypes)); + + TypePackId exprPack = checkPack(scope, value, expectedTypes).tp; + + if (i < local->vars.size) + { + std::vector packTypes = flatten(*arena, singletonTypes, exprPack, varTypes.size() - i); + + // fill out missing values in varTypes with values from exprPack + for (size_t j = i; j < varTypes.size(); ++j) + { + if (!varTypes[j]) + { + if (j - i < packTypes.size()) + varTypes[j] = packTypes[j - i]; + else + varTypes[j] = freshType(scope); + } + } + + std::vector tailValues{varTypes.begin() + i, varTypes.end()}; + TypePackId tailPack = arena->addTypePack(std::move(tailValues)); + addConstraint(scope, local->location, PackSubtypeConstraint{exprPack, tailPack}); + } + } + } + + for (size_t i = 0; i < local->vars.size; ++i) + { + AstLocal* l = local->vars.data[i]; + Location location = l->location; + + if (!varTypes[i]) + varTypes[i] = freshType(scope); + + scope->bindings[l] = Binding{varTypes[i], location}; + + // HACK: In the greedy solver, we say the type state of a variable is the type annotation itself, but + // the actual type state is the corresponding initializer expression (if it exists) or nil otherwise. + if (auto def = dfg->getDef(l)) + scope->dcrRefinements[*def] = varTypes[i]; } if (local->values.size > 0) @@ -316,7 +449,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_) if (!expr) return; - TypeId t = check(scope, expr); + TypeId t = check(scope, expr).ty; addConstraint(scope, expr->location, SubtypeConstraint{t, singletonTypes->numberType}); }; @@ -334,7 +467,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatForIn* forIn) { ScopePtr loopScope = childScope(forIn, scope); - TypePackId iterator = checkPack(scope, forIn->values); + TypePackId iterator = checkPack(scope, forIn->values).tp; std::vector variableTypes; variableTypes.reserve(forIn->vars.size); @@ -455,7 +588,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct } else if (AstExprIndexName* indexName = function->name->as()) { - TypeId containingTableType = check(scope, indexName->expr); + TypeId containingTableType = check(scope, indexName->expr).ty; functionType = arena->addType(BlockedTypeVar{}); @@ -497,7 +630,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatReturn* ret) for (TypeId ty : scope->returnType) expectedTypes.push_back(ty); - TypePackId exprTypes = checkPack(scope, ret->list, expectedTypes); + TypePackId exprTypes = checkPack(scope, ret->list, expectedTypes).tp; addConstraint(scope, ret->location, PackSubtypeConstraint{exprTypes, scope->returnType}); } @@ -510,8 +643,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block) void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) { - TypePackId varPackId = checkPack(scope, assign->vars); - TypePackId valuePack = checkPack(scope, assign->values); + TypePackId varPackId = checkLValues(scope, assign->vars); + TypePackId valuePack = checkPack(scope, assign->values).tp; addConstraint(scope, assign->location, PackSubtypeConstraint{valuePack, varPackId}); } @@ -532,14 +665,19 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign* void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement) { - check(scope, ifStatement->condition); + // TODO: Optimization opportunity, the interior scope of the condition could be + // reused for the then body, so we don't need to refine twice. + ScopePtr condScope = childScope(ifStatement->condition, scope); + auto [_, connective] = check(condScope, ifStatement->condition, std::nullopt); ScopePtr thenScope = childScope(ifStatement->thenbody, scope); + applyRefinements(thenScope, Location{}, connective); visit(thenScope, ifStatement->thenbody); if (ifStatement->elsebody) { ScopePtr elseScope = childScope(ifStatement->elsebody, scope); + applyRefinements(elseScope, Location{}, connectiveArena.negation(connective)); visit(elseScope, ifStatement->elsebody); } } @@ -695,7 +833,6 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* d void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction* global) { - std::vector> generics = createGenerics(scope, global->generics); std::vector> genericPacks = createGenericPacks(scope, global->genericPacks); @@ -742,7 +879,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatError* error) check(scope, expr); } -TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstArray exprs, const std::vector& expectedTypes) +InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstArray exprs, const std::vector& expectedTypes) { std::vector head; std::optional tail; @@ -755,219 +892,180 @@ TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstArray expectedType; if (i < expectedTypes.size()) expectedType = expectedTypes[i]; - head.push_back(check(scope, expr)); + head.push_back(check(scope, expr).ty); } else { std::vector expectedTailTypes; if (i < expectedTypes.size()) expectedTailTypes.assign(begin(expectedTypes) + i, end(expectedTypes)); - tail = checkPack(scope, expr, expectedTailTypes); + tail = checkPack(scope, expr, expectedTailTypes).tp; } } if (head.empty() && tail) - return *tail; + return InferencePack{*tail}; else - return arena->addTypePack(TypePack{std::move(head), tail}); + return InferencePack{arena->addTypePack(TypePack{std::move(head), tail})}; } -TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector& expectedTypes) +InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector& expectedTypes) { RecursionCounter counter{&recursionCount}; if (recursionCount >= FInt::LuauCheckRecursionLimit) { reportCodeTooComplex(expr->location); - return singletonTypes->errorRecoveryTypePack(); + return InferencePack{singletonTypes->errorRecoveryTypePack()}; } - TypePackId result = nullptr; + InferencePack result; if (AstExprCall* call = expr->as()) - { - TypeId fnType = check(scope, call->func); - const size_t constraintIndex = scope->constraints.size(); - const size_t scopeIndex = scopes.size(); - - std::vector args; - - for (AstExpr* arg : call->args) - { - args.push_back(check(scope, arg)); - } - - // TODO self - - if (matchSetmetatable(*call)) - { - LUAU_ASSERT(args.size() == 2); - TypeId target = args[0]; - TypeId mt = args[1]; - - MetatableTypeVar mtv{target, mt}; - TypeId resultTy = arena->addType(mtv); - result = arena->addTypePack({resultTy}); - } - else - { - const size_t constraintEndIndex = scope->constraints.size(); - const size_t scopeEndIndex = scopes.size(); - - astOriginalCallTypes[call->func] = fnType; - - TypeId instantiatedType = arena->addType(BlockedTypeVar{}); - // TODO: How do expectedTypes play into this? Do they? - TypePackId rets = arena->addTypePack(BlockedTypePack{}); - TypePackId argPack = arena->addTypePack(TypePack{args, {}}); - FunctionTypeVar ftv(TypeLevel{}, scope.get(), argPack, rets); - TypeId inferredFnType = arena->addType(ftv); - - scope->unqueuedConstraints.push_back( - std::make_unique(NotNull{scope.get()}, call->func->location, InstantiationConstraint{instantiatedType, fnType})); - NotNull ic(scope->unqueuedConstraints.back().get()); - - scope->unqueuedConstraints.push_back( - std::make_unique(NotNull{scope.get()}, call->func->location, SubtypeConstraint{inferredFnType, instantiatedType})); - NotNull sc(scope->unqueuedConstraints.back().get()); - - // We force constraints produced by checking function arguments to wait - // until after we have resolved the constraint on the function itself. - // This ensures, for instance, that we start inferring the contents of - // lambdas under the assumption that their arguments and return types - // will be compatible with the enclosing function call. - for (size_t ci = constraintIndex; ci < constraintEndIndex; ++ci) - scope->constraints[ci]->dependencies.push_back(sc); - - for (size_t si = scopeIndex; si < scopeEndIndex; ++si) - { - for (auto& c : scopes[si].second->constraints) - { - c->dependencies.push_back(sc); - } - } - - addConstraint(scope, call->func->location, - FunctionCallConstraint{ - {ic, sc}, - fnType, - argPack, - rets, - call, - }); - - result = rets; - } - } + result = {checkPack(scope, call, expectedTypes)}; else if (AstExprVarargs* varargs = expr->as()) { if (scope->varargPack) - result = *scope->varargPack; + result = InferencePack{*scope->varargPack}; else - result = singletonTypes->errorRecoveryTypePack(); + result = InferencePack{singletonTypes->errorRecoveryTypePack()}; } else { std::optional expectedType; if (!expectedTypes.empty()) expectedType = expectedTypes[0]; - TypeId t = check(scope, expr, expectedType); - result = arena->addTypePack({t}); + TypeId t = check(scope, expr, expectedType).ty; + result = InferencePack{arena->addTypePack({t})}; } - LUAU_ASSERT(result); - astTypePacks[expr] = result; + LUAU_ASSERT(result.tp); + astTypePacks[expr] = result.tp; return result; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType) +InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCall* call, const std::vector& expectedTypes) +{ + TypeId fnType = check(scope, call->func).ty; + const size_t constraintIndex = scope->constraints.size(); + const size_t scopeIndex = scopes.size(); + + std::vector args; + + for (AstExpr* arg : call->args) + { + args.push_back(check(scope, arg).ty); + } + + // TODO self + + if (matchSetmetatable(*call)) + { + LUAU_ASSERT(args.size() == 2); + TypeId target = args[0]; + TypeId mt = args[1]; + + AstExpr* targetExpr = call->args.data[0]; + + MetatableTypeVar mtv{target, mt}; + TypeId resultTy = arena->addType(mtv); + + if (AstExprLocal* targetLocal = targetExpr->as()) + scope->bindings[targetLocal->local].typeId = resultTy; + + return InferencePack{arena->addTypePack({resultTy})}; + } + else + { + const size_t constraintEndIndex = scope->constraints.size(); + const size_t scopeEndIndex = scopes.size(); + + astOriginalCallTypes[call->func] = fnType; + + TypeId instantiatedType = arena->addType(BlockedTypeVar{}); + // TODO: How do expectedTypes play into this? Do they? + TypePackId rets = arena->addTypePack(BlockedTypePack{}); + TypePackId argPack = arena->addTypePack(TypePack{args, {}}); + FunctionTypeVar ftv(TypeLevel{}, scope.get(), argPack, rets); + TypeId inferredFnType = arena->addType(ftv); + + scope->unqueuedConstraints.push_back( + std::make_unique(NotNull{scope.get()}, call->func->location, InstantiationConstraint{instantiatedType, fnType})); + NotNull ic(scope->unqueuedConstraints.back().get()); + + scope->unqueuedConstraints.push_back( + std::make_unique(NotNull{scope.get()}, call->func->location, SubtypeConstraint{inferredFnType, instantiatedType})); + NotNull sc(scope->unqueuedConstraints.back().get()); + + // We force constraints produced by checking function arguments to wait + // until after we have resolved the constraint on the function itself. + // This ensures, for instance, that we start inferring the contents of + // lambdas under the assumption that their arguments and return types + // will be compatible with the enclosing function call. + for (size_t ci = constraintIndex; ci < constraintEndIndex; ++ci) + scope->constraints[ci]->dependencies.push_back(sc); + + for (size_t si = scopeIndex; si < scopeEndIndex; ++si) + { + for (auto& c : scopes[si].second->constraints) + { + c->dependencies.push_back(sc); + } + } + + addConstraint(scope, call->func->location, + FunctionCallConstraint{ + {ic, sc}, + fnType, + argPack, + rets, + call, + }); + + return InferencePack{rets}; + } +} + +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType, bool forceSingleton) { RecursionCounter counter{&recursionCount}; if (recursionCount >= FInt::LuauCheckRecursionLimit) { reportCodeTooComplex(expr->location); - return singletonTypes->errorRecoveryType(); + return Inference{singletonTypes->errorRecoveryType()}; } - TypeId result = nullptr; + Inference result; if (auto group = expr->as()) - result = check(scope, group->expr); + result = check(scope, group->expr, expectedType, forceSingleton); else if (auto stringExpr = expr->as()) - { - if (expectedType) - { - const TypeId expectedTy = follow(*expectedType); - if (get(expectedTy) || get(expectedTy)) - { - result = arena->addType(BlockedTypeVar{}); - TypeId singletonType = arena->addType(SingletonTypeVar(StringSingleton{std::string(stringExpr->value.data, stringExpr->value.size)})); - addConstraint(scope, expr->location, PrimitiveTypeConstraint{result, expectedTy, singletonType, singletonTypes->stringType}); - } - else if (maybeSingleton(expectedTy)) - result = arena->addType(SingletonTypeVar{StringSingleton{std::string{stringExpr->value.data, stringExpr->value.size}}}); - else - result = singletonTypes->stringType; - } - else - result = singletonTypes->stringType; - } + result = check(scope, stringExpr, expectedType, forceSingleton); else if (expr->is()) - result = singletonTypes->numberType; + result = Inference{singletonTypes->numberType}; else if (auto boolExpr = expr->as()) - { - if (expectedType) - { - const TypeId expectedTy = follow(*expectedType); - const TypeId singletonType = boolExpr->value ? singletonTypes->trueType : singletonTypes->falseType; - - if (get(expectedTy) || get(expectedTy)) - { - result = arena->addType(BlockedTypeVar{}); - addConstraint(scope, expr->location, PrimitiveTypeConstraint{result, expectedTy, singletonType, singletonTypes->booleanType}); - } - else if (maybeSingleton(expectedTy)) - result = singletonType; - else - result = singletonTypes->booleanType; - } - else - result = singletonTypes->booleanType; - } + result = check(scope, boolExpr, expectedType, forceSingleton); else if (expr->is()) - result = singletonTypes->nilType; - else if (auto a = expr->as()) - { - std::optional ty = scope->lookup(a->local); - if (ty) - result = *ty; - else - result = singletonTypes->errorRecoveryType(); // FIXME? Record an error at this point? - } - else if (auto g = expr->as()) - { - std::optional ty = scope->lookup(g->name); - if (ty) - result = *ty; - else - { - /* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any - * global that is not already in-scope is definitely an unknown symbol. - */ - reportError(g->location, UnknownSymbol{g->name.value}); - result = singletonTypes->errorRecoveryType(); // FIXME? Record an error at this point? - } - } + result = Inference{singletonTypes->nilType}; + else if (auto local = expr->as()) + result = check(scope, local); + else if (auto global = expr->as()) + result = check(scope, global); else if (expr->is()) result = flattenPack(scope, expr->location, checkPack(scope, expr)); - else if (expr->is()) - result = flattenPack(scope, expr->location, checkPack(scope, expr)); + else if (auto call = expr->as()) + { + std::vector expectedTypes; + if (expectedType) + expectedTypes.push_back(*expectedType); + result = flattenPack(scope, expr->location, checkPack(scope, call, expectedTypes)); // TODO: needs predicates too + } else if (auto a = expr->as()) { FunctionSignature sig = checkFunctionSignature(scope, a); checkFunctionBody(sig.bodyScope, a); - return sig.signature; + return Inference{sig.signature}; } else if (auto indexName = expr->as()) result = check(scope, indexName); @@ -978,7 +1076,7 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std:: else if (auto unary = expr->as()) result = check(scope, unary); else if (auto binary = expr->as()) - result = check(scope, binary); + result = check(scope, binary, expectedType); else if (auto ifElse = expr->as()) result = check(scope, ifElse, expectedType); else if (auto typeAssert = expr->as()) @@ -989,22 +1087,105 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std:: for (AstExpr* subExpr : err->expressions) check(scope, subExpr); - result = singletonTypes->errorRecoveryType(); + result = Inference{singletonTypes->errorRecoveryType()}; } else { LUAU_ASSERT(0); - result = freshType(scope); + result = Inference{freshType(scope)}; } - LUAU_ASSERT(result); - astTypes[expr] = result; + LUAU_ASSERT(result.ty); + astTypes[expr] = result.ty; return result; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType, bool forceSingleton) { - TypeId obj = check(scope, indexName->expr); + if (forceSingleton) + return Inference{arena->addType(SingletonTypeVar{StringSingleton{std::string{string->value.data, string->value.size}}})}; + + if (expectedType) + { + const TypeId expectedTy = follow(*expectedType); + if (get(expectedTy) || get(expectedTy)) + { + TypeId ty = arena->addType(BlockedTypeVar{}); + TypeId singletonType = arena->addType(SingletonTypeVar(StringSingleton{std::string(string->value.data, string->value.size)})); + addConstraint(scope, string->location, PrimitiveTypeConstraint{ty, expectedTy, singletonType, singletonTypes->stringType}); + return Inference{ty}; + } + else if (maybeSingleton(expectedTy)) + return Inference{arena->addType(SingletonTypeVar{StringSingleton{std::string{string->value.data, string->value.size}}})}; + + return Inference{singletonTypes->stringType}; + } + + return Inference{singletonTypes->stringType}; +} + +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBool* boolExpr, std::optional expectedType, bool forceSingleton) +{ + const TypeId singletonType = boolExpr->value ? singletonTypes->trueType : singletonTypes->falseType; + if (forceSingleton) + return Inference{singletonType}; + + if (expectedType) + { + const TypeId expectedTy = follow(*expectedType); + + if (get(expectedTy) || get(expectedTy)) + { + TypeId ty = arena->addType(BlockedTypeVar{}); + addConstraint(scope, boolExpr->location, PrimitiveTypeConstraint{ty, expectedTy, singletonType, singletonTypes->booleanType}); + return Inference{ty}; + } + else if (maybeSingleton(expectedTy)) + return Inference{singletonType}; + + return Inference{singletonTypes->booleanType}; + } + + return Inference{singletonTypes->booleanType}; +} + +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local) +{ + std::optional resultTy; + auto def = dfg->getDef(local); + if (def) + resultTy = scope->lookup(*def); + + if (!resultTy) + { + if (auto ty = scope->lookup(local->local)) + resultTy = *ty; + } + + if (!resultTy) + return Inference{singletonTypes->errorRecoveryType()}; // TODO: replace with ice, locals should never exist before its definition. + + if (def) + return Inference{*resultTy, connectiveArena.proposition(*def, singletonTypes->truthyType)}; + else + return Inference{*resultTy}; +} + +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* global) +{ + if (std::optional ty = scope->lookup(global->name)) + return Inference{*ty}; + + /* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any + * global that is not already in-scope is definitely an unknown symbol. + */ + reportError(global->location, UnknownSymbol{global->name.value}); + return Inference{singletonTypes->errorRecoveryType()}; +} + +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName) +{ + TypeId obj = check(scope, indexName->expr).ty; TypeId result = freshType(scope); TableTypeVar::Props props{{indexName->index.value, Property{result}}}; @@ -1015,13 +1196,13 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* in addConstraint(scope, indexName->expr->location, SubtypeConstraint{obj, expectedTableType}); - return result; + return Inference{result}; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* indexExpr) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* indexExpr) { - TypeId obj = check(scope, indexExpr->expr); - TypeId indexType = check(scope, indexExpr->index); + TypeId obj = check(scope, indexExpr->expr).ty; + TypeId indexType = check(scope, indexExpr->index).ty; TypeId result = freshType(scope); @@ -1031,86 +1212,279 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* in addConstraint(scope, indexExpr->expr->location, SubtypeConstraint{obj, tableType}); - return result; + return Inference{result}; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary) { - TypeId operandType = check(scope, unary->expr); - + auto [operandType, connective] = check(scope, unary->expr); TypeId resultType = arena->addType(BlockedTypeVar{}); addConstraint(scope, unary->location, UnaryConstraint{unary->op, operandType, resultType}); - return resultType; + + if (unary->op == AstExprUnary::Not) + return Inference{resultType, connectiveArena.negation(connective)}; + else + return Inference{resultType}; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType) { - TypeId leftType = check(scope, binary->left); - TypeId rightType = check(scope, binary->right); - switch (binary->op) - { - case AstExprBinary::And: - case AstExprBinary::Or: - { - addConstraint(scope, binary->location, SubtypeConstraint{leftType, rightType}); - return leftType; - } - case AstExprBinary::Add: - case AstExprBinary::Sub: - case AstExprBinary::Mul: - case AstExprBinary::Div: - case AstExprBinary::Mod: - case AstExprBinary::Pow: - case AstExprBinary::CompareNe: - case AstExprBinary::CompareEq: - case AstExprBinary::CompareLt: - case AstExprBinary::CompareLe: - case AstExprBinary::CompareGt: - case AstExprBinary::CompareGe: - { - TypeId resultType = arena->addType(BlockedTypeVar{}); - addConstraint(scope, binary->location, BinaryConstraint{binary->op, leftType, rightType, resultType}); - return resultType; - } - case AstExprBinary::Concat: - { - addConstraint(scope, binary->left->location, SubtypeConstraint{leftType, singletonTypes->stringType}); - addConstraint(scope, binary->right->location, SubtypeConstraint{rightType, singletonTypes->stringType}); - return singletonTypes->stringType; - } - default: - LUAU_ASSERT(0); - } + auto [leftType, rightType, connective] = checkBinary(scope, binary, expectedType); - LUAU_ASSERT(0); - return nullptr; + TypeId resultType = arena->addType(BlockedTypeVar{}); + addConstraint(scope, binary->location, BinaryConstraint{binary->op, leftType, rightType, resultType}); + return Inference{resultType, std::move(connective)}; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType) { check(scope, ifElse->condition); - TypeId thenType = check(scope, ifElse->trueExpr, expectedType); - TypeId elseType = check(scope, ifElse->falseExpr, expectedType); + TypeId thenType = check(scope, ifElse->trueExpr, expectedType).ty; + TypeId elseType = check(scope, ifElse->falseExpr, expectedType).ty; if (ifElse->hasElse) { TypeId resultType = expectedType ? *expectedType : freshType(scope); addConstraint(scope, ifElse->trueExpr->location, SubtypeConstraint{thenType, resultType}); addConstraint(scope, ifElse->falseExpr->location, SubtypeConstraint{elseType, resultType}); - return resultType; + return Inference{resultType}; } - return thenType; + return Inference{thenType}; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert) { - check(scope, typeAssert->expr); - return resolveType(scope, typeAssert->annotation); + check(scope, typeAssert->expr, std::nullopt); + return Inference{resolveType(scope, typeAssert->annotation)}; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType) +std::tuple ConstraintGraphBuilder::checkBinary( + const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType) +{ + if (binary->op == AstExprBinary::And) + { + auto [leftType, leftConnective] = check(scope, binary->left, expectedType); + + ScopePtr rightScope = childScope(binary->right, scope); + applyRefinements(rightScope, binary->right->location, leftConnective); + auto [rightType, rightConnective] = check(rightScope, binary->right, expectedType); + + return {leftType, rightType, connectiveArena.conjunction(leftConnective, rightConnective)}; + } + else if (binary->op == AstExprBinary::Or) + { + auto [leftType, leftConnective] = check(scope, binary->left, expectedType); + + ScopePtr rightScope = childScope(binary->right, scope); + applyRefinements(rightScope, binary->right->location, connectiveArena.negation(leftConnective)); + auto [rightType, rightConnective] = check(rightScope, binary->right, expectedType); + + return {leftType, rightType, connectiveArena.disjunction(leftConnective, rightConnective)}; + } + else if (binary->op == AstExprBinary::CompareEq || binary->op == AstExprBinary::CompareNe) + { + TypeId leftType = check(scope, binary->left, expectedType, true).ty; + TypeId rightType = check(scope, binary->right, expectedType, true).ty; + + ConnectiveId leftConnective = nullptr; + if (auto def = dfg->getDef(binary->left)) + leftConnective = connectiveArena.proposition(*def, rightType); + + ConnectiveId rightConnective = nullptr; + if (auto def = dfg->getDef(binary->right)) + rightConnective = connectiveArena.proposition(*def, leftType); + + if (binary->op == AstExprBinary::CompareNe) + { + leftConnective = connectiveArena.negation(leftConnective); + rightConnective = connectiveArena.negation(rightConnective); + } + + return {leftType, rightType, connectiveArena.equivalence(leftConnective, rightConnective)}; + } + else + { + TypeId leftType = check(scope, binary->left, expectedType).ty; + TypeId rightType = check(scope, binary->right, expectedType).ty; + return {leftType, rightType, nullptr}; + } +} + +TypePackId ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray exprs) +{ + std::vector types; + types.reserve(exprs.size); + + for (size_t i = 0; i < exprs.size; ++i) + { + AstExpr* const expr = exprs.data[i]; + types.push_back(checkLValue(scope, expr)); + } + + return arena->addTypePack(std::move(types)); +} + +static bool isUnsealedTable(TypeId ty) +{ + ty = follow(ty); + const TableTypeVar* ttv = get(ty); + return ttv && ttv->state == TableState::Unsealed; +}; + +/** + * If the expr is a dotted set of names, and if the root symbol refers to an + * unsealed table, return that table type, plus the indeces that follow as a + * vector. + */ +static std::optional>> extractDottedName(AstExpr* expr) +{ + std::vector names; + + while (expr) + { + if (auto global = expr->as()) + { + std::reverse(begin(names), end(names)); + return std::pair{global->name, std::move(names)}; + } + else if (auto local = expr->as()) + { + std::reverse(begin(names), end(names)); + return std::pair{local->local, std::move(names)}; + } + else if (auto indexName = expr->as()) + { + names.push_back(indexName->index.value); + expr = indexName->expr; + } + else + return std::nullopt; + } + + return std::nullopt; +} + +/** + * Create a shallow copy of `ty` and its properties along `path`. Insert a new + * property (the last segment of `path`) into the tail table with the value `t`. + * + * On success, returns the new outermost table type. If the root table or any + * of its subkeys are not unsealed tables, the function fails and returns + * std::nullopt. + * + * TODO: Prove that we completely give up in the face of indexers and + * metatables. + */ +static std::optional updateTheTableType(NotNull arena, TypeId ty, const std::vector& path, TypeId replaceTy) +{ + if (path.empty()) + return std::nullopt; + + // First walk the path and ensure that it's unsealed tables all the way + // to the end. + { + TypeId t = ty; + for (size_t i = 0; i < path.size() - 1; ++i) + { + if (!isUnsealedTable(t)) + return std::nullopt; + + const TableTypeVar* tbl = get(t); + auto it = tbl->props.find(path[i]); + if (it == tbl->props.end()) + return std::nullopt; + + t = it->second.type; + } + + // The last path segment should not be a property of the table at all. + // We are not changing property types. We are only admitting this one + // new property to be appended. + if (!isUnsealedTable(t)) + return std::nullopt; + const TableTypeVar* tbl = get(t); + auto it = tbl->props.find(path.back()); + if (it != tbl->props.end()) + return std::nullopt; + } + + const TypeId res = shallowClone(ty, arena); + TypeId t = res; + + for (size_t i = 0; i < path.size() - 1; ++i) + { + const std::string segment = path[i]; + + TableTypeVar* ttv = getMutable(t); + LUAU_ASSERT(ttv); + + auto propIt = ttv->props.find(segment); + if (propIt != ttv->props.end()) + { + LUAU_ASSERT(isUnsealedTable(propIt->second.type)); + t = shallowClone(follow(propIt->second.type), arena); + ttv->props[segment].type = t; + } + else + return std::nullopt; + } + + TableTypeVar* ttv = getMutable(t); + LUAU_ASSERT(ttv); + + const std::string lastSegment = path.back(); + LUAU_ASSERT(0 == ttv->props.count(lastSegment)); + ttv->props[lastSegment] = Property{replaceTy}; + return res; +} + +/** + * This function is mostly about identifying properties that are being inserted into unsealed tables. + * + * If expr has the form name.a.b.c + */ +TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) +{ + if (auto indexExpr = expr->as()) + { + if (auto constantString = indexExpr->index->as()) + { + AstName syntheticIndex{constantString->value.data}; + AstExprIndexName synthetic{ + indexExpr->location, indexExpr->expr, syntheticIndex, constantString->location, indexExpr->expr->location.end, '.'}; + return checkLValue(scope, &synthetic); + } + } + + auto dottedPath = extractDottedName(expr); + if (!dottedPath) + return check(scope, expr).ty; + const auto [sym, segments] = std::move(*dottedPath); + + if (!sym.local) + return check(scope, expr).ty; + + auto lookupResult = scope->lookupEx(sym); + if (!lookupResult) + return check(scope, expr).ty; + const auto [ty, symbolScope] = std::move(*lookupResult); + + TypeId replaceTy = arena->freshType(scope.get()); + + std::optional updatedType = updateTheTableType(arena, ty, segments, replaceTy); + if (!updatedType) + return check(scope, expr).ty; + + std::optional def = dfg->getDef(sym); + LUAU_ASSERT(def); + symbolScope->bindings[sym].typeId = *updatedType; + symbolScope->dcrRefinements[*def] = *updatedType; + return replaceTy; +} + +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType) { TypeId ty = arena->addType(TableTypeVar{}); TableTypeVar* ttv = getMutable(ty); @@ -1144,16 +1518,14 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, } } - TypeId itemTy = check(scope, item.value, expectedValueType); - if (get(follow(itemTy))) - return ty; + TypeId itemTy = check(scope, item.value, expectedValueType).ty; if (item.key) { // Even though we don't need to use the type of the item's key if // it's a string constant, we still want to check it to populate // astTypes. - TypeId keyTy = check(scope, item.key); + TypeId keyTy = check(scope, item.key).ty; if (AstExprConstantString* key = item.key->as()) { @@ -1173,7 +1545,7 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, } } - return ty; + return Inference{ty}; } ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionSignature(const ScopePtr& parent, AstExprFunction* fn) @@ -1275,6 +1647,9 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS argTypes.push_back(t); signatureScope->bindings[local] = Binding{t, local->location}; + if (auto def = dfg->getDef(local)) + signatureScope->dcrRefinements[*def] = t; + if (local->annotation) { TypeId argAnnotation = resolveType(signatureScope, local->annotation, /* topLevel */ true); @@ -1338,9 +1713,18 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b } } - std::optional alias = scope->lookupType(ref->name.value); + std::optional alias; - if (alias.has_value() || ref->prefix.has_value()) + if (ref->prefix.has_value()) + { + alias = scope->lookupImportedType(ref->prefix->value, ref->name.value); + } + else + { + alias = scope->lookupType(ref->name.value); + } + + if (alias.has_value()) { // If the alias is not generic, we don't need to set up a blocked // type and an instantiation constraint. @@ -1383,7 +1767,11 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b } else { - reportError(ty->location, UnknownSymbol{ref->name.value, UnknownSymbol::Context::Type}); + std::string typeName; + if (ref->prefix) + typeName = std::string(ref->prefix->value) + "."; + typeName += ref->name.value; + result = singletonTypes->errorRecoveryType(); } } @@ -1482,7 +1870,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b else if (auto tof = ty->as()) { // TODO: Recursion limit. - TypeId exprType = check(scope, tof->expr); + TypeId exprType = check(scope, tof->expr).ty; result = exprType; } else if (auto unionAnnotation = ty->as()) @@ -1491,7 +1879,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b for (AstType* part : unionAnnotation->types) { // TODO: Recursion limit. - parts.push_back(resolveType(scope, part)); + parts.push_back(resolveType(scope, part, topLevel)); } result = arena->addType(UnionTypeVar{parts}); @@ -1502,7 +1890,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b for (AstType* part : intersectionAnnotation->types) { // TODO: Recursion limit. - parts.push_back(resolveType(scope, part)); + parts.push_back(resolveType(scope, part, topLevel)); } result = arena->addType(IntersectionTypeVar{parts}); @@ -1592,10 +1980,7 @@ std::vector> ConstraintGraphBuilder::crea if (generic.defaultValue) defaultTy = resolveType(scope, generic.defaultValue); - result.push_back({generic.name.value, GenericTypeDefinition{ - genericTy, - defaultTy, - }}); + result.push_back({generic.name.value, GenericTypeDefinition{genericTy, defaultTy}}); } return result; @@ -1613,19 +1998,21 @@ std::vector> ConstraintGraphBuilder:: if (generic.defaultValue) defaultTy = resolveTypePack(scope, generic.defaultValue); - result.push_back({generic.name.value, GenericTypePackDefinition{ - genericTy, - defaultTy, - }}); + result.push_back({generic.name.value, GenericTypePackDefinition{genericTy, defaultTy}}); } return result; } -TypeId ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location location, TypePackId tp) +Inference ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location location, InferencePack pack) { + const auto& [tp, connectives] = pack; + ConnectiveId connective = nullptr; + if (!connectives.empty()) + connective = connectives[0]; + if (auto f = first(tp)) - return *f; + return Inference{*f, connective}; TypeId typeResult = freshType(scope); TypePack onePack{{typeResult}, freshTypePack(scope)}; @@ -1633,7 +2020,7 @@ TypeId ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location locat addConstraint(scope, location, PackSubtypeConstraint{tp, oneTypePack}); - return typeResult; + return Inference{typeResult, connective}; } void ConstraintGraphBuilder::reportError(Location location, TypeErrorData err) diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index e29eeaaa..c53ac659 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -3,14 +3,16 @@ #include "Luau/Anyification.h" #include "Luau/ApplyTypeFunction.h" #include "Luau/ConstraintSolver.h" +#include "Luau/DcrLogger.h" #include "Luau/Instantiation.h" #include "Luau/Location.h" +#include "Luau/Metamethods.h" #include "Luau/ModuleResolver.h" #include "Luau/Quantify.h" #include "Luau/ToString.h" +#include "Luau/TypeUtils.h" #include "Luau/TypeVar.h" #include "Luau/Unifier.h" -#include "Luau/DcrLogger.h" #include "Luau/VisitTypeVar.h" #include "Luau/TypeUtils.h" @@ -438,6 +440,8 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*fcc, constraint); else if (auto hpc = get(*constraint)) success = tryDispatch(*hpc, constraint); + else if (auto sottc = get(*constraint)) + success = tryDispatch(*sottc, constraint); else LUAU_ASSERT(false); @@ -540,6 +544,7 @@ bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNullty.emplace(singletonTypes->numberType); return true; } @@ -548,13 +553,46 @@ bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNull(operandType) || get(operandType)) { asMutable(c.resultType)->ty.emplace(c.operandType); - return true; } - break; + else if (std::optional mm = findMetatableEntry(singletonTypes, errors, operandType, "__unm", constraint->location)) + { + const FunctionTypeVar* ftv = get(follow(*mm)); + + if (!ftv) + { + if (std::optional callMm = findMetatableEntry(singletonTypes, errors, follow(*mm), "__call", constraint->location)) + { + ftv = get(follow(*callMm)); + } + } + + if (!ftv) + { + asMutable(c.resultType)->ty.emplace(singletonTypes->errorRecoveryType()); + return true; + } + + TypePackId argsPack = arena->addTypePack({operandType}); + unify(ftv->argTypes, argsPack, constraint->scope); + + TypeId result = singletonTypes->errorRecoveryType(); + if (ftv) + { + result = first(ftv->retTypes).value_or(singletonTypes->errorRecoveryType()); + } + + asMutable(c.resultType)->ty.emplace(result); + } + else + { + asMutable(c.resultType)->ty.emplace(singletonTypes->errorRecoveryType()); + } + + return true; } } - LUAU_ASSERT(false); // TODO metatable handling + LUAU_ASSERT(false); return false; } @@ -564,44 +602,192 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull - * - * This constraint is the one that is meant to unblock A, so it doesn't - * make any sense to stop and wait for someone else to do it. - */ - if (leftType != resultType && rightType != resultType) - { - block(c.leftType, constraint); - block(c.rightType, constraint); - return false; - } - } + bool isLogical = c.op == AstExprBinary::Op::And || c.op == AstExprBinary::Op::Or; - if (isNumber(leftType)) - { - unify(leftType, rightType, constraint->scope); - asMutable(resultType)->ty.emplace(leftType); - return true; - } + /* Compound assignments create constraints of the form + * + * A <: Binary + * + * This constraint is the one that is meant to unblock A, so it doesn't + * make any sense to stop and wait for someone else to do it. + */ + + if (isBlocked(leftType) && leftType != resultType) + return block(c.leftType, constraint); + + if (isBlocked(rightType) && rightType != resultType) + return block(c.rightType, constraint); if (!force) { - if (get(leftType)) + // Logical expressions may proceed if the LHS is free. + if (get(leftType) && !isLogical) return block(leftType, constraint); } - if (isBlocked(leftType)) + // Logical expressions may proceed if the LHS is free. + if (isBlocked(leftType) || (get(leftType) && !isLogical)) { asMutable(resultType)->ty.emplace(errorRecoveryType()); - // reportError(constraint->location, CannotInferBinaryOperation{c.op, std::nullopt, CannotInferBinaryOperation::Operation}); + unblock(resultType); return true; } - // TODO metatables, classes + // For or expressions, the LHS will never have nil as a possible output. + // Consider: + // local foo = nil or 2 + // `foo` will always be 2. + if (c.op == AstExprBinary::Op::Or) + leftType = stripNil(singletonTypes, *arena, leftType); + + // Metatables go first, even if there is primitive behavior. + if (auto it = kBinaryOpMetamethods.find(c.op); it != kBinaryOpMetamethods.end()) + { + // Metatables are not the same. The metamethod will not be invoked. + if ((c.op == AstExprBinary::Op::CompareEq || c.op == AstExprBinary::Op::CompareNe) && + getMetatable(leftType, singletonTypes) != getMetatable(rightType, singletonTypes)) + { + // TODO: Boolean singleton false? The result is _always_ boolean false. + asMutable(resultType)->ty.emplace(singletonTypes->booleanType); + unblock(resultType); + return true; + } + + std::optional mm; + + // The LHS metatable takes priority over the RHS metatable, where + // present. + if (std::optional leftMm = findMetatableEntry(singletonTypes, errors, leftType, it->second, constraint->location)) + mm = leftMm; + else if (std::optional rightMm = findMetatableEntry(singletonTypes, errors, rightType, it->second, constraint->location)) + mm = rightMm; + + if (mm) + { + // TODO: Is a table with __call legal here? + // TODO: Overloads + if (const FunctionTypeVar* ftv = get(follow(*mm))) + { + TypePackId inferredArgs; + // For >= and > we invoke __lt and __le respectively with + // swapped argument ordering. + if (c.op == AstExprBinary::Op::CompareGe || c.op == AstExprBinary::Op::CompareGt) + { + inferredArgs = arena->addTypePack({rightType, leftType}); + } + else + { + inferredArgs = arena->addTypePack({leftType, rightType}); + } + + unify(inferredArgs, ftv->argTypes, constraint->scope); + + TypeId mmResult; + + // Comparison operations always evaluate to a boolean, + // regardless of what the metamethod returns. + switch (c.op) + { + case AstExprBinary::Op::CompareEq: + case AstExprBinary::Op::CompareNe: + case AstExprBinary::Op::CompareGe: + case AstExprBinary::Op::CompareGt: + case AstExprBinary::Op::CompareLe: + case AstExprBinary::Op::CompareLt: + mmResult = singletonTypes->booleanType; + break; + default: + mmResult = first(ftv->retTypes).value_or(errorRecoveryType()); + } + + asMutable(resultType)->ty.emplace(mmResult); + unblock(resultType); + return true; + } + } + + // If there's no metamethod available, fall back to primitive behavior. + } + + // If any is present, the expression must evaluate to any as well. + bool leftAny = get(leftType) || get(leftType); + bool rightAny = get(rightType) || get(rightType); + bool anyPresent = leftAny || rightAny; + + switch (c.op) + { + // For arithmetic operators, if the LHS is a number, the RHS must be a + // number as well. The result will also be a number. + case AstExprBinary::Op::Add: + case AstExprBinary::Op::Sub: + case AstExprBinary::Op::Mul: + case AstExprBinary::Op::Div: + case AstExprBinary::Op::Pow: + case AstExprBinary::Op::Mod: + if (isNumber(leftType)) + { + unify(leftType, rightType, constraint->scope); + asMutable(resultType)->ty.emplace(anyPresent ? singletonTypes->anyType : leftType); + unblock(resultType); + return true; + } + + break; + // For concatenation, if the LHS is a string, the RHS must be a string as + // well. The result will also be a string. + case AstExprBinary::Op::Concat: + if (isString(leftType)) + { + unify(leftType, rightType, constraint->scope); + asMutable(resultType)->ty.emplace(anyPresent ? singletonTypes->anyType : leftType); + unblock(resultType); + return true; + } + + break; + // Inexact comparisons require that the types be both numbers or both + // strings, and evaluate to a boolean. + case AstExprBinary::Op::CompareGe: + case AstExprBinary::Op::CompareGt: + case AstExprBinary::Op::CompareLe: + case AstExprBinary::Op::CompareLt: + if ((isNumber(leftType) && isNumber(rightType)) || (isString(leftType) && isString(rightType))) + { + asMutable(resultType)->ty.emplace(singletonTypes->booleanType); + unblock(resultType); + return true; + } + + break; + // == and ~= always evaluate to a boolean, and impose no other constraints + // on their parameters. + case AstExprBinary::Op::CompareEq: + case AstExprBinary::Op::CompareNe: + asMutable(resultType)->ty.emplace(singletonTypes->booleanType); + unblock(resultType); + return true; + // And evalutes to a boolean if the LHS is falsey, and the RHS type if LHS is + // truthy. + case AstExprBinary::Op::And: + asMutable(resultType)->ty.emplace(unionOfTypes(rightType, singletonTypes->booleanType, constraint->scope, false)); + unblock(resultType); + return true; + // Or evaluates to the LHS type if the LHS is truthy, and the RHS type if + // LHS is falsey. + case AstExprBinary::Op::Or: + asMutable(resultType)->ty.emplace(unionOfTypes(rightType, leftType, constraint->scope, true)); + unblock(resultType); + return true; + default: + iceReporter.ice("Unhandled AstExprBinary::Op for binary operation", constraint->location); + break; + } + + // We failed to either evaluate a metamethod or invoke primitive behavior. + unify(leftType, errorRecoveryType(), constraint->scope); + unify(rightType, errorRecoveryType(), constraint->scope); + asMutable(resultType)->ty.emplace(errorRecoveryType()); + unblock(resultType); return true; } @@ -710,6 +896,10 @@ bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNullname = c.name; else if (MetatableTypeVar* mtv = getMutable(target)) mtv->syntheticName = c.name; + else if (get(target) || get(target)) + { + // nothing (yet) + } else return block(c.namedType, constraint); @@ -943,6 +1133,31 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull callMm = findMetatableEntry(singletonTypes, errors, fn, "__call", constraint->location)) + { + std::vector args{fn}; + + for (TypeId arg : c.argsPack) + args.push_back(arg); + + TypeId instantiatedType = arena->addType(BlockedTypeVar{}); + TypeId inferredFnType = + arena->addType(FunctionTypeVar(TypeLevel{}, constraint->scope.get(), arena->addTypePack(TypePack{args, {}}), c.result)); + + // Alter the inner constraints. + LUAU_ASSERT(c.innerConstraints.size() == 2); + + asMutable(*c.innerConstraints.at(0)).c = InstantiationConstraint{instantiatedType, *callMm}; + asMutable(*c.innerConstraints.at(1)).c = SubtypeConstraint{inferredFnType, instantiatedType}; + + unsolvedConstraints.insert(end(unsolvedConstraints), begin(c.innerConstraints), end(c.innerConstraints)); + + asMutable(c.result)->ty.emplace(constraint->scope); + unblock(c.result); + return true; + } + const FunctionTypeVar* ftv = get(fn); bool usedMagic = false; @@ -1059,6 +1274,22 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull constraint) +{ + if (isBlocked(c.discriminantType)) + return false; + + TypeId followed = follow(c.discriminantType); + + // `nil` is a singleton type too! There's only one value of type `nil`. + if (get(followed) || isNil(followed)) + *asMutable(c.resultType) = NegationTypeVar{c.discriminantType}; + else + *asMutable(c.resultType) = BoundTypeVar{singletonTypes->unknownType}; + + return true; +} + bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force) { auto block_ = [&](auto&& t) { @@ -1502,4 +1733,39 @@ TypePackId ConstraintSolver::errorRecoveryTypePack() const return singletonTypes->errorRecoveryTypePack(); } +TypeId ConstraintSolver::unionOfTypes(TypeId a, TypeId b, NotNull scope, bool unifyFreeTypes) +{ + a = follow(a); + b = follow(b); + + if (unifyFreeTypes && (get(a) || get(b))) + { + Unifier u{normalizer, Mode::Strict, scope, Location{}, Covariant}; + u.useScopes = true; + u.tryUnify(b, a); + + if (u.errors.empty()) + { + u.log.commit(); + return a; + } + else + { + return singletonTypes->errorRecoveryType(singletonTypes->anyType); + } + } + + if (*a == *b) + return a; + + std::vector types = reduceUnion({a, b}); + if (types.empty()) + return singletonTypes->neverType; + + if (types.size() == 1) + return types[0]; + + return arena->addType(UnionTypeVar{types}); +} + } // namespace Luau diff --git a/Analysis/src/DataFlowGraphBuilder.cpp b/Analysis/src/DataFlowGraphBuilder.cpp new file mode 100644 index 00000000..e2c4c285 --- /dev/null +++ b/Analysis/src/DataFlowGraphBuilder.cpp @@ -0,0 +1,440 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/DataFlowGraphBuilder.h" + +#include "Luau/Error.h" + +LUAU_FASTFLAG(DebugLuauFreezeArena) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) + +namespace Luau +{ + +std::optional DataFlowGraph::getDef(const AstExpr* expr) const +{ + if (auto def = astDefs.find(expr)) + return NotNull{*def}; + return std::nullopt; +} + +std::optional DataFlowGraph::getDef(const AstLocal* local) const +{ + if (auto def = localDefs.find(local)) + return NotNull{*def}; + return std::nullopt; +} + +std::optional DataFlowGraph::getDef(const Symbol& symbol) const +{ + if (symbol.local) + return getDef(symbol.local); + else + return std::nullopt; +} + +DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull handle) +{ + LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); + + DataFlowGraphBuilder builder; + builder.handle = handle; + builder.visit(nullptr, block); // nullptr is the root DFG scope. + if (FFlag::DebugLuauFreezeArena) + builder.arena->allocator.freeze(); + return std::move(builder.graph); +} + +DfgScope* DataFlowGraphBuilder::childScope(DfgScope* scope) +{ + return scopes.emplace_back(new DfgScope{scope}).get(); +} + +std::optional DataFlowGraphBuilder::use(DfgScope* scope, Symbol symbol, AstExpr* e) +{ + for (DfgScope* current = scope; current; current = current->parent) + { + if (auto loc = current->bindings.find(symbol)) + { + graph.astDefs[e] = *loc; + return NotNull{*loc}; + } + } + + return std::nullopt; +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBlock* b) +{ + DfgScope* child = childScope(scope); + return visitBlockWithoutChildScope(child, b); +} + +void DataFlowGraphBuilder::visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b) +{ + for (AstStat* s : b->body) + visit(scope, s); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStat* s) +{ + if (auto b = s->as()) + return visit(scope, b); + else if (auto i = s->as()) + return visit(scope, i); + else if (auto w = s->as()) + return visit(scope, w); + else if (auto r = s->as()) + return visit(scope, r); + else if (auto b = s->as()) + return visit(scope, b); + else if (auto c = s->as()) + return visit(scope, c); + else if (auto r = s->as()) + return visit(scope, r); + else if (auto e = s->as()) + return visit(scope, e); + else if (auto l = s->as()) + return visit(scope, l); + else if (auto f = s->as()) + return visit(scope, f); + else if (auto f = s->as()) + return visit(scope, f); + else if (auto a = s->as()) + return visit(scope, a); + else if (auto c = s->as()) + return visit(scope, c); + else if (auto f = s->as()) + return visit(scope, f); + else if (auto l = s->as()) + return visit(scope, l); + else if (auto t = s->as()) + return; // ok + else if (auto d = s->as()) + return; // ok + else if (auto d = s->as()) + return; // ok + else if (auto d = s->as()) + return; // ok + else if (auto d = s->as()) + return; // ok + else if (auto _ = s->as()) + return; // ok + else + handle->ice("Unknown AstStat in DataFlowGraphBuilder"); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatIf* i) +{ + DfgScope* condScope = childScope(scope); + visitExpr(condScope, i->condition); + visit(condScope, i->thenbody); + + if (i->elsebody) + visit(scope, i->elsebody); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatWhile* w) +{ + // TODO(controlflow): entry point has a back edge from exit point + DfgScope* whileScope = childScope(scope); + visitExpr(whileScope, w->condition); + visit(whileScope, w->body); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatRepeat* r) +{ + // TODO(controlflow): entry point has a back edge from exit point + DfgScope* repeatScope = childScope(scope); // TODO: loop scope. + visitBlockWithoutChildScope(repeatScope, r->body); + visitExpr(repeatScope, r->condition); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBreak* b) +{ + // TODO: Control flow analysis + return; // ok +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatContinue* c) +{ + // TODO: Control flow analysis + return; // ok +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatReturn* r) +{ + // TODO: Control flow analysis + for (AstExpr* e : r->list) + visitExpr(scope, e); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatExpr* e) +{ + visitExpr(scope, e->expr); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l) +{ + // TODO: alias tracking + for (AstExpr* e : l->values) + visitExpr(scope, e); + + for (AstLocal* local : l->vars) + { + DefId def = arena->freshDef(); + graph.localDefs[local] = def; + scope->bindings[local] = def; + } +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFor* f) +{ + DfgScope* forScope = childScope(scope); // TODO: loop scope. + DefId def = arena->freshDef(); + graph.localDefs[f->var] = def; + scope->bindings[f->var] = def; + + // TODO(controlflow): entry point has a back edge from exit point + visit(forScope, f->body); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatForIn* f) +{ + DfgScope* forScope = childScope(scope); // TODO: loop scope. + + for (AstLocal* local : f->vars) + { + DefId def = arena->freshDef(); + graph.localDefs[local] = def; + forScope->bindings[local] = def; + } + + // TODO(controlflow): entry point has a back edge from exit point + for (AstExpr* e : f->values) + visitExpr(forScope, e); + + visit(forScope, f->body); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatAssign* a) +{ + for (AstExpr* r : a->values) + visitExpr(scope, r); + + for (AstExpr* l : a->vars) + { + AstExpr* root = l; + + bool isUpdatable = true; + while (true) + { + if (root->is() || root->is()) + break; + + AstExprIndexName* indexName = root->as(); + if (!indexName) + { + isUpdatable = false; + break; + } + + root = indexName->expr; + } + + if (isUpdatable) + { + // TODO global? + if (auto exprLocal = root->as()) + { + DefId def = arena->freshDef(); + graph.astDefs[exprLocal] = def; + + // Update the def in the scope that introduced the local. Not + // the current scope. + AstLocal* local = exprLocal->local; + DfgScope* s = scope; + while (s && !s->bindings.find(local)) + s = s->parent; + LUAU_ASSERT(s && s->bindings.find(local)); + s->bindings[local] = def; + } + } + + visitExpr(scope, l); // TODO: they point to a new def!! + } +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatCompoundAssign* c) +{ + // TODO(typestates): The lhs is being read and written to. This might or might not be annoying. + visitExpr(scope, c->value); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f) +{ + visitExpr(scope, f->name); + visitExpr(scope, f->func); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocalFunction* l) +{ + DefId def = arena->freshDef(); + graph.localDefs[l->name] = def; + scope->bindings[l->name] = def; + + visitExpr(scope, l->func); +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e) +{ + if (auto g = e->as()) + return visitExpr(scope, g->expr); + else if (auto c = e->as()) + return {}; // ok + else if (auto c = e->as()) + return {}; // ok + else if (auto c = e->as()) + return {}; // ok + else if (auto c = e->as()) + return {}; // ok + else if (auto l = e->as()) + return visitExpr(scope, l); + else if (auto g = e->as()) + return visitExpr(scope, g); + else if (auto v = e->as()) + return {}; // ok + else if (auto c = e->as()) + return visitExpr(scope, c); + else if (auto i = e->as()) + return visitExpr(scope, i); + else if (auto i = e->as()) + return visitExpr(scope, i); + else if (auto f = e->as()) + return visitExpr(scope, f); + else if (auto t = e->as()) + return visitExpr(scope, t); + else if (auto u = e->as()) + return visitExpr(scope, u); + else if (auto b = e->as()) + return visitExpr(scope, b); + else if (auto t = e->as()) + return visitExpr(scope, t); + else if (auto i = e->as()) + return visitExpr(scope, i); + else if (auto i = e->as()) + return visitExpr(scope, i); + else if (auto _ = e->as()) + return {}; // ok + else + handle->ice("Unknown AstExpr in DataFlowGraphBuilder"); +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprLocal* l) +{ + return {use(scope, l->local, l)}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGlobal* g) +{ + return {use(scope, g->name, g)}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprCall* c) +{ + visitExpr(scope, c->func); + + for (AstExpr* arg : c->args) + visitExpr(scope, arg); + + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexName* i) +{ + std::optional def = visitExpr(scope, i->expr).def; + if (!def) + return {}; + + // TODO: properties for the above def. + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexExpr* i) +{ + visitExpr(scope, i->expr); + visitExpr(scope, i->expr); + + if (i->index->as()) + { + // TODO: properties for the def + } + + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* f) +{ + if (AstLocal* self = f->self) + { + DefId def = arena->freshDef(); + graph.localDefs[self] = def; + scope->bindings[self] = def; + } + + for (AstLocal* param : f->args) + { + DefId def = arena->freshDef(); + graph.localDefs[param] = def; + scope->bindings[param] = def; + } + + visit(scope, f->body); + + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTable* t) +{ + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprUnary* u) +{ + visitExpr(scope, u->expr); + + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprBinary* b) +{ + visitExpr(scope, b->left); + visitExpr(scope, b->right); + + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTypeAssertion* t) +{ + ExpressionFlowGraph result = visitExpr(scope, t->expr); + // TODO: visit type + return result; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIfElse* i) +{ + DfgScope* condScope = childScope(scope); + visitExpr(condScope, i->condition); + visitExpr(condScope, i->trueExpr); + + visitExpr(scope, i->falseExpr); + + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprInterpString* i) +{ + for (AstExpr* e : i->expressions) + visitExpr(scope, e); + return {}; +} + +} // namespace Luau diff --git a/Analysis/src/Def.cpp b/Analysis/src/Def.cpp new file mode 100644 index 00000000..935301c8 --- /dev/null +++ b/Analysis/src/Def.cpp @@ -0,0 +1,12 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Def.h" + +namespace Luau +{ + +DefId DefArena::freshDef() +{ + return NotNull{allocator.allocate(Undefined{})}; +} + +} // namespace Luau diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 0f04ace0..339de975 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -13,47 +13,47 @@ declare bit32: { bor: (...number) -> number, bxor: (...number) -> number, btest: (number, ...number) -> boolean, - rrotate: (number, number) -> number, - lrotate: (number, number) -> number, - lshift: (number, number) -> number, - arshift: (number, number) -> number, - rshift: (number, number) -> number, - bnot: (number) -> number, - extract: (number, number, number?) -> number, - replace: (number, number, number, number?) -> number, - countlz: (number) -> number, - countrz: (number) -> number, + rrotate: (x: number, disp: number) -> number, + lrotate: (x: number, disp: number) -> number, + lshift: (x: number, disp: number) -> number, + arshift: (x: number, disp: number) -> number, + rshift: (x: number, disp: number) -> number, + bnot: (x: number) -> number, + extract: (n: number, field: number, width: number?) -> number, + replace: (n: number, v: number, field: number, width: number?) -> number, + countlz: (n: number) -> number, + countrz: (n: number) -> number, } declare math: { - frexp: (number) -> (number, number), - ldexp: (number, number) -> number, - fmod: (number, number) -> number, - modf: (number) -> (number, number), - pow: (number, number) -> number, - exp: (number) -> number, + frexp: (n: number) -> (number, number), + ldexp: (s: number, e: number) -> number, + fmod: (x: number, y: number) -> number, + modf: (n: number) -> (number, number), + pow: (x: number, y: number) -> number, + exp: (n: number) -> number, - ceil: (number) -> number, - floor: (number) -> number, - abs: (number) -> number, - sqrt: (number) -> number, + ceil: (n: number) -> number, + floor: (n: number) -> number, + abs: (n: number) -> number, + sqrt: (n: number) -> number, - log: (number, number?) -> number, - log10: (number) -> number, + log: (n: number, base: number?) -> number, + log10: (n: number) -> number, - rad: (number) -> number, - deg: (number) -> number, + rad: (n: number) -> number, + deg: (n: number) -> number, - sin: (number) -> number, - cos: (number) -> number, - tan: (number) -> number, - sinh: (number) -> number, - cosh: (number) -> number, - tanh: (number) -> number, - atan: (number) -> number, - acos: (number) -> number, - asin: (number) -> number, - atan2: (number, number) -> number, + sin: (n: number) -> number, + cos: (n: number) -> number, + tan: (n: number) -> number, + sinh: (n: number) -> number, + cosh: (n: number) -> number, + tanh: (n: number) -> number, + atan: (n: number) -> number, + acos: (n: number) -> number, + asin: (n: number) -> number, + atan2: (y: number, x: number) -> number, min: (number, ...number) -> number, max: (number, ...number) -> number, @@ -61,13 +61,13 @@ declare math: { pi: number, huge: number, - randomseed: (number) -> (), + randomseed: (seed: number) -> (), random: (number?, number?) -> number, - sign: (number) -> number, - clamp: (number, number, number) -> number, - noise: (number, number?, number?) -> number, - round: (number) -> number, + sign: (n: number) -> number, + clamp: (n: number, min: number, max: number) -> number, + noise: (x: number, y: number?, z: number?) -> number, + round: (n: number) -> number, } type DateTypeArg = { @@ -93,9 +93,9 @@ type DateTypeResult = { } declare os: { - time: (DateTypeArg?) -> number, - date: (string?, number?) -> DateTypeResult | string, - difftime: (DateTypeResult | number, DateTypeResult | number) -> number, + time: (time: DateTypeArg?) -> number, + date: (formatString: string?, time: number?) -> DateTypeResult | string, + difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number, clock: () -> number, } @@ -145,51 +145,51 @@ declare function loadstring(src: string, chunkname: string?): (((A...) -> declare function newproxy(mt: boolean?): any declare coroutine: { - create: ((A...) -> R...) -> thread, - resume: (thread, A...) -> (boolean, R...), + create: (f: (A...) -> R...) -> thread, + resume: (co: thread, A...) -> (boolean, R...), running: () -> thread, - status: (thread) -> "dead" | "running" | "normal" | "suspended", + status: (co: thread) -> "dead" | "running" | "normal" | "suspended", -- FIXME: This technically returns a function, but we can't represent this yet. - wrap: ((A...) -> R...) -> any, + wrap: (f: (A...) -> R...) -> any, yield: (A...) -> R..., isyieldable: () -> boolean, - close: (thread) -> (boolean, any) + close: (co: thread) -> (boolean, any) } declare table: { - concat: ({V}, string?, number?, number?) -> string, - insert: (({V}, V) -> ()) & (({V}, number, V) -> ()), - maxn: ({V}) -> number, - remove: ({V}, number?) -> V?, - sort: ({V}, ((V, V) -> boolean)?) -> (), - create: (number, V?) -> {V}, - find: ({V}, V, number?) -> number?, + concat: (t: {V}, sep: string?, i: number?, j: number?) -> string, + insert: ((t: {V}, value: V) -> ()) & ((t: {V}, pos: number, value: V) -> ()), + maxn: (t: {V}) -> number, + remove: (t: {V}, number?) -> V?, + sort: (t: {V}, comp: ((V, V) -> boolean)?) -> (), + create: (count: number, value: V?) -> {V}, + find: (haystack: {V}, needle: V, init: number?) -> number?, - unpack: ({V}, number?, number?) -> ...V, + unpack: (list: {V}, i: number?, j: number?) -> ...V, pack: (...V) -> { n: number, [number]: V }, - getn: ({V}) -> number, - foreach: ({[K]: V}, (K, V) -> ()) -> (), + getn: (t: {V}) -> number, + foreach: (t: {[K]: V}, f: (K, V) -> ()) -> (), foreachi: ({V}, (number, V) -> ()) -> (), - move: ({V}, number, number, number, {V}?) -> {V}, - clear: ({[K]: V}) -> (), + move: (src: {V}, a: number, b: number, t: number, dst: {V}?) -> {V}, + clear: (table: {[K]: V}) -> (), - isfrozen: ({[K]: V}) -> boolean, + isfrozen: (t: {[K]: V}) -> boolean, } declare debug: { - info: ((thread, number, string) -> R...) & ((number, string) -> R...) & (((A...) -> R1..., string) -> R2...), - traceback: ((string?, number?) -> string) & ((thread, string?, number?) -> string), + info: ((thread: thread, level: number, options: string) -> R...) & ((level: number, options: string) -> R...) & ((func: (A...) -> R1..., options: string) -> R2...), + traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string), } declare utf8: { char: (...number) -> string, charpattern: string, - codes: (string) -> ((string, number) -> (number, number), string, number), - codepoint: (string, number?, number?) -> ...number, - len: (string, number?, number?) -> (number?, number?), - offset: (string, number?, number?) -> number, + codes: (str: string) -> ((string, number) -> (number, number), string, number), + codepoint: (str: string, i: number?, j: number?) -> ...number, + len: (s: string, i: number?, j: number?) -> (number?, number?), + offset: (s: string, n: number?, i: number?) -> number, } -- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 4e9b6882..ed1a49cd 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -7,7 +7,7 @@ #include -LUAU_FASTFLAGVARIABLE(LuauTypeMismatchModuleNameResolution, false) +LUAU_FASTFLAGVARIABLE(LuauIceExceptionInheritanceChange, false) static std::string wrongNumberOfArgsString( size_t expectedCount, std::optional maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) @@ -70,7 +70,7 @@ struct ErrorConverter { if (auto wantedDefinitionModule = getDefinitionModuleName(tm.wantedType)) { - if (FFlag::LuauTypeMismatchModuleNameResolution && fileResolver != nullptr) + if (fileResolver != nullptr) { std::string givenModuleName = fileResolver->getHumanReadableModuleName(*givenDefinitionModule); std::string wantedModuleName = fileResolver->getHumanReadableModuleName(*wantedDefinitionModule); @@ -96,14 +96,7 @@ struct ErrorConverter if (!tm.reason.empty()) result += tm.reason + " "; - if (FFlag::LuauTypeMismatchModuleNameResolution) - { - result += Luau::toString(*tm.error, TypeErrorToStringOptions{fileResolver}); - } - else - { - result += Luau::toString(*tm.error); - } + result += Luau::toString(*tm.error, TypeErrorToStringOptions{fileResolver}); } else if (!tm.reason.empty()) { @@ -469,6 +462,11 @@ struct ErrorConverter { return "Code is too complex to typecheck! Consider simplifying the code around this area"; } + + std::string operator()(const TypePackMismatch& e) const + { + return "Type pack '" + toString(e.givenTp) + "' could not be converted into '" + toString(e.wantedTp) + "'"; + } }; struct InvalidNameChecker @@ -727,6 +725,11 @@ bool TypesAreUnrelated::operator==(const TypesAreUnrelated& rhs) const return left == rhs.left && right == rhs.right; } +bool TypePackMismatch::operator==(const TypePackMismatch& rhs) const +{ + return *wantedTp == *rhs.wantedTp && *givenTp == *rhs.givenTp; +} + std::string toString(const TypeError& error) { return toString(error, TypeErrorToStringOptions{}); @@ -878,6 +881,11 @@ void copyError(T& e, TypeArena& destArena, CloneState cloneState) else if constexpr (std::is_same_v) { } + else if constexpr (std::is_same_v) + { + e.wantedTp = clone(e.wantedTp); + e.givenTp = clone(e.givenTp); + } else static_assert(always_false_v, "Non-exhaustive type switch"); } @@ -922,4 +930,30 @@ const char* InternalCompilerError::what() const throw() return this->message.data(); } +// TODO: Inline me when LuauIceExceptionInheritanceChange is deleted. +void throwRuntimeError(const std::string& message) +{ + if (FFlag::LuauIceExceptionInheritanceChange) + { + throw InternalCompilerError(message); + } + else + { + throw std::runtime_error(message); + } +} + +// TODO: Inline me when LuauIceExceptionInheritanceChange is deleted. +void throwRuntimeError(const std::string& message, const std::string& moduleName) +{ + if (FFlag::LuauIceExceptionInheritanceChange) + { + throw InternalCompilerError(message, moduleName); + } + else + { + throw std::runtime_error(message); + } +} + } // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 5705ac17..39e6428d 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1,11 +1,13 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Frontend.h" +#include "Luau/BuiltinDefinitions.h" #include "Luau/Clone.h" #include "Luau/Common.h" #include "Luau/Config.h" #include "Luau/ConstraintGraphBuilder.h" #include "Luau/ConstraintSolver.h" +#include "Luau/DataFlowGraphBuilder.h" #include "Luau/DcrLogger.h" #include "Luau/FileResolver.h" #include "Luau/Parser.h" @@ -15,7 +17,6 @@ #include "Luau/TypeChecker2.h" #include "Luau/TypeInfer.h" #include "Luau/Variant.h" -#include "Luau/BuiltinDefinitions.h" #include #include @@ -26,10 +27,11 @@ LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAG(LuauNoMoreGlobalSingletonTypes) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) -LUAU_FASTFLAGVARIABLE(LuauAutocompleteDynamicLimits, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAG(DebugLuauLogSolverToJson); +LUAU_FASTFLAGVARIABLE(LuauFixMarkDirtyReverseDeps, false) +LUAU_FASTFLAGVARIABLE(LuauPersistTypesAfterGeneratingDocSyms, false) namespace Luau { @@ -110,24 +112,57 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, c CloneState cloneState; - for (const auto& [name, ty] : checkedModule->declaredGlobals) + if (FFlag::LuauPersistTypesAfterGeneratingDocSyms) { - TypeId globalTy = clone(ty, globalTypes, cloneState); - std::string documentationSymbol = packageName + "/global/" + name; - generateDocumentationSymbols(globalTy, documentationSymbol); - globalScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; + std::vector typesToPersist; + typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->getModuleScope()->exportedTypeBindings.size()); - persist(globalTy); + for (const auto& [name, ty] : checkedModule->declaredGlobals) + { + TypeId globalTy = clone(ty, globalTypes, cloneState); + std::string documentationSymbol = packageName + "/global/" + name; + generateDocumentationSymbols(globalTy, documentationSymbol); + globalScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; + + typesToPersist.push_back(globalTy); + } + + for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) + { + TypeFun globalTy = clone(ty, globalTypes, cloneState); + std::string documentationSymbol = packageName + "/globaltype/" + name; + generateDocumentationSymbols(globalTy.type, documentationSymbol); + globalScope->exportedTypeBindings[name] = globalTy; + + typesToPersist.push_back(globalTy.type); + } + + for (TypeId ty : typesToPersist) + { + persist(ty); + } } - - for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) + else { - TypeFun globalTy = clone(ty, globalTypes, cloneState); - std::string documentationSymbol = packageName + "/globaltype/" + name; - generateDocumentationSymbols(globalTy.type, documentationSymbol); - globalScope->exportedTypeBindings[name] = globalTy; + for (const auto& [name, ty] : checkedModule->declaredGlobals) + { + TypeId globalTy = clone(ty, globalTypes, cloneState); + std::string documentationSymbol = packageName + "/global/" + name; + generateDocumentationSymbols(globalTy, documentationSymbol); + globalScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; - persist(globalTy.type); + persist(globalTy); + } + + for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) + { + TypeFun globalTy = clone(ty, globalTypes, cloneState); + std::string documentationSymbol = packageName + "/globaltype/" + name; + generateDocumentationSymbols(globalTy.type, documentationSymbol); + globalScope->exportedTypeBindings[name] = globalTy; + + persist(globalTy.type); + } } return LoadDefinitionFileResult{true, parseResult, checkedModule}; @@ -159,24 +194,57 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t CloneState cloneState; - for (const auto& [name, ty] : checkedModule->declaredGlobals) + if (FFlag::LuauPersistTypesAfterGeneratingDocSyms) { - TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState); - std::string documentationSymbol = packageName + "/global/" + name; - generateDocumentationSymbols(globalTy, documentationSymbol); - targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; + std::vector typesToPersist; + typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->getModuleScope()->exportedTypeBindings.size()); - persist(globalTy); + for (const auto& [name, ty] : checkedModule->declaredGlobals) + { + TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState); + std::string documentationSymbol = packageName + "/global/" + name; + generateDocumentationSymbols(globalTy, documentationSymbol); + targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; + + typesToPersist.push_back(globalTy); + } + + for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) + { + TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState); + std::string documentationSymbol = packageName + "/globaltype/" + name; + generateDocumentationSymbols(globalTy.type, documentationSymbol); + targetScope->exportedTypeBindings[name] = globalTy; + + typesToPersist.push_back(globalTy.type); + } + + for (TypeId ty : typesToPersist) + { + persist(ty); + } } - - for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) + else { - TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState); - std::string documentationSymbol = packageName + "/globaltype/" + name; - generateDocumentationSymbols(globalTy.type, documentationSymbol); - targetScope->exportedTypeBindings[name] = globalTy; + for (const auto& [name, ty] : checkedModule->declaredGlobals) + { + TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState); + std::string documentationSymbol = packageName + "/global/" + name; + generateDocumentationSymbols(globalTy, documentationSymbol); + targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; - persist(globalTy.type); + persist(globalTy); + } + + for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) + { + TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState); + std::string documentationSymbol = packageName + "/globaltype/" + name; + generateDocumentationSymbols(globalTy.type, documentationSymbol); + targetScope->exportedTypeBindings[name] = globalTy; + + persist(globalTy.type); + } } return LoadDefinitionFileResult{true, parseResult, checkedModule}; @@ -425,13 +493,13 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalsecond == nullptr) - throw std::runtime_error("Frontend::modules does not have data for " + name); + throwRuntimeError("Frontend::modules does not have data for " + name, name); } else { auto it2 = moduleResolver.modules.find(name); if (it2 == moduleResolver.modules.end() || it2->second == nullptr) - throw std::runtime_error("Frontend::modules does not have data for " + name); + throwRuntimeError("Frontend::modules does not have data for " + name, name); } return CheckResult{ @@ -488,23 +556,19 @@ CheckResult Frontend::check(const ModuleName& name, std::optional 0) - typeCheckerForAutocomplete.instantiationChildLimit = - std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckerForAutocomplete.instantiationChildLimit = std::nullopt; + // TODO: This is a dirty ad hoc solution for autocomplete timeouts + // We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit + // so that we'll have type information for the whole file at lower quality instead of a full abort in the middle + if (FInt::LuauTarjanChildLimit > 0) + typeCheckerForAutocomplete.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckerForAutocomplete.instantiationChildLimit = std::nullopt; - if (FInt::LuauTypeInferIterationLimit > 0) - typeCheckerForAutocomplete.unifierIterationLimit = - std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckerForAutocomplete.unifierIterationLimit = std::nullopt; - } + if (FInt::LuauTypeInferIterationLimit > 0) + typeCheckerForAutocomplete.unifierIterationLimit = + std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckerForAutocomplete.unifierIterationLimit = std::nullopt; ModulePtr moduleForAutocomplete = FFlag::DebugLuauDeferredConstraintResolution ? check(sourceModule, mode, environmentScope, requireCycles, /*forAutocomplete*/ true) @@ -518,10 +582,9 @@ CheckResult Frontend::check(const ModuleName& name, std::optional* marked sourceNode.dirtyModule = true; sourceNode.dirtyModuleForAutocomplete = true; - if (0 == reverseDeps.count(name)) - continue; + if (FFlag::LuauFixMarkDirtyReverseDeps) + { + if (0 == reverseDeps.count(next)) + continue; - sourceModules.erase(name); + sourceModules.erase(next); - const std::vector& dependents = reverseDeps[name]; - queue.insert(queue.end(), dependents.begin(), dependents.end()); + const std::vector& dependents = reverseDeps[next]; + queue.insert(queue.end(), dependents.begin(), dependents.end()); + } + else + { + if (0 == reverseDeps.count(name)) + continue; + + sourceModules.erase(name); + + const std::vector& dependents = reverseDeps[name]; + queue.insert(queue.end(), dependents.begin(), dependents.end()); + } } } @@ -857,13 +933,25 @@ ModulePtr Frontend::check( } } + DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, NotNull{&iceHandler}); + const NotNull mr{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}; const ScopePtr& globalScope{forAutocomplete ? typeCheckerForAutocomplete.globalScope : typeChecker.globalScope}; Normalizer normalizer{&result->internalTypes, singletonTypes, NotNull{&typeChecker.unifierState}}; ConstraintGraphBuilder cgb{ - sourceModule.name, result, &result->internalTypes, mr, singletonTypes, NotNull(&iceHandler), globalScope, logger.get()}; + sourceModule.name, + result, + &result->internalTypes, + mr, + singletonTypes, + NotNull(&iceHandler), + globalScope, + logger.get(), + NotNull{&dfg}, + }; + cgb.visit(sourceModule.root); result->errors = std::move(cgb.errors); @@ -986,11 +1074,11 @@ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const double timestamp = getTimestamp(); - auto parseResult = Luau::Parser::parse(src.data(), src.size(), *sourceModule.names, *sourceModule.allocator, parseOptions); + Luau::ParseResult parseResult = Luau::Parser::parse(src.data(), src.size(), *sourceModule.names, *sourceModule.allocator, parseOptions); stats.timeParse += getTimestamp() - timestamp; stats.files++; - stats.lines += std::count(src.begin(), src.end(), '\n') + (src.size() && src.back() != '\n'); + stats.lines += parseResult.lines; if (!parseResult.errors.empty()) sourceModule.parseErrors.insert(sourceModule.parseErrors.end(), parseResult.errors.begin(), parseResult.errors.end()); diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index e4fac455..b47270a0 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -188,6 +188,8 @@ static void errorToString(std::ostream& stream, const T& err) stream << "TypesAreUnrelated { left = '" + toString(err.left) + "', right = '" + toString(err.right) + "' }"; else if constexpr (std::is_same_v) stream << "NormalizationTooComplex { }"; + else if constexpr (std::is_same_v) + stream << "TypePackMismatch { wanted = '" + toString(err.wantedTp) + "', given = '" + toString(err.givenTp) + "' }"; else static_assert(always_false_v, "Non-exhaustive type switch"); } diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 31a089a4..0412f007 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -60,36 +60,6 @@ bool isWithinComment(const SourceModule& sourceModule, Position pos) return contains(pos, *iter); } -struct ForceNormal : TypeVarOnceVisitor -{ - const TypeArena* typeArena = nullptr; - - ForceNormal(const TypeArena* typeArena) - : typeArena(typeArena) - { - } - - bool visit(TypeId ty) override - { - if (ty->owningArena != typeArena) - return false; - - asMutable(ty)->normal = true; - return true; - } - - bool visit(TypeId ty, const FreeTypeVar& ftv) override - { - visit(ty); - return true; - } - - bool visit(TypePackId tp, const FreeTypePack& ftp) override - { - return true; - } -}; - struct ClonePublicInterface : Substitution { NotNull singletonTypes; @@ -241,8 +211,6 @@ void Module::clonePublicInterface(NotNull singletonTypes, Intern moduleScope->varargPack = varargPack; } - ForceNormal forceNormal{&interfaceTypes}; - if (exportedTypeBindings) { for (auto& [name, tf] : *exportedTypeBindings) @@ -262,7 +230,6 @@ void Module::clonePublicInterface(NotNull singletonTypes, Intern { auto t = asMutable(ty); t->ty = AnyTypeVar{}; - t->normal = true; } } } diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 81114b76..21e9f787 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -7,8 +7,9 @@ #include "Luau/Clone.h" #include "Luau/Common.h" +#include "Luau/RecursionCounter.h" +#include "Luau/TypeVar.h" #include "Luau/Unifier.h" -#include "Luau/VisitTypeVar.h" LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false) LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) @@ -16,11 +17,13 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) // This could theoretically be 2000 on amd64, but x86 requires this. LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); -LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauTypeNormalization2, false); +LUAU_FASTFLAGVARIABLE(LuauNegatedStringSingletons, false); +LUAU_FASTFLAGVARIABLE(LuauNegatedFunctionTypes, false); LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauOverloadedFunctionSubtypingPerf); namespace Luau { @@ -107,12 +110,132 @@ bool TypeIds::operator==(const TypeIds& there) const return hash == there.hash && types == there.types; } +NormalizedStringType::NormalizedStringType(bool isCofinite, std::optional> singletons) + : isCofinite(isCofinite) + , singletons(std::move(singletons)) +{ + if (!FFlag::LuauNegatedStringSingletons) + LUAU_ASSERT(!isCofinite); +} + +void NormalizedStringType::resetToString() +{ + if (FFlag::LuauNegatedStringSingletons) + { + isCofinite = true; + singletons->clear(); + } + else + singletons.reset(); +} + +void NormalizedStringType::resetToNever() +{ + if (FFlag::LuauNegatedStringSingletons) + { + isCofinite = false; + singletons.emplace(); + } + else + { + if (singletons) + singletons->clear(); + else + singletons.emplace(); + } +} + +bool NormalizedStringType::isNever() const +{ + if (FFlag::LuauNegatedStringSingletons) + return !isCofinite && singletons->empty(); + else + return singletons && singletons->empty(); +} + +bool NormalizedStringType::isString() const +{ + if (FFlag::LuauNegatedStringSingletons) + return isCofinite && singletons->empty(); + else + return !singletons; +} + +bool NormalizedStringType::isUnion() const +{ + if (FFlag::LuauNegatedStringSingletons) + return !isCofinite; + else + return singletons.has_value(); +} + +bool NormalizedStringType::isIntersection() const +{ + if (FFlag::LuauNegatedStringSingletons) + return isCofinite; + else + return false; +} + +bool NormalizedStringType::includes(const std::string& str) const +{ + if (isString()) + return true; + else if (isUnion() && singletons->count(str)) + return true; + else if (isIntersection() && !singletons->count(str)) + return true; + else + return false; +} + +const NormalizedStringType NormalizedStringType::never{false, {{}}}; + +bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr) +{ + if (subStr.isUnion() && superStr.isUnion()) + { + for (auto [name, ty] : *subStr.singletons) + { + if (!superStr.singletons->count(name)) + return false; + } + } + else if (subStr.isString() && superStr.isUnion()) + return false; + + return true; +} + +NormalizedFunctionType::NormalizedFunctionType() + : parts(FFlag::LuauNegatedFunctionTypes ? std::optional{TypeIds{}} : std::nullopt) +{ +} + +void NormalizedFunctionType::resetToTop() +{ + isTop = true; + parts.emplace(); +} + +void NormalizedFunctionType::resetToNever() +{ + isTop = false; + parts.emplace(); +} + +bool NormalizedFunctionType::isNever() const +{ + return !isTop && (!parts || parts->empty()); +} + NormalizedType::NormalizedType(NotNull singletonTypes) : tops(singletonTypes->neverType) , booleans(singletonTypes->neverType) , errors(singletonTypes->neverType) , nils(singletonTypes->neverType) , numbers(singletonTypes->neverType) + , strings{NormalizedStringType::never} , threads(singletonTypes->neverType) { } @@ -120,8 +243,8 @@ NormalizedType::NormalizedType(NotNull singletonTypes) static bool isInhabited(const NormalizedType& norm) { return !get(norm.tops) || !get(norm.booleans) || !norm.classes.empty() || !get(norm.errors) || - !get(norm.nils) || !get(norm.numbers) || !norm.strings || !norm.strings->empty() || - !get(norm.threads) || norm.functions || !norm.tables.empty() || !norm.tyvars.empty(); + !get(norm.nils) || !get(norm.numbers) || !norm.strings.isNever() || !get(norm.threads) || + !norm.functions.isNever() || !norm.tables.empty() || !norm.tyvars.empty(); } static int tyvarIndex(TypeId ty) @@ -183,10 +306,10 @@ static bool isNormalizedNumber(TypeId ty) static bool isNormalizedString(const NormalizedStringType& ty) { - if (!ty) + if (ty.isString()) return true; - for (auto& [str, ty] : *ty) + for (auto& [str, ty] : *ty.singletons) { if (const SingletonTypeVar* stv = get(ty)) { @@ -217,10 +340,14 @@ static bool isNormalizedThread(TypeId ty) static bool areNormalizedFunctions(const NormalizedFunctionType& tys) { - if (tys) - for (TypeId ty : *tys) + if (tys.parts) + { + for (TypeId ty : *tys.parts) + { if (!get(ty) && !get(ty)) return false; + } + } return true; } @@ -317,13 +444,10 @@ void Normalizer::clearNormal(NormalizedType& norm) norm.errors = singletonTypes->neverType; norm.nils = singletonTypes->neverType; norm.numbers = singletonTypes->neverType; - if (norm.strings) - norm.strings->clear(); - else - norm.strings.emplace(); + norm.strings.resetToNever(); norm.threads = singletonTypes->neverType; norm.tables.clear(); - norm.functions = std::nullopt; + norm.functions.resetToNever(); norm.tyvars.clear(); } @@ -495,10 +619,56 @@ void Normalizer::unionClasses(TypeIds& heres, const TypeIds& theres) void Normalizer::unionStrings(NormalizedStringType& here, const NormalizedStringType& there) { - if (!there) - here.reset(); - else if (here) - here->insert(there->begin(), there->end()); + if (FFlag::LuauNegatedStringSingletons) + { + if (there.isString()) + here.resetToString(); + else if (here.isUnion() && there.isUnion()) + here.singletons->insert(there.singletons->begin(), there.singletons->end()); + else if (here.isUnion() && there.isIntersection()) + { + here.isCofinite = true; + for (const auto& pair : *there.singletons) + { + auto it = here.singletons->find(pair.first); + if (it != end(*here.singletons)) + here.singletons->erase(it); + else + here.singletons->insert(pair); + } + } + else if (here.isIntersection() && there.isUnion()) + { + for (const auto& [name, ty] : *there.singletons) + here.singletons->erase(name); + } + else if (here.isIntersection() && there.isIntersection()) + { + auto iter = begin(*here.singletons); + auto endIter = end(*here.singletons); + + while (iter != endIter) + { + if (!there.singletons->count(iter->first)) + { + auto eraseIt = iter; + ++iter; + here.singletons->erase(eraseIt); + } + else + ++iter; + } + } + else + LUAU_ASSERT(!"Unreachable"); + } + else + { + if (there.isString()) + here.resetToString(); + else if (here.isUnion()) + here.singletons->insert(there.singletons->begin(), there.singletons->end()); + } } std::optional Normalizer::unionOfTypePacks(TypePackId here, TypePackId there) @@ -666,20 +836,28 @@ std::optional Normalizer::unionOfFunctions(TypeId here, TypeId there) void Normalizer::unionFunctions(NormalizedFunctionType& heres, const NormalizedFunctionType& theres) { - if (!theres) + if (FFlag::LuauNegatedFunctionTypes) + { + if (heres.isTop) + return; + if (theres.isTop) + heres.resetToTop(); + } + + if (theres.isNever()) return; TypeIds tmps; - if (!heres) + if (heres.isNever()) { - tmps.insert(theres->begin(), theres->end()); - heres = std::move(tmps); + tmps.insert(theres.parts->begin(), theres.parts->end()); + heres.parts = std::move(tmps); return; } - for (TypeId here : *heres) - for (TypeId there : *theres) + for (TypeId here : *heres.parts) + for (TypeId there : *theres.parts) { if (std::optional fun = unionOfFunctions(here, there)) tmps.insert(*fun); @@ -687,28 +865,28 @@ void Normalizer::unionFunctions(NormalizedFunctionType& heres, const NormalizedF tmps.insert(singletonTypes->errorRecoveryType(there)); } - heres = std::move(tmps); + heres.parts = std::move(tmps); } void Normalizer::unionFunctionsWithFunction(NormalizedFunctionType& heres, TypeId there) { - if (!heres) + if (heres.isNever()) { TypeIds tmps; tmps.insert(there); - heres = std::move(tmps); + heres.parts = std::move(tmps); return; } TypeIds tmps; - for (TypeId here : *heres) + for (TypeId here : *heres.parts) { if (std::optional fun = unionOfFunctions(here, there)) tmps.insert(*fun); else tmps.insert(singletonTypes->errorRecoveryType(there)); } - heres = std::move(tmps); + heres.parts = std::move(tmps); } void Normalizer::unionTablesWithTable(TypeIds& heres, TypeId there) @@ -858,9 +1036,14 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor else if (ptv->type == PrimitiveTypeVar::Number) here.numbers = there; else if (ptv->type == PrimitiveTypeVar::String) - here.strings = std::nullopt; + here.strings.resetToString(); else if (ptv->type == PrimitiveTypeVar::Thread) here.threads = there; + else if (ptv->type == PrimitiveTypeVar::Function) + { + LUAU_ASSERT(FFlag::LuauNegatedFunctionTypes); + here.functions.resetToTop(); + } else LUAU_ASSERT(!"Unreachable"); } @@ -870,12 +1053,36 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor here.booleans = unionOfBools(here.booleans, there); else if (const StringSingleton* sstv = get(stv)) { - if (here.strings) - here.strings->insert({sstv->value, there}); + if (FFlag::LuauNegatedStringSingletons) + { + if (here.strings.isCofinite) + { + auto it = here.strings.singletons->find(sstv->value); + if (it != here.strings.singletons->end()) + here.strings.singletons->erase(it); + } + else + here.strings.singletons->insert({sstv->value, there}); + } + else + { + if (here.strings.isUnion()) + here.strings.singletons->insert({sstv->value, there}); + } } else LUAU_ASSERT(!"Unreachable"); } + else if (const NegationTypeVar* ntv = get(there)) + { + const NormalizedType* thereNormal = normalize(ntv->ty); + std::optional tn = negateNormal(*thereNormal); + if (!tn) + return false; + + if (!unionNormals(here, *tn)) + return false; + } else LUAU_ASSERT(!"Unreachable"); @@ -887,6 +1094,177 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor return true; } +// ------- Negations + +std::optional Normalizer::negateNormal(const NormalizedType& here) +{ + NormalizedType result{singletonTypes}; + if (!get(here.tops)) + { + // The negation of unknown or any is never. Easy. + return result; + } + + if (!get(here.errors)) + { + // Negating an error yields the same error. + result.errors = here.errors; + return result; + } + + if (get(here.booleans)) + result.booleans = singletonTypes->booleanType; + else if (get(here.booleans)) + result.booleans = singletonTypes->neverType; + else if (auto stv = get(here.booleans)) + { + auto boolean = get(stv); + LUAU_ASSERT(boolean != nullptr); + if (boolean->value) + result.booleans = singletonTypes->falseType; + else + result.booleans = singletonTypes->trueType; + } + + result.classes = negateAll(here.classes); + result.nils = get(here.nils) ? singletonTypes->nilType : singletonTypes->neverType; + result.numbers = get(here.numbers) ? singletonTypes->numberType : singletonTypes->neverType; + + result.strings = here.strings; + result.strings.isCofinite = !result.strings.isCofinite; + + result.threads = get(here.threads) ? singletonTypes->threadType : singletonTypes->neverType; + + /* + * Things get weird and so, so complicated if we allow negations of + * arbitrary function types. Ordinary code can never form these kinds of + * types, so we decline to negate them. + */ + if (FFlag::LuauNegatedFunctionTypes) + { + if (here.functions.isNever()) + result.functions.resetToTop(); + else if (here.functions.isTop) + result.functions.resetToNever(); + else + return std::nullopt; + } + + // TODO: negating tables + // TODO: negating tyvars? + + return result; +} + +TypeIds Normalizer::negateAll(const TypeIds& theres) +{ + TypeIds tys; + for (TypeId there : theres) + tys.insert(negate(there)); + return tys; +} + +TypeId Normalizer::negate(TypeId there) +{ + there = follow(there); + if (get(there)) + return there; + else if (get(there)) + return singletonTypes->neverType; + else if (get(there)) + return singletonTypes->unknownType; + else if (auto ntv = get(there)) + return ntv->ty; // TODO: do we want to normalize this? + else if (auto utv = get(there)) + { + std::vector parts; + for (TypeId option : utv) + parts.push_back(negate(option)); + return arena->addType(IntersectionTypeVar{std::move(parts)}); + } + else if (auto itv = get(there)) + { + std::vector options; + for (TypeId part : itv) + options.push_back(negate(part)); + return arena->addType(UnionTypeVar{std::move(options)}); + } + else + return there; +} + +void Normalizer::subtractPrimitive(NormalizedType& here, TypeId ty) +{ + const PrimitiveTypeVar* ptv = get(follow(ty)); + LUAU_ASSERT(ptv); + switch (ptv->type) + { + case PrimitiveTypeVar::NilType: + here.nils = singletonTypes->neverType; + break; + case PrimitiveTypeVar::Boolean: + here.booleans = singletonTypes->neverType; + break; + case PrimitiveTypeVar::Number: + here.numbers = singletonTypes->neverType; + break; + case PrimitiveTypeVar::String: + here.strings.resetToNever(); + break; + case PrimitiveTypeVar::Thread: + here.threads = singletonTypes->neverType; + break; + case PrimitiveTypeVar::Function: + LUAU_ASSERT(FFlag::LuauNegatedStringSingletons); + here.functions.resetToNever(); + break; + } +} + +void Normalizer::subtractSingleton(NormalizedType& here, TypeId ty) +{ + LUAU_ASSERT(FFlag::LuauNegatedStringSingletons); + + const SingletonTypeVar* stv = get(ty); + LUAU_ASSERT(stv); + + if (const StringSingleton* ss = get(stv)) + { + if (here.strings.isCofinite) + here.strings.singletons->insert({ss->value, ty}); + else + { + auto it = here.strings.singletons->find(ss->value); + if (it != here.strings.singletons->end()) + here.strings.singletons->erase(it); + } + } + else if (const BooleanSingleton* bs = get(stv)) + { + if (get(here.booleans)) + { + // Nothing + } + else if (get(here.booleans)) + here.booleans = bs->value ? singletonTypes->falseType : singletonTypes->trueType; + else if (auto hereSingleton = get(here.booleans)) + { + const BooleanSingleton* hereBooleanSingleton = get(hereSingleton); + LUAU_ASSERT(hereBooleanSingleton); + + // Crucial subtlety: ty (and thus bs) are the value that is being + // negated out. We therefore reduce to never when the values match, + // rather than when they differ. + if (bs->value == hereBooleanSingleton->value) + here.booleans = singletonTypes->neverType; + } + else + LUAU_ASSERT(!"Unreachable"); + } + else + LUAU_ASSERT(!"Unreachable"); +} + // ------- Normalizing intersections TypeId Normalizer::intersectionOfTops(TypeId here, TypeId there) { @@ -971,17 +1349,17 @@ void Normalizer::intersectClassesWithClass(TypeIds& heres, TypeId there) void Normalizer::intersectStrings(NormalizedStringType& here, const NormalizedStringType& there) { - if (!there) + if (there.isString()) return; - if (!here) - here.emplace(); + if (here.isString()) + here.resetToNever(); - for (auto it = here->begin(); it != here->end();) + for (auto it = here.singletons->begin(); it != here.singletons->end();) { - if (there->count(it->first)) + if (there.singletons->count(it->first)) it++; else - it = here->erase(it); + it = here.singletons->erase(it); } } @@ -1269,19 +1647,35 @@ std::optional Normalizer::intersectionOfFunctions(TypeId here, TypeId th return std::nullopt; if (hftv->genericPacks != tftv->genericPacks) return std::nullopt; - if (hftv->retTypes != tftv->retTypes) + + TypePackId argTypes; + TypePackId retTypes; + + if (hftv->retTypes == tftv->retTypes) + { + std::optional argTypesOpt = unionOfTypePacks(hftv->argTypes, tftv->argTypes); + if (!argTypesOpt) + return std::nullopt; + argTypes = *argTypesOpt; + retTypes = hftv->retTypes; + } + else if (FFlag::LuauOverloadedFunctionSubtypingPerf && hftv->argTypes == tftv->argTypes) + { + std::optional retTypesOpt = intersectionOfTypePacks(hftv->argTypes, tftv->argTypes); + if (!retTypesOpt) + return std::nullopt; + argTypes = hftv->argTypes; + retTypes = *retTypesOpt; + } + else return std::nullopt; - std::optional argTypes = unionOfTypePacks(hftv->argTypes, tftv->argTypes); - if (!argTypes) - return std::nullopt; - - if (*argTypes == hftv->argTypes) + if (argTypes == hftv->argTypes && retTypes == hftv->retTypes) return here; - if (*argTypes == tftv->argTypes) + if (argTypes == tftv->argTypes && retTypes == tftv->retTypes) return there; - FunctionTypeVar result{*argTypes, hftv->retTypes}; + FunctionTypeVar result{argTypes, retTypes}; result.generics = hftv->generics; result.genericPacks = hftv->genericPacks; return arena->addType(std::move(result)); @@ -1405,18 +1799,20 @@ std::optional Normalizer::unionSaturatedFunctions(TypeId here, TypeId th void Normalizer::intersectFunctionsWithFunction(NormalizedFunctionType& heres, TypeId there) { - if (!heres) + if (heres.isNever()) return; - for (auto it = heres->begin(); it != heres->end();) + heres.isTop = false; + + for (auto it = heres.parts->begin(); it != heres.parts->end();) { TypeId here = *it; if (get(here)) it++; else if (std::optional tmp = intersectionOfFunctions(here, there)) { - heres->erase(it); - heres->insert(*tmp); + heres.parts->erase(it); + heres.parts->insert(*tmp); return; } else @@ -1424,27 +1820,27 @@ void Normalizer::intersectFunctionsWithFunction(NormalizedFunctionType& heres, T } TypeIds tmps; - for (TypeId here : *heres) + for (TypeId here : *heres.parts) { if (std::optional tmp = unionSaturatedFunctions(here, there)) tmps.insert(*tmp); } - heres->insert(there); - heres->insert(tmps.begin(), tmps.end()); + heres.parts->insert(there); + heres.parts->insert(tmps.begin(), tmps.end()); } void Normalizer::intersectFunctions(NormalizedFunctionType& heres, const NormalizedFunctionType& theres) { - if (!heres) + if (heres.isNever()) return; - else if (!theres) + else if (theres.isNever()) { - heres = std::nullopt; + heres.resetToNever(); return; } else { - for (TypeId there : *theres) + for (TypeId there : *theres.parts) intersectFunctionsWithFunction(heres, there); } } @@ -1602,6 +1998,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) TypeId nils = here.nils; TypeId numbers = here.numbers; NormalizedStringType strings = std::move(here.strings); + NormalizedFunctionType functions = std::move(here.functions); TypeId threads = here.threads; clearNormal(here); @@ -1616,6 +2013,11 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) here.strings = std::move(strings); else if (ptv->type == PrimitiveTypeVar::Thread) here.threads = threads; + else if (ptv->type == PrimitiveTypeVar::Function) + { + LUAU_ASSERT(FFlag::LuauNegatedFunctionTypes); + here.functions = std::move(functions); + } else LUAU_ASSERT(!"Unreachable"); } @@ -1630,12 +2032,37 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) here.booleans = intersectionOfBools(booleans, there); else if (const StringSingleton* sstv = get(stv)) { - if (!strings || strings->count(sstv->value)) - here.strings->insert({sstv->value, there}); + if (strings.includes(sstv->value)) + here.strings.singletons->insert({sstv->value, there}); } else LUAU_ASSERT(!"Unreachable"); } + else if (const NegationTypeVar* ntv = get(there); FFlag::LuauNegatedStringSingletons && ntv) + { + TypeId t = follow(ntv->ty); + if (const PrimitiveTypeVar* ptv = get(t)) + subtractPrimitive(here, ntv->ty); + else if (const SingletonTypeVar* stv = get(t)) + subtractSingleton(here, follow(ntv->ty)); + else if (const UnionTypeVar* itv = get(t)) + { + for (TypeId part : itv->options) + { + const NormalizedType* normalPart = normalize(part); + std::optional negated = negateNormal(*normalPart); + if (!negated) + return false; + intersectNormals(here, *negated); + } + } + else + { + // TODO negated unions, intersections, table, and function. + // Report a TypeError for other types. + LUAU_ASSERT(!"Unimplemented"); + } + } else LUAU_ASSERT(!"Unreachable"); @@ -1660,14 +2087,16 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) result.insert(result.end(), norm.classes.begin(), norm.classes.end()); if (!get(norm.errors)) result.push_back(norm.errors); - if (norm.functions) + if (FFlag::LuauNegatedFunctionTypes && norm.functions.isTop) + result.push_back(singletonTypes->functionType); + else if (!norm.functions.isNever()) { - if (norm.functions->size() == 1) - result.push_back(*norm.functions->begin()); + if (norm.functions.parts->size() == 1) + result.push_back(*norm.functions.parts->begin()); else { std::vector parts; - parts.insert(parts.end(), norm.functions->begin(), norm.functions->end()); + parts.insert(parts.end(), norm.functions.parts->begin(), norm.functions.parts->end()); result.push_back(arena->addType(IntersectionTypeVar{std::move(parts)})); } } @@ -1675,11 +2104,25 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) result.push_back(norm.nils); if (!get(norm.numbers)) result.push_back(norm.numbers); - if (norm.strings) - for (auto& [_, ty] : *norm.strings) - result.push_back(ty); - else + if (norm.strings.isString()) result.push_back(singletonTypes->stringType); + else if (norm.strings.isUnion()) + { + for (auto& [_, ty] : *norm.strings.singletons) + result.push_back(ty); + } + else if (FFlag::LuauNegatedStringSingletons && norm.strings.isIntersection()) + { + std::vector parts; + parts.push_back(singletonTypes->stringType); + for (const auto& [name, ty] : *norm.strings.singletons) + parts.push_back(arena->addType(NegationTypeVar{ty})); + + result.push_back(arena->addType(IntersectionTypeVar{std::move(parts)})); + } + if (!get(norm.threads)) + result.push_back(singletonTypes->threadType); + result.insert(result.end(), norm.tables.begin(), norm.tables.end()); for (auto& [tyvar, intersect] : norm.tyvars) { @@ -1700,672 +2143,28 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) return arena->addType(UnionTypeVar{std::move(result)}); } -namespace -{ - -struct Replacer -{ - TypeArena* arena; - TypeId sourceType; - TypeId replacedType; - DenseHashMap newTypes; - - Replacer(TypeArena* arena, TypeId sourceType, TypeId replacedType) - : arena(arena) - , sourceType(sourceType) - , replacedType(replacedType) - , newTypes(nullptr) - { - } - - TypeId smartClone(TypeId t) - { - t = follow(t); - TypeId* res = newTypes.find(t); - if (res) - return *res; - - TypeId result = shallowClone(t, *arena, TxnLog::empty()); - newTypes[t] = result; - newTypes[result] = result; - - return result; - } -}; - -} // anonymous namespace - -bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, bool anyIsTop) +bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice) { UnifierSharedState sharedState{&ice}; TypeArena arena; Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; - u.anyIsTop = anyIsTop; u.tryUnify(subTy, superTy); const bool ok = u.errors.empty() && u.log.empty(); return ok; } -bool isSubtype( - TypePackId subPack, TypePackId superPack, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, bool anyIsTop) +bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice) { UnifierSharedState sharedState{&ice}; TypeArena arena; Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; - u.anyIsTop = anyIsTop; u.tryUnify(subPack, superPack); const bool ok = u.errors.empty() && u.log.empty(); return ok; } -template -static bool areNormal_(const T& t, const std::unordered_set& seen, InternalErrorReporter& ice) -{ - int count = 0; - auto isNormal = [&](TypeId ty) { - ++count; - if (count >= FInt::LuauNormalizeIterationLimit) - ice.ice("Luau::areNormal hit iteration limit"); - - return ty->normal; - }; - - return std::all_of(begin(t), end(t), isNormal); -} - -static bool areNormal(const std::vector& types, const std::unordered_set& seen, InternalErrorReporter& ice) -{ - return areNormal_(types, seen, ice); -} - -static bool areNormal(TypePackId tp, const std::unordered_set& seen, InternalErrorReporter& ice) -{ - tp = follow(tp); - if (get(tp)) - return false; - - auto [head, tail] = flatten(tp); - - if (!areNormal_(head, seen, ice)) - return false; - - if (!tail) - return true; - - if (auto vtp = get(*tail)) - return vtp->ty->normal || follow(vtp->ty)->normal || seen.find(asMutable(vtp->ty)) != seen.end(); - - return true; -} - -#define CHECK_ITERATION_LIMIT(...) \ - do \ - { \ - if (iterationLimit > FInt::LuauNormalizeIterationLimit) \ - { \ - limitExceeded = true; \ - return __VA_ARGS__; \ - } \ - ++iterationLimit; \ - } while (false) - -struct Normalize final : TypeVarVisitor -{ - using TypeVarVisitor::Set; - - Normalize(TypeArena& arena, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice) - : arena(arena) - , scope(scope) - , singletonTypes(singletonTypes) - , ice(ice) - { - } - - TypeArena& arena; - NotNull scope; - NotNull singletonTypes; - InternalErrorReporter& ice; - - int iterationLimit = 0; - bool limitExceeded = false; - - bool visit(TypeId ty, const FreeTypeVar&) override - { - LUAU_ASSERT(!ty->normal); - return false; - } - - bool visit(TypeId ty, const BoundTypeVar& btv) override - { - // A type could be considered normal when it is in the stack, but we will eventually find out it is not normal as normalization progresses. - // So we need to avoid eagerly saying that this bound type is normal if the thing it is bound to is in the stack. - if (seen.find(asMutable(btv.boundTo)) != seen.end()) - return false; - - // It should never be the case that this TypeVar is normal, but is bound to a non-normal type, except in nontrivial cases. - LUAU_ASSERT(!ty->normal || ty->normal == btv.boundTo->normal); - - if (!ty->normal) - asMutable(ty)->normal = btv.boundTo->normal; - return !ty->normal; - } - - bool visit(TypeId ty, const PrimitiveTypeVar&) override - { - LUAU_ASSERT(ty->normal); - return false; - } - - bool visit(TypeId ty, const GenericTypeVar&) override - { - if (!ty->normal) - asMutable(ty)->normal = true; - return false; - } - - bool visit(TypeId ty, const ErrorTypeVar&) override - { - if (!ty->normal) - asMutable(ty)->normal = true; - return false; - } - - bool visit(TypeId ty, const UnknownTypeVar&) override - { - if (!ty->normal) - asMutable(ty)->normal = true; - return false; - } - - bool visit(TypeId ty, const NeverTypeVar&) override - { - if (!ty->normal) - asMutable(ty)->normal = true; - return false; - } - - bool visit(TypeId ty, const ConstrainedTypeVar& ctvRef) override - { - CHECK_ITERATION_LIMIT(false); - LUAU_ASSERT(!ty->normal); - - ConstrainedTypeVar* ctv = const_cast(&ctvRef); - - std::vector parts = std::move(ctv->parts); - - // We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar - for (TypeId part : parts) - traverse(part); - - std::vector newParts = normalizeUnion(parts); - ctv->parts = std::move(newParts); - - return false; - } - - bool visit(TypeId ty, const FunctionTypeVar& ftv) override - { - CHECK_ITERATION_LIMIT(false); - - if (ty->normal) - return false; - - traverse(ftv.argTypes); - traverse(ftv.retTypes); - - asMutable(ty)->normal = areNormal(ftv.argTypes, seen, ice) && areNormal(ftv.retTypes, seen, ice); - - return false; - } - - bool visit(TypeId ty, const TableTypeVar& ttv) override - { - CHECK_ITERATION_LIMIT(false); - - if (ty->normal) - return false; - - bool normal = true; - - auto checkNormal = [&](TypeId t) { - // if t is on the stack, it is possible that this type is normal. - // If t is not normal and it is not on the stack, this type is definitely not normal. - if (!t->normal && seen.find(asMutable(t)) == seen.end()) - normal = false; - }; - - if (ttv.boundTo) - { - traverse(*ttv.boundTo); - asMutable(ty)->normal = (*ttv.boundTo)->normal; - return false; - } - - for (const auto& [_name, prop] : ttv.props) - { - traverse(prop.type); - checkNormal(prop.type); - } - - if (ttv.indexer) - { - traverse(ttv.indexer->indexType); - checkNormal(ttv.indexer->indexType); - traverse(ttv.indexer->indexResultType); - checkNormal(ttv.indexer->indexResultType); - } - - // An unsealed table can never be normal, ditto for free tables iff the type it is bound to is also not normal. - if (ttv.state == TableState::Generic || ttv.state == TableState::Sealed || (ttv.state == TableState::Free && follow(ty)->normal)) - asMutable(ty)->normal = normal; - - return false; - } - - bool visit(TypeId ty, const MetatableTypeVar& mtv) override - { - CHECK_ITERATION_LIMIT(false); - - if (ty->normal) - return false; - - traverse(mtv.table); - traverse(mtv.metatable); - - asMutable(ty)->normal = mtv.table->normal && mtv.metatable->normal; - - return false; - } - - bool visit(TypeId ty, const ClassTypeVar& ctv) override - { - if (!ty->normal) - asMutable(ty)->normal = true; - return false; - } - - bool visit(TypeId ty, const AnyTypeVar&) override - { - LUAU_ASSERT(ty->normal); - return false; - } - - bool visit(TypeId ty, const UnionTypeVar& utvRef) override - { - CHECK_ITERATION_LIMIT(false); - - if (ty->normal) - return false; - - UnionTypeVar* utv = &const_cast(utvRef); - - // We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar - for (TypeId option : utv->options) - traverse(option); - - std::vector newOptions = normalizeUnion(utv->options); - - const bool normal = areNormal(newOptions, seen, ice); - - LUAU_ASSERT(!newOptions.empty()); - - if (newOptions.size() == 1) - *asMutable(ty) = BoundTypeVar{newOptions[0]}; - else - utv->options = std::move(newOptions); - - asMutable(ty)->normal = normal; - - return false; - } - - bool visit(TypeId ty, const IntersectionTypeVar& itvRef) override - { - CHECK_ITERATION_LIMIT(false); - - if (ty->normal) - return false; - - IntersectionTypeVar* itv = &const_cast(itvRef); - - std::vector oldParts = itv->parts; - IntersectionTypeVar newIntersection; - - for (TypeId part : oldParts) - traverse(part); - - std::vector tables; - for (TypeId part : oldParts) - { - part = follow(part); - if (get(part)) - tables.push_back(part); - else - { - Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD - combineIntoIntersection(replacer, &newIntersection, part); - } - } - - // Don't allocate a new table if there's just one in the intersection. - if (tables.size() == 1) - newIntersection.parts.push_back(tables[0]); - else if (!tables.empty()) - { - const TableTypeVar* first = get(tables[0]); - LUAU_ASSERT(first); - - TypeId newTable = arena.addType(TableTypeVar{first->state, first->level}); - TableTypeVar* ttv = getMutable(newTable); - for (TypeId part : tables) - { - // Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need - // to be rewritten to point at 'newTable' in the clone. - Replacer replacer{&arena, part, newTable}; - combineIntoTable(replacer, ttv, part); - } - - newIntersection.parts.push_back(newTable); - } - - itv->parts = std::move(newIntersection.parts); - - asMutable(ty)->normal = areNormal(itv->parts, seen, ice); - - if (itv->parts.size() == 1) - { - TypeId part = itv->parts[0]; - *asMutable(ty) = BoundTypeVar{part}; - } - - return false; - } - - std::vector normalizeUnion(const std::vector& options) - { - if (options.size() == 1) - return options; - - std::vector result; - - for (TypeId part : options) - { - // AnyTypeVar always win the battle no matter what we do, so we're done. - if (FFlag::LuauUnknownAndNeverType && get(follow(part))) - return {part}; - - combineIntoUnion(result, part); - } - - return result; - } - - void combineIntoUnion(std::vector& result, TypeId ty) - { - ty = follow(ty); - if (auto utv = get(ty)) - { - for (TypeId t : utv) - { - // AnyTypeVar always win the battle no matter what we do, so we're done. - if (FFlag::LuauUnknownAndNeverType && get(t)) - { - result = {t}; - return; - } - - combineIntoUnion(result, t); - } - - return; - } - - for (TypeId& part : result) - { - if (isSubtype(ty, part, scope, singletonTypes, ice)) - return; // no need to do anything - else if (isSubtype(part, ty, scope, singletonTypes, ice)) - { - part = ty; // replace the less general type by the more general one - return; - } - } - - result.push_back(ty); - } - - /** - * @param replacer knows how to clone a type such that any recursive references point at the new containing type. - * @param result is an intersection that is safe for us to mutate in-place. - */ - void combineIntoIntersection(Replacer& replacer, IntersectionTypeVar* result, TypeId ty) - { - // Note: this check guards against running out of stack space - // so if you increase the size of a stack frame, you'll need to decrease the limit. - CHECK_ITERATION_LIMIT(); - - ty = follow(ty); - if (auto itv = get(ty)) - { - for (TypeId part : itv->parts) - combineIntoIntersection(replacer, result, part); - return; - } - - // Let's say that the last part of our result intersection is always a table, if any table is part of this intersection - if (get(ty)) - { - if (result->parts.empty()) - result->parts.push_back(arena.addType(TableTypeVar{TableState::Sealed, TypeLevel{}})); - - TypeId theTable = result->parts.back(); - - if (!get(follow(theTable))) - { - result->parts.push_back(arena.addType(TableTypeVar{TableState::Sealed, TypeLevel{}})); - theTable = result->parts.back(); - } - - TypeId newTable = replacer.smartClone(theTable); - result->parts.back() = newTable; - - combineIntoTable(replacer, getMutable(newTable), ty); - } - else if (auto ftv = get(ty)) - { - bool merged = false; - for (TypeId& part : result->parts) - { - if (isSubtype(part, ty, scope, singletonTypes, ice)) - { - merged = true; - break; // no need to do anything - } - else if (isSubtype(ty, part, scope, singletonTypes, ice)) - { - merged = true; - part = ty; // replace the less general type by the more general one - break; - } - } - - if (!merged) - result->parts.push_back(ty); - } - else - result->parts.push_back(ty); - } - - TableState combineTableStates(TableState lhs, TableState rhs) - { - if (lhs == rhs) - return lhs; - - if (lhs == TableState::Free || rhs == TableState::Free) - return TableState::Free; - - if (lhs == TableState::Unsealed || rhs == TableState::Unsealed) - return TableState::Unsealed; - - return lhs; - } - - /** - * @param replacer gives us a way to clone a type such that recursive references are rewritten to the new - * "containing" type. - * @param table always points into a table that is safe for us to mutate. - */ - void combineIntoTable(Replacer& replacer, TableTypeVar* table, TypeId ty) - { - // Note: this check guards against running out of stack space - // so if you increase the size of a stack frame, you'll need to decrease the limit. - CHECK_ITERATION_LIMIT(); - - LUAU_ASSERT(table); - - ty = follow(ty); - - TableTypeVar* tyTable = getMutable(ty); - LUAU_ASSERT(tyTable); - - for (const auto& [propName, prop] : tyTable->props) - { - if (auto it = table->props.find(propName); it != table->props.end()) - { - /** - * If we are going to recursively merge intersections of tables, we need to ensure that we never mutate - * a table that comes from somewhere else in the type graph. - * - * smarClone() does some nice things for us: It will perform a clone that is as shallow as possible - * while still rewriting any cyclic references back to the new 'root' table. - * - * replacer also keeps a mapping of types that have previously been copied, so we have the added - * advantage here of knowing that, whether or not a new copy was actually made, the resulting TypeVar is - * safe for us to mutate in-place. - */ - TypeId clone = replacer.smartClone(it->second.type); - it->second.type = combine(replacer, clone, prop.type); - } - else - table->props.insert({propName, prop}); - } - - if (tyTable->indexer) - { - if (table->indexer) - { - table->indexer->indexType = combine(replacer, replacer.smartClone(tyTable->indexer->indexType), table->indexer->indexType); - table->indexer->indexResultType = - combine(replacer, replacer.smartClone(tyTable->indexer->indexResultType), table->indexer->indexResultType); - } - else - { - table->indexer = - TableIndexer{replacer.smartClone(tyTable->indexer->indexType), replacer.smartClone(tyTable->indexer->indexResultType)}; - } - } - - table->state = combineTableStates(table->state, tyTable->state); - table->level = max(table->level, tyTable->level); - } - - /** - * @param a is always cloned by the caller. It is safe to mutate in-place. - * @param b will never be mutated. - */ - TypeId combine(Replacer& replacer, TypeId a, TypeId b) - { - b = follow(b); - - if (FFlag::LuauNormalizeCombineTableFix && a == b) - return a; - - if (!get(a) && !get(a)) - { - if (!FFlag::LuauNormalizeCombineTableFix && a == b) - return a; - else - return arena.addType(IntersectionTypeVar{{a, b}}); - } - - if (auto itv = getMutable(a)) - { - combineIntoIntersection(replacer, itv, b); - return a; - } - else if (auto ttv = getMutable(a)) - { - if (FFlag::LuauNormalizeCombineTableFix && !get(b)) - return arena.addType(IntersectionTypeVar{{a, b}}); - combineIntoTable(replacer, ttv, b); - return a; - } - - LUAU_ASSERT(!"Impossible"); - LUAU_UNREACHABLE(); - } -}; - -#undef CHECK_ITERATION_LIMIT - -/** - * @returns A tuple of TypeId and a success indicator. (true indicates that the normalization completed successfully) - */ -std::pair normalize( - TypeId ty, NotNull scope, TypeArena& arena, NotNull singletonTypes, InternalErrorReporter& ice) -{ - CloneState state; - if (FFlag::DebugLuauCopyBeforeNormalizing) - (void)clone(ty, arena, state); - - Normalize n{arena, scope, singletonTypes, ice}; - n.traverse(ty); - - return {ty, !n.limitExceeded}; -} - -// TODO: Think about using a temporary arena and cloning types out of it so that we -// reclaim memory used by wantonly allocated intermediate types here. -// The main wrinkle here is that we don't want clone() to copy a type if the source and dest -// arena are the same. -std::pair normalize(TypeId ty, NotNull module, NotNull singletonTypes, InternalErrorReporter& ice) -{ - return normalize(ty, NotNull{module->getModuleScope().get()}, module->internalTypes, singletonTypes, ice); -} - -std::pair normalize(TypeId ty, const ModulePtr& module, NotNull singletonTypes, InternalErrorReporter& ice) -{ - return normalize(ty, NotNull{module.get()}, singletonTypes, ice); -} - -/** - * @returns A tuple of TypeId and a success indicator. (true indicates that the normalization completed successfully) - */ -std::pair normalize( - TypePackId tp, NotNull scope, TypeArena& arena, NotNull singletonTypes, InternalErrorReporter& ice) -{ - CloneState state; - if (FFlag::DebugLuauCopyBeforeNormalizing) - (void)clone(tp, arena, state); - - Normalize n{arena, scope, singletonTypes, ice}; - n.traverse(tp); - - return {tp, !n.limitExceeded}; -} - -std::pair normalize(TypePackId tp, NotNull module, NotNull singletonTypes, InternalErrorReporter& ice) -{ - return normalize(tp, NotNull{module->getModuleScope().get()}, module->internalTypes, singletonTypes, ice); -} - -std::pair normalize(TypePackId tp, const ModulePtr& module, NotNull singletonTypes, InternalErrorReporter& ice) -{ - return normalize(tp, NotNull{module.get()}, singletonTypes, ice); -} - } // namespace Luau diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index e4c069bd..e9de094b 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -57,29 +57,6 @@ struct Quantifier final : TypeVarOnceVisitor return false; } - bool visit(TypeId ty, const ConstrainedTypeVar&) override - { - ConstrainedTypeVar* ctv = getMutable(ty); - - seenMutableType = true; - - if (!level.subsumes(ctv->level)) - return false; - - std::vector opts = std::move(ctv->parts); - - // We might transmute, so it's not safe to rely on the builtin traversal logic - for (TypeId opt : opts) - traverse(opt); - - if (opts.size() == 1) - *asMutable(ty) = BoundTypeVar{opts[0]}; - else - *asMutable(ty) = UnionTypeVar{std::move(opts)}; - - return false; - } - bool visit(TypeId ty, const TableTypeVar&) override { LUAU_ASSERT(getMutable(ty)); diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index 9a7d3609..84925f79 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -27,6 +27,44 @@ void Scope::addBuiltinTypeBinding(const Name& name, const TypeFun& tyFun) builtinTypeNames.insert(name); } +std::optional Scope::lookup(Symbol sym) const +{ + auto r = const_cast(this)->lookupEx(sym); + if (r) + return r->first; + else + return std::nullopt; +} + +std::optional> Scope::lookupEx(Symbol sym) +{ + Scope* s = this; + + while (true) + { + auto it = s->bindings.find(sym); + if (it != s->bindings.end()) + return std::pair{it->second.typeId, s}; + + if (s->parent) + s = s->parent.get(); + else + return std::nullopt; + } +} + +// TODO: We might kill Scope::lookup(Symbol) once data flow is fully fleshed out with type states and control flow analysis. +std::optional Scope::lookup(DefId def) const +{ + for (const Scope* current = this; current; current = current->parent.get()) + { + if (auto ty = current->dcrRefinements.find(def)) + return *ty; + } + + return std::nullopt; +} + std::optional Scope::lookupType(const Name& name) { const Scope* scope = this; @@ -111,23 +149,6 @@ std::optional Scope::linearSearchForBinding(const std::string& name, bo return std::nullopt; } -std::optional Scope::lookup(Symbol sym) -{ - Scope* s = this; - - while (true) - { - auto it = s->bindings.find(sym); - if (it != s->bindings.end()) - return it->second.typeId; - - if (s->parent) - s = s->parent.get(); - else - return std::nullopt; - } -} - bool subsumesStrict(Scope* left, Scope* right) { while (right) diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 2137d73e..20ed34f6 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -73,11 +73,6 @@ void Tarjan::visitChildren(TypeId ty, int index) for (TypeId part : itv->parts) visitChild(part); } - else if (const ConstrainedTypeVar* ctv = get(ty)) - { - for (TypeId part : ctv->parts) - visitChild(part); - } else if (const PendingExpansionTypeVar* petv = get(ty)) { for (TypeId a : petv->typeArguments) @@ -97,6 +92,10 @@ void Tarjan::visitChildren(TypeId ty, int index) if (ctv->metatable) visitChild(*ctv->metatable); } + else if (const NegationTypeVar* ntv = get(ty)) + { + visitChild(ntv->ty); + } } void Tarjan::visitChildren(TypePackId tp, int index) @@ -605,11 +604,6 @@ void Substitution::replaceChildren(TypeId ty) for (TypeId& part : itv->parts) part = replace(part); } - else if (ConstrainedTypeVar* ctv = getMutable(ty)) - { - for (TypeId& part : ctv->parts) - part = replace(part); - } else if (PendingExpansionTypeVar* petv = getMutable(ty)) { for (TypeId& a : petv->typeArguments) @@ -629,6 +623,10 @@ void Substitution::replaceChildren(TypeId ty) if (ctv->metatable) ctv->metatable = replace(*ctv->metatable); } + else if (NegationTypeVar* ntv = getMutable(ty)) + { + ntv->ty = replace(ntv->ty); + } } void Substitution::replaceChildren(TypePackId tp) diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index 0d989ca0..68fa5393 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -237,15 +237,6 @@ void StateDot::visitChildren(TypeId ty, int index) finishNodeLabel(ty); finishNode(); } - else if (const ConstrainedTypeVar* ctv = get(ty)) - { - formatAppend(result, "ConstrainedTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - for (TypeId part : ctv->parts) - visitChild(part, index); - } else if (get(ty)) { formatAppend(result, "ErrorTypeVar %d", index); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 9572ef19..903e156b 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -10,11 +10,12 @@ #include #include +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauLvaluelessPath) LUAU_FASTFLAG(LuauUnknownAndNeverType) -LUAU_FASTFLAGVARIABLE(LuauSpecialTypesAsterisked, false) LUAU_FASTFLAGVARIABLE(LuauFixNameMaps, false) -LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false) LUAU_FASTFLAGVARIABLE(LuauFunctionReturnStringificationFixup, false) +LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false) /* * Prefix generic typenames with gen- @@ -224,6 +225,20 @@ struct StringifierState result.name += s; } + void emitLevel(Scope* scope) + { + size_t count = 0; + for (Scope* s = scope; s; s = s->parent.get()) + ++count; + + emit(count); + emit("-"); + char buffer[16]; + uint32_t s = uint32_t(intptr_t(scope) & 0xFFFFFF); + snprintf(buffer, sizeof(buffer), "0x%x", s); + emit(buffer); + } + void emit(TypeLevel level) { emit(std::to_string(level.level)); @@ -295,10 +310,7 @@ struct TypeVarStringifier if (tv->ty.valueless_by_exception()) { state.result.error = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("* VALUELESS BY EXCEPTION *"); - else - state.emit("< VALUELESS BY EXCEPTION >"); + state.emit("* VALUELESS BY EXCEPTION *"); return; } @@ -376,7 +388,10 @@ struct TypeVarStringifier if (FFlag::DebugLuauVerboseTypeNames) { state.emit("-"); - state.emit(ftv.level); + if (FFlag::DebugLuauDeferredConstraintResolution) + state.emitLevel(ftv.scope); + else + state.emit(ftv.level); } } @@ -398,29 +413,15 @@ struct TypeVarStringifier } else state.emit(state.getName(ty)); - } - void operator()(TypeId, const ConstrainedTypeVar& ctv) - { - state.result.invalid = true; - - state.emit("["); if (FFlag::DebugLuauVerboseTypeNames) - state.emit(ctv.level); - state.emit("["); - - bool first = true; - for (TypeId ty : ctv.parts) { - if (first) - first = false; + state.emit("-"); + if (FFlag::DebugLuauDeferredConstraintResolution) + state.emitLevel(gtv.scope); else - state.emit("|"); - - stringify(ty); + state.emit(gtv.level); } - - state.emit("]]"); } void operator()(TypeId, const BlockedTypeVar& btv) @@ -456,9 +457,12 @@ struct TypeVarStringifier case PrimitiveTypeVar::Thread: state.emit("thread"); return; + case PrimitiveTypeVar::Function: + state.emit("function"); + return; default: LUAU_ASSERT(!"Unknown primitive type"); - throw std::runtime_error("Unknown primitive type " + std::to_string(ptv.type)); + throwRuntimeError("Unknown primitive type " + std::to_string(ptv.type)); } } @@ -475,7 +479,7 @@ struct TypeVarStringifier else { LUAU_ASSERT(!"Unknown singleton type"); - throw std::runtime_error("Unknown singleton type"); + throwRuntimeError("Unknown singleton type"); } } @@ -484,10 +488,7 @@ struct TypeVarStringifier if (state.hasSeen(&ftv)) { state.result.cycle = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("*CYCLE*"); - else - state.emit(""); + state.emit("*CYCLE*"); return; } @@ -595,10 +596,7 @@ struct TypeVarStringifier if (state.hasSeen(&ttv)) { state.result.cycle = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("*CYCLE*"); - else - state.emit(""); + state.emit("*CYCLE*"); return; } @@ -732,10 +730,7 @@ struct TypeVarStringifier if (state.hasSeen(&uv)) { state.result.cycle = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("*CYCLE*"); - else - state.emit(""); + state.emit("*CYCLE*"); return; } @@ -802,10 +797,7 @@ struct TypeVarStringifier if (state.hasSeen(&uv)) { state.result.cycle = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("*CYCLE*"); - else - state.emit(""); + state.emit("*CYCLE*"); return; } @@ -850,10 +842,7 @@ struct TypeVarStringifier void operator()(TypeId, const ErrorTypeVar& tv) { state.result.error = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*"); - else - state.emit(FFlag::LuauUnknownAndNeverType ? "" : "*unknown*"); + state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*"); } void operator()(TypeId, const LazyTypeVar& ltv) @@ -871,6 +860,23 @@ struct TypeVarStringifier { state.emit("never"); } + + void operator()(TypeId, const NegationTypeVar& ntv) + { + state.emit("~"); + + // The precedence of `~` should be less than `|` and `&`. + TypeId followed = follow(ntv.ty); + bool parens = get(followed) || get(followed); + + if (parens) + state.emit("("); + + stringify(ntv.ty); + + if (parens) + state.emit(")"); + } }; struct TypePackStringifier @@ -907,10 +913,7 @@ struct TypePackStringifier if (tp->ty.valueless_by_exception()) { state.result.error = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("* VALUELESS TP BY EXCEPTION *"); - else - state.emit("< VALUELESS TP BY EXCEPTION >"); + state.emit("* VALUELESS TP BY EXCEPTION *"); return; } @@ -934,10 +937,7 @@ struct TypePackStringifier if (state.hasSeen(&tp)) { state.result.cycle = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("*CYCLETP*"); - else - state.emit(""); + state.emit("*CYCLETP*"); return; } @@ -982,10 +982,7 @@ struct TypePackStringifier void operator()(TypePackId, const Unifiable::Error& error) { state.result.error = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*"); - else - state.emit(FFlag::LuauUnknownAndNeverType ? "" : "*unknown*"); + state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*"); } void operator()(TypePackId, const VariadicTypePack& pack) @@ -993,10 +990,7 @@ struct TypePackStringifier state.emit("..."); if (FFlag::DebugLuauVerboseTypeNames && pack.hidden) { - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("*hidden*"); - else - state.emit(""); + state.emit("*hidden*"); } stringify(pack.ty); } @@ -1031,7 +1025,10 @@ struct TypePackStringifier if (FFlag::DebugLuauVerboseTypeNames) { state.emit("-"); - state.emit(pack.level); + if (FFlag::DebugLuauDeferredConstraintResolution) + state.emitLevel(pack.scope); + else + state.emit(pack.level); } state.emit("..."); @@ -1204,10 +1201,7 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts) { result.truncated = true; - if (FFlag::LuauSpecialTypesAsterisked) - result.name += "... *TRUNCATED*"; - else - result.name += "... "; + result.name += "... *TRUNCATED*"; } return result; @@ -1280,10 +1274,7 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts) if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) { - if (FFlag::LuauSpecialTypesAsterisked) - result.name += "... *TRUNCATED*"; - else - result.name += "... "; + result.name += "... *TRUNCATED*"; } return result; @@ -1442,7 +1433,7 @@ std::string generateName(size_t i) std::string toString(const Constraint& constraint, ToStringOptions& opts) { - auto go = [&opts](auto&& c) { + auto go = [&opts](auto&& c) -> std::string { using T = std::decay_t; // TODO: Inline and delete this function when clipping FFlag::LuauFixNameMaps @@ -1526,6 +1517,13 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) { return tos(c.resultType, opts) + " ~ hasProp " + tos(c.subjectType, opts) + ", \"" + c.prop + "\""; } + else if constexpr (std::is_same_v) + { + std::string result = tos(c.resultType, opts); + std::string discriminant = tos(c.discriminantType, opts); + + return result + " ~ if isSingleton D then ~D else unknown where D = " + discriminant; + } else static_assert(always_false_v, "Non-exhaustive constraint switch"); }; @@ -1545,6 +1543,8 @@ std::string dump(const Constraint& c) std::string toString(const LValue& lvalue) { + LUAU_ASSERT(!FFlag::LuauLvaluelessPath); + std::string s; for (const LValue* current = &lvalue; current; current = baseof(*current)) { @@ -1559,4 +1559,37 @@ std::string toString(const LValue& lvalue) return s; } +std::optional getFunctionNameAsString(const AstExpr& expr) +{ + LUAU_ASSERT(FFlag::LuauLvaluelessPath); + + const AstExpr* curr = &expr; + std::string s; + + for (;;) + { + if (auto local = curr->as()) + return local->local->name.value + s; + + if (auto global = curr->as()) + return global->name.value + s; + + if (auto indexname = curr->as()) + { + curr = indexname->expr; + + s = "." + std::string(indexname->index.value) + s; + } + else if (auto group = curr->as()) + { + curr = group->expr; + } + else + { + return std::nullopt; + } + } + + return s; +} } // namespace Luau diff --git a/Analysis/src/TopoSortStatements.cpp b/Analysis/src/TopoSortStatements.cpp index 1ea2e27d..052c10de 100644 --- a/Analysis/src/TopoSortStatements.cpp +++ b/Analysis/src/TopoSortStatements.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TopoSortStatements.h" +#include "Luau/Error.h" /* Decide the order in which we typecheck Lua statements in a block. * * Algorithm: @@ -149,7 +150,7 @@ Identifier mkName(const AstStatFunction& function) auto name = mkName(*function.name); LUAU_ASSERT(bool(name)); if (!name) - throw std::runtime_error("Internal error: Function declaration has a bad name"); + throwRuntimeError("Internal error: Function declaration has a bad name"); return *name; } @@ -255,7 +256,7 @@ struct ArcCollector : public AstVisitor { auto name = mkName(*node->name); if (!name) - throw std::runtime_error("Internal error: AstStatFunction has a bad name"); + throwRuntimeError("Internal error: AstStatFunction has a bad name"); add(*name); return true; diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 06bde195..034aeaec 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -251,7 +251,7 @@ PendingType* TxnLog::bindTable(TypeId ty, std::optional newBoundTo) PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel) { - LUAU_ASSERT(get(ty) || get(ty) || get(ty) || get(ty)); + LUAU_ASSERT(get(ty) || get(ty) || get(ty)); PendingType* newTy = queue(ty); if (FreeTypeVar* ftv = Luau::getMutable(newTy)) @@ -267,11 +267,6 @@ PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel) { ftv->level = newLevel; } - else if (ConstrainedTypeVar* ctv = Luau::getMutable(newTy)) - { - if (FFlag::LuauUnknownAndNeverType) - ctv->level = newLevel; - } return newTy; } @@ -291,7 +286,7 @@ PendingTypePack* TxnLog::changeLevel(TypePackId tp, TypeLevel newLevel) PendingType* TxnLog::changeScope(TypeId ty, NotNull newScope) { - LUAU_ASSERT(get(ty) || get(ty) || get(ty) || get(ty)); + LUAU_ASSERT(get(ty) || get(ty) || get(ty)); PendingType* newTy = queue(ty); if (FreeTypeVar* ftv = Luau::getMutable(newTy)) @@ -307,10 +302,6 @@ PendingType* TxnLog::changeScope(TypeId ty, NotNull newScope) { ftv->scope = newScope; } - else if (ConstrainedTypeVar* ctv = Luau::getMutable(newTy)) - { - ctv->scope = newScope; - } return newTy; } diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 84494083..c97ed05d 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -104,16 +104,6 @@ public: return allocator->alloc(Location(), std::nullopt, AstName("*pending-expansion*")); } - AstType* operator()(const ConstrainedTypeVar& ctv) - { - AstArray types; - types.size = ctv.parts.size(); - types.data = static_cast(allocator->allocate(sizeof(AstType*) * ctv.parts.size())); - for (size_t i = 0; i < ctv.parts.size(); ++i) - types.data[i] = Luau::visit(*this, ctv.parts[i]->ty); - return allocator->alloc(Location(), types); - } - AstType* operator()(const SingletonTypeVar& stv) { if (const BooleanSingleton* bs = get(&stv)) @@ -348,6 +338,11 @@ public: { return allocator->alloc(Location(), std::nullopt, AstName{"never"}); } + AstType* operator()(const NegationTypeVar& ntv) + { + // FIXME: do the same thing we do with ErrorTypeVar + throwRuntimeError("Cannot convert NegationTypeVar into AstNode"); + } private: Allocator* allocator; diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 4753a7c2..dde41a65 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -5,6 +5,7 @@ #include "Luau/AstQuery.h" #include "Luau/Clone.h" #include "Luau/Instantiation.h" +#include "Luau/Metamethods.h" #include "Luau/Normalize.h" #include "Luau/ToString.h" #include "Luau/TxnLog.h" @@ -62,6 +63,23 @@ struct StackPusher } }; +static std::optional getIdentifierOfBaseVar(AstExpr* node) +{ + if (AstExprGlobal* expr = node->as()) + return expr->name.value; + + if (AstExprLocal* expr = node->as()) + return expr->local->name.value; + + if (AstExprIndexExpr* expr = node->as()) + return getIdentifierOfBaseVar(expr->expr); + + if (AstExprIndexName* expr = node->as()) + return getIdentifierOfBaseVar(expr->expr); + + return std::nullopt; +} + struct TypeChecker2 { NotNull singletonTypes; @@ -283,7 +301,6 @@ struct TypeChecker2 UnifierSharedState sharedState{&ice}; Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, Mode::Strict, stack.back(), ret->location, Covariant}; - u.anyIsTop = true; u.tryUnify(actualRetType, expectedRetType); const bool ok = u.errors.empty() && u.log.empty(); @@ -313,16 +330,21 @@ struct TypeChecker2 if (value) visit(value); - if (i != local->values.size - 1) + TypeId* maybeValueType = value ? module->astTypes.find(value) : nullptr; + if (i != local->values.size - 1 || maybeValueType) { AstLocal* var = i < local->vars.size ? local->vars.data[i] : nullptr; if (var && var->annotation) { - TypeId varType = lookupAnnotation(var->annotation); + TypeId annotationType = lookupAnnotation(var->annotation); TypeId valueType = value ? lookupType(value) : nullptr; - if (valueType && !isSubtype(varType, valueType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) - reportError(TypeMismatch{varType, valueType}, value->location); + if (valueType) + { + ErrorVec errors = tryUnify(stack.back(), value->location, valueType, annotationType); + if (!errors.empty()) + reportErrors(std::move(errors)); + } } } else @@ -588,7 +610,7 @@ struct TypeChecker2 visit(rhs); TypeId rhsType = lookupType(rhs); - if (!isSubtype(rhsType, lhsType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(rhsType, lhsType, stack.back(), singletonTypes, ice)) { reportError(TypeMismatch{lhsType, rhsType}, rhs->location); } @@ -739,7 +761,7 @@ struct TypeChecker2 TypeId actualType = lookupType(number); TypeId numberType = singletonTypes->numberType; - if (!isSubtype(numberType, actualType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(numberType, actualType, stack.back(), singletonTypes, ice)) { reportError(TypeMismatch{actualType, numberType}, number->location); } @@ -750,7 +772,7 @@ struct TypeChecker2 TypeId actualType = lookupType(string); TypeId stringType = singletonTypes->stringType; - if (!isSubtype(stringType, actualType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(actualType, stringType, stack.back(), singletonTypes, ice)) { reportError(TypeMismatch{actualType, stringType}, string->location); } @@ -783,26 +805,55 @@ struct TypeChecker2 TypePackId expectedRetType = lookupPack(call); TypeId functionType = lookupType(call->func); - LUAU_ASSERT(functionType); + TypeId testFunctionType = functionType; + TypePack args; if (get(functionType) || get(functionType)) return; - - // TODO: Lots of other types are callable: intersections of functions - // and things with the __call metamethod. - if (!get(functionType)) + else if (std::optional callMm = findMetatableEntry(singletonTypes, module->errors, functionType, "__call", call->func->location)) + { + if (get(follow(*callMm))) + { + if (std::optional instantiatedCallMm = instantiation.substitute(*callMm)) + { + args.head.push_back(functionType); + testFunctionType = follow(*instantiatedCallMm); + } + else + { + reportError(UnificationTooComplex{}, call->func->location); + return; + } + } + else + { + // TODO: This doesn't flag the __call metamethod as the problem + // very clearly. + reportError(CannotCallNonFunction{*callMm}, call->func->location); + return; + } + } + else if (get(functionType)) + { + if (std::optional instantiatedFunctionType = instantiation.substitute(functionType)) + { + testFunctionType = *instantiatedFunctionType; + } + else + { + reportError(UnificationTooComplex{}, call->func->location); + return; + } + } + else { reportError(CannotCallNonFunction{functionType}, call->func->location); return; } - TypeId instantiatedFunctionType = follow(instantiation.substitute(functionType).value_or(nullptr)); - - TypePack args; for (AstExpr* arg : call->args) { - TypeId argTy = module->astTypes[arg]; - LUAU_ASSERT(argTy); + TypeId argTy = lookupType(arg); args.head.push_back(argTy); } @@ -810,7 +861,7 @@ struct TypeChecker2 FunctionTypeVar ftv{argsTp, expectedRetType}; TypeId expectedType = arena.addType(ftv); - if (!isSubtype(instantiatedFunctionType, expectedType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(testFunctionType, expectedType, stack.back(), singletonTypes, ice)) { CloneState cloneState; expectedType = clone(expectedType, module->internalTypes, cloneState); @@ -829,7 +880,7 @@ struct TypeChecker2 getIndexTypeFromType(module->getModuleScope(), leftType, indexName->index.value, indexName->location, /* addErrors */ true); if (ty) { - if (!isSubtype(resultType, *ty, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(resultType, *ty, stack.back(), singletonTypes, ice)) { reportError(TypeMismatch{resultType, *ty}, indexName->location); } @@ -862,7 +913,7 @@ struct TypeChecker2 TypeId inferredArgTy = *argIt; TypeId annotatedArgTy = lookupAnnotation(arg->annotation); - if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back(), singletonTypes, ice)) { reportError(TypeMismatch{annotatedArgTy, inferredArgTy}, arg->location); } @@ -887,15 +938,264 @@ struct TypeChecker2 void visit(AstExprUnary* expr) { - // TODO! visit(expr->expr); + + NotNull scope = stack.back(); + TypeId operandType = lookupType(expr->expr); + + if (get(operandType) || get(operandType) || get(operandType)) + return; + + if (auto it = kUnaryOpMetamethods.find(expr->op); it != kUnaryOpMetamethods.end()) + { + std::optional mm = findMetatableEntry(singletonTypes, module->errors, operandType, it->second, expr->location); + if (mm) + { + if (const FunctionTypeVar* ftv = get(follow(*mm))) + { + TypePackId expectedArgs = module->internalTypes.addTypePack({operandType}); + reportErrors(tryUnify(scope, expr->location, ftv->argTypes, expectedArgs)); + + if (std::optional ret = first(ftv->retTypes)) + { + if (expr->op == AstExprUnary::Op::Len) + { + reportErrors(tryUnify(scope, expr->location, follow(*ret), singletonTypes->numberType)); + } + } + else + { + reportError(GenericError{format("Metamethod '%s' must return a value", it->second)}, expr->location); + } + } + + return; + } + } + + if (expr->op == AstExprUnary::Op::Len) + { + DenseHashSet seen{nullptr}; + int recursionCount = 0; + + if (!hasLength(operandType, seen, &recursionCount)) + { + reportError(NotATable{operandType}, expr->location); + } + } + else if (expr->op == AstExprUnary::Op::Minus) + { + reportErrors(tryUnify(scope, expr->location, operandType, singletonTypes->numberType)); + } + else if (expr->op == AstExprUnary::Op::Not) + { + } + else + { + LUAU_ASSERT(!"Unhandled unary operator"); + } } void visit(AstExprBinary* expr) { - // TODO! visit(expr->left); visit(expr->right); + + NotNull scope = stack.back(); + + bool isEquality = expr->op == AstExprBinary::Op::CompareEq || expr->op == AstExprBinary::Op::CompareNe; + bool isComparison = expr->op >= AstExprBinary::Op::CompareEq && expr->op <= AstExprBinary::Op::CompareGe; + bool isLogical = expr->op == AstExprBinary::Op::And || expr->op == AstExprBinary::Op::Or; + + TypeId leftType = lookupType(expr->left); + TypeId rightType = lookupType(expr->right); + + if (expr->op == AstExprBinary::Op::Or) + { + leftType = stripNil(singletonTypes, module->internalTypes, leftType); + } + + bool isStringOperation = isString(leftType) && isString(rightType); + + if (get(leftType) || get(leftType) || get(rightType) || get(rightType)) + return; + + if ((get(leftType) || get(leftType)) && !isEquality && !isLogical) + { + auto name = getIdentifierOfBaseVar(expr->left); + reportError(CannotInferBinaryOperation{expr->op, name, + isComparison ? CannotInferBinaryOperation::OpKind::Comparison : CannotInferBinaryOperation::OpKind::Operation}, + expr->location); + return; + } + + if (auto it = kBinaryOpMetamethods.find(expr->op); it != kBinaryOpMetamethods.end()) + { + std::optional leftMt = getMetatable(leftType, singletonTypes); + std::optional rightMt = getMetatable(rightType, singletonTypes); + + bool matches = leftMt == rightMt; + if (isEquality && !matches) + { + auto testUnion = [&matches, singletonTypes = this->singletonTypes](const UnionTypeVar* utv, std::optional otherMt) { + for (TypeId option : utv) + { + if (getMetatable(follow(option), singletonTypes) == otherMt) + { + matches = true; + break; + } + } + }; + + if (const UnionTypeVar* utv = get(leftType); utv && rightMt) + { + testUnion(utv, rightMt); + } + + if (const UnionTypeVar* utv = get(rightType); utv && leftMt && !matches) + { + testUnion(utv, leftMt); + } + } + + if (!matches && isComparison) + { + reportError(GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", + toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str())}, + expr->location); + + return; + } + + std::optional mm; + if (std::optional leftMm = findMetatableEntry(singletonTypes, module->errors, leftType, it->second, expr->left->location)) + mm = leftMm; + else if (std::optional rightMm = findMetatableEntry(singletonTypes, module->errors, rightType, it->second, expr->right->location)) + mm = rightMm; + + if (mm) + { + if (const FunctionTypeVar* ftv = get(*mm)) + { + TypePackId expectedArgs; + // For >= and > we invoke __lt and __le respectively with + // swapped argument ordering. + if (expr->op == AstExprBinary::Op::CompareGe || expr->op == AstExprBinary::Op::CompareGt) + { + expectedArgs = module->internalTypes.addTypePack({rightType, leftType}); + } + else + { + expectedArgs = module->internalTypes.addTypePack({leftType, rightType}); + } + + reportErrors(tryUnify(scope, expr->location, ftv->argTypes, expectedArgs)); + + if (expr->op == AstExprBinary::CompareEq || expr->op == AstExprBinary::CompareNe || expr->op == AstExprBinary::CompareGe || + expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::Op::CompareLe || expr->op == AstExprBinary::Op::CompareLt) + { + TypePackId expectedRets = module->internalTypes.addTypePack({singletonTypes->booleanType}); + if (!isSubtype(ftv->retTypes, expectedRets, scope, singletonTypes, ice)) + { + reportError(GenericError{format("Metamethod '%s' must return type 'boolean'", it->second)}, expr->location); + } + } + else if (!first(ftv->retTypes)) + { + reportError(GenericError{format("Metamethod '%s' must return a value", it->second)}, expr->location); + } + } + else + { + reportError(CannotCallNonFunction{*mm}, expr->location); + } + + return; + } + // If this is a string comparison, or a concatenation of strings, we + // want to fall through to primitive behavior. + else if (!isEquality && !(isStringOperation && (expr->op == AstExprBinary::Op::Concat || isComparison))) + { + if (leftMt || rightMt) + { + if (isComparison) + { + reportError(GenericError{format( + "Types '%s' and '%s' cannot be compared with %s because neither type's metatable has a '%s' metamethod", + toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str(), it->second)}, + expr->location); + } + else + { + reportError(GenericError{format( + "Operator %s is not applicable for '%s' and '%s' because neither type's metatable has a '%s' metamethod", + toString(expr->op).c_str(), toString(leftType).c_str(), toString(rightType).c_str(), it->second)}, + expr->location); + } + + return; + } + else if (!leftMt && !rightMt && (get(leftType) || get(rightType))) + { + if (isComparison) + { + reportError(GenericError{format("Types '%s' and '%s' cannot be compared with %s because neither type has a metatable", + toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str())}, + expr->location); + } + else + { + reportError(GenericError{format("Operator %s is not applicable for '%s' and '%s' because neither type has a metatable", + toString(expr->op).c_str(), toString(leftType).c_str(), toString(rightType).c_str())}, + expr->location); + } + + return; + } + } + } + + switch (expr->op) + { + case AstExprBinary::Op::Add: + case AstExprBinary::Op::Sub: + case AstExprBinary::Op::Mul: + case AstExprBinary::Op::Div: + case AstExprBinary::Op::Pow: + case AstExprBinary::Op::Mod: + reportErrors(tryUnify(scope, expr->left->location, leftType, singletonTypes->numberType)); + reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->numberType)); + + break; + case AstExprBinary::Op::Concat: + reportErrors(tryUnify(scope, expr->left->location, leftType, singletonTypes->stringType)); + reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->stringType)); + + break; + case AstExprBinary::Op::CompareGe: + case AstExprBinary::Op::CompareGt: + case AstExprBinary::Op::CompareLe: + case AstExprBinary::Op::CompareLt: + if (isNumber(leftType)) + reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->numberType)); + else if (isString(leftType)) + reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->stringType)); + else + reportError(GenericError{format("Types '%s' and '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), + toString(rightType).c_str(), toString(expr->op).c_str())}, + expr->location); + + break; + case AstExprBinary::Op::And: + case AstExprBinary::Op::Or: + case AstExprBinary::Op::CompareEq: + case AstExprBinary::Op::CompareNe: + break; + default: + // Unhandled AstExprBinary::Op possibility. + LUAU_ASSERT(false); + } } void visit(AstExprTypeAssertion* expr) @@ -907,10 +1207,10 @@ struct TypeChecker2 TypeId computedType = lookupType(expr->expr); // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. - if (isSubtype(annotationType, computedType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (isSubtype(annotationType, computedType, stack.back(), singletonTypes, ice)) return; - if (isSubtype(computedType, annotationType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (isSubtype(computedType, annotationType, stack.back(), singletonTypes, ice)) return; reportError(TypesAreUnrelated{computedType, annotationType}, expr->location); @@ -998,9 +1298,8 @@ struct TypeChecker2 Scope* scope = findInnermostScope(ty->location); LUAU_ASSERT(scope); - // TODO: Imported types - - std::optional alias = scope->lookupType(ty->name.value); + std::optional alias = + (ty->prefix) ? scope->lookupImportedType(ty->prefix->value, ty->name.value) : scope->lookupType(ty->name.value); if (alias.has_value()) { @@ -1212,7 +1511,6 @@ struct TypeChecker2 UnifierSharedState sharedState{&ice}; Normalizer normalizer{&module->internalTypes, singletonTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, Mode::Strict, scope, location, Covariant}; - u.anyIsTop = true; u.tryUnify(subTy, superTy); return std::move(u.errors); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index b806edb7..ccb1490a 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -31,12 +31,12 @@ LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) -LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTFLAG(LuauTypeNormalization2) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) LUAU_FASTFLAGVARIABLE(LuauAnyifyModuleReturnGenerics, false) +LUAU_FASTFLAGVARIABLE(LuauLvaluelessPath, false) LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false) LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false) LUAU_FASTFLAGVARIABLE(LuauFixVarargExprHeadType, false) @@ -44,15 +44,15 @@ LUAU_FASTFLAGVARIABLE(LuauNeverTypesAndOperatorsInference, false) LUAU_FASTFLAGVARIABLE(LuauReturnsFromCallsitesAreNotWidened, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAGVARIABLE(LuauCompleteVisitor, false) -LUAU_FASTFLAGVARIABLE(LuauUnionOfTypesFollow, false) LUAU_FASTFLAGVARIABLE(LuauReportShadowedTypeAlias, false) LUAU_FASTFLAGVARIABLE(LuauBetterMessagingOnCountMismatch, false) +LUAU_FASTFLAGVARIABLE(LuauArgMismatchReportFunctionLocation, false) namespace Luau { - -const char* TimeLimitError::what() const throw() +const char* TimeLimitError_DEPRECATED::what() const throw() { + LUAU_ASSERT(!FFlag::LuauIceExceptionInheritanceChange); return "Typeinfer failed to complete in allotted time"; } @@ -265,6 +265,11 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona reportErrorCodeTooComplex(module.root->location); return std::move(currentModule); } + catch (const RecursionLimitException_DEPRECATED&) + { + reportErrorCodeTooComplex(module.root->location); + return std::move(currentModule); + } } ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mode mode, std::optional environmentScope) @@ -280,11 +285,8 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo iceHandler->moduleName = module.name; normalizer.arena = ¤tModule->internalTypes; - if (FFlag::LuauAutocompleteDynamicLimits) - { - unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; - unifierState.counters.iterationLimit = unifierIterationLimit ? *unifierIterationLimit : FInt::LuauTypeInferIterationLimit; - } + unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; + unifierState.counters.iterationLimit = unifierIterationLimit ? *unifierIterationLimit : FInt::LuauTypeInferIterationLimit; ScopePtr parentScope = environmentScope.value_or(globalScope); ScopePtr moduleScope = std::make_shared(parentScope); @@ -312,6 +314,10 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo { currentModule->timeout = true; } + catch (const TimeLimitError_DEPRECATED&) + { + currentModule->timeout = true; + } if (FFlag::DebugLuauSharedSelf) { @@ -419,7 +425,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStat& program) ice("Unknown AstStat"); if (finishTime && TimeTrace::getClock() > *finishTime) - throw TimeLimitError(); + throwTimeLimitError(); } // This particular overload is for do...end. If you need to not increase the scope level, use checkBlock directly. @@ -446,6 +452,11 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) reportErrorCodeTooComplex(block.location); return; } + catch (const RecursionLimitException_DEPRECATED&) + { + reportErrorCodeTooComplex(block.location); + return; + } } struct InplaceDemoter : TypeVarOnceVisitor @@ -773,16 +784,6 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement) checkExpr(repScope, *statement.condition); } -void TypeChecker::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const ScopePtr& scope, const Location& location) -{ - Unifier state = mkUnifier(scope, location); - state.unifyLowerBound(subTy, superTy, demotedLevel); - - state.log.commit(); - - reportErrors(state.errors); -} - struct Demoter : Substitution { Demoter(TypeArena* arena) @@ -2091,39 +2092,6 @@ std::optional TypeChecker::getIndexTypeFromTypeImpl( return std::nullopt; } -std::vector TypeChecker::reduceUnion(const std::vector& types) -{ - std::vector result; - for (TypeId t : types) - { - t = follow(t); - if (get(t)) - continue; - - if (get(t) || get(t)) - return {t}; - - if (const UnionTypeVar* utv = get(t)) - { - for (TypeId ty : utv) - { - ty = follow(ty); - if (get(ty)) - continue; - if (get(ty) || get(ty)) - return {ty}; - - if (result.end() == std::find(result.begin(), result.end(), ty)) - result.push_back(ty); - } - } - else if (std::find(result.begin(), result.end(), t) == result.end()) - result.push_back(t); - } - - return result; -} - std::optional TypeChecker::tryStripUnionFromNil(TypeId ty) { if (const UnionTypeVar* utv = get(ty)) @@ -2503,11 +2471,8 @@ std::string opToMetaTableEntry(const AstExprBinary::Op& op) TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const ScopePtr& scope, const Location& location, bool unifyFreeTypes) { - if (FFlag::LuauUnionOfTypesFollow) - { - a = follow(a); - b = follow(b); - } + a = follow(a); + b = follow(b); if (unifyFreeTypes && (get(a) || get(b))) { @@ -3643,8 +3608,17 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam location = {state.location.begin, argLocations.back().end}; std::string namePath; - if (std::optional lValue = tryGetLValue(funName)) - namePath = toString(*lValue); + + if (FFlag::LuauLvaluelessPath) + { + if (std::optional path = getFunctionNameAsString(funName)) + namePath = *path; + } + else + { + if (std::optional lValue = tryGetLValue(funName)) + namePath = toString(*lValue); + } auto [minParams, optMaxParams] = getParameterExtents(&state.log, paramPack); state.reportError(TypeError{location, @@ -3753,11 +3727,28 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam bool isVariadic = tail && Luau::isVariadic(*tail); std::string namePath; - if (std::optional lValue = tryGetLValue(funName)) - namePath = toString(*lValue); - state.reportError(TypeError{ - state.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}}); + if (FFlag::LuauLvaluelessPath) + { + if (std::optional path = getFunctionNameAsString(funName)) + namePath = *path; + } + else + { + if (std::optional lValue = tryGetLValue(funName)) + namePath = toString(*lValue); + } + + if (FFlag::LuauArgMismatchReportFunctionLocation) + { + state.reportError(TypeError{ + funName.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}}); + } + else + { + state.reportError(TypeError{ + state.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}}); + } return; } ++paramIter; @@ -4597,7 +4588,7 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat Instantiation instantiation{log, ¤tModule->internalTypes, scope->level, /*scope*/ nullptr}; - if (FFlag::LuauAutocompleteDynamicLimits && instantiationChildLimit) + if (instantiationChildLimit) instantiation.childLimit = *instantiationChildLimit; std::optional instantiated = instantiation.substitute(ty); @@ -4694,6 +4685,19 @@ void TypeChecker::ice(const std::string& message) iceHandler->ice(message); } +// TODO: Inline me when LuauIceExceptionInheritanceChange is deleted. +void TypeChecker::throwTimeLimitError() +{ + if (FFlag::LuauIceExceptionInheritanceChange) + { + throw TimeLimitError(iceHandler->moduleName); + } + else + { + throw TimeLimitError_DEPRECATED(); + } +} + void TypeChecker::prepareErrorsForDisplay(ErrorVec& errVec) { // Remove errors with names that were generated by recovery from a parse error diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 0fa4df60..0852f053 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypePack.h" +#include "Luau/Error.h" #include "Luau/TxnLog.h" #include @@ -234,7 +235,7 @@ TypePackId follow(TypePackId tp, std::function mapper) cycleTester = nullptr; if (tp == cycleTester) - throw std::runtime_error("Luau::follow detected a TypeVar cycle!!"); + throwRuntimeError("Luau::follow detected a TypeVar cycle!!"); } } } diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 688c8767..72597c4a 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -6,6 +6,8 @@ #include "Luau/ToString.h" #include "Luau/TypeInfer.h" +#include + namespace Luau { @@ -146,18 +148,15 @@ std::optional getIndexTypeFromType(const ScopePtr& scope, ErrorVec& erro return std::nullopt; } + goodOptions = reduceUnion(goodOptions); + if (goodOptions.empty()) return singletonTypes->neverType; if (goodOptions.size() == 1) return goodOptions[0]; - // TODO: inefficient. - TypeId result = arena->addType(UnionTypeVar{std::move(goodOptions)}); - auto [ty, ok] = normalize(result, NotNull{scope.get()}, *arena, singletonTypes, handle); - if (!ok && addErrors) - errors.push_back(TypeError{location, NormalizationTooComplex{}}); - return ok ? ty : singletonTypes->anyType; + return arena->addType(UnionTypeVar{std::move(goodOptions)}); } else if (const IntersectionTypeVar* itv = get(type)) { @@ -264,4 +263,79 @@ std::vector flatten(TypeArena& arena, NotNull singletonT return result; } +std::vector reduceUnion(const std::vector& types) +{ + std::vector result; + for (TypeId t : types) + { + t = follow(t); + if (get(t)) + continue; + + if (get(t) || get(t)) + return {t}; + + if (const UnionTypeVar* utv = get(t)) + { + for (TypeId ty : utv) + { + ty = follow(ty); + if (get(ty)) + continue; + if (get(ty) || get(ty)) + return {ty}; + + if (result.end() == std::find(result.begin(), result.end(), ty)) + result.push_back(ty); + } + } + else if (std::find(result.begin(), result.end(), t) == result.end()) + result.push_back(t); + } + + return result; +} + +static std::optional tryStripUnionFromNil(TypeArena& arena, TypeId ty) +{ + if (const UnionTypeVar* utv = get(ty)) + { + if (!std::any_of(begin(utv), end(utv), isNil)) + return ty; + + std::vector result; + + for (TypeId option : utv) + { + if (!isNil(option)) + result.push_back(option); + } + + if (result.empty()) + return std::nullopt; + + return result.size() == 1 ? result[0] : arena.addType(UnionTypeVar{std::move(result)}); + } + + return std::nullopt; +} + +TypeId stripNil(NotNull singletonTypes, TypeArena& arena, TypeId ty) +{ + ty = follow(ty); + + if (get(ty)) + { + std::optional cleaned = tryStripUnionFromNil(arena, ty); + + // If there is no union option without 'nil' + if (!cleaned) + return singletonTypes->nilType; + + return follow(*cleaned); + } + + return follow(ty); +} + } // namespace Luau diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index bcdaff7d..de0890e1 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -66,7 +66,7 @@ TypeId follow(TypeId t, std::function mapper) { TypeId res = ltv->thunk(); if (get(res)) - throw std::runtime_error("Lazy TypeVar cannot resolve to another Lazy TypeVar"); + throwRuntimeError("Lazy TypeVar cannot resolve to another Lazy TypeVar"); *asMutable(ty) = BoundTypeVar(res); } @@ -104,7 +104,7 @@ TypeId follow(TypeId t, std::function mapper) cycleTester = nullptr; if (t == cycleTester) - throw std::runtime_error("Luau::follow detected a TypeVar cycle!!"); + throwRuntimeError("Luau::follow detected a TypeVar cycle!!"); } } } @@ -754,12 +754,15 @@ SingletonTypes::SingletonTypes() , stringType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true})) , booleanType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true})) , threadType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true})) + , functionType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Function}, /*persistent*/ true})) , trueType(arena->addType(TypeVar{SingletonTypeVar{BooleanSingleton{true}}, /*persistent*/ true})) , falseType(arena->addType(TypeVar{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true})) , anyType(arena->addType(TypeVar{AnyTypeVar{}, /*persistent*/ true})) , unknownType(arena->addType(TypeVar{UnknownTypeVar{}, /*persistent*/ true})) , neverType(arena->addType(TypeVar{NeverTypeVar{}, /*persistent*/ true})) , errorType(arena->addType(TypeVar{ErrorTypeVar{}, /*persistent*/ true})) + , falsyType(arena->addType(TypeVar{UnionTypeVar{{falseType, nilType}}, /*persistent*/ true})) + , truthyType(arena->addType(TypeVar{NegationTypeVar{falsyType}, /*persistent*/ true})) , anyTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, /*persistent*/ true})) , neverTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{neverType}, /*persistent*/ true})) , uninhabitableTypePack(arena->addTypePack({neverType}, neverTypePack)) @@ -896,7 +899,6 @@ void persist(TypeId ty) continue; asMutable(t)->persistent = true; - asMutable(t)->normal = true; // all persistent types are assumed to be normal if (auto btv = get(t)) queue.push_back(btv->boundTo); @@ -933,17 +935,13 @@ void persist(TypeId ty) for (TypeId opt : itv->parts) queue.push_back(opt); } - else if (auto ctv = get(t)) - { - for (TypeId opt : ctv->parts) - queue.push_back(opt); - } else if (auto mtv = get(t)) { queue.push_back(mtv->table); queue.push_back(mtv->metatable); } - else if (get(t) || get(t) || get(t) || get(t) || get(t)) + else if (get(t) || get(t) || get(t) || get(t) || get(t) || + get(t)) { } else @@ -990,8 +988,6 @@ const TypeLevel* getLevel(TypeId ty) return &ttv->level; else if (auto ftv = get(ty)) return &ftv->level; - else if (auto ctv = get(ty)) - return &ctv->level; else return nullptr; } @@ -1056,11 +1052,6 @@ const std::vector& getTypes(const IntersectionTypeVar* itv) return itv->parts; } -const std::vector& getTypes(const ConstrainedTypeVar* ctv) -{ - return ctv->parts; -} - UnionTypeVarIterator begin(const UnionTypeVar* utv) { return UnionTypeVarIterator{utv}; @@ -1081,17 +1072,6 @@ IntersectionTypeVarIterator end(const IntersectionTypeVar* itv) return IntersectionTypeVarIterator{}; } -ConstrainedTypeVarIterator begin(const ConstrainedTypeVar* ctv) -{ - return ConstrainedTypeVarIterator{ctv}; -} - -ConstrainedTypeVarIterator end(const ConstrainedTypeVar* ctv) -{ - return ConstrainedTypeVarIterator{}; -} - - static std::vector parseFormatString(TypeChecker& typechecker, const char* data, size_t size) { const char* options = "cdiouxXeEfgGqs*"; diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 42fcd2fd..df5d86f1 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -8,23 +8,23 @@ #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" #include "Luau/TimeTrace.h" +#include "Luau/TypeVar.h" #include "Luau/VisitTypeVar.h" #include "Luau/ToString.h" #include -LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); -LUAU_FASTINT(LuauTypeInferIterationLimit); -LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) -LUAU_FASTINTVARIABLE(LuauTypeInferLowerBoundsIterationLimit, 2000); LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAGVARIABLE(LuauReportTypeMismatchForTypePackUnificationFailure, false) LUAU_FASTFLAGVARIABLE(LuauSubtypeNormalizer, false); LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) +LUAU_FASTFLAGVARIABLE(LuauOverloadedFunctionSubtypingPerf, false); LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauNegatedFunctionTypes) namespace Luau { @@ -95,15 +95,6 @@ struct PromoteTypeLevels final : TypeVarOnceVisitor return true; } - bool visit(TypeId ty, const ConstrainedTypeVar&) override - { - if (!FFlag::LuauUnknownAndNeverType) - return visit(ty); - - promote(ty, log.getMutable(ty)); - return true; - } - bool visit(TypeId ty, const FunctionTypeVar&) override { // Type levels of types from other modules are already global, so we don't need to promote anything inside @@ -285,7 +276,7 @@ TypeId Widen::clean(TypeId ty) TypePackId Widen::clean(TypePackId) { - throw std::runtime_error("Widen attempted to clean a dirty type pack?"); + throwRuntimeError("Widen attempted to clean a dirty type pack?"); } bool Widen::ignoreChildren(TypeId ty) @@ -368,26 +359,14 @@ void Unifier::tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall, bool i void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) { - RecursionLimiter _ra(&sharedState.counters.recursionCount, - FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); ++sharedState.counters.iterationCount; - if (FFlag::LuauAutocompleteDynamicLimits) + if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) { - if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) - { - reportError(TypeError{location, UnificationTooComplex{}}); - return; - } - } - else - { - if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) - { - reportError(TypeError{location, UnificationTooComplex{}}); - return; - } + reportError(location, UnificationTooComplex{}); + return; } superTy = log.follow(superTy); @@ -396,9 +375,6 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (superTy == subTy) return; - if (log.get(superTy)) - return tryUnifyWithConstrainedSuperTypeVar(subTy, superTy); - auto superFree = log.getMutable(superTy); auto subFree = log.getMutable(subTy); @@ -430,7 +406,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (subGeneric && !subsumes(useScopes, subGeneric, superFree)) { // TODO: a more informative error message? CLI-39912 - reportError(TypeError{location, GenericError{"Generic subtype escaping scope"}}); + reportError(location, GenericError{"Generic subtype escaping scope"}); return; } @@ -459,7 +435,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (superGeneric && !subsumes(useScopes, superGeneric, subFree)) { // TODO: a more informative error message? CLI-39912 - reportError(TypeError{location, GenericError{"Generic supertype escaping scope"}}); + reportError(location, GenericError{"Generic supertype escaping scope"}); return; } @@ -476,15 +452,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool return tryUnifyWithAny(subTy, superTy); if (get(subTy)) - { - if (anyIsTop) - { - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); - return; - } - else - return tryUnifyWithAny(superTy, subTy); - } + return tryUnifyWithAny(superTy, subTy); if (log.get(subTy)) return tryUnifyWithAny(superTy, subTy); @@ -504,7 +472,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (auto error = sharedState.cachedUnifyError.find({subTy, superTy})) { - reportError(TypeError{location, *error}); + reportError(location, *error); return; } } @@ -520,9 +488,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool size_t errorCount = errors.size(); - if (log.get(subTy)) - tryUnifyWithConstrainedSubTypeVar(subTy, superTy); - else if (const UnionTypeVar* subUnion = log.getMutable(subTy)) + if (const UnionTypeVar* subUnion = log.getMutable(subTy)) { tryUnifyUnionWithType(subTy, subUnion, superTy); } @@ -548,6 +514,12 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if ((log.getMutable(superTy) || log.getMutable(superTy)) && log.getMutable(subTy)) tryUnifySingletons(subTy, superTy); + else if (auto ptv = get(superTy); + FFlag::LuauNegatedFunctionTypes && ptv && ptv->type == PrimitiveTypeVar::Function && get(subTy)) + { + // Ok. Do nothing. forall functions F, F <: function + } + else if (log.getMutable(superTy) && log.getMutable(subTy)) tryUnifyFunctions(subTy, superTy, isFunctionCall); @@ -580,8 +552,14 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if (log.getMutable(subTy)) tryUnifyWithClass(subTy, superTy, /*reversed*/ true); + else if (log.get(superTy)) + tryUnifyTypeWithNegation(subTy, superTy); + + else if (log.get(subTy)) + tryUnifyNegationWithType(subTy, superTy); + else - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(location, TypeMismatch{superTy, subTy}); if (cacheEnabled) cacheResult(subTy, superTy, errorCount); @@ -655,9 +633,9 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* subUnion, else if (failed) { if (firstFailedOption) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}}); + reportError(location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}); else - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(location, TypeMismatch{superTy, subTy}); } } @@ -756,7 +734,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp const NormalizedType* subNorm = normalizer->normalize(subTy); const NormalizedType* superNorm = normalizer->normalize(superTy); if (!subNorm || !superNorm) - reportError(TypeError{location, UnificationTooComplex{}}); + reportError(location, UnificationTooComplex{}); else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); else @@ -765,9 +743,9 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp else if (!found) { if ((failedOptionCount == 1 || foundHeuristic) && failedOption) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}}); + reportError(location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}); else - reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); + reportError(location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}); } } @@ -796,7 +774,7 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I if (unificationTooComplex) reportError(*unificationTooComplex); else if (firstFailedOption) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); + reportError(location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}); } void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall) @@ -854,11 +832,11 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeV if (subNorm && superNorm) tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); else - reportError(TypeError{location, UnificationTooComplex{}}); + reportError(location, UnificationTooComplex{}); } else if (!found) { - reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); + reportError(location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}); } } @@ -870,43 +848,37 @@ void Unifier::tryUnifyNormalizedTypes( if (get(superNorm.tops) || get(superNorm.tops) || get(subNorm.tops)) return; else if (get(subNorm.tops)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); if (get(subNorm.errors)) if (!get(superNorm.errors)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); if (get(subNorm.booleans)) { if (!get(superNorm.booleans)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); } else if (const SingletonTypeVar* stv = get(subNorm.booleans)) { if (!get(superNorm.booleans) && stv != get(superNorm.booleans)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); } if (get(subNorm.nils)) if (!get(superNorm.nils)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); if (get(subNorm.numbers)) if (!get(superNorm.numbers)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); - if (subNorm.strings && superNorm.strings) - { - for (auto [name, ty] : *subNorm.strings) - if (!superNorm.strings->count(name)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); - } - else if (!subNorm.strings && superNorm.strings) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + if (!isSubtype(subNorm.strings, superNorm.strings)) + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); if (get(subNorm.threads)) if (!get(superNorm.errors)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); for (TypeId subClass : subNorm.classes) { @@ -922,7 +894,7 @@ void Unifier::tryUnifyNormalizedTypes( } } if (!found) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); } for (TypeId subTable : subNorm.tables) @@ -947,21 +919,19 @@ void Unifier::tryUnifyNormalizedTypes( return reportError(*e); } if (!found) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); } - if (subNorm.functions) + if (!subNorm.functions.isNever()) { - if (!superNorm.functions) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); - if (superNorm.functions->empty()) - return; - for (TypeId superFun : *superNorm.functions) + if (superNorm.functions.isNever()) + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); + for (TypeId superFun : *superNorm.functions.parts) { Unifier innerState = makeChildUnifier(); const FunctionTypeVar* superFtv = get(superFun); if (!superFtv) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); TypePackId tgt = innerState.tryApplyOverloadedFunction(subTy, subNorm.functions, superFtv->argTypes); innerState.tryUnify_(tgt, superFtv->retTypes); if (innerState.errors.empty()) @@ -969,7 +939,7 @@ void Unifier::tryUnifyNormalizedTypes( else if (auto e = hasUnificationTooComplex(innerState.errors)) return reportError(*e); else - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); } } @@ -987,15 +957,15 @@ void Unifier::tryUnifyNormalizedTypes( TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const NormalizedFunctionType& overloads, TypePackId args) { - if (!overloads || overloads->empty()) + if (overloads.isNever()) { - reportError(TypeError{location, CannotCallNonFunction{function}}); + reportError(location, CannotCallNonFunction{function}); return singletonTypes->errorRecoveryTypePack(); } std::optional result; const FunctionTypeVar* firstFun = nullptr; - for (TypeId overload : *overloads) + for (TypeId overload : *overloads.parts) { if (const FunctionTypeVar* ftv = get(overload)) { @@ -1011,10 +981,17 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized log.concat(std::move(innerState.log)); if (result) { + if (FFlag::LuauOverloadedFunctionSubtypingPerf) + { + innerState.log.clear(); + innerState.tryUnify_(*result, ftv->retTypes); + } + if (FFlag::LuauOverloadedFunctionSubtypingPerf && innerState.errors.empty()) + log.concat(std::move(innerState.log)); // Annoyingly, since we don't support intersection of generic type packs, // the intersection may fail. We rather arbitrarily use the first matching overload // in that case. - if (std::optional intersect = normalizer->intersectionOfTypePacks(*result, ftv->retTypes)) + else if (std::optional intersect = normalizer->intersectionOfTypePacks(*result, ftv->retTypes)) result = intersect; } else @@ -1036,12 +1013,12 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized // TODO: better error reporting? // The logic for error reporting overload resolution // is currently over in TypeInfer.cpp, should we move it? - reportError(TypeError{location, GenericError{"No matching overload."}}); + reportError(location, GenericError{"No matching overload."}); return singletonTypes->errorRecoveryTypePack(firstFun->retTypes); } else { - reportError(TypeError{location, CannotCallNonFunction{function}}); + reportError(location, CannotCallNonFunction{function}); return singletonTypes->errorRecoveryTypePack(); } } @@ -1214,26 +1191,14 @@ void Unifier::tryUnify(TypePackId subTp, TypePackId superTp, bool isFunctionCall */ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCall) { - RecursionLimiter _ra(&sharedState.counters.recursionCount, - FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); ++sharedState.counters.iterationCount; - if (FFlag::LuauAutocompleteDynamicLimits) + if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) { - if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) - { - reportError(TypeError{location, UnificationTooComplex{}}); - return; - } - } - else - { - if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) - { - reportError(TypeError{location, UnificationTooComplex{}}); - return; - } + reportError(location, UnificationTooComplex{}); + return; } superTp = log.follow(superTp); @@ -1405,7 +1370,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal size_t actualSize = size(subTp); if (ctx == CountMismatch::FunctionResult || ctx == CountMismatch::ExprListResult) std::swap(expectedSize, actualSize); - reportError(TypeError{location, CountMismatch{expectedSize, std::nullopt, actualSize, ctx}}); + reportError(location, CountMismatch{expectedSize, std::nullopt, actualSize, ctx}); while (superIter.good()) { @@ -1426,7 +1391,10 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal } else { - reportError(TypeError{location, GenericError{"Failed to unify type packs"}}); + if (FFlag::LuauReportTypeMismatchForTypePackUnificationFailure) + reportError(location, TypePackMismatch{subTp, superTp}); + else + reportError(location, GenericError{"Failed to unify type packs"}); } } @@ -1438,7 +1406,7 @@ void Unifier::tryUnifyPrimitives(TypeId subTy, TypeId superTy) ice("passed non primitive types to unifyPrimitives"); if (superPrim->type != subPrim->type) - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(location, TypeMismatch{superTy, subTy}); } void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) @@ -1459,7 +1427,7 @@ void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) if (superPrim && superPrim->type == PrimitiveTypeVar::String && get(subSingleton) && variance == Covariant) return; - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(location, TypeMismatch{superTy, subTy}); } void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall) @@ -1475,7 +1443,10 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal bool shouldInstantiate = (numGenerics == 0 && subFunction->generics.size() > 0) || (numGenericPacks == 0 && subFunction->genericPacks.size() > 0); - if (FFlag::LuauInstantiateInSubtyping && variance == Covariant && shouldInstantiate) + // TODO: This is unsound when the context is invariant, but the annotation burden without allowing it and without + // read-only properties is too high for lua-apps. Read-only properties _should_ resolve their issue by allowing + // generic methods in tables to be marked read-only. + if (FFlag::LuauInstantiateInSubtyping && shouldInstantiate) { Instantiation instantiation{&log, types, scope->level, scope}; @@ -1492,21 +1463,21 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal } else { - reportError(TypeError{location, UnificationTooComplex{}}); + reportError(location, UnificationTooComplex{}); } } else if (numGenerics != subFunction->generics.size()) { numGenerics = std::min(superFunction->generics.size(), subFunction->generics.size()); - reportError(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}}); + reportError(location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}); } if (numGenericPacks != subFunction->genericPacks.size()) { numGenericPacks = std::min(superFunction->genericPacks.size(), subFunction->genericPacks.size()); - reportError(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}}); + reportError(location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}); } for (size_t i = 0; i < numGenerics; i++) @@ -1533,11 +1504,10 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal if (auto e = hasUnificationTooComplex(innerState.errors)) reportError(*e); else if (!innerState.errors.empty() && innerState.firstPackErrorPos) - reportError( - TypeError{location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), - innerState.errors.front()}}); + reportError(location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), + innerState.errors.front()}); else if (!innerState.errors.empty()) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); + reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}); innerState.ctx = CountMismatch::FunctionResult; innerState.tryUnify_(subFunction->retTypes, superFunction->retTypes); @@ -1547,13 +1517,12 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal if (auto e = hasUnificationTooComplex(innerState.errors)) reportError(*e); else if (!innerState.errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes)) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}}); + reportError(location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}); else if (!innerState.errors.empty() && innerState.firstPackErrorPos) - reportError( - TypeError{location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), - innerState.errors.front()}}); + reportError(location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), + innerState.errors.front()}); else if (!innerState.errors.empty()) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); + reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}); } log.concat(std::move(innerState.log)); @@ -1610,6 +1579,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { TableTypeVar* superTable = log.getMutable(superTy); TableTypeVar* subTable = log.getMutable(subTy); + TableTypeVar* instantiatedSubTable = subTable; if (!superTable || !subTable) ice("passed non-table types to unifyTables"); @@ -1627,13 +1597,14 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (instantiated.has_value()) { subTable = log.getMutable(*instantiated); + instantiatedSubTable = subTable; if (!subTable) ice("instantiation made a table type into a non-table type in tryUnifyTables"); } else { - reportError(TypeError{location, UnificationTooComplex{}}); + reportError(location, UnificationTooComplex{}); } } } @@ -1651,7 +1622,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (!missingProperties.empty()) { - reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); + reportError(location, MissingProperties{superTy, subTy, std::move(missingProperties)}); return; } } @@ -1669,7 +1640,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (!extraProperties.empty()) { - reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); + reportError(location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}); return; } } @@ -1730,7 +1701,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // txn log. TableTypeVar* newSuperTable = log.getMutable(superTy); TableTypeVar* newSubTable = log.getMutable(subTy); - if (superTable != newSuperTable || subTable != newSubTable) + if (superTable != newSuperTable || (subTable != newSubTable && subTable != instantiatedSubTable)) { if (errors.empty()) return tryUnifyTables(subTy, superTy, isIntersection); @@ -1792,7 +1763,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // txn log. TableTypeVar* newSuperTable = log.getMutable(superTy); TableTypeVar* newSubTable = log.getMutable(subTy); - if (superTable != newSuperTable || subTable != newSubTable) + if (superTable != newSuperTable || (subTable != newSubTable && subTable != instantiatedSubTable)) { if (errors.empty()) return tryUnifyTables(subTy, superTy, isIntersection); @@ -1850,13 +1821,13 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (!missingProperties.empty()) { - reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); + reportError(location, MissingProperties{superTy, subTy, std::move(missingProperties)}); return; } if (!extraProperties.empty()) { - reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); + reportError(location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}); return; } @@ -1892,14 +1863,14 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) std::swap(subTy, superTy); if (auto ttv = log.get(superTy); !ttv || ttv->state != TableState::Free) - return reportError(TypeError{location, TypeMismatch{osuperTy, osubTy}}); + return reportError(location, TypeMismatch{osuperTy, osubTy}); auto fail = [&](std::optional e) { std::string reason = "The former's metatable does not satisfy the requirements."; if (e) - reportError(TypeError{location, TypeMismatch{osuperTy, osubTy, reason, *e}}); + reportError(location, TypeMismatch{osuperTy, osubTy, reason, *e}); else - reportError(TypeError{location, TypeMismatch{osuperTy, osubTy, reason}}); + reportError(location, TypeMismatch{osuperTy, osubTy, reason}); }; // Given t1 where t1 = { lower: (t1) -> (a, b...) } @@ -1931,7 +1902,7 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) } } - reportError(TypeError{location, TypeMismatch{osuperTy, osubTy}}); + reportError(location, TypeMismatch{osuperTy, osubTy}); return; } @@ -1972,7 +1943,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) if (auto e = hasUnificationTooComplex(innerState.errors)) reportError(*e); else if (!innerState.errors.empty()) - reportError(TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}}); + reportError(location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}); log.concat(std::move(innerState.log)); } @@ -2049,9 +2020,9 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) auto fail = [&]() { if (!reversed) - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(location, TypeMismatch{superTy, subTy}); else - reportError(TypeError{location, TypeMismatch{subTy, superTy}}); + reportError(location, TypeMismatch{subTy, superTy}); }; const ClassTypeVar* superClass = get(superTy); @@ -2096,7 +2067,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) if (!classProp) { ok = false; - reportError(TypeError{location, UnknownProperty{superTy, propName}}); + reportError(location, UnknownProperty{superTy, propName}); } else { @@ -2120,7 +2091,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) { ok = false; std::string msg = "Class " + superClass->name + " does not have an indexer"; - reportError(TypeError{location, GenericError{msg}}); + reportError(location, GenericError{msg}); } if (!ok) @@ -2132,6 +2103,34 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) return fail(); } +void Unifier::tryUnifyTypeWithNegation(TypeId subTy, TypeId superTy) +{ + const NegationTypeVar* ntv = get(superTy); + if (!ntv) + ice("tryUnifyTypeWithNegation superTy must be a negation type"); + + const NormalizedType* subNorm = normalizer->normalize(subTy); + const NormalizedType* superNorm = normalizer->normalize(superTy); + if (!subNorm || !superNorm) + return reportError(location, UnificationTooComplex{}); + + // T (subTy); + if (!ntv) + ice("tryUnifyNegationWithType subTy must be a negation type"); + + // TODO: ~T & queue, DenseHashSet& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) { while (true) @@ -2192,7 +2191,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever } else if (get(tail)) { - reportError(TypeError{location, GenericError{"Cannot unify variadic and generic packs"}}); + reportError(location, GenericError{"Cannot unify variadic and generic packs"}); } else if (get(tail)) { @@ -2206,7 +2205,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever } else { - reportError(TypeError{location, GenericError{"Failed to unify variadic packs"}}); + reportError(location, GenericError{"Failed to unify variadic packs"}); } } @@ -2314,186 +2313,6 @@ std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, N return Luau::findTablePropertyRespectingMeta(singletonTypes, errors, lhsType, name, location); } -void Unifier::tryUnifyWithConstrainedSubTypeVar(TypeId subTy, TypeId superTy) -{ - const ConstrainedTypeVar* subConstrained = get(subTy); - if (!subConstrained) - ice("tryUnifyWithConstrainedSubTypeVar received non-ConstrainedTypeVar subTy!"); - - const std::vector& subTyParts = subConstrained->parts; - - // A | B <: T if A <: T and B <: T - bool failed = false; - std::optional unificationTooComplex; - - const size_t count = subTyParts.size(); - - for (size_t i = 0; i < count; ++i) - { - TypeId type = subTyParts[i]; - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(type, superTy); - - if (i == count - 1) - log.concat(std::move(innerState.log)); - - ++i; - - if (auto e = hasUnificationTooComplex(innerState.errors)) - unificationTooComplex = e; - - if (!innerState.errors.empty()) - { - failed = true; - break; - } - } - - if (unificationTooComplex) - reportError(*unificationTooComplex); - else if (failed) - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); - else - log.replace(subTy, BoundTypeVar{superTy}); -} - -void Unifier::tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy) -{ - ConstrainedTypeVar* superC = log.getMutable(superTy); - if (!superC) - ice("tryUnifyWithConstrainedSuperTypeVar received non-ConstrainedTypeVar superTy!"); - - // subTy could be a - // table - // metatable - // class - // function - // primitive - // free - // generic - // intersection - // union - // Do we really just tack it on? I think we might! - // We can certainly do some deduplication. - // Is there any point to deducing Player|Instance when we could just reduce to Instance? - // Is it actually ok to have multiple free types in a single intersection? What if they are later unified into the same type? - // Maybe we do a simplification step during quantification. - - auto it = std::find(superC->parts.begin(), superC->parts.end(), subTy); - if (it != superC->parts.end()) - return; - - superC->parts.push_back(subTy); -} - -void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel) -{ - // The duplication between this and regular typepack unification is tragic. - - auto superIter = begin(superTy, &log); - auto superEndIter = end(superTy); - - auto subIter = begin(subTy, &log); - auto subEndIter = end(subTy); - - int count = FInt::LuauTypeInferLowerBoundsIterationLimit; - - for (; subIter != subEndIter; ++subIter) - { - if (0 >= --count) - ice("Internal recursion counter limit exceeded in Unifier::unifyLowerBound"); - - if (superIter != superEndIter) - { - tryUnify_(*subIter, *superIter); - ++superIter; - continue; - } - - if (auto t = superIter.tail()) - { - TypePackId tailPack = follow(*t); - - if (log.get(tailPack) && occursCheck(tailPack, subTy)) - return; - - FreeTypePack* freeTailPack = log.getMutable(tailPack); - if (!freeTailPack) - return; - - TypePack* tp = getMutable(log.replace(tailPack, TypePack{})); - - for (; subIter != subEndIter; ++subIter) - { - tp->head.push_back(types->addType(ConstrainedTypeVar{demotedLevel, {follow(*subIter)}})); - } - - tp->tail = subIter.tail(); - } - - return; - } - - if (superIter != superEndIter) - { - if (auto subTail = subIter.tail()) - { - TypePackId subTailPack = follow(*subTail); - if (get(subTailPack)) - { - TypePack* tp = getMutable(log.replace(subTailPack, TypePack{})); - - for (; superIter != superEndIter; ++superIter) - tp->head.push_back(*superIter); - } - else if (const VariadicTypePack* subVariadic = log.getMutable(subTailPack)) - { - while (superIter != superEndIter) - { - tryUnify_(subVariadic->ty, *superIter); - ++superIter; - } - } - } - else - { - while (superIter != superEndIter) - { - if (!isOptional(*superIter)) - { - errors.push_back(TypeError{location, CountMismatch{size(superTy), std::nullopt, size(subTy), CountMismatch::Return}}); - return; - } - ++superIter; - } - } - - return; - } - - // Both iters are at their respective tails - auto subTail = subIter.tail(); - auto superTail = superIter.tail(); - if (subTail && superTail) - tryUnify(*subTail, *superTail); - else if (subTail) - { - const FreeTypePack* freeSubTail = log.getMutable(*subTail); - if (freeSubTail) - { - log.replace(*subTail, TypePack{}); - } - } - else if (superTail) - { - const FreeTypePack* freeSuperTail = log.getMutable(*superTail); - if (freeSuperTail) - { - log.replace(*superTail, TypePack{}); - } - } -} - bool Unifier::occursCheck(TypeId needle, TypeId haystack) { sharedState.tempSeenTy.clear(); @@ -2503,8 +2322,7 @@ bool Unifier::occursCheck(TypeId needle, TypeId haystack) bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack) { - RecursionLimiter _ra(&sharedState.counters.recursionCount, - FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); bool occurrence = false; @@ -2529,7 +2347,7 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays if (needle == haystack) { - reportError(TypeError{location, OccursCheckFailed{}}); + reportError(location, OccursCheckFailed{}); log.replace(needle, *singletonTypes->errorRecoveryType()); return true; @@ -2547,11 +2365,6 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays for (TypeId ty : a->parts) check(ty); } - else if (auto a = log.getMutable(haystack)) - { - for (TypeId ty : a->parts) - check(ty); - } return occurrence; } @@ -2579,14 +2392,13 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ if (!log.getMutable(needle)) ice("Expected needle pack to be free"); - RecursionLimiter _ra(&sharedState.counters.recursionCount, - FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); while (!log.getMutable(haystack)) { if (needle == haystack) { - reportError(TypeError{location, OccursCheckFailed{}}); + reportError(location, OccursCheckFailed{}); log.replace(needle, *singletonTypes->errorRecoveryTypePack()); return true; @@ -2607,18 +2419,31 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ Unifier Unifier::makeChildUnifier() { Unifier u = Unifier{normalizer, mode, scope, location, variance, &log}; - u.anyIsTop = anyIsTop; u.normalize = normalize; + u.useScopes = useScopes; return u; } // A utility function that appends the given error to the unifier's error log. // This allows setting a breakpoint wherever the unifier reports an error. +// +// Note: report error accepts its arguments by value intentionally to reduce the stack usage of functions which call `reportError`. +void Unifier::reportError(Location location, TypeErrorData data) +{ + errors.emplace_back(std::move(location), std::move(data)); +} + +// A utility function that appends the given error to the unifier's error log. +// This allows setting a breakpoint wherever the unifier reports an error. +// +// Note: to conserve stack space in calling functions it is generally preferred to call `Unifier::reportError(Location location, TypeErrorData data)` +// instead of this method. void Unifier::reportError(TypeError err) { errors.push_back(std::move(err)); } + bool Unifier::isNonstrictMode() const { return (mode == Mode::Nonstrict) || (mode == Mode::NoCheck); @@ -2629,7 +2454,7 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId if (auto e = hasUnificationTooComplex(innerErrors)) reportError(*e); else if (!innerErrors.empty()) - reportError(TypeError{location, TypeMismatch{wantedType, givenType}}); + reportError(location, TypeMismatch{wantedType, givenType}); } void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType) diff --git a/Ast/include/Luau/ParseResult.h b/Ast/include/Luau/ParseResult.h index 17ce2e3b..9c0a9527 100644 --- a/Ast/include/Luau/ParseResult.h +++ b/Ast/include/Luau/ParseResult.h @@ -58,6 +58,8 @@ struct Comment struct ParseResult { AstStatBlock* root; + size_t lines = 0; + std::vector hotcomments; std::vector errors; diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 848d7117..8b7eb73c 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -302,8 +302,8 @@ private: AstStatError* reportStatError(const Location& location, const AstArray& expressions, const AstArray& statements, const char* format, ...) LUAU_PRINTF_ATTR(5, 6); AstExprError* reportExprError(const Location& location, const AstArray& expressions, const char* format, ...) LUAU_PRINTF_ATTR(4, 5); - AstTypeError* reportTypeAnnotationError(const Location& location, const AstArray& types, bool isMissing, const char* format, ...) - LUAU_PRINTF_ATTR(5, 6); + AstTypeError* reportTypeAnnotationError(const Location& location, const AstArray& types, const char* format, ...) + LUAU_PRINTF_ATTR(4, 5); // `parseErrorLocation` is associated with the parser error // `astErrorLocation` is associated with the AstTypeError created // It can be useful to have different error locations so that the parse error can include the next lexeme, while the AstTypeError can precisely diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index d93f2ccb..66436acd 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -641,8 +641,8 @@ Lexeme Lexer::readInterpolatedStringSection(Position start, Lexeme::Type formatT return brokenDoubleBrace; } - Lexeme lexemeOutput(Location(start, position()), Lexeme::InterpStringBegin, &buffer[startOffset], offset - startOffset); consume(); + Lexeme lexemeOutput(Location(start, position()), Lexeme::InterpStringBegin, &buffer[startOffset], offset - startOffset - 1); return lexemeOutput; } diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 7150b18f..4c0cc125 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -23,9 +23,9 @@ LUAU_FASTFLAGVARIABLE(LuauErrorDoubleHexPrefix, false) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false) LUAU_FASTFLAGVARIABLE(LuauInterpolatedStringBaseSupport, false) -LUAU_FASTFLAGVARIABLE(LuauTypeAnnotationLocationChange, false) LUAU_FASTFLAGVARIABLE(LuauCommaParenWarnings, false) +LUAU_FASTFLAGVARIABLE(LuauTableConstructorRecovery, false) bool lua_telemetry_parsed_out_of_range_bin_integer = false; bool lua_telemetry_parsed_out_of_range_hex_integer = false; @@ -164,15 +164,16 @@ ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& n try { AstStatBlock* root = p.parseChunk(); + size_t lines = p.lexer.current().location.end.line + (bufferSize > 0 && buffer[bufferSize - 1] != '\n'); - return ParseResult{root, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations)}; + return ParseResult{root, lines, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations)}; } catch (ParseError& err) { // when catching a fatal error, append it to the list of non-fatal errors and return p.parseErrors.push_back(err); - return ParseResult{nullptr, {}, p.parseErrors}; + return ParseResult{nullptr, 0, {}, p.parseErrors}; } } @@ -811,9 +812,8 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() if (args.size() == 0 || args[0].name.name != "self" || args[0].annotation != nullptr) { - return AstDeclaredClassProp{fnName.name, - reportTypeAnnotationError(Location(start, end), {}, /*isMissing*/ false, "'self' must be present as the unannotated first parameter"), - true}; + return AstDeclaredClassProp{ + fnName.name, reportTypeAnnotationError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), true}; } // Skip the first index. @@ -824,8 +824,7 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() if (args[i].annotation) vars.push_back(args[i].annotation); else - vars.push_back(reportTypeAnnotationError( - Location(start, end), {}, /*isMissing*/ false, "All declaration parameters aside from 'self' must be annotated")); + vars.push_back(reportTypeAnnotationError(Location(start, end), {}, "All declaration parameters aside from 'self' must be annotated")); } if (vararg && !varargAnnotation) @@ -1537,7 +1536,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location if (isUnion && isIntersection) { - return reportTypeAnnotationError(Location(begin, parts.back()->location), copy(parts), /*isMissing*/ false, + return reportTypeAnnotationError(Location(begin, parts.back()->location), copy(parts), "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); } @@ -1623,18 +1622,18 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) return {allocator.alloc(start, svalue)}; } else - return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, "String literal contains malformed escape sequence")}; + return {reportTypeAnnotationError(start, {}, "String literal contains malformed escape sequence")}; } else if (lexer.current().type == Lexeme::InterpStringBegin || lexer.current().type == Lexeme::InterpStringSimple) { parseInterpString(); - return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, "Interpolated string literals cannot be used as types")}; + return {reportTypeAnnotationError(start, {}, "Interpolated string literals cannot be used as types")}; } else if (lexer.current().type == Lexeme::BrokenString) { nextLexeme(); - return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, "Malformed string")}; + return {reportTypeAnnotationError(start, {}, "Malformed string")}; } else if (lexer.current().type == Lexeme::Name) { @@ -1693,33 +1692,20 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) { nextLexeme(); - return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, + return {reportTypeAnnotationError(start, {}, "Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> " "...any'"), {}}; } else { - if (FFlag::LuauTypeAnnotationLocationChange) - { - // For a missing type annotation, capture 'space' between last token and the next one - Location astErrorlocation(lexer.previousLocation().end, start.begin); - // The parse error includes the next lexeme to make it easier to display where the error is (e.g. in an IDE or a CLI error message). - // Including the current lexeme also makes the parse error consistent with other parse errors returned by Luau. - Location parseErrorLocation(lexer.previousLocation().end, start.end); - return { - reportMissingTypeAnnotationError(parseErrorLocation, astErrorlocation, "Expected type, got %s", lexer.current().toString().c_str()), - {}}; - } - else - { - Location location = lexer.current().location; - - // For a missing type annotation, capture 'space' between last token and the next one - location = Location(lexer.previousLocation().end, lexer.current().location.begin); - - return {reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str()), {}}; - } + // For a missing type annotation, capture 'space' between last token and the next one + Location astErrorlocation(lexer.previousLocation().end, start.begin); + // The parse error includes the next lexeme to make it easier to display where the error is (e.g. in an IDE or a CLI error message). + // Including the current lexeme also makes the parse error consistent with other parse errors returned by Luau. + Location parseErrorLocation(lexer.previousLocation().end, start.end); + return { + reportMissingTypeAnnotationError(parseErrorLocation, astErrorlocation, "Expected type, got %s", lexer.current().toString().c_str()), {}}; } } @@ -2325,9 +2311,13 @@ AstExpr* Parser::parseTableConstructor() MatchLexeme matchBrace = lexer.current(); expectAndConsume('{', "table literal"); + unsigned lastElementIndent = 0; while (lexer.current().type != '}') { + if (FFlag::LuauTableConstructorRecovery) + lastElementIndent = lexer.current().location.begin.column; + if (lexer.current().type == '[') { MatchLexeme matchLocationBracket = lexer.current(); @@ -2372,10 +2362,14 @@ AstExpr* Parser::parseTableConstructor() { nextLexeme(); } - else + else if (FFlag::LuauTableConstructorRecovery && (lexer.current().type == '[' || lexer.current().type == Lexeme::Name) && + lexer.current().location.begin.column == lastElementIndent) { - if (lexer.current().type != '}') - break; + report(lexer.current().location, "Expected ',' after table constructor element"); + } + else if (lexer.current().type != '}') + { + break; } } @@ -3033,27 +3027,18 @@ AstExprError* Parser::reportExprError(const Location& location, const AstArray(location, expressions, unsigned(parseErrors.size() - 1)); } -AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const AstArray& types, bool isMissing, const char* format, ...) +AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const AstArray& types, const char* format, ...) { - if (FFlag::LuauTypeAnnotationLocationChange) - { - // Missing type annotations should be using `reportMissingTypeAnnotationError` when LuauTypeAnnotationLocationChange is enabled - // Note: `isMissing` can be removed once FFlag::LuauTypeAnnotationLocationChange is removed since it will always be true. - LUAU_ASSERT(!isMissing); - } - va_list args; va_start(args, format); report(location, format, args); va_end(args); - return allocator.alloc(location, types, isMissing, unsigned(parseErrors.size() - 1)); + return allocator.alloc(location, types, false, unsigned(parseErrors.size() - 1)); } AstTypeError* Parser::reportMissingTypeAnnotationError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...) { - LUAU_ASSERT(FFlag::LuauTypeAnnotationLocationChange); - va_list args; va_start(args, format); report(parseErrorLocation, format, args); diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 7e4c5691..6257e2f3 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -14,7 +14,6 @@ #endif LUAU_FASTFLAG(DebugLuauTimeTracing) -LUAU_FASTFLAG(LuauTypeMismatchModuleNameResolution) enum class ReportFormat { @@ -55,11 +54,9 @@ static void reportError(const Luau::Frontend& frontend, ReportFormat format, con if (const Luau::SyntaxError* syntaxError = Luau::get_if(&error.data)) report(format, humanReadableName.c_str(), error.location, "SyntaxError", syntaxError->message.c_str()); - else if (FFlag::LuauTypeMismatchModuleNameResolution) + else report(format, humanReadableName.c_str(), error.location, "TypeError", Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str()); - else - report(format, humanReadableName.c_str(), error.location, "TypeError", Luau::toString(error).c_str()); } static void reportWarning(ReportFormat format, const char* name, const Luau::LintWarning& warning) diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 06735d19..f0c62c56 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -49,6 +49,8 @@ enum class CompileFormat Binary, Remarks, Codegen, + CodegenVerbose, + CodegenNull, Null }; @@ -673,21 +675,33 @@ static void reportError(const char* name, const Luau::CompileError& error) report(name, error.getLocation(), "CompileError", error.what()); } -static std::string getCodegenAssembly(const char* name, const std::string& bytecode) +static std::string getCodegenAssembly(const char* name, const std::string& bytecode, Luau::CodeGen::AssemblyOptions options) { std::unique_ptr globalState(luaL_newstate(), lua_close); lua_State* L = globalState.get(); - setupState(L); - if (luau_load(L, name, bytecode.data(), bytecode.size(), 0) == 0) - return Luau::CodeGen::getAssemblyText(L, -1); + return Luau::CodeGen::getAssembly(L, -1, options); fprintf(stderr, "Error loading bytecode %s\n", name); return ""; } -static bool compileFile(const char* name, CompileFormat format) +static void annotateInstruction(void* context, std::string& text, int fid, int instpos) +{ + Luau::BytecodeBuilder& bcb = *(Luau::BytecodeBuilder*)context; + + bcb.annotateInstruction(text, fid, instpos); +} + +struct CompileStats +{ + size_t lines; + size_t bytecode; + size_t codegen; +}; + +static bool compileFile(const char* name, CompileFormat format, CompileStats& stats) { std::optional source = readFile(name); if (!source) @@ -696,9 +710,13 @@ static bool compileFile(const char* name, CompileFormat format) return false; } + // NOTE: Normally, you should use Luau::compile or luau_compile (see lua_require as an example) + // This function is much more complicated because it supports many output human-readable formats through internal interfaces + try { Luau::BytecodeBuilder bcb; + Luau::CodeGen::AssemblyOptions options = {format == CompileFormat::CodegenNull, format == CompileFormat::Codegen, annotateInstruction, &bcb}; if (format == CompileFormat::Text) { @@ -711,8 +729,24 @@ static bool compileFile(const char* name, CompileFormat format) bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Remarks); bcb.setDumpSource(*source); } + else if (format == CompileFormat::Codegen || format == CompileFormat::CodegenVerbose) + { + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals | + Luau::BytecodeBuilder::Dump_Remarks); + bcb.setDumpSource(*source); + } - Luau::compileOrThrow(bcb, *source, copts()); + Luau::Allocator allocator; + Luau::AstNameTable names(allocator); + Luau::ParseResult result = Luau::Parser::parse(source->c_str(), source->size(), names, allocator); + + if (!result.errors.empty()) + throw Luau::ParseErrors(result.errors); + + stats.lines += result.lines; + + Luau::compileOrThrow(bcb, result, names, copts()); + stats.bytecode += bcb.getBytecode().size(); switch (format) { @@ -726,7 +760,11 @@ static bool compileFile(const char* name, CompileFormat format) fwrite(bcb.getBytecode().data(), 1, bcb.getBytecode().size(), stdout); break; case CompileFormat::Codegen: - printf("%s", getCodegenAssembly(name, bcb.getBytecode()).c_str()); + case CompileFormat::CodegenVerbose: + printf("%s", getCodegenAssembly(name, bcb.getBytecode(), options).c_str()); + break; + case CompileFormat::CodegenNull: + stats.codegen += getCodegenAssembly(name, bcb.getBytecode(), options).size(); break; case CompileFormat::Null: break; @@ -755,7 +793,7 @@ static void displayHelp(const char* argv0) printf("\n"); printf("Available modes:\n"); printf(" omitted: compile and run input files one by one\n"); - printf(" --compile[=format]: compile input files and output resulting formatted bytecode (binary, text, remarks, codegen or null)\n"); + printf(" --compile[=format]: compile input files and output resulting bytecode/assembly (binary, text, remarks, codegen)\n"); printf("\n"); printf("Available options:\n"); printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); @@ -813,6 +851,14 @@ int replMain(int argc, char** argv) { compileFormat = CompileFormat::Codegen; } + else if (strcmp(argv[1], "--compile=codegenverbose") == 0) + { + compileFormat = CompileFormat::CodegenVerbose; + } + else if (strcmp(argv[1], "--compile=codegennull") == 0) + { + compileFormat = CompileFormat::CodegenNull; + } else if (strcmp(argv[1], "--compile=null") == 0) { compileFormat = CompileFormat::Null; @@ -924,10 +970,17 @@ int replMain(int argc, char** argv) _setmode(_fileno(stdout), _O_BINARY); #endif + CompileStats stats = {}; int failed = 0; for (const std::string& path : files) - failed += !compileFile(path.c_str(), compileFormat); + failed += !compileFile(path.c_str(), compileFormat, stats); + + if (compileFormat == CompileFormat::Null) + printf("Compiled %d KLOC into %d KB bytecode\n", int(stats.lines / 1000), int(stats.bytecode / 1024)); + else if (compileFormat == CompileFormat::CodegenNull) + printf("Compiled %d KLOC into %d KB bytecode => %d KB native code\n", int(stats.lines / 1000), int(stats.bytecode / 1024), + int(stats.codegen / 1024)); return failed ? 1 : 0; } diff --git a/CMakeLists.txt b/CMakeLists.txt index 0016160a..05d701ee 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -143,6 +143,11 @@ if (MSVC AND MSVC_VERSION GREATER_EQUAL 1924) set_source_files_properties(VM/src/lvmexecute.cpp PROPERTIES COMPILE_FLAGS /d2ssa-pre-) endif() +if (NOT MSVC) + # disable support for math_errno which allows compilers to lower sqrt() into a single CPU instruction + target_compile_options(Luau.VM PRIVATE -fno-math-errno) +endif() + if(MSVC AND LUAU_BUILD_CLI) # the default stack size that MSVC linker uses is 1 MB; we need more stack space in Debug because stack frames are larger set_target_properties(Luau.Analyze.CLI PROPERTIES LINK_FLAGS_DEBUG /STACK:2097152) diff --git a/CodeGen/include/Luau/AddressA64.h b/CodeGen/include/Luau/AddressA64.h new file mode 100644 index 00000000..351e6715 --- /dev/null +++ b/CodeGen/include/Luau/AddressA64.h @@ -0,0 +1,52 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/RegisterA64.h" + +namespace Luau +{ +namespace CodeGen +{ + +enum class AddressKindA64 : uint8_t +{ + imm, // reg + imm + reg, // reg + reg + + // TODO: + // reg + reg << shift + // reg + sext(reg) << shift + // reg + uext(reg) << shift + // pc + offset +}; + +struct AddressA64 +{ + AddressA64(RegisterA64 base, int off = 0) + : kind(AddressKindA64::imm) + , base(base) + , offset(xzr) + , data(off) + { + LUAU_ASSERT(base.kind == KindA64::x); + LUAU_ASSERT(off >= 0 && off < 4096); + } + + AddressA64(RegisterA64 base, RegisterA64 offset) + : kind(AddressKindA64::reg) + , base(base) + , offset(offset) + , data(0) + { + LUAU_ASSERT(base.kind == KindA64::x); + LUAU_ASSERT(offset.kind == KindA64::x); + } + + AddressKindA64 kind; + RegisterA64 base; + RegisterA64 offset; + int data; +}; + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h new file mode 100644 index 00000000..9a1402be --- /dev/null +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -0,0 +1,144 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/RegisterA64.h" +#include "Luau/AddressA64.h" +#include "Luau/ConditionA64.h" +#include "Luau/Label.h" + +#include +#include + +namespace Luau +{ +namespace CodeGen +{ + +class AssemblyBuilderA64 +{ +public: + explicit AssemblyBuilderA64(bool logText); + ~AssemblyBuilderA64(); + + // Moves + void mov(RegisterA64 dst, RegisterA64 src); + void mov(RegisterA64 dst, uint16_t src, int shift = 0); + void movk(RegisterA64 dst, uint16_t src, int shift = 0); + + // Arithmetics + void add(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); + void add(RegisterA64 dst, RegisterA64 src1, int src2); + void sub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); + void sub(RegisterA64 dst, RegisterA64 src1, int src2); + void neg(RegisterA64 dst, RegisterA64 src); + + // Comparisons + // Note: some arithmetic instructions also have versions that update flags (ADDS etc) but we aren't using them atm + // TODO: add cmp + + // Binary + // Note: shifted-register support and bitfield operations are omitted for simplicity + // TODO: support immediate arguments (they have odd encoding and forbid many values) + // TODO: support not variants for and/or/eor (required to support not...) + void and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void orr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void lsl(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void lsr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void asr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void ror(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void clz(RegisterA64 dst, RegisterA64 src); + void rbit(RegisterA64 dst, RegisterA64 src); + + // Load + // Note: paired loads are currently omitted for simplicity + void ldr(RegisterA64 dst, AddressA64 src); + void ldrb(RegisterA64 dst, AddressA64 src); + void ldrh(RegisterA64 dst, AddressA64 src); + void ldrsb(RegisterA64 dst, AddressA64 src); + void ldrsh(RegisterA64 dst, AddressA64 src); + void ldrsw(RegisterA64 dst, AddressA64 src); + + // Store + void str(RegisterA64 src, AddressA64 dst); + void strb(RegisterA64 src, AddressA64 dst); + void strh(RegisterA64 src, AddressA64 dst); + + // Control flow + // Note: tbz/tbnz are currently not supported because they have 15-bit offsets and we don't support branch thunks + void b(ConditionA64 cond, Label& label); + void cbz(RegisterA64 src, Label& label); + void cbnz(RegisterA64 src, Label& label); + void ret(); + + // Run final checks + bool finalize(); + + // Places a label at current location and returns it + Label setLabel(); + + // Assigns label position to the current location + void setLabel(Label& label); + + void logAppend(const char* fmt, ...) LUAU_PRINTF_ATTR(2, 3); + + uint32_t getCodeSize() const; + + // Resulting data and code that need to be copied over one after the other + // The *end* of 'data' has to be aligned to 16 bytes, this will also align 'code' + std::vector data; + std::vector code; + + std::string text; + + const bool logText = false; + +private: + // Instruction archetypes + void place0(const char* name, uint32_t word); + void placeSR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift = 0); + void placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op); + void placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t op2); + void placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op); + void placeI12(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op); + void placeI16(const char* name, RegisterA64 dst, int src, uint8_t op, int shift = 0); + void placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size); + void placeBC(const char* name, Label& label, uint8_t op, uint8_t cond); + void placeBR(const char* name, Label& label, uint8_t op, RegisterA64 cond); + + void place(uint32_t word); + void placeLabel(Label& label); + + void commit(); + LUAU_NOINLINE void extend(); + + // Data + size_t allocateData(size_t size, size_t align); + + // Logging of assembly in text form + LUAU_NOINLINE void log(const char* opcode); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src1, int src2); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, int src, int shift = 0); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, AddressA64 src); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 src, Label label); + LUAU_NOINLINE void log(const char* opcode, Label label); + LUAU_NOINLINE void log(Label label); + LUAU_NOINLINE void log(RegisterA64 reg); + LUAU_NOINLINE void log(AddressA64 addr); + + uint32_t nextLabel = 1; + std::vector(a) -> a" == toString(idType)); @@ -66,7 +42,7 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "generic_function") TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "proper_let_generalization") { - AstStatBlock* block = parse(R"( + solve(R"( local function a(c) local function d(e) return c @@ -78,21 +54,9 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "proper_let_generalization") local b = a(5) )"); - cgb.visit(block); - NotNull rootScope{cgb.rootScope}; - - ToStringOptions opts; - - NullModuleResolver resolver; - InternalErrorReporter iceHandler; - UnifierSharedState sharedState{&iceHandler}; - Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; - ConstraintSolver cs{NotNull{&normalizer}, rootScope, "MainModule", NotNull(&resolver), {}, &logger}; - - cs.run(); - TypeId idType = requireBinding(rootScope, "b"); + ToStringOptions opts; CHECK("(a) -> number" == toString(idType, opts)); } diff --git a/tests/DataFlowGraphBuilder.test.cpp b/tests/DataFlowGraphBuilder.test.cpp new file mode 100644 index 00000000..9aa7cde6 --- /dev/null +++ b/tests/DataFlowGraphBuilder.test.cpp @@ -0,0 +1,104 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/DataFlowGraphBuilder.h" +#include "Luau/Error.h" +#include "Luau/Parser.h" + +#include "AstQueryDsl.h" +#include "ScopedFlags.h" + +#include "doctest.h" + +using namespace Luau; + +class DataFlowGraphFixture +{ + // Only needed to fix the operator== reflexivity of an empty Symbol. + ScopedFastFlag dcr{"DebugLuauDeferredConstraintResolution", true}; + + InternalErrorReporter handle; + + Allocator allocator; + AstNameTable names{allocator}; + AstStatBlock* module; + + std::optional graph; + +public: + void dfg(const std::string& code) + { + ParseResult parseResult = Parser::parse(code.c_str(), code.size(), names, allocator); + if (!parseResult.errors.empty()) + throw ParseErrors(std::move(parseResult.errors)); + module = parseResult.root; + graph = DataFlowGraphBuilder::build(module, NotNull{&handle}); + } + + template + std::optional getDef(const std::vector& nths = {nth(N)}) + { + T* node = query(module, nths); + REQUIRE(node); + return graph->getDef(node); + } + + template + DefId requireDef(const std::vector& nths = {nth(N)}) + { + auto loc = getDef(nths); + REQUIRE(loc); + return NotNull{*loc}; + } +}; + +TEST_SUITE_BEGIN("DataFlowGraphBuilder"); + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "define_locals_in_local_stat") +{ + dfg(R"( + local x = 5 + local y = x + )"); + + REQUIRE(getDef()); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "define_parameters_in_functions") +{ + dfg(R"( + local function f(x) + local y = x + end + )"); + + REQUIRE(getDef()); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "find_aliases") +{ + dfg(R"( + local x = 5 + local y = x + local z = y + )"); + + DefId x = requireDef(); + DefId y = requireDef(); + REQUIRE(x != y); // TODO: they should be equal but it's not just locals that can alias, so we'll support this later. +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "independent_locals") +{ + dfg(R"( + local x = 5 + local y = 5 + + local a = x + local b = y + )"); + + DefId x = requireDef(); + DefId y = requireDef(); + REQUIRE(x != y); +} + +TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 579b8942..b28155e3 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -506,6 +506,16 @@ std::optional linearSearchForBinding(Scope* scope, const char* name) return std::nullopt; } +void registerHiddenTypes(Fixture& fixture, TypeArena& arena) +{ + TypeId t = arena.addType(GenericTypeVar{"T"}); + GenericTypeDefinition genericT{t}; + + ScopePtr moduleScope = fixture.frontend.getGlobalScope(); + moduleScope->exportedTypeBindings["Not"] = TypeFun{{genericT}, arena.addType(NegationTypeVar{t})}; + moduleScope->exportedTypeBindings["fun"] = TypeFun{{}, fixture.singletonTypes->functionType}; +} + void dump(const std::vector& constraints) { ToStringOptions opts; diff --git a/tests/Fixture.h b/tests/Fixture.h index 2fb48468..24c9566f 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -186,6 +186,8 @@ std::optional lookupName(ScopePtr scope, const std::string& name); // Wa std::optional linearSearchForBinding(Scope* scope, const char* name); +void registerHiddenTypes(Fixture& fixture, TypeArena& arena); + } // namespace Luau #define LUAU_REQUIRE_ERRORS(result) \ diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 957f3c7c..df0abdc9 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -517,6 +517,33 @@ TEST_CASE_FIXTURE(FrontendFixture, "recheck_if_dependent_script_is_dirty") CHECK_EQ("{| b_value: string |}", toString(*bExports)); } +TEST_CASE_FIXTURE(FrontendFixture, "mark_non_immediate_reverse_deps_as_dirty") +{ + ScopedFastFlag sff[] = { + {"LuauFixMarkDirtyReverseDeps", true}, + }; + + fileResolver.source["game/Gui/Modules/A"] = "return {hello=5, world=true}"; + fileResolver.source["game/Gui/Modules/B"] = R"( + return require(game:GetService('Gui').Modules.A) + )"; + fileResolver.source["game/Gui/Modules/C"] = R"( + local Modules = game:GetService('Gui').Modules + local B = require(Modules.B) + return {c_value = B.hello} + )"; + + frontend.check("game/Gui/Modules/C"); + + std::vector markedDirty; + frontend.markDirty("game/Gui/Modules/A", &markedDirty); + + REQUIRE(markedDirty.size() == 3); + CHECK(std::find(markedDirty.begin(), markedDirty.end(), "game/Gui/Modules/A") != markedDirty.end()); + CHECK(std::find(markedDirty.begin(), markedDirty.end(), "game/Gui/Modules/B") != markedDirty.end()); + CHECK(std::find(markedDirty.begin(), markedDirty.end(), "game/Gui/Modules/C") != markedDirty.end()); +} + #if 0 // Does not work yet. :( TEST_CASE_FIXTURE(FrontendFixture, "recheck_if_dependent_script_has_a_parse_error") diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index d5f635e6..b289b59e 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -11,6 +11,7 @@ using namespace Luau; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAG(LuauIceExceptionInheritanceChange); TEST_SUITE_BEGIN("ModuleTests"); @@ -226,24 +227,6 @@ TEST_CASE_FIXTURE(Fixture, "clone_free_tables") CHECK_EQ(clonedTtv->state, TableState::Free); } -TEST_CASE_FIXTURE(Fixture, "clone_constrained_intersection") -{ - TypeArena src; - - TypeId constrained = src.addType(ConstrainedTypeVar{TypeLevel{}, {singletonTypes->numberType, singletonTypes->stringType}}); - - TypeArena dest; - CloneState cloneState; - - TypeId cloned = clone(constrained, dest, cloneState); - CHECK_NE(constrained, cloned); - - const ConstrainedTypeVar* ctv = get(cloned); - REQUIRE_EQ(2, ctv->parts.size()); - CHECK_EQ(singletonTypes->numberType, ctv->parts[0]); - CHECK_EQ(singletonTypes->stringType, ctv->parts[1]); -} - TEST_CASE_FIXTURE(BuiltinsFixture, "clone_self_property") { fileResolver.source["Module/A"] = R"( @@ -296,7 +279,14 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") TypeArena dest; CloneState cloneState; - CHECK_THROWS_AS(clone(table, dest, cloneState), RecursionLimitException); + if (FFlag::LuauIceExceptionInheritanceChange) + { + CHECK_THROWS_AS(clone(table, dest, cloneState), RecursionLimitException); + } + else + { + CHECK_THROWS_AS(clone(table, dest, cloneState), RecursionLimitException_DEPRECATED); + } } TEST_CASE_FIXTURE(Fixture, "any_persistance_does_not_leak") diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index b3522f6e..a8f3c7ba 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -3,6 +3,7 @@ #include "Fixture.h" #include "Luau/Common.h" +#include "Luau/TypeVar.h" #include "doctest.h" #include "Luau/Normalize.h" @@ -10,13 +11,16 @@ using namespace Luau; -struct NormalizeFixture : Fixture +namespace +{ +struct IsSubtypeFixture : Fixture { bool isSubtype(TypeId a, TypeId b) { return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, singletonTypes, ice); } }; +} // namespace void createSomeClasses(Frontend& frontend) { @@ -55,7 +59,7 @@ void createSomeClasses(Frontend& frontend) TEST_SUITE_BEGIN("isSubtype"); -TEST_CASE_FIXTURE(NormalizeFixture, "primitives") +TEST_CASE_FIXTURE(IsSubtypeFixture, "primitives") { check(R"( local a = 41 @@ -75,7 +79,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "primitives") CHECK(!isSubtype(d, a)); } -TEST_CASE_FIXTURE(NormalizeFixture, "functions") +TEST_CASE_FIXTURE(IsSubtypeFixture, "functions") { check(R"( function a(x: number): number return x end @@ -96,7 +100,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "functions") CHECK(isSubtype(a, d)); } -TEST_CASE_FIXTURE(NormalizeFixture, "functions_and_any") +TEST_CASE_FIXTURE(IsSubtypeFixture, "functions_and_any") { check(R"( function a(n: number) return "string" end @@ -106,15 +110,13 @@ TEST_CASE_FIXTURE(NormalizeFixture, "functions_and_any") TypeId a = requireType("a"); TypeId b = requireType("b"); - // Intuition: - // We cannot use b where a is required because we cannot rely on b to return a string. - // We cannot use a where b is required because we cannot rely on a to accept non-number arguments. + // any makes things work even when it makes no sense. - CHECK(!isSubtype(b, a)); - CHECK(!isSubtype(a, b)); + CHECK(isSubtype(b, a)); + CHECK(isSubtype(a, b)); } -TEST_CASE_FIXTURE(NormalizeFixture, "variadic_functions_with_no_head") +TEST_CASE_FIXTURE(IsSubtypeFixture, "variadic_functions_with_no_head") { check(R"( local a: (...number) -> () @@ -129,7 +131,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "variadic_functions_with_no_head") } #if 0 -TEST_CASE_FIXTURE(NormalizeFixture, "variadic_function_with_head") +TEST_CASE_FIXTURE(IsSubtypeFixture, "variadic_function_with_head") { check(R"( local a: (...number) -> () @@ -144,7 +146,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "variadic_function_with_head") } #endif -TEST_CASE_FIXTURE(NormalizeFixture, "union") +TEST_CASE_FIXTURE(IsSubtypeFixture, "union") { check(R"( local a: number | string @@ -171,7 +173,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "union") CHECK(!isSubtype(d, b)); } -TEST_CASE_FIXTURE(NormalizeFixture, "table_with_union_prop") +TEST_CASE_FIXTURE(IsSubtypeFixture, "table_with_union_prop") { check(R"( local a: {x: number} @@ -185,7 +187,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "table_with_union_prop") CHECK(!isSubtype(b, a)); } -TEST_CASE_FIXTURE(NormalizeFixture, "table_with_any_prop") +TEST_CASE_FIXTURE(IsSubtypeFixture, "table_with_any_prop") { check(R"( local a: {x: number} @@ -196,10 +198,10 @@ TEST_CASE_FIXTURE(NormalizeFixture, "table_with_any_prop") TypeId b = requireType("b"); CHECK(isSubtype(a, b)); - CHECK(!isSubtype(b, a)); + CHECK(isSubtype(b, a)); } -TEST_CASE_FIXTURE(NormalizeFixture, "intersection") +TEST_CASE_FIXTURE(IsSubtypeFixture, "intersection") { ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, @@ -229,7 +231,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "intersection") CHECK(isSubtype(a, d)); } -TEST_CASE_FIXTURE(NormalizeFixture, "union_and_intersection") +TEST_CASE_FIXTURE(IsSubtypeFixture, "union_and_intersection") { check(R"( local a: number & string @@ -243,7 +245,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "union_and_intersection") CHECK(isSubtype(a, b)); } -TEST_CASE_FIXTURE(NormalizeFixture, "tables") +TEST_CASE_FIXTURE(IsSubtypeFixture, "tables") { check(R"( local a: {x: number} @@ -258,7 +260,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "tables") TypeId d = requireType("d"); CHECK(isSubtype(a, b)); - CHECK(!isSubtype(b, a)); + CHECK(isSubtype(b, a)); CHECK(!isSubtype(c, a)); CHECK(!isSubtype(a, c)); @@ -271,7 +273,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "tables") } #if 0 -TEST_CASE_FIXTURE(NormalizeFixture, "table_indexers_are_invariant") +TEST_CASE_FIXTURE(IsSubtypeFixture, "table_indexers_are_invariant") { check(R"( local a: {[string]: number} @@ -290,7 +292,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "table_indexers_are_invariant") CHECK(isSubtype(a, c)); } -TEST_CASE_FIXTURE(NormalizeFixture, "mismatched_indexers") +TEST_CASE_FIXTURE(IsSubtypeFixture, "mismatched_indexers") { check(R"( local a: {x: number} @@ -309,7 +311,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "mismatched_indexers") CHECK(isSubtype(b, c)); } -TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_table") +TEST_CASE_FIXTURE(IsSubtypeFixture, "cyclic_table") { check(R"( type A = {method: (A) -> ()} @@ -348,7 +350,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_table") } #endif -TEST_CASE_FIXTURE(NormalizeFixture, "classes") +TEST_CASE_FIXTURE(IsSubtypeFixture, "classes") { createSomeClasses(frontend); @@ -365,7 +367,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "classes") } #if 0 -TEST_CASE_FIXTURE(NormalizeFixture, "metatable" * doctest::expected_failures{1}) +TEST_CASE_FIXTURE(IsSubtypeFixture, "metatable" * doctest::expected_failures{1}) { check(R"( local T = {} @@ -389,26 +391,156 @@ TEST_CASE_FIXTURE(NormalizeFixture, "metatable" * doctest::expected_failures{1}) TEST_SUITE_END(); +struct NormalizeFixture : Fixture +{ + ScopedFastFlag sff0{"LuauNegatedStringSingletons", true}; + ScopedFastFlag sff1{"LuauNegatedFunctionTypes", true}; + + TypeArena arena; + InternalErrorReporter iceHandler; + UnifierSharedState unifierState{&iceHandler}; + Normalizer normalizer{&arena, singletonTypes, NotNull{&unifierState}}; + + NormalizeFixture() + { + registerHiddenTypes(*this, arena); + } + + const NormalizedType* toNormalizedType(const std::string& annotation) + { + CheckResult result = check("type _Res = " + annotation); + LUAU_REQUIRE_NO_ERRORS(result); + std::optional ty = lookupType("_Res"); + REQUIRE(ty); + return normalizer.normalize(*ty); + } + + TypeId normal(const std::string& annotation) + { + const NormalizedType* norm = toNormalizedType(annotation); + REQUIRE(norm); + return normalizer.typeFromNormal(*norm); + } +}; + TEST_SUITE_BEGIN("Normalize"); -TEST_CASE_FIXTURE(NormalizeFixture, "union_with_overlapping_field_that_has_a_subtype_relationship") +TEST_CASE_FIXTURE(NormalizeFixture, "negate_string") { - check(R"( - local t: {x: number} | {x: number?} - )"); + CHECK("number" == toString(normal(R"( + (number | string) & Not + )"))); +} - ModulePtr tempModule{new Module}; - tempModule->scopes.emplace_back(Location(), std::make_shared(singletonTypes->anyTypePack)); +TEST_CASE_FIXTURE(NormalizeFixture, "negate_string_from_cofinite_string_intersection") +{ + CHECK("number" == toString(normal(R"( + (number | (string & Not<"hello"> & Not<"world">)) & Not + )"))); +} - // HACK: Normalization is an in-place operation. We need to cheat a little here and unfreeze - // the arena that the type lives in. - ModulePtr mainModule = getMainModule(); - unfreeze(mainModule->internalTypes); +TEST_CASE_FIXTURE(NormalizeFixture, "no_op_negation_is_dropped") +{ + CHECK("number" == toString(normal(R"( + number & Not + )"))); +} - TypeId tType = requireType("t"); - normalize(tType, tempModule, singletonTypes, *typeChecker.iceHandler); +TEST_CASE_FIXTURE(NormalizeFixture, "union_of_negation") +{ + CHECK("string" == toString(normal(R"( + (string & Not<"hello">) | "hello" + )"))); +} - CHECK_EQ("{| x: number? |}", toString(tType, {true})); +TEST_CASE_FIXTURE(NormalizeFixture, "intersect_truthy") +{ + CHECK("number | string | true" == toString(normal(R"( + (string | number | boolean | nil) & Not + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersect_truthy_expressed_as_intersection") +{ + CHECK("number | string | true" == toString(normal(R"( + (string | number | boolean | nil) & Not & Not + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "union_of_union") +{ + CHECK(R"("alpha" | "beta" | "gamma")" == toString(normal(R"( + ("alpha" | "beta") | "gamma" + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "union_of_negations") +{ + CHECK(R"(string & ~"world")" == toString(normal(R"( + (string & Not<"hello"> & Not<"world">) | (string & Not<"goodbye"> & Not<"world">) + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "disjoint_negations_normalize_to_string") +{ + CHECK(R"(string)" == toString(normal(R"( + (string & Not<"hello"> & Not<"world">) | (string & Not<"goodbye">) + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "negate_boolean") +{ + CHECK("true" == toString(normal(R"( + boolean & Not + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "negate_boolean_2") +{ + CHECK("never" == toString(normal(R"( + true & Not + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersect_function_and_top_function") +{ + CHECK("() -> ()" == toString(normal(R"( + fun & (() -> ()) + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersect_function_and_top_function_reverse") +{ + CHECK("() -> ()" == toString(normal(R"( + (() -> ()) & fun + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "union_function_and_top_function") +{ + CHECK("function" == toString(normal(R"( + fun | (() -> ()) + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "negated_function_is_anything_except_a_function") +{ + CHECK("(boolean | number | string | thread)?" == toString(normal(R"( + Not + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "specific_functions_cannot_be_negated") +{ + CHECK(nullptr == toNormalizedType("Not<(boolean) -> boolean>")); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "bare_negated_boolean") +{ + // TODO: We don't yet have a way to say number | string | thread | nil | Class | Table | Function + CHECK("(function | number | string | thread)?" == toString(normal(R"( + Not + )"))); } TEST_CASE_FIXTURE(Fixture, "higher_order_function") diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 662c2900..77cf6130 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Parser.h" +#include "AstQueryDsl.h" #include "Fixture.h" #include "ScopedFlags.h" @@ -1736,8 +1737,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_type_annotation") TEST_CASE_FIXTURE(Fixture, "parse_error_missing_type_annotation") { - ScopedFastFlag LuauTypeAnnotationLocationChange{"LuauTypeAnnotationLocationChange", true}; - { ParseResult result = tryParse("local x:"); CHECK(result.errors.size() == 1); @@ -2777,4 +2776,51 @@ TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_an_extra_comma_at_the CHECK(2 == f->generics.size); } +TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_no_comma_between_table_members") +{ + ScopedFastFlag luauTableConstructorRecovery{"LuauTableConstructorRecovery", true}; + + ParseResult result = tryParse(R"( + local t = { + first = 1 + second = 2, + third = 3, + fouth = 4, + } + )"); + + REQUIRE(1 == result.errors.size()); + + CHECK(Location({3, 12}, {3, 18}) == result.errors[0].getLocation()); + CHECK("Expected ',' after table constructor element" == result.errors[0].getMessage()); + + REQUIRE(1 == result.root->body.size); + + AstExprTable* table = Luau::query(result.root); + REQUIRE(table); + CHECK(table->items.size == 4); +} + +TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_no_comma_after_last_table_member") +{ + ParseResult result = tryParse(R"( + local t = { + first = 1 + + local ok = true + local good = ok == true + )"); + + REQUIRE(1 == result.errors.size()); + + CHECK(Location({4, 8}, {4, 13}) == result.errors[0].getLocation()); + CHECK("Expected '}' (to close '{' at line 2), got 'local'" == result.errors[0].getMessage()); + + REQUIRE(3 == result.root->body.size); + + AstExprTable* table = Luau::query(result.root); + REQUIRE(table); + CHECK(table->items.size == 1); +} + TEST_SUITE_END(); diff --git a/tests/Symbol.test.cpp b/tests/Symbol.test.cpp index e7d2973b..278c6ce2 100644 --- a/tests/Symbol.test.cpp +++ b/tests/Symbol.test.cpp @@ -10,7 +10,7 @@ using namespace Luau; TEST_SUITE_BEGIN("SymbolTests"); -TEST_CASE("hashing_globals") +TEST_CASE("equality_and_hashing_of_globals") { std::string s1 = "name"; std::string s2 = "name"; @@ -37,7 +37,7 @@ TEST_CASE("hashing_globals") REQUIRE_EQ(1, theMap.size()); } -TEST_CASE("hashing_locals") +TEST_CASE("equality_and_hashing_of_locals") { std::string s1 = "name"; std::string s2 = "name"; @@ -64,4 +64,24 @@ TEST_CASE("hashing_locals") REQUIRE_EQ(2, theMap.size()); } +TEST_CASE("equality_of_empty_symbols") +{ + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + std::string s1 = "name"; + std::string s2 = "name"; + + AstName one{s1.data()}; + AstLocal two{AstName{s2.data()}, Location(), nullptr, 0, 0, nullptr}; + + Symbol global{one}; + Symbol local{&two}; + Symbol empty1{}; + Symbol empty2{}; + + CHECK(empty1 != global); + CHECK(empty1 != local); + CHECK(empty1 == empty2); +} + TEST_SUITE_END(); diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp index 98eb9863..26c9a1ee 100644 --- a/tests/ToDot.test.cpp +++ b/tests/ToDot.test.cpp @@ -79,8 +79,8 @@ n1 [label="AnyTypeVar 1"]; TEST_CASE_FIXTURE(Fixture, "bound") { CheckResult result = check(R"( -local a = 444 -local b = a +function a(): number return 444 end +local b = a() )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -367,27 +367,6 @@ n3 [label="number"]; toDot(*ty, opts)); } -TEST_CASE_FIXTURE(Fixture, "constrained") -{ - // ConstrainedTypeVars never appear in the final type graph, so we have to create one directly - // to dotify it. - TypeVar t{ConstrainedTypeVar{TypeLevel{}, {typeChecker.numberType, typeChecker.stringType, typeChecker.nilType}}}; - - ToDotOptions opts; - opts.showPointers = false; - - CHECK_EQ(R"(digraph graphname { -n1 [label="ConstrainedTypeVar 1"]; -n1 -> n2; -n2 [label="number"]; -n1 -> n3; -n3 [label="string"]; -n1 -> n4; -n4 [label="nil"]; -})", - toDot(&t, opts)); -} - TEST_CASE_FIXTURE(Fixture, "singletontypes") { CheckResult result = check(R"( diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 53e5f71b..a510f914 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -10,7 +10,6 @@ using namespace Luau; LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction); -LUAU_FASTFLAG(LuauSpecialTypesAsterisked); LUAU_FASTFLAG(LuauFixNameMaps); LUAU_FASTFLAG(LuauFunctionReturnStringificationFixup); @@ -270,16 +269,8 @@ TEST_CASE_FIXTURE(Fixture, "quit_stringifying_type_when_length_is_exceeded") o.maxTypeLength = 40; CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()"); - if (FFlag::LuauSpecialTypesAsterisked) - { - CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); - } - else - { - CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... "); - CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... "); - } + CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); } TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive") @@ -297,16 +288,8 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive") o.maxTypeLength = 40; CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()"); - if (FFlag::LuauSpecialTypesAsterisked) - { - CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); - } - else - { - CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... "); - CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... "); - } + CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); } TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_correctly_use_matching_table_state_braces") @@ -512,9 +495,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") TableTypeVar* tMeta5 = getMutable(tMeta4->props["__index"].type); REQUIRE(tMeta5); + REQUIRE(tMeta5->props.count("one") > 0); TableTypeVar* tMeta6 = getMutable(tMeta3->table); REQUIRE(tMeta6); + REQUIRE(tMeta6->props.count("two") > 0); ToStringResult oneResult = toStringDetailed(tMeta5->props["one"].type, opts); if (!FFlag::LuauFixNameMaps) @@ -533,10 +518,7 @@ local function target(callback: nil) return callback(4, "hello") end )"); LUAU_REQUIRE_ERRORS(result); - if (FFlag::LuauSpecialTypesAsterisked) - CHECK_EQ("(nil) -> (*error-type*)", toString(requireType("target"))); - else - CHECK_EQ("(nil) -> ()", toString(requireType("target"))); + CHECK_EQ("(nil) -> (*error-type*)", toString(requireType("target"))); } TEST_CASE_FIXTURE(Fixture, "toStringGenericPack") diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 5ecc2a8c..8c738b7d 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -846,8 +846,16 @@ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_uni type FutureIntersection = A & B )"); - // TODO: shared self causes this test to break in bizarre ways. - LUAU_REQUIRE_ERRORS(result); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + // To be quite honest, I don't know exactly why DCR fixes this. + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + // TODO: shared self causes this test to break in bizarre ways. + LUAU_REQUIRE_ERRORS(result); + } } TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_ok") @@ -905,4 +913,14 @@ TEST_CASE_FIXTURE(Fixture, "it_is_ok_to_shadow_user_defined_alias") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "cannot_create_cyclic_type_with_unknown_module") +{ + CheckResult result = check(R"( + type AAA = B.AAA + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Unknown type 'B.AAA'"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 28767889..bb97bbeb 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -7,6 +7,8 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauIceExceptionInheritanceChange) + using namespace Luau; TEST_SUITE_BEGIN("AnnotationTests"); @@ -664,8 +666,8 @@ TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_exception_with_flag") AssertionCatcher ac; CHECK_THROWS_AS(check(R"( - local a: _luau_ice = 55 - )"), + local a: _luau_ice = 55 + )"), InternalCompilerError); LUAU_ASSERT(1 == AssertionCatcher::tripped); @@ -682,8 +684,8 @@ TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_exception_with_flag_handler }; CHECK_THROWS_AS(check(R"( - local a: _luau_ice = 55 - )"), + local a: _luau_ice = 55 + )"), InternalCompilerError); CHECK_EQ(true, caught); diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp index 8c6f2e4f..91201812 100644 --- a/tests/TypeInfer.anyerror.test.cpp +++ b/tests/TypeInfer.anyerror.test.cpp @@ -13,8 +13,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauSpecialTypesAsterisked) - TEST_SUITE_BEGIN("TypeInferAnyError"); TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any") @@ -96,10 +94,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauSpecialTypesAsterisked) - CHECK_EQ("*error-type*", toString(requireType("a"))); - else - CHECK_EQ("", toString(requireType("a"))); + CHECK_EQ("*error-type*", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") @@ -115,10 +110,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauSpecialTypesAsterisked) - CHECK_EQ("*error-type*", toString(requireType("a"))); - else - CHECK_EQ("", toString(requireType("a"))); + CHECK_EQ("*error-type*", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "length_of_error_type_does_not_produce_an_error") @@ -233,10 +225,7 @@ TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") CHECK_EQ("unknown", err->name); - if (FFlag::LuauSpecialTypesAsterisked) - CHECK_EQ("*error-type*", toString(requireType("a"))); - else - CHECK_EQ("", toString(requireType("a"))); + CHECK_EQ("*error-type*", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") @@ -245,10 +234,7 @@ TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") local a = Utility.Create "Foo" {} )"); - if (FFlag::LuauSpecialTypesAsterisked) - CHECK_EQ("*error-type*", toString(requireType("a"))); - else - CHECK_EQ("", toString(requireType("a"))); + CHECK_EQ("*error-type*", toString(requireType("a"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "replace_every_free_type_when_unifying_a_complex_function_with_any") @@ -348,8 +334,6 @@ TEST_CASE_FIXTURE(Fixture, "prop_access_on_any_with_other_options") TEST_CASE_FIXTURE(BuiltinsFixture, "union_of_types_regression_test") { - ScopedFastFlag LuauUnionOfTypesFollow{"LuauUnionOfTypesFollow", true}; - CheckResult result = check(R"( --!strict local stat diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index f9c104fd..7c465c53 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -8,7 +8,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauSpecialTypesAsterisked); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) TEST_SUITE_BEGIN("BuiltinTests"); @@ -685,7 +684,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_variadic_typepack_tail") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution && FFlag::LuauSpecialTypesAsterisked) + if (FFlag::DebugLuauDeferredConstraintResolution) { CHECK_EQ("string", toString(requireType("foo"))); CHECK_EQ("*error-type*", toString(requireType("bar"))); @@ -714,7 +713,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_variadic_typepack_tail_and_strin LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution && FFlag::LuauSpecialTypesAsterisked) + if (FFlag::DebugLuauDeferredConstraintResolution) { CHECK_EQ("string", toString(requireType("foo"))); CHECK_EQ("string", toString(requireType("bar"))); @@ -1016,10 +1015,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic") CHECK_EQ("number", toString(requireType("a"))); CHECK_EQ("string", toString(requireType("b"))); CHECK_EQ("boolean", toString(requireType("c"))); - if (FFlag::LuauSpecialTypesAsterisked) - CHECK_EQ("*error-type*", toString(requireType("d"))); - else - CHECK_EQ("", toString(requireType("d"))); + CHECK_EQ("*error-type*", toString(requireType("d"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "set_metatable_needs_arguments") diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index 15c63ec7..684b47e9 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -309,6 +309,23 @@ TEST_CASE_FIXTURE(Fixture, "definitions_documentation_symbols") CHECK_EQ(yTtv->props["x"].documentationSymbol, "@test/global/y.x"); } +TEST_CASE_FIXTURE(Fixture, "definitions_symbols_are_generated_for_recursively_referenced_types") +{ + ScopedFastFlag LuauPersistTypesAfterGeneratingDocSyms("LuauPersistTypesAfterGeneratingDocSyms", true); + + loadDefinition(R"( + declare class MyClass + function myMethod(self) + end + + declare function myFunc(): MyClass + )"); + + std::optional myClassTy = typeChecker.globalScope->lookupType("MyClass"); + REQUIRE(bool(myClassTy)); + CHECK_EQ(myClassTy->type->documentationSymbol, "@test/globaltype/MyClass"); +} + TEST_CASE_FIXTURE(Fixture, "documentation_symbols_dont_attach_to_persistent_types") { loadDefinition(R"( diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index edc25c7e..ddf73349 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -15,7 +15,6 @@ using namespace Luau; LUAU_FASTFLAG(LuauInstantiateInSubtyping); -LUAU_FASTFLAG(LuauSpecialTypesAsterisked); TEST_SUITE_BEGIN("TypeInferFunctions"); @@ -229,6 +228,48 @@ TEST_CASE_FIXTURE(Fixture, "too_many_arguments") CHECK_EQ(0, acm->actual); } +TEST_CASE_FIXTURE(Fixture, "too_many_arguments_error_location") +{ + ScopedFastFlag sff{"LuauArgMismatchReportFunctionLocation", true}; + + CheckResult result = check(R"( + --!strict + + function myfunction(a: number, b:number) end + myfunction(1) + + function getmyfunction() + return myfunction + end + getmyfunction()() + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + { + TypeError err = result.errors[0]; + + // Ensure the location matches the location of the function identifier + CHECK_EQ(err.location, Location(Position(4, 8), Position(4, 18))); + + auto acm = get(err); + REQUIRE(acm); + CHECK_EQ(2, acm->expected); + CHECK_EQ(1, acm->actual); + } + { + TypeError err = result.errors[1]; + + // Ensure the location matches the location of the expression returning the function + CHECK_EQ(err.location, Location(Position(9, 8), Position(9, 23))); + + auto acm = get(err); + REQUIRE(acm); + CHECK_EQ(2, acm->expected); + CHECK_EQ(0, acm->actual); + } +} + TEST_CASE_FIXTURE(Fixture, "recursive_function") { CheckResult result = check(R"( @@ -938,19 +979,13 @@ TEST_CASE_FIXTURE(Fixture, "function_cast_error_uses_correct_language") REQUIRE(tm1); CHECK_EQ("(string) -> number", toString(tm1->wantedType)); - if (FFlag::LuauSpecialTypesAsterisked) - CHECK_EQ("(string, *error-type*) -> number", toString(tm1->givenType)); - else - CHECK_EQ("(string, ) -> number", toString(tm1->givenType)); + CHECK_EQ("(string, *error-type*) -> number", toString(tm1->givenType)); auto tm2 = get(result.errors[1]); REQUIRE(tm2); CHECK_EQ("(number, number) -> (number, number)", toString(tm2->wantedType)); - if (FFlag::LuauSpecialTypesAsterisked) - CHECK_EQ("(string, *error-type*) -> number", toString(tm2->givenType)); - else - CHECK_EQ("(string, ) -> number", toString(tm2->givenType)); + CHECK_EQ("(string, *error-type*) -> number", toString(tm2->givenType)); } TEST_CASE_FIXTURE(Fixture, "no_lossy_function_type") @@ -1496,20 +1531,10 @@ function t:b() return 2 end -- not OK )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauSpecialTypesAsterisked) - { - CHECK_EQ(R"(Type '(*error-type*) -> number' could not be converted into '() -> number' + CHECK_EQ(R"(Type '(*error-type*) -> number' could not be converted into '() -> number' caused by: Argument count mismatch. Function expects 1 argument, but none are specified)", - toString(result.errors[0])); - } - else - { - CHECK_EQ(R"(Type '() -> number' could not be converted into '() -> number' -caused by: - Argument count mismatch. Function expects 1 argument, but none are specified)", - toString(result.errors[0])); - } + toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic") @@ -1659,17 +1684,20 @@ foo3() string.find() local t = {} -function t.foo(x: number, y: string?, ...: any) end +function t.foo(x: number, y: string?, ...: any) return 1 end function t:bar(x: number, y: string?) end t.foo() t:bar() -local u = { a = t } +local u = { a = t, b = function() return t end } u.a.foo() +local x = (u.a).foo() + +u.b().foo() )"); - LUAU_REQUIRE_ERROR_COUNT(7, result); + LUAU_REQUIRE_ERROR_COUNT(9, result); CHECK_EQ(toString(result.errors[0]), "Argument count mismatch. Function 'foo1' expects 1 argument, but none are specified"); CHECK_EQ(toString(result.errors[1]), "Argument count mismatch. Function 'foo2' expects 1 to 2 arguments, but none are specified"); CHECK_EQ(toString(result.errors[2]), "Argument count mismatch. Function 'foo3' expects 1 to 3 arguments, but none are specified"); @@ -1677,6 +1705,8 @@ u.a.foo() CHECK_EQ(toString(result.errors[4]), "Argument count mismatch. Function 't.foo' expects at least 1 argument, but none are specified"); CHECK_EQ(toString(result.errors[5]), "Argument count mismatch. Function 't.bar' expects 2 to 3 arguments, but only 1 is specified"); CHECK_EQ(toString(result.errors[6]), "Argument count mismatch. Function 'u.a.foo' expects at least 1 argument, but none are specified"); + CHECK_EQ(toString(result.errors[7]), "Argument count mismatch. Function 'u.a.foo' expects at least 1 argument, but none are specified"); + CHECK_EQ(toString(result.errors[8]), "Argument count mismatch. Function expects at least 1 argument, but none are specified"); } // This might be surprising, but since 'any' became optional, unannotated functions in non-strict 'expect' 0 arguments @@ -1692,4 +1722,112 @@ foo(string.find("hello", "e")) CHECK_EQ(toString(result.errors[0]), "Argument count mismatch. Function 'foo' expects 0 to 2 arguments, but 3 are specified"); } +TEST_CASE_FIXTURE(Fixture, "luau_subtyping_is_np_hard") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + {"LuauOverloadedFunctionSubtypingPerf", true}, + }; + + CheckResult result = check(R"( +--!strict + +-- An example of coding up graph coloring in the Luau type system. +-- This codes a three-node, two color problem. +-- A three-node triangle is uncolorable, +-- but a three-node line is colorable. + +type Red = "red" +type Blue = "blue" +type Color = Red | Blue +type Coloring = (Color) -> (Color) -> (Color) -> boolean +type Uncolorable = (Color) -> (Color) -> (Color) -> false + +type Line = Coloring + & ((Red) -> (Red) -> (Color) -> false) + & ((Blue) -> (Blue) -> (Color) -> false) + & ((Color) -> (Red) -> (Red) -> false) + & ((Color) -> (Blue) -> (Blue) -> false) + +type Triangle = Line + & ((Red) -> (Color) -> (Red) -> false) + & ((Blue) -> (Color) -> (Blue) -> false) + +local x : Triangle +local y : Line +local z : Uncolorable +z = x -- OK, so the triangle is uncolorable +z = y -- Not OK, so the line is colorable + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), + "Type '((\"blue\" | \"red\") -> (\"blue\" | \"red\") -> (\"blue\" | \"red\") -> boolean) & ((\"blue\" | \"red\") -> (\"blue\") -> (\"blue\") " + "-> false) & ((\"blue\" | \"red\") -> (\"red\") -> (\"red\") -> false) & ((\"blue\") -> (\"blue\") -> (\"blue\" | \"red\") -> false) & " + "((\"red\") -> (\"red\") -> (\"blue\" | \"red\") -> false)' could not be converted into '(\"blue\" | \"red\") -> (\"blue\" | \"red\") -> " + "(\"blue\" | \"red\") -> false'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "function_is_supertype_of_concrete_functions") +{ + ScopedFastFlag sff{"LuauNegatedFunctionTypes", true}; + registerHiddenTypes(*this, frontend.globalTypes); + + CheckResult result = check(R"( + function foo(f: fun) end + + function a() end + function id(x) return x end + + foo(a) + foo(id) + foo(foo) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "concrete_functions_are_not_supertypes_of_function") +{ + ScopedFastFlag sff{"LuauNegatedFunctionTypes", true}; + registerHiddenTypes(*this, frontend.globalTypes); + + CheckResult result = check(R"( + local a: fun = function() end + + function one(arg: () -> ()) end + function two(arg: (T) -> T) end + + one(a) + two(a) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK(6 == result.errors[0].location.begin.line); + CHECK(7 == result.errors[1].location.begin.line); +} + +TEST_CASE_FIXTURE(Fixture, "other_things_are_not_related_to_function") +{ + ScopedFastFlag sff{"LuauNegatedFunctionTypes", true}; + registerHiddenTypes(*this, frontend.globalTypes); + + CheckResult result = check(R"( + local a: fun = function() end + local b: {} = a + local c: boolean = a + local d: fun = true + local e: fun = {} + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + + CHECK(2 == result.errors[0].location.begin.line); + CHECK(3 == result.errors[1].location.begin.line); + CHECK(4 == result.errors[2].location.begin.line); + CHECK(5 == result.errors[3].location.begin.line); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index e1729ef5..de41c3a6 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -10,7 +10,6 @@ #include "doctest.h" LUAU_FASTFLAG(LuauInstantiateInSubtyping) -LUAU_FASTFLAG(LuauSpecialTypesAsterisked) using namespace Luau; @@ -1011,10 +1010,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying") std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); REQUIRE(t0); - if (FFlag::LuauSpecialTypesAsterisked) - CHECK_EQ("*error-type*", toString(t0->type)); - else - CHECK_EQ("", toString(t0->type)); + CHECK_EQ("*error-type*", toString(t0->type)); auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { return get(err); diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index ca22c351..0c10eb87 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -781,7 +781,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatables") CheckResult result = check(R"( local a : string? = nil local b : number? = nil - + local x = setmetatable({}, { p = 5, q = a }); local y = setmetatable({}, { q = b, r = "hi" }); local z = setmetatable({}, { p = 5, q = nil, r = "hi" }); diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index d6f787be..40912a95 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -13,7 +13,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauSpecialTypesAsterisked) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) TEST_SUITE_BEGIN("TypeInferLoops"); @@ -157,10 +156,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_error") LUAU_REQUIRE_ERROR_COUNT(2, result); TypeId p = requireType("p"); - if (FFlag::LuauSpecialTypesAsterisked) - CHECK_EQ("*error-type*", toString(p)); - else - CHECK_EQ("", toString(p)); + CHECK_EQ("*error-type*", toString(p)); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function") diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index 8b7b3514..4cc628fb 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -14,8 +14,6 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping) using namespace Luau; -LUAU_FASTFLAG(LuauSpecialTypesAsterisked) - TEST_SUITE_BEGIN("TypeInferModules"); TEST_CASE_FIXTURE(BuiltinsFixture, "dcr_require_basic") @@ -176,10 +174,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "require_module_that_does_not_export") auto hootyType = requireType(bModule, "Hooty"); - if (FFlag::LuauSpecialTypesAsterisked) - CHECK_EQ("*error-type*", toString(hootyType)); - else - CHECK_EQ("", toString(hootyType)); + CHECK_EQ("*error-type*", toString(hootyType)); } TEST_CASE_FIXTURE(BuiltinsFixture, "warn_if_you_try_to_require_a_non_modulescript") @@ -251,23 +246,7 @@ end return m )"); - if (FFlag::LuauInstantiateInSubtyping) - { - // though this didn't error before the flag, it seems as though it should error since fields of a table are invariant. - // the user's intent would likely be that these "method" fields would be read-only, but without an annotation, accepting this should be - // unsound. - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ(R"(Type 'n' could not be converted into 't1 where t1 = {- Clone: (t1) -> (a...) -}' -caused by: - Property 'Clone' is not compatible. Type '(a) -> ()' could not be converted into 't1 where t1 = ({- Clone: t1 -}) -> (a...)'; different number of generic type parameters)", - toString(result.errors[0])); - } - else - { - LUAU_REQUIRE_NO_ERRORS(result); - } + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(BuiltinsFixture, "custom_require_global") @@ -298,10 +277,7 @@ local ModuleA = require(game.A) std::optional oty = requireType("ModuleA"); - if (FFlag::LuauSpecialTypesAsterisked) - CHECK_EQ("*error-type*", toString(*oty)); - else - CHECK_EQ("", toString(*oty)); + CHECK_EQ("*error-type*", toString(*oty)); } TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_modify_imported_types") diff --git a/tests/TypeInfer.negations.test.cpp b/tests/TypeInfer.negations.test.cpp new file mode 100644 index 00000000..e8256f97 --- /dev/null +++ b/tests/TypeInfer.negations.test.cpp @@ -0,0 +1,52 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" +#include "Luau/Common.h" +#include "ScopedFlags.h" + +using namespace Luau; + +namespace +{ +struct NegationFixture : Fixture +{ + TypeArena arena; + ScopedFastFlag sff[2]{ + {"LuauNegatedStringSingletons", true}, + {"LuauSubtypeNormalizer", true}, + }; + + NegationFixture() + { + registerHiddenTypes(*this, arena); + } +}; +} // namespace + +TEST_SUITE_BEGIN("Negations"); + +TEST_CASE_FIXTURE(NegationFixture, "negated_string_is_a_subtype_of_string") +{ + CheckResult result = check(R"( + function foo(arg: string) end + local a: string & Not<"Hello"> + foo(a) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(NegationFixture, "string_is_not_a_subtype_of_negated_string") +{ + CheckResult result = check(R"( + function foo(arg: string & Not<"hello">) end + local a: string + foo(a) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 3d6c0193..b2516f6d 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -13,6 +13,8 @@ using namespace Luau; +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) + TEST_SUITE_BEGIN("TypeInferOperators"); TEST_CASE_FIXTURE(Fixture, "or_joins_types") @@ -33,7 +35,7 @@ TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_extras") local x:number|string = s local y = x or "s" )"); - CHECK_EQ(0, result.errors.size()); + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ(toString(*requireType("s")), "number | string"); CHECK_EQ(toString(*requireType("y")), "number | string"); } @@ -44,7 +46,7 @@ TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_superfluous_union") local s = "a" or "b" local x:string = s )"); - CHECK_EQ(0, result.errors.size()); + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ(*requireType("s"), *typeChecker.stringType); } @@ -54,7 +56,7 @@ TEST_CASE_FIXTURE(Fixture, "and_adds_boolean") local s = "a" and 10 local x:boolean|number = s )"); - CHECK_EQ(0, result.errors.size()); + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ(toString(*requireType("s")), "boolean | number"); } @@ -64,7 +66,7 @@ TEST_CASE_FIXTURE(Fixture, "and_adds_boolean_no_superfluous_union") local s = "a" and true local x:boolean = s )"); - CHECK_EQ(0, result.errors.size()); + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ(*requireType("x"), *typeChecker.booleanType); } @@ -73,7 +75,7 @@ TEST_CASE_FIXTURE(Fixture, "and_or_ternary") CheckResult result = check(R"( local s = (1/2) > 0.5 and "a" or 10 )"); - CHECK_EQ(0, result.errors.size()); + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ(toString(*requireType("s")), "number | string"); } @@ -81,7 +83,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "primitive_arith_no_metatable") { CheckResult result = check(R"( function add(a: number, b: string) - return a + (tonumber(b) :: number), a .. b + return a + (tonumber(b) :: number), tostring(a) .. b end local n, s = add(2,"3") )"); @@ -432,16 +434,17 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus") { CheckResult result = check(R"( --!strict - local foo = { - value = 10 - } + local foo local mt = {} - setmetatable(foo, mt) mt.__unm = function(val: typeof(foo)): string - return val.value .. "test" + return tostring(val.value) .. "test" end + foo = setmetatable({ + value = 10 + }, mt) + local a = -foo local b = 1+-1 @@ -457,25 +460,32 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus") CHECK_EQ("string", toString(requireType("a"))); CHECK_EQ("number", toString(requireType("b"))); - GenericError* gen = get(result.errors[0]); - REQUIRE_EQ(gen->message, "Unary operator '-' not supported by type 'bar'"); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK(toString(result.errors[0]) == "Type '{ value: number }' could not be converted into 'number'"); + } + else + { + GenericError* gen = get(result.errors[0]); + REQUIRE(gen); + REQUIRE_EQ(gen->message, "Unary operator '-' not supported by type 'bar'"); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus_error") { CheckResult result = check(R"( --!strict - local foo = { - value = 10 - } - local mt = {} - setmetatable(foo, mt) mt.__unm = function(val: boolean): string return "test" end + local foo = setmetatable({ + value = 10 + }, mt) + local a = -foo )"); @@ -492,16 +502,16 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_len_error") { CheckResult result = check(R"( --!strict - local foo = { - value = 10 - } local mt = {} - setmetatable(foo, mt) - mt.__len = function(val: any): string + mt.__len = function(val): string return "test" end + local foo = setmetatable({ + value = 10, + }, mt) + local a = #foo )"); @@ -558,15 +568,21 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "disallow_string_and_types_without_metatables LUAU_REQUIRE_ERROR_COUNT(3, result); TypeMismatch* tm = get(result.errors[0]); - REQUIRE_EQ(*tm->wantedType, *typeChecker.numberType); - REQUIRE_EQ(*tm->givenType, *typeChecker.stringType); + REQUIRE(tm); + CHECK_EQ(*tm->wantedType, *typeChecker.numberType); + CHECK_EQ(*tm->givenType, *typeChecker.stringType); + + GenericError* gen1 = get(result.errors[1]); + REQUIRE(gen1); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ(gen1->message, "Operator + is not applicable for '{ value: number }' and 'number' because neither type has a metatable"); + else + CHECK_EQ(gen1->message, "Binary operator '+' not supported by types 'foo' and 'number'"); TypeMismatch* tm2 = get(result.errors[2]); + REQUIRE(tm2); CHECK_EQ(*tm2->wantedType, *typeChecker.numberType); CHECK_EQ(*tm2->givenType, *requireType("foo")); - - GenericError* gen2 = get(result.errors[1]); - REQUIRE_EQ(gen2->message, "Binary operator '+' not supported by types 'foo' and 'number'"); } // CLI-29033 @@ -611,12 +627,10 @@ TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown") { std::vector ops = {"+", "-", "*", "/", "%", "^", ".."}; - std::string src = R"( - function foo(a, b) - )"; + std::string src = "function foo(a, b)\n"; for (const auto& op : ops) - src += "local _ = a " + op + "b\n"; + src += "local _ = a " + op + " b\n"; src += "end"; @@ -651,7 +665,11 @@ TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operato GenericError* ge = get(result.errors[0]); REQUIRE(ge); - CHECK_EQ("Type 'boolean' cannot be compared with relational operator <", ge->message); + + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("Types 'boolean' and 'boolean' cannot be compared with relational operator <", ge->message); + else + CHECK_EQ("Type 'boolean' cannot be compared with relational operator <", ge->message); } TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators2") @@ -666,7 +684,10 @@ TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operato GenericError* ge = get(result.errors[0]); REQUIRE(ge); - CHECK_EQ("Type 'number | string' cannot be compared with relational operator <", ge->message); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("Types 'number | string' and 'number | string' cannot be compared with relational operator <", ge->message); + else + CHECK_EQ("Type 'number | string' cannot be compared with relational operator <", ge->message); } TEST_CASE_FIXTURE(Fixture, "cli_38355_recursive_union") @@ -891,4 +912,63 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "expected_types_through_binary_or") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "mm_ops_must_return_a_value") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local mm = { + __add = function(self, other) + return + end, + } + + local x = setmetatable({}, mm) + local y = x + 123 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(requireType("y") == singletonTypes->errorRecoveryType()); + + const GenericError* ge = get(result.errors[0]); + REQUIRE(ge); + CHECK(ge->message == "Metamethod '__add' must return a value"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "mm_comparisons_must_return_a_boolean") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local mm1 = { + __lt = function(self, other) + return 123 + end, + } + + local mm2 = { + __lt = function(self, other) + return + end, + } + + local o1 = setmetatable({}, mm1) + local v1 = o1 < o1 + + local o2 = setmetatable({}, mm2) + local v2 = o2 < o2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK(requireType("v1") == singletonTypes->booleanType); + CHECK(requireType("v2") == singletonTypes->booleanType); + + CHECK(toString(result.errors[0]) == "Metamethod '__lt' must return type 'boolean'"); + CHECK(toString(result.errors[1]) == "Metamethod '__lt' must return type 'boolean'"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.primitives.test.cpp b/tests/TypeInfer.primitives.test.cpp index a31c9c50..3c2c8781 100644 --- a/tests/TypeInfer.primitives.test.cpp +++ b/tests/TypeInfer.primitives.test.cpp @@ -11,8 +11,6 @@ #include "doctest.h" -LUAU_FASTFLAG(LuauSpecialTypesAsterisked) - using namespace Luau; TEST_SUITE_BEGIN("TypeInferPrimitives"); @@ -49,10 +47,7 @@ TEST_CASE_FIXTURE(Fixture, "string_index") REQUIRE(nat); CHECK_EQ("string", toString(nat->ty)); - if (FFlag::LuauSpecialTypesAsterisked) - CHECK_EQ("*error-type*", toString(requireType("t"))); - else - CHECK_EQ("", toString(requireType("t"))); + CHECK_EQ("*error-type*", toString(requireType("t"))); } TEST_CASE_FIXTURE(Fixture, "string_method") diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index ccc4d775..f6e60cdc 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -624,15 +624,18 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_with_a_singleton_argument") CHECK_EQ("{string | string}", toString(requireType("t"))); } -struct NormalizeFixture : Fixture +namespace +{ +struct IsSubtypeFixture : Fixture { bool isSubtype(TypeId a, TypeId b) { return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, singletonTypes, ice); } }; +} // namespace -TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_functions_of_different_arities") +TEST_CASE_FIXTURE(IsSubtypeFixture, "intersection_of_functions_of_different_arities") { check(R"( type A = (any) -> () @@ -653,7 +656,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_functions_of_different_arit CHECK("((any) -> ()) & ((any, any) -> ())" == toString(requireType("t"))); } -TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity") +TEST_CASE_FIXTURE(IsSubtypeFixture, "functions_with_mismatching_arity") { check(R"( local a: (number) -> () @@ -676,7 +679,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity") CHECK(!isSubtype(b, c)); } -TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity_but_optional_parameters") +TEST_CASE_FIXTURE(IsSubtypeFixture, "functions_with_mismatching_arity_but_optional_parameters") { /* * (T0..TN) <: (T0..TN, A?) @@ -736,7 +739,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity_but_option // CHECK(!isSubtype(b, c)); } -TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity_but_any_is_an_optional_param") +TEST_CASE_FIXTURE(IsSubtypeFixture, "functions_with_mismatching_arity_but_any_is_an_optional_param") { check(R"( local a: (number?) -> () diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index f707f952..26f23438 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -7,7 +7,7 @@ #include "doctest.h" -LUAU_FASTFLAG(LuauSpecialTypesAsterisked) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) using namespace Luau; @@ -49,7 +49,6 @@ struct RefinementClassFixture : Fixture {"Y", Property{typeChecker.numberType}}, {"Z", Property{typeChecker.numberType}}, }; - normalize(vec3, scope, arena, singletonTypes, *typeChecker.iceHandler); TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); @@ -57,21 +56,17 @@ struct RefinementClassFixture : Fixture TypePackId isARets = arena.addTypePack({typeChecker.booleanType}); TypeId isA = arena.addType(FunctionTypeVar{isAParams, isARets}); getMutable(isA)->magicFunction = magicFunctionInstanceIsA; - normalize(isA, scope, arena, singletonTypes, *typeChecker.iceHandler); getMutable(inst)->props = { {"Name", Property{typeChecker.stringType}}, {"IsA", Property{isA}}, }; - normalize(inst, scope, arena, singletonTypes, *typeChecker.iceHandler); TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr, "Test"}); - normalize(folder, scope, arena, singletonTypes, *typeChecker.iceHandler); TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr, "Test"}); getMutable(part)->props = { {"Position", Property{vec3}}, }; - normalize(part, scope, arena, singletonTypes, *typeChecker.iceHandler); typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3}; typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst}; @@ -102,8 +97,16 @@ TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("string", toString(requireTypeAtPosition({3, 26}))); - CHECK_EQ("nil", toString(requireTypeAtPosition({5, 26}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({5, 26}))); + } + else + { + CHECK_EQ("string", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({5, 26}))); + } } TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint") @@ -120,8 +123,16 @@ TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("nil", toString(requireTypeAtPosition({3, 26}))); - CHECK_EQ("string", toString(requireTypeAtPosition({5, 26}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({5, 26}))); + } + else + { + CHECK_EQ("nil", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("string", toString(requireTypeAtPosition({5, 26}))); + } } TEST_CASE_FIXTURE(Fixture, "parenthesized_expressions_are_followed_through") @@ -138,8 +149,16 @@ TEST_CASE_FIXTURE(Fixture, "parenthesized_expressions_are_followed_through") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("nil", toString(requireTypeAtPosition({3, 26}))); - CHECK_EQ("string", toString(requireTypeAtPosition({5, 26}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({5, 26}))); + } + else + { + CHECK_EQ("nil", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("string", toString(requireTypeAtPosition({5, 26}))); + } } TEST_CASE_FIXTURE(Fixture, "and_constraint") @@ -158,8 +177,16 @@ TEST_CASE_FIXTURE(Fixture, "and_constraint") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("string", toString(requireTypeAtPosition({3, 26}))); - CHECK_EQ("number", toString(requireTypeAtPosition({4, 26}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("(number?) & ~(false?)", toString(requireTypeAtPosition({4, 26}))); + } + else + { + CHECK_EQ("string", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("number", toString(requireTypeAtPosition({4, 26}))); + } CHECK_EQ("string?", toString(requireTypeAtPosition({6, 26}))); CHECK_EQ("number?", toString(requireTypeAtPosition({7, 26}))); @@ -184,8 +211,16 @@ TEST_CASE_FIXTURE(Fixture, "not_and_constraint") CHECK_EQ("string?", toString(requireTypeAtPosition({3, 26}))); CHECK_EQ("number?", toString(requireTypeAtPosition({4, 26}))); - CHECK_EQ("string", toString(requireTypeAtPosition({6, 26}))); - CHECK_EQ("number", toString(requireTypeAtPosition({7, 26}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({6, 26}))); + CHECK_EQ("(number?) & ~(false?)", toString(requireTypeAtPosition({7, 26}))); + } + else + { + CHECK_EQ("string", toString(requireTypeAtPosition({6, 26}))); + CHECK_EQ("number", toString(requireTypeAtPosition({7, 26}))); + } } TEST_CASE_FIXTURE(Fixture, "or_predicate_with_truthy_predicates") @@ -207,8 +242,56 @@ TEST_CASE_FIXTURE(Fixture, "or_predicate_with_truthy_predicates") CHECK_EQ("string?", toString(requireTypeAtPosition({3, 26}))); CHECK_EQ("number?", toString(requireTypeAtPosition({4, 26}))); - CHECK_EQ("nil", toString(requireTypeAtPosition({6, 26}))); - CHECK_EQ("nil", toString(requireTypeAtPosition({7, 26}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({6, 26}))); + CHECK_EQ("(number?) & ~~(false?)", toString(requireTypeAtPosition({7, 26}))); + } + else + { + CHECK_EQ("nil", toString(requireTypeAtPosition({6, 26}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({7, 26}))); + } +} + +TEST_CASE_FIXTURE(Fixture, "a_and_b_or_a_and_c") +{ + CheckResult result = check(R"( + function f(a: string?, b: number?, c: boolean) + if (a and b) or (a and c) then + local foo = a + local bar = b + local baz = c + else + local foo = a + local bar = b + local baz = c + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(string?) & (~(false?) | ~(false?))", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("number?", toString(requireTypeAtPosition({4, 28}))); + CHECK_EQ("boolean", toString(requireTypeAtPosition({5, 28}))); + + CHECK_EQ("string?", toString(requireTypeAtPosition({7, 28}))); + CHECK_EQ("number?", toString(requireTypeAtPosition({8, 28}))); + CHECK_EQ("boolean", toString(requireTypeAtPosition({9, 28}))); + } + else + { + CHECK_EQ("string", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("number?", toString(requireTypeAtPosition({4, 28}))); + CHECK_EQ("true", toString(requireTypeAtPosition({5, 28}))); // oh no! :( + + CHECK_EQ("string?", toString(requireTypeAtPosition({7, 28}))); + CHECK_EQ("number?", toString(requireTypeAtPosition({8, 28}))); + CHECK_EQ("boolean", toString(requireTypeAtPosition({9, 28}))); + } } TEST_CASE_FIXTURE(Fixture, "type_assertion_expr_carry_its_constraints") @@ -224,8 +307,17 @@ TEST_CASE_FIXTURE(Fixture, "type_assertion_expr_carry_its_constraints") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number", toString(requireTypeAtPosition({3, 26}))); - CHECK_EQ("string", toString(requireTypeAtPosition({4, 26}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("number?", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({4, 26}))); + } + else + { + // We're going to drop support for type refinements through type assertions. + CHECK_EQ("number", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("string", toString(requireTypeAtPosition({4, 26}))); + } } TEST_CASE_FIXTURE(Fixture, "typeguard_in_if_condition_position") @@ -361,11 +453,22 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_another_lvalue") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "(number | string)?"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "boolean?"); // a == b + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "((number | string)?) & (boolean?)"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "((number | string)?) & (boolean?)"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "(number | string)?"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "boolean?"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "((number | string)?) & unknown"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "(boolean?) & unknown"); // a ~= b + } + else + { + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "(number | string)?"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "boolean?"); // a == b + + CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "(number | string)?"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "boolean?"); // a ~= b + } } TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_a_term") @@ -382,8 +485,16 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_a_term") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "(number | string)?"); // a == 1 - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= 1 + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "((number | string)?) & number"); // a == 1 + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "((number | string)?) & unknown"); // a ~= 1 + } + else + { + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "(number | string)?"); // a == 1; + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= 1 + } } TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue") @@ -400,8 +511,16 @@ TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), R"("hello")"); // a == "hello" - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= "hello" + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), R"("hello" & ((number | string)?))"); // a == "hello" + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), R"(((number | string)?) & ~"hello")"); // a ~= "hello" + } + else + { + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), R"("hello")"); // a == "hello" + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), R"((number | string)?)"); // a ~= "hello" + } } TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil") @@ -418,8 +537,16 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number | string"); // a ~= nil - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a == nil + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "((number | string)?) & ~nil"); // a ~= nil + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "((number | string)?) & nil"); // a == nil + } + else + { + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number | string"); // a ~= nil + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a == nil + } } TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue") @@ -434,8 +561,17 @@ TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "a"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "string?"); // a == b + if (FFlag::DebugLuauDeferredConstraintResolution) + { + ToStringOptions opts; + CHECK_EQ(toString(requireTypeAtPosition({3, 33}), opts), "(string?) & a"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 36}), opts), "(string?) & a"); // a == b + } + else + { + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "a"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "string?"); // a == b + } } TEST_CASE_FIXTURE(Fixture, "unknown_lvalue_is_not_synonymous_with_other_on_not_equal") @@ -450,8 +586,16 @@ TEST_CASE_FIXTURE(Fixture, "unknown_lvalue_is_not_synonymous_with_other_on_not_e LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "{| x: number |}?"); // a ~= b + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any & unknown"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "({| x: number |}?) & unknown"); // a ~= b + } + else + { + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "{| x: number |}?"); // a ~= b + } } TEST_CASE_FIXTURE(Fixture, "string_not_equal_to_string_or_nil") @@ -470,11 +614,22 @@ TEST_CASE_FIXTURE(Fixture, "string_not_equal_to_string_or_nil") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(requireTypeAtPosition({6, 29})), "string"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({6, 32})), "string?"); // a ~= b + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(toString(requireTypeAtPosition({6, 29})), "string & unknown"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({6, 32})), "(string?) & unknown"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string?"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "(string?) & string"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "(string?) & string"); // a == b + } + else + { + CHECK_EQ(toString(requireTypeAtPosition({6, 29})), "string"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({6, 32})), "string?"); // a ~= b + + CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string?"); // a == b + } } TEST_CASE_FIXTURE(Fixture, "narrow_property_of_a_bounded_variable") @@ -505,10 +660,7 @@ TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauSpecialTypesAsterisked) - CHECK_EQ("*error-type*", toString(requireTypeAtPosition({3, 28}))); - else - CHECK_EQ("", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("*error-type*", toString(requireTypeAtPosition({3, 28}))); } TEST_CASE_FIXTURE(Fixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true") @@ -695,8 +847,16 @@ TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("nil", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("nil", toString(requireTypeAtPosition({4, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(number?) & ~~(false?)", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("(number?) & ~~(false?)", toString(requireTypeAtPosition({4, 28}))); + } + else + { + CHECK_EQ("nil", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({4, 28}))); + } } TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2") @@ -712,8 +872,16 @@ TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("nil", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("nil", toString(requireTypeAtPosition({4, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(number?) & ~~(false?)", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("(number?) & ~~(false?)", toString(requireTypeAtPosition({4, 28}))); + } + else + { + CHECK_EQ("nil", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({4, 28}))); + } } TEST_CASE_FIXTURE(Fixture, "either_number_or_string") @@ -963,19 +1131,27 @@ TEST_CASE_FIXTURE(Fixture, "and_or_peephole_refinement") TEST_CASE_FIXTURE(Fixture, "narrow_boolean_to_true_or_false") { CheckResult result = check(R"( - local function is_true(b: true) end - local function is_false(b: false) end - local function f(x: boolean) if x then - is_true(x) + local foo = x else - is_false(x) + local foo = x end end )"); LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("boolean & ~(false?)", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("boolean & ~~(false?)", toString(requireTypeAtPosition({5, 28}))); + } + else + { + CHECK_EQ("true", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("false", toString(requireTypeAtPosition({5, 28}))); + } } TEST_CASE_FIXTURE(Fixture, "discriminate_on_properties_of_disjoint_tables_where_that_property_is_true_or_false") diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 73ccac70..e37880d0 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -79,6 +79,16 @@ TEST_CASE_FIXTURE(Fixture, "string_singleton_subtype") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "string_singleton_subtype_multi_assignment") +{ + CheckResult result = check(R"( + local a: "foo" = "foo" + local b: string, c: number = a, 10 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 53f9a1ab..68757fef 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -1,5 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" +#include "Luau/Common.h" +#include "Luau/Frontend.h" +#include "Luau/ToString.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" @@ -11,6 +14,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauInstantiateInSubtyping) TEST_SUITE_BEGIN("TableTests"); @@ -44,7 +49,7 @@ TEST_CASE_FIXTURE(Fixture, "augment_table") const TableTypeVar* tType = get(requireType("t")); REQUIRE(tType != nullptr); - CHECK(tType->props.find("foo") != tType->props.end()); + CHECK(1 == tType->props.count("foo")); } TEST_CASE_FIXTURE(Fixture, "augment_nested_table") @@ -101,7 +106,11 @@ TEST_CASE_FIXTURE(Fixture, "updating_sealed_table_prop_is_ok") TEST_CASE_FIXTURE(Fixture, "cannot_change_type_of_unsealed_table_prop") { - CheckResult result = check("local t = {} t.prop = 999 t.prop = 'hello'"); + CheckResult result = check(R"( + local t = {} + t.prop = 999 + t.prop = 'hello' + )"); LUAU_REQUIRE_ERROR_COUNT(1, result); } @@ -858,11 +867,12 @@ TEST_CASE_FIXTURE(Fixture, "assigning_to_an_unsealed_table_with_string_literal_s LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("a")); + CHECK("string" == toString(*typeChecker.stringType)); TableTypeVar* tableType = getMutable(requireType("t")); REQUIRE(tableType != nullptr); REQUIRE(tableType->indexer == std::nullopt); + REQUIRE(0 != tableType->props.count("a")); TypeId propertyA = tableType->props["a"].type; REQUIRE(propertyA != nullptr); @@ -1950,7 +1960,11 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table local c : string = t.m("hi") )"); - LUAU_REQUIRE_ERRORS(result); + // TODO: test behavior is wrong with LuauInstantiateInSubtyping until we can re-enable the covariant requirement for instantiation in subtyping + if (FFlag::LuauInstantiateInSubtyping) + LUAU_REQUIRE_NO_ERRORS(result); + else + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_should_cope_with_optional_properties_in_nonstrict") @@ -2390,9 +2404,12 @@ TEST_CASE_FIXTURE(Fixture, "wrong_assign_does_hit_indexer") TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_no_indexer") { - CheckResult result = check("local a = {a=1, b=2} a['a'] = nil"); + CheckResult result = check(R"( + local a = {a=1, b=2} + a['a'] = nil + )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 30}, Position{0, 33}}, TypeMismatch{ + CHECK_EQ(result.errors[0], (TypeError{Location{Position{2, 17}, Position{2, 20}}, TypeMismatch{ typeChecker.numberType, typeChecker.nilType, }})); @@ -2701,6 +2718,62 @@ local baz = foo[bar] CHECK_EQ(result.errors[0].location, Location{Position{3, 16}, Position{3, 19}}); } +TEST_CASE_FIXTURE(BuiltinsFixture, "table_call_metamethod_basic") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local a = setmetatable({ + a = 1, + }, { + __call = function(self, b: number) + return self.a * b + end, + }) + + local foo = a(12) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(requireType("foo") == singletonTypes->numberType); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_call_metamethod_must_be_callable") +{ + CheckResult result = check(R"( + local a = setmetatable({}, { + __call = 123, + }) + + local foo = a() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(result.errors[0] == TypeError{ + Location{{5, 20}, {5, 21}}, + CannotCallNonFunction{singletonTypes->numberType}, + }); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_call_metamethod_generic") +{ + CheckResult result = check(R"( + local a = setmetatable({}, { + __call = function(self, b: T) + return b + end, + }) + + local foo = a(12) + local bar = a("bar") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(requireType("foo") == singletonTypes->numberType); + CHECK(requireType("bar") == singletonTypes->stringType); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "table_simple_call") { CheckResult result = check(R"( @@ -3196,11 +3269,14 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table local c : string = t.m("hi") )"); - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(toString(result.errors[0]), R"(Type 't' could not be converted into '{| m: (number) -> number |}' -caused by: - Property 'm' is not compatible. Type '(a) -> a' could not be converted into '(number) -> number'; different number of generic type parameters)"); - // this error message is not great since the underlying issue is that the context is invariant, + LUAU_REQUIRE_NO_ERRORS(result); + // TODO: test behavior is wrong until we can re-enable the covariant requirement for instantiation in subtyping + // LUAU_REQUIRE_ERRORS(result); + // CHECK_EQ(toString(result.errors[0]), R"(Type 't' could not be converted into '{| m: (number) -> number |}' + // caused by: + // Property 'm' is not compatible. Type '(a) -> a' could not be converted into '(number) -> number'; different number of generic type + // parameters)"); + // // this error message is not great since the underlying issue is that the context is invariant, // and `(number) -> number` cannot be a subtype of `(a) -> a`. } @@ -3226,4 +3302,40 @@ local g : ({ p : number, q : string }) -> ({ p : number, r : boolean }) = f CHECK_EQ("r", error->properties[0]); } +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_has_a_side_effect") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local mt = { + __add = function(x, y) + return 123 + end, + } + + local foo = {} + setmetatable(foo, mt) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireType("foo")) == "{ @metatable { __add: (a, b) -> number }, { } }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "tables_should_be_fully_populated") +{ + CheckResult result = check(R"( + local t = { + x = 5 :: NonexistingTypeWhichEndsUpReturningAnErrorType, + y = 5 + } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + ToStringOptions opts; + opts.exhaustive = true; + CHECK_EQ("{ x: *error-type*, y: number }", toString(requireType("t"), opts)); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 239b8c28..6c7201a6 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -17,7 +17,6 @@ LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauInstantiateInSubtyping); -LUAU_FASTFLAG(LuauSpecialTypesAsterisked); using namespace Luau; @@ -238,20 +237,10 @@ TEST_CASE_FIXTURE(Fixture, "type_errors_infer_types") // TODO: Should we assert anything about these tests when DCR is being used? if (!FFlag::DebugLuauDeferredConstraintResolution) { - if (FFlag::LuauSpecialTypesAsterisked) - { - CHECK_EQ("*error-type*", toString(requireType("c"))); - CHECK_EQ("*error-type*", toString(requireType("d"))); - CHECK_EQ("*error-type*", toString(requireType("e"))); - CHECK_EQ("*error-type*", toString(requireType("f"))); - } - else - { - CHECK_EQ("", toString(requireType("c"))); - CHECK_EQ("", toString(requireType("d"))); - CHECK_EQ("", toString(requireType("e"))); - CHECK_EQ("", toString(requireType("f"))); - } + CHECK_EQ("*error-type*", toString(requireType("c"))); + CHECK_EQ("*error-type*", toString(requireType("d"))); + CHECK_EQ("*error-type*", toString(requireType("e"))); + CHECK_EQ("*error-type*", toString(requireType("f"))); } } @@ -662,10 +651,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional") std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); REQUIRE(t0); - if (FFlag::LuauSpecialTypesAsterisked) - CHECK_EQ("*error-type*", toString(t0->type)); - else - CHECK_EQ("", toString(t0->type)); + CHECK_EQ("*error-type*", toString(t0->type)); auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { return get(err); @@ -1046,7 +1032,6 @@ TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_normalizer") ScopedFastFlag sffs[]{ {"LuauSubtypeNormalizer", true}, {"LuauTypeNormalization2", true}, - {"LuauAutocompleteDynamicLimits", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index c178d2a4..f04a3d95 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -9,8 +9,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauSpecialTypesAsterisked) - struct TryUnifyFixture : Fixture { TypeArena arena; @@ -124,10 +122,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "members_of_failed_typepack_unification_are_u LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("a", toString(requireType("a"))); - if (FFlag::LuauSpecialTypesAsterisked) - CHECK_EQ("*error-type*", toString(requireType("b"))); - else - CHECK_EQ("", toString(requireType("b"))); + CHECK_EQ("*error-type*", toString(requireType("b"))); } TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_constrained") @@ -142,10 +137,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_con LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("a", toString(requireType("a"))); - if (FFlag::LuauSpecialTypesAsterisked) - CHECK_EQ("*error-type*", toString(requireType("b"))); - else - CHECK_EQ("", toString(requireType("b"))); + CHECK_EQ("*error-type*", toString(requireType("b"))); CHECK_EQ("number", toString(requireType("c"))); } diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index aaa7ded4..4c8eeac6 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -467,6 +467,8 @@ type I = W TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit") { + ScopedFastFlag sff("LuauFunctionReturnStringificationFixup", true); + CheckResult result = check(R"( type X = (T...) -> (T...) @@ -490,6 +492,8 @@ type F = X<(string, ...number)> TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi") { + ScopedFastFlag sff("LuauFunctionReturnStringificationFixup", true); + CheckResult result = check(R"( type Y = (T...) -> (U...) diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index dc551634..0c25386f 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -6,8 +6,6 @@ #include "doctest.h" -LUAU_FASTFLAG(LuauSpecialTypesAsterisked) - using namespace Luau; TEST_SUITE_BEGIN("UnionTypes"); @@ -199,10 +197,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property") CHECK_EQ(mup->missing[0], *bTy); CHECK_EQ(mup->key, "x"); - if (FFlag::LuauSpecialTypesAsterisked) - CHECK_EQ("*error-type*", toString(requireType("r"))); - else - CHECK_EQ("", toString(requireType("r"))); + CHECK_EQ("*error-type*", toString(requireType("r"))); } TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_property_of_type_any") diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index b81c80ce..5dd1b1bc 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -436,7 +436,6 @@ TEST_CASE("proof_that_isBoolean_uses_all_of") TEST_CASE("content_reassignment") { TypeVar myAny{AnyTypeVar{}, /*presistent*/ true}; - myAny.normal = true; myAny.documentationSymbol = "@global/any"; TypeArena arena; @@ -446,7 +445,6 @@ TEST_CASE("content_reassignment") CHECK(get(futureAny) != nullptr); CHECK(!futureAny->persistent); - CHECK(futureAny->normal); CHECK(futureAny->documentationSymbol == "@global/any"); CHECK(futureAny->owningArena == &arena); } diff --git a/tests/Variant.test.cpp b/tests/Variant.test.cpp index aa0731ca..83eec519 100644 --- a/tests/Variant.test.cpp +++ b/tests/Variant.test.cpp @@ -217,4 +217,35 @@ TEST_CASE("Visit") CHECK(r3 == "1231147"); } +struct MoveOnly +{ + MoveOnly() = default; + + MoveOnly(const MoveOnly&) = delete; + MoveOnly& operator=(const MoveOnly&) = delete; + + MoveOnly(MoveOnly&&) = default; + MoveOnly& operator=(MoveOnly&&) = default; +}; + +TEST_CASE("Move") +{ + Variant v1 = MoveOnly{}; + Variant v2 = std::move(v1); +} + +TEST_CASE("MoveWithCopyableAlternative") +{ + Variant v1 = std::string{"Hello, world! I am longer than a normal hello world string to avoid SSO."}; + Variant v2 = std::move(v1); + + std::string* s1 = get_if(&v1); + REQUIRE(s1); + CHECK(*s1 == ""); + + std::string* s2 = get_if(&v2); + REQUIRE(s2); + CHECK(*s2 == "Hello, world! I am longer than a normal hello world string to avoid SSO."); +} + TEST_SUITE_END(); diff --git a/tests/VisitTypeVar.test.cpp b/tests/VisitTypeVar.test.cpp index 4fba694a..589c3bad 100644 --- a/tests/VisitTypeVar.test.cpp +++ b/tests/VisitTypeVar.test.cpp @@ -22,7 +22,14 @@ TEST_CASE_FIXTURE(Fixture, "throw_when_limit_is_exceeded") TypeId tType = requireType("t"); - CHECK_THROWS_AS(toString(tType), RecursionLimitException); + if (FFlag::LuauIceExceptionInheritanceChange) + { + CHECK_THROWS_AS(toString(tType), RecursionLimitException); + } + else + { + CHECK_THROWS_AS(toString(tType), RecursionLimitException_DEPRECATED); + } } TEST_CASE_FIXTURE(Fixture, "dont_throw_when_limit_is_high_enough") diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index b7e85aa7..7a05f8e9 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -93,7 +93,10 @@ assert((function() local a = 1 a = a * 2 return a end)() == 2) assert((function() local a = 1 a = a / 2 return a end)() == 0.5) assert((function() local a = 5 a = a % 2 return a end)() == 1) assert((function() local a = 3 a = a ^ 2 return a end)() == 9) +assert((function() local a = 3 a = a ^ 3 return a end)() == 27) assert((function() local a = 9 a = a ^ 0.5 return a end)() == 3) +assert((function() local a = -2 a = a ^ 2 return a end)() == 4) +assert((function() local a = -2 a = a ^ 0.5 return tostring(a) end)() == "nan") assert((function() local a = '1' a = a .. '2' return a end)() == "12") assert((function() local a = '1' a = a .. '2' .. '3' return a end)() == "123") @@ -706,7 +709,11 @@ end assert(chainTest(100) == "v0,v100") -- this validates import fallbacks +assert(idontexist == nil) +assert(math.idontexist == nil) assert(pcall(function() return idontexist.a end) == false) +assert(pcall(function() return math.pow.a end) == false) +assert(pcall(function() return math.a.b end) == false) -- make sure that NaN is preserved by the bytecode compiler local realnan = tostring(math.abs(0)/math.abs(0)) diff --git a/tests/conformance/calls.lua b/tests/conformance/calls.lua index 7f9610a3..621a921a 100644 --- a/tests/conformance/calls.lua +++ b/tests/conformance/calls.lua @@ -226,4 +226,14 @@ assert((function () return nil end)(4) == nil) assert((function () local a; return a end)(4) == nil) assert((function (a) return a end)() == nil) +-- C-stack overflow while handling C-stack overflow +if not limitedstack then + local function loop () + assert(pcall(loop)) + end + + local err, msg = xpcall(loop, loop) + assert(not err and string.find(msg, "error")) +end + return('OK') diff --git a/tests/conformance/datetime.lua b/tests/conformance/datetime.lua index ca35cf2f..dc73948b 100644 --- a/tests/conformance/datetime.lua +++ b/tests/conformance/datetime.lua @@ -16,6 +16,7 @@ D = os.date("*t", t) assert(os.date(string.rep("%d", 1000), t) == string.rep(os.date("%d", t), 1000)) assert(os.date(string.rep("%", 200)) == string.rep("%", 100)) +assert(os.date("", -1) == nil) local function checkDateTable (t) local D = os.date("!*t", t) diff --git a/tests/conformance/errors.lua b/tests/conformance/errors.lua index 529e9b0c..57d2b693 100644 --- a/tests/conformance/errors.lua +++ b/tests/conformance/errors.lua @@ -405,5 +405,7 @@ assert(ecall(function() (""):foo() end) == "attempt to call missing method 'foo' assert(ecall(function() (42):foo() end) == "attempt to index number with 'foo'") assert(ecall(function() ({foo=42}):foo() end) == "attempt to call a number value") assert(ecall(function() local ud = newproxy(true) getmetatable(ud).__index = {} ud:foo() end) == "attempt to call missing method 'foo' of userdata") +assert(ecall(function() local ud = newproxy(true) getmetatable(ud).__index = function() end ud:foo() end) == "attempt to call missing method 'foo' of userdata") +assert(ecall(function() local ud = newproxy(true) getmetatable(ud).__index = function() error("nope") end ud:foo() end) == "nope") return('OK') diff --git a/tests/conformance/events.lua b/tests/conformance/events.lua index 447b67bc..94314c3f 100644 --- a/tests/conformance/events.lua +++ b/tests/conformance/events.lua @@ -13,6 +13,11 @@ assert(getmetatable(a) == "xuxu") ud=newproxy(true); getmetatable(ud).__metatable = "xuxu" assert(getmetatable(ud) == "xuxu") +assert(pcall(getmetatable) == false) +assert(pcall(function() return getmetatable() end) == false) +assert(select(2, pcall(getmetatable, {})) == nil) +assert(select(2, pcall(getmetatable, ud)) == "xuxu") + local res,err = pcall(tostring, a) assert(not res and err == "'__tostring' must return a string") -- cannot change a protected metatable @@ -475,6 +480,9 @@ function testfenv() assert(_G.X == 20) assert(_G == getfenv(0)) + + assert(pcall(getfenv, 10) == false) + assert(pcall(setfenv, setfenv, {}) == false) end testfenv() -- DONT MOVE THIS LINE diff --git a/tests/conformance/iter.lua b/tests/conformance/iter.lua index 468ffafb..5f8f1a89 100644 --- a/tests/conformance/iter.lua +++ b/tests/conformance/iter.lua @@ -193,4 +193,24 @@ do assert(x == 15) end +-- pairs/ipairs/next may be substituted through getfenv +-- however, they *must* be substituted with functions - we don't support them falling back to generalized iteration +function testgetfenv() + local env = getfenv(1) + env.pairs = function() return "nope" end + env.ipairs = function() return "nope" end + env.next = {1, 2, 3} + + local ok, err = pcall(function() for k, v in pairs({}) do end end) + assert(not ok and err:match("attempt to iterate over a string value")) + + local ok, err = pcall(function() for k, v in ipairs({}) do end end) + assert(not ok and err:match("attempt to iterate over a string value")) + + local ok, err = pcall(function() for k, v in next, {} do end end) + assert(not ok and err:match("attempt to iterate over a table value")) +end + +testgetfenv() -- DONT MOVE THIS LINE + return"OK" diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index 0cd0cdce..972c399b 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -283,6 +283,13 @@ assert(math.fmod(-3, 2) == -1) assert(math.fmod(3, -2) == 1) assert(math.fmod(-3, -2) == -1) +-- pow +assert(math.pow(2, 0) == 1) +assert(math.pow(2, 2) == 4) +assert(math.pow(4, 0.5) == 2) +assert(math.pow(-2, 2) == 4) +assert(tostring(math.pow(-2, 0.5)) == "nan") + -- most of the tests above go through fastcall path -- to make sure the basic implementations are also correct we test these functions with string->number coercions assert(math.abs("-4") == 4) diff --git a/tests/conformance/move.lua b/tests/conformance/move.lua index 3f28b4b3..27a96ffc 100644 --- a/tests/conformance/move.lua +++ b/tests/conformance/move.lua @@ -74,4 +74,6 @@ checkerror("wrap around", table.move, {}, 1, maxI, 2) checkerror("wrap around", table.move, {}, 1, 2, maxI) checkerror("wrap around", table.move, {}, minI, -2, 2) +checkerror("readonly", table.move, table.freeze({}), 1, 1, 1) + return"OK" diff --git a/tests/conformance/strings.lua b/tests/conformance/strings.lua index 3d8fdd1f..61bac726 100644 --- a/tests/conformance/strings.lua +++ b/tests/conformance/strings.lua @@ -48,6 +48,7 @@ assert(string.find("", "") == 1) assert(string.find('', 'aaa', 1) == nil) assert(('alo(.)alo'):find('(.)', 1, 1) == 4) assert(string.find('', '1', 2) == nil) +assert(string.find('123', '2', 0) == 2) print('+') assert(string.len("") == 0) @@ -88,6 +89,8 @@ assert(string.lower("\0ABCc%$") == "\0abcc%$") assert(string.rep('teste', 0) == '') assert(string.rep('tés\00tê', 2) == 'tés\0têtés\000tê') assert(string.rep('', 10) == '') +assert(string.rep('', 1e9) == '') +assert(pcall(string.rep, 'x', 2e9) == false) assert(string.reverse"" == "") assert(string.reverse"\0\1\2\3" == "\3\2\1\0") @@ -126,6 +129,13 @@ assert(string.format("-%.20s.20s", string.rep("%", 2000)) == "-"..string.rep("%" assert(string.format('"-%20s.20s"', string.rep("%", 2000)) == string.format("%q", "-"..string.rep("%", 2000)..".20s")) +assert(string.format("%o %u %x %X", -1, -1, -1, -1) == "1777777777777777777777 18446744073709551615 ffffffffffffffff FFFFFFFFFFFFFFFF") + +assert(string.format("%e %E", 1.5, -1.5) == "1.500000e+00 -1.500000E+00") + +assert(pcall(string.format, "%##################d", 1) == false) +assert(pcall(string.format, "%.123d", 1) == false) +assert(pcall(string.format, "%?", 1) == false) -- longest number that can be formated assert(string.len(string.format('%99.99f', -1e308)) >= 100) @@ -179,6 +189,26 @@ assert(table.concat(a, ",", 2) == "b,c") assert(table.concat(a, ",", 3) == "c") assert(table.concat(a, ",", 4) == "") +-- string.split +do + local function eq(a, b) + if #a ~= #b then + return false + end + for i=1,#a do + if a[i] ~= b[i] then + return false + end + end + return true + end + + assert(eq(string.split("abc", ""), {'a', 'b', 'c'})) + assert(eq(string.split("abc", "b"), {'a', 'c'})) + assert(eq(string.split("abc", "d"), {'abc'})) + assert(eq(string.split("abc", "c"), {'ab', ''})) +end + --[[ local locales = { "ptb", "ISO-8859-1", "pt_BR" } local function trylocale (w) diff --git a/tests/conformance/tables.lua b/tests/conformance/tables.lua index 0eff8540..7ae80cc4 100644 --- a/tests/conformance/tables.lua +++ b/tests/conformance/tables.lua @@ -87,35 +87,59 @@ print'+' -- testing tables dynamically built local lim = 130 -local a = {}; a[2] = 1; check(a, 0, 1) -a = {}; a[0] = 1; check(a, 0, 1); a[2] = 1; check(a, 0, 2) -a = {}; a[0] = 1; a[1] = 1; check(a, 1, 1) -a = {} -for i = 1,lim do - a[i] = 1 - assert(#a == i) - check(a, mp2(i), 0) + +do + local a = {}; a[2] = 1; check(a, 0, 1) + a = {}; a[0] = 1; check(a, 0, 1); a[2] = 1; check(a, 0, 2) + a = {}; a[0] = 1; a[1] = 1; check(a, 1, 1) + a = {} + for i = 1,lim do + a[i] = 1 + assert(#a == i) + check(a, mp2(i), 0) + end end -a = {} -for i = 1,lim do - a['a'..i] = 1 - assert(#a == 0) - check(a, 0, mp2(i)) +do + local a = {} + for i = 1,lim do + a['a'..i] = 1 + assert(#a == 0) + check(a, 0, mp2(i)) + end end -a = {} -for i=1,16 do a[i] = i end -check(a, 16, 0) -for i=1,11 do a[i] = nil end -for i=30,40 do a[i] = nil end -- force a rehash (?) -check(a, 0, 8) -a[10] = 1 -for i=30,40 do a[i] = nil end -- force a rehash (?) -check(a, 0, 8) -for i=1,14 do a[i] = nil end -for i=30,50 do a[i] = nil end -- force a rehash (?) -check(a, 0, 4) +do + local a = {} + for i=1,16 do a[i] = i end + check(a, 16, 0) + for i=1,11 do a[i] = nil end + for i=30,40 do a[i] = nil end -- force a rehash (?) + check(a, 0, 8) + a[10] = 1 + for i=30,40 do a[i] = nil end -- force a rehash (?) + check(a, 0, 8) + for i=1,14 do a[i] = nil end + for i=30,50 do a[i] = nil end -- force a rehash (?) + check(a, 0, 4) +end + +do -- rehash moving elements from array to hash + local a = {} + for i = 1, 100 do a[i] = i end + check(a, 128, 0) + + for i = 5, 95 do a[i] = nil end + check(a, 128, 0) + + a.x = 1 -- force a re-hash + check(a, 4, 8) + + for i = 1, 4 do assert(a[i] == i) end + for i = 5, 95 do assert(a[i] == nil) end + for i = 96, 100 do assert(a[i] == i) end + assert(a.x == 1) +end -- reverse filling for i=1,lim do @@ -612,4 +636,54 @@ do assert(hit and child.foo == nil and parent.foo == nil) end +-- testing next x GC of deleted keys +do + local co = coroutine.wrap(function (t) + for k, v in pairs(t) do + local k1 = next(t) -- all previous keys were deleted + assert(k == k1) -- current key is the first in the table + t[k] = nil + local expected = (type(k) == "table" and k[1] or + type(k) == "function" and k() or + string.sub(k, 1, 1)) + assert(expected == v) + coroutine.yield(v) + end + end) + local t = {} + t[{1}] = 1 -- add several unanchored, collectable keys + t[{2}] = 2 + t[string.rep("a", 50)] = "a" -- long string + t[string.rep("b", 50)] = "b" + t[{3}] = 3 + t[string.rep("c", 10)] = "c" -- short string + t[function () return 10 end] = 10 + local count = 7 + while co(t) do + collectgarbage("collect") -- collect dead keys + count = count - 1 + end + assert(count == 0 and next(t) == nil) -- traversed the whole table +end + +-- test error cases for table functions +do + assert(pcall(table.insert, {}) == false) + assert(pcall(table.insert, {}, 1, 2, 3) == false) + assert(pcall(table.insert, table.freeze({1, 2, 3}), 4) == false) + assert(pcall(table.insert, table.freeze({1, 2, 3}), 1, 4) == false) + + assert(pcall(table.remove, table.freeze({1})) == false) + + assert(pcall(table.concat, {true}) == false) + + assert(pcall(table.create) == false) + assert(pcall(table.create, -1) == false) + assert(pcall(table.create, 1e9) == false) + + assert(pcall(table.find, {}, 42, 0) == false) + + assert(pcall(table.clear, table.freeze({})) == false) +end + return"OK" diff --git a/tests/conformance/tpack.lua b/tests/conformance/tpack.lua index 835bf564..b240f482 100644 --- a/tests/conformance/tpack.lua +++ b/tests/conformance/tpack.lua @@ -306,6 +306,8 @@ do -- testing initial position assert(i == 4 and p == 17) local i, p = unpack("!4 i4", x, -#x) assert(i == 1 and p == 5) + local i, p = unpack("!4 i4", x, 0) + assert(i == 1 and p == 5) -- limits for i = 1, #x + 1 do diff --git a/tools/faillist.txt b/tools/faillist.txt index 0eb02209..4ac2b357 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -1,16 +1,18 @@ -AnnotationTests.builtin_types_are_not_exported AnnotationTests.corecursive_types_error_on_tight_loop AnnotationTests.duplicate_type_param_name AnnotationTests.for_loop_counter_annotation_is_checked AnnotationTests.generic_aliases_are_cloned_properly AnnotationTests.instantiation_clone_has_to_follow +AnnotationTests.luau_print_is_not_special_without_the_flag AnnotationTests.occurs_check_on_cyclic_intersection_typevar AnnotationTests.occurs_check_on_cyclic_union_typevar AnnotationTests.too_many_type_params AnnotationTests.two_type_params -AnnotationTests.use_type_required_from_another_file +AnnotationTests.unknown_type_reference_generates_error AstQuery.last_argument_function_call_type +AstQuery::getDocumentationSymbolAtPosition.overloaded_class_method AstQuery::getDocumentationSymbolAtPosition.overloaded_fn +AstQuery::getDocumentationSymbolAtPosition.table_overloaded_function_prop AutocompleteTest.autocomplete_first_function_arg_expected_type AutocompleteTest.autocomplete_interpolated_string AutocompleteTest.autocomplete_oop_implicit_self @@ -18,12 +20,10 @@ AutocompleteTest.autocomplete_string_singleton_equality AutocompleteTest.autocomplete_string_singleton_escape AutocompleteTest.autocomplete_string_singletons AutocompleteTest.autocompleteProp_index_function_metamethod_is_variadic -AutocompleteTest.cyclic_table AutocompleteTest.do_compatible_self_calls AutocompleteTest.do_wrong_compatible_self_calls AutocompleteTest.keyword_methods AutocompleteTest.no_incompatible_self_calls -AutocompleteTest.no_incompatible_self_calls_2 AutocompleteTest.no_wrong_compatible_self_calls_with_generics AutocompleteTest.suggest_table_keys AutocompleteTest.type_correct_argument_type_suggestion @@ -40,8 +40,6 @@ AutocompleteTest.type_correct_keywords AutocompleteTest.type_correct_suggestion_for_overloads AutocompleteTest.type_correct_suggestion_in_argument AutocompleteTest.type_correct_suggestion_in_table -AutocompleteTest.unsealed_table -AutocompleteTest.unsealed_table_2 BuiltinTests.aliased_string_format BuiltinTests.assert_removes_falsy_types BuiltinTests.assert_removes_falsy_types2 @@ -75,7 +73,6 @@ BuiltinTests.select_with_decimal_argument_is_rounded_down BuiltinTests.set_metatable_needs_arguments BuiltinTests.setmetatable_should_not_mutate_persisted_types BuiltinTests.sort_with_bad_predicate -BuiltinTests.sort_with_predicate BuiltinTests.string_format_arg_count_mismatch BuiltinTests.string_format_arg_types_inference BuiltinTests.string_format_as_method @@ -91,8 +88,8 @@ BuiltinTests.table_pack BuiltinTests.table_pack_reduce BuiltinTests.table_pack_variadic BuiltinTests.tonumber_returns_optional_number_type -BuiltinTests.tonumber_returns_optional_number_type2 DefinitionTests.class_definition_overload_metamethods +DefinitionTests.class_definition_string_props DefinitionTests.declaring_generic_functions DefinitionTests.definition_file_classes FrontendTest.environments @@ -100,7 +97,6 @@ FrontendTest.imported_table_modification_2 FrontendTest.it_should_be_safe_to_stringify_errors_when_full_type_graph_is_discarded FrontendTest.nocheck_cycle_used_by_checked FrontendTest.reexport_cyclic_type -FrontendTest.reexport_type_alias FrontendTest.trace_requires_in_nonstrict_mode GenericsTests.apply_type_function_nested_generics1 GenericsTests.apply_type_function_nested_generics2 @@ -109,26 +105,22 @@ GenericsTests.calling_self_generic_methods GenericsTests.check_generic_typepack_function GenericsTests.check_mutual_generic_functions GenericsTests.correctly_instantiate_polymorphic_member_functions -GenericsTests.do_not_always_instantiate_generic_intersection_types GenericsTests.do_not_infer_generic_functions GenericsTests.duplicate_generic_type_packs GenericsTests.duplicate_generic_types -GenericsTests.factories_of_generics GenericsTests.generic_argument_count_too_few GenericsTests.generic_argument_count_too_many GenericsTests.generic_factories -GenericsTests.generic_functions_in_types GenericsTests.generic_functions_should_be_memory_safe GenericsTests.generic_table_method GenericsTests.generic_type_pack_parentheses GenericsTests.generic_type_pack_unification1 GenericsTests.generic_type_pack_unification2 -GenericsTests.generic_type_pack_unification3 GenericsTests.higher_rank_polymorphism_should_not_accept_instantiated_arguments GenericsTests.infer_generic_function_function_argument GenericsTests.infer_generic_function_function_argument_overloaded GenericsTests.infer_generic_methods -GenericsTests.inferred_local_vars_can_be_polytypes +GenericsTests.infer_generic_property GenericsTests.instantiate_cyclic_generic_function GenericsTests.instantiated_function_argument_names GenericsTests.instantiation_sharing_types @@ -147,7 +139,6 @@ IntersectionTypes.table_write_sealed_indirect ModuleTests.any_persistance_does_not_leak ModuleTests.clone_self_property ModuleTests.deepClone_cyclic_table -ModuleTests.do_not_clone_reexports NonstrictModeTests.for_in_iterator_variables_are_any NonstrictModeTests.function_parameters_are_any NonstrictModeTests.inconsistent_module_return_types_are_ok @@ -162,7 +153,6 @@ NonstrictModeTests.parameters_having_type_any_are_optional NonstrictModeTests.table_dot_insert_and_recursive_calls NonstrictModeTests.table_props_are_any Normalize.cyclic_table_normalizes_sensibly -Normalize.intersection_combine_on_bound_self ParseErrorRecovery.generic_type_list_recovery ParseErrorRecovery.recovery_of_parenthesized_expressions ParserTests.parse_nesting_based_end_detection_failsafe_earlier @@ -173,6 +163,7 @@ ProvisionalTests.do_not_ice_when_trying_to_pick_first_of_generic_type_pack ProvisionalTests.error_on_eq_metamethod_returning_a_type_other_than_boolean ProvisionalTests.generic_type_leak_to_module_interface_variadic ProvisionalTests.greedy_inference_with_shared_self_triggers_function_with_no_returns +ProvisionalTests.lvalue_equals_another_lvalue_with_no_overlap ProvisionalTests.pcall_returns_at_least_two_value_but_function_returns_nothing ProvisionalTests.setmetatable_constrains_free_type_into_free_table ProvisionalTests.specialization_binds_with_prototypes_too_early @@ -180,7 +171,6 @@ ProvisionalTests.table_insert_with_a_singleton_argument ProvisionalTests.typeguard_inference_incomplete ProvisionalTests.weirditer_should_not_loop_forever ProvisionalTests.while_body_are_also_refined -RefinementTest.and_constraint RefinementTest.apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string RefinementTest.assert_a_to_be_truthy_then_assert_a_to_be_number RefinementTest.assert_non_binary_expressions_actually_resolve_constraints @@ -195,28 +185,17 @@ RefinementTest.either_number_or_string RefinementTest.eliminate_subclasses_of_instance RefinementTest.falsiness_of_TruthyPredicate_narrows_into_nil RefinementTest.index_on_a_refined_property -RefinementTest.invert_is_truthy_constraint RefinementTest.invert_is_truthy_constraint_ifelse_expression -RefinementTest.is_truthy_constraint RefinementTest.is_truthy_constraint_ifelse_expression -RefinementTest.lvalue_is_not_nil RefinementTest.merge_should_be_fully_agnostic_of_hashmap_ordering -RefinementTest.narrow_boolean_to_true_or_false RefinementTest.narrow_property_of_a_bounded_variable RefinementTest.narrow_this_large_union RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true -RefinementTest.not_a_and_not_b -RefinementTest.not_a_and_not_b2 -RefinementTest.not_and_constraint RefinementTest.not_t_or_some_prop_of_t -RefinementTest.or_predicate_with_truthy_predicates -RefinementTest.parenthesized_expressions_are_followed_through RefinementTest.refine_a_property_not_to_be_nil_through_an_intersection_table RefinementTest.refine_the_correct_types_opposite_of_when_a_is_not_number_or_string RefinementTest.refine_unknowns -RefinementTest.term_is_equal_to_an_lvalue RefinementTest.truthy_constraint_on_properties -RefinementTest.type_assertion_expr_carry_its_constraints RefinementTest.type_comparison_ifelse_expression RefinementTest.type_guard_can_filter_for_intersection_of_tables RefinementTest.type_guard_can_filter_for_overloaded_function @@ -234,17 +213,14 @@ RefinementTest.typeguard_not_to_be_string RefinementTest.what_nonsensical_condition RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table RefinementTest.x_is_not_instance_or_else_not_part +RuntimeLimits.typescript_port_of_Result_type TableTests.a_free_shape_can_turn_into_a_scalar_if_it_is_compatible TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible TableTests.access_index_metamethod_that_returns_variadic TableTests.accidentally_checked_prop_in_opposite_branch -TableTests.assigning_to_an_unsealed_table_with_string_literal_should_infer_new_properties_over_indexer -TableTests.augment_nested_table -TableTests.augment_table TableTests.builtin_table_names TableTests.call_method TableTests.cannot_augment_sealed_table -TableTests.cannot_change_type_of_unsealed_table_prop TableTests.casting_sealed_tables_with_props_into_table_with_indexer TableTests.casting_tables_with_props_into_table_with_indexer3 TableTests.casting_tables_with_props_into_table_with_indexer4 @@ -255,7 +231,6 @@ TableTests.defining_a_self_method_for_a_builtin_sealed_table_must_fail TableTests.defining_a_self_method_for_a_local_sealed_table_must_fail TableTests.dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar TableTests.dont_hang_when_trying_to_look_up_in_cyclic_metatable_index -TableTests.dont_invalidate_the_properties_iterator_of_free_table_when_rolled_back TableTests.dont_leak_free_table_props TableTests.dont_quantify_table_that_belongs_to_outer_scope TableTests.dont_suggest_exact_match_keys @@ -283,9 +258,9 @@ TableTests.infer_indexer_from_value_property_in_literal TableTests.inferred_return_type_of_free_table TableTests.inferring_crazy_table_should_also_be_quick TableTests.instantiate_table_cloning_3 +TableTests.invariant_table_properties_means_instantiating_tables_in_assignment_is_unsound TableTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound TableTests.leaking_bad_metatable_errors -TableTests.length_operator_union_errors TableTests.less_exponential_blowup_please TableTests.meta_add TableTests.meta_add_both_ways @@ -294,6 +269,7 @@ TableTests.metatable_mismatch_should_fail TableTests.missing_metatable_for_sealed_tables_do_not_get_inferred TableTests.mixed_tables_with_implicit_numbered_keys TableTests.nil_assign_doesnt_hit_indexer +TableTests.nil_assign_doesnt_hit_no_indexer TableTests.okay_to_add_property_to_unsealed_tables_by_function_call TableTests.only_ascribe_synthetic_names_at_module_scope TableTests.oop_indexer_works @@ -327,7 +303,6 @@ TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors TableTests.tables_get_names_from_their_locals TableTests.tc_member_function TableTests.tc_member_function_2 -TableTests.top_table_type TableTests.type_mismatch_on_massive_table_is_cut_short TableTests.unification_of_unions_in_a_self_referential_type TableTests.unifying_tables_shouldnt_uaf2 @@ -347,15 +322,14 @@ ToString.toStringNamedFunction_id ToString.toStringNamedFunction_include_self_param ToString.toStringNamedFunction_map ToString.toStringNamedFunction_variadics -TranspilerTests.types_should_not_be_considered_cyclic_if_they_are_not_recursive TryUnifyTests.cli_41095_concat_log_in_sealed_table_unification TryUnifyTests.members_of_failed_typepack_unification_are_unified_with_errorType TryUnifyTests.result_of_failed_typepack_unification_is_constrained TryUnifyTests.typepack_unification_should_trim_free_tails TryUnifyTests.variadics_should_use_reversed_properly +TypeAliases.cannot_create_cyclic_type_with_unknown_module TypeAliases.forward_declared_alias_is_not_clobbered_by_prior_unification_with_any TypeAliases.generic_param_remap -TypeAliases.mismatched_generic_pack_type_param TypeAliases.mismatched_generic_type_param TypeAliases.mutually_recursive_types_restriction_not_ok_1 TypeAliases.mutually_recursive_types_restriction_not_ok_2 @@ -369,7 +343,7 @@ TypeAliases.type_alias_fwd_declaration_is_precise TypeAliases.type_alias_local_mutation TypeAliases.type_alias_local_rename TypeAliases.type_alias_of_an_imported_recursive_generic_type -TypeAliases.type_alias_of_an_imported_recursive_type +TypeInfer.check_type_infer_recursion_count TypeInfer.checking_should_not_ice TypeInfer.cli_50041_committing_txnlog_in_apollo_client_error TypeInfer.dont_report_type_errors_within_an_AstExprError @@ -377,22 +351,20 @@ TypeInfer.dont_report_type_errors_within_an_AstStatError TypeInfer.globals TypeInfer.globals2 TypeInfer.infer_assignment_value_types_mutable_lval +TypeInfer.it_is_ok_to_have_inconsistent_number_of_return_values_in_nonstrict TypeInfer.no_stack_overflow_from_isoptional TypeInfer.tc_after_error_recovery_no_replacement_name_in_error TypeInfer.tc_if_else_expressions_expected_type_3 TypeInfer.tc_interpolated_string_basic -TypeInfer.tc_interpolated_string_constant_type TypeInfer.tc_interpolated_string_with_invalid_expression TypeInfer.type_infer_recursion_limit_no_ice -TypeInferAnyError.assign_prop_to_table_by_calling_any_yields_any +TypeInfer.type_infer_recursion_limit_normalizer TypeInferAnyError.for_in_loop_iterator_is_any2 TypeInferAnyError.for_in_loop_iterator_is_error2 TypeInferClasses.call_base_method TypeInferClasses.call_instance_method -TypeInferClasses.can_assign_to_prop_of_base_class_using_string TypeInferClasses.can_read_prop_of_base_class_using_string TypeInferClasses.class_type_mismatch_with_name_conflict -TypeInferClasses.classes_can_have_overloaded_operators TypeInferClasses.classes_without_overloaded_operators_cannot_be_added TypeInferClasses.detailed_class_unification_error TypeInferClasses.higher_order_function_arguments_are_contravariant @@ -402,6 +374,7 @@ TypeInferClasses.warn_when_prop_almost_matches TypeInferClasses.we_can_report_when_someone_is_trying_to_use_a_table_rather_than_a_class TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types TypeInferFunctions.calling_function_with_incorrect_argument_type_yields_errors_spanning_argument +TypeInferFunctions.cannot_hoist_interior_defns_into_signature TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists TypeInferFunctions.dont_infer_parameter_types_for_functions_from_their_call_site TypeInferFunctions.duplicate_functions_with_different_signatures_not_allowed_in_nonstrict @@ -420,6 +393,7 @@ TypeInferFunctions.infer_return_value_type TypeInferFunctions.infer_that_function_does_not_return_a_table TypeInferFunctions.list_all_overloads_if_no_overload_takes_given_argument_count TypeInferFunctions.list_only_alternative_overloads_that_match_argument_count +TypeInferFunctions.luau_subtyping_is_np_hard TypeInferFunctions.no_lossy_function_type TypeInferFunctions.occurs_check_failure_in_function_return_type TypeInferFunctions.record_matching_overload @@ -430,12 +404,12 @@ TypeInferFunctions.too_few_arguments_variadic TypeInferFunctions.too_few_arguments_variadic_generic TypeInferFunctions.too_few_arguments_variadic_generic2 TypeInferFunctions.too_many_arguments +TypeInferFunctions.too_many_arguments_error_location TypeInferFunctions.too_many_return_values TypeInferFunctions.too_many_return_values_in_parentheses TypeInferFunctions.too_many_return_values_no_function TypeInferFunctions.vararg_function_is_quantified TypeInferLoops.for_in_loop_error_on_factory_not_returning_the_right_amount_of_values -TypeInferLoops.for_in_loop_with_custom_iterator TypeInferLoops.for_in_loop_with_next TypeInferLoops.for_in_with_generic_next TypeInferLoops.for_in_with_just_one_iterator_is_ok @@ -443,15 +417,11 @@ TypeInferLoops.loop_iter_no_indexer_nonstrict TypeInferLoops.loop_iter_trailing_nil TypeInferLoops.unreachable_code_after_infinite_loop TypeInferLoops.varlist_declared_by_for_in_loop_should_be_free -TypeInferModules.bound_free_table_export_is_ok TypeInferModules.custom_require_global TypeInferModules.do_not_modify_imported_types -TypeInferModules.do_not_modify_imported_types_2 -TypeInferModules.do_not_modify_imported_types_3 TypeInferModules.module_type_conflict TypeInferModules.module_type_conflict_instantiated TypeInferModules.require_a_variadic_function -TypeInferModules.require_types TypeInferModules.type_error_of_unknown_qualified_type TypeInferOOP.CheckMethodsOfSealed TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_another_overload_works @@ -459,9 +429,7 @@ TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2 TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory TypeInferOOP.methods_are_topologically_sorted -TypeInferOperators.and_adds_boolean -TypeInferOperators.and_adds_boolean_no_superfluous_union -TypeInferOperators.and_binexps_dont_unify +TypeInferOOP.object_constructor_can_refer_to_method_of_self TypeInferOperators.and_or_ternary TypeInferOperators.CallAndOrOfFunctions TypeInferOperators.cannot_compare_tables_that_do_not_have_the_same_metatable @@ -471,28 +439,13 @@ TypeInferOperators.cli_38355_recursive_union TypeInferOperators.compound_assign_mismatch_metatable TypeInferOperators.compound_assign_mismatch_op TypeInferOperators.compound_assign_mismatch_result -TypeInferOperators.concat_op_on_free_lhs_and_string_rhs TypeInferOperators.disallow_string_and_types_without_metatables_from_arithmetic_binary_ops -TypeInferOperators.dont_strip_nil_from_rhs_or_operator -TypeInferOperators.equality_operations_succeed_if_any_union_branch_succeeds -TypeInferOperators.error_on_invalid_operand_types_to_relational_operators -TypeInferOperators.error_on_invalid_operand_types_to_relational_operators2 -TypeInferOperators.expected_types_through_binary_and -TypeInferOperators.expected_types_through_binary_or +TypeInferOperators.in_nonstrict_mode_strip_nil_from_intersections_when_considering_relational_operators TypeInferOperators.infer_any_in_all_modes_when_lhs_is_unknown -TypeInferOperators.or_joins_types -TypeInferOperators.or_joins_types_with_no_extras -TypeInferOperators.primitive_arith_possible_metatable TypeInferOperators.produce_the_correct_error_message_when_comparing_a_table_with_a_metatable_with_one_that_does_not TypeInferOperators.refine_and_or -TypeInferOperators.strict_binary_op_where_lhs_unknown -TypeInferOperators.strip_nil_from_lhs_or_operator -TypeInferOperators.strip_nil_from_lhs_or_operator2 TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection_on_rhs -TypeInferOperators.typecheck_unary_len_error -TypeInferOperators.typecheck_unary_minus -TypeInferOperators.typecheck_unary_minus_error TypeInferOperators.UnknownGlobalCompoundAssign TypeInferPrimitives.CheckMethodsOfNumber TypeInferPrimitives.singleton_types @@ -511,6 +464,7 @@ TypeInferUnknownNever.math_operators_and_never TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable2 TypeInferUnknownNever.unary_minus_of_never +TypePackTests.detect_cyclic_typepacks2 TypePackTests.higher_order_function TypePackTests.pack_tail_unification_check TypePackTests.parenthesized_varargs_returns_any @@ -535,26 +489,17 @@ TypePackTests.unify_variadic_tails_in_arguments TypePackTests.unify_variadic_tails_in_arguments_free TypePackTests.varargs_inference_through_multiple_scopes TypePackTests.variadic_packs -TypeSingletons.enums_using_singletons -TypeSingletons.enums_using_singletons_mismatch -TypeSingletons.enums_using_singletons_subtyping TypeSingletons.error_detailed_tagged_union_mismatch_bool TypeSingletons.error_detailed_tagged_union_mismatch_string TypeSingletons.function_call_with_singletons TypeSingletons.function_call_with_singletons_mismatch -TypeSingletons.if_then_else_expression_singleton_options TypeSingletons.indexing_on_string_singletons TypeSingletons.indexing_on_union_of_string_singletons -TypeSingletons.no_widening_from_callsites TypeSingletons.overloaded_function_call_with_singletons TypeSingletons.overloaded_function_call_with_singletons_mismatch TypeSingletons.return_type_of_f_is_not_widened -TypeSingletons.string_singleton_subtype -TypeSingletons.string_singletons -TypeSingletons.string_singletons_escape_chars -TypeSingletons.string_singletons_mismatch +TypeSingletons.table_properties_singleton_strings_mismatch TypeSingletons.table_properties_type_error_escapes -TypeSingletons.tagged_unions_using_singletons TypeSingletons.taking_the_length_of_string_singleton TypeSingletons.taking_the_length_of_union_of_string_singleton TypeSingletons.widen_the_supertype_if_it_is_free_and_subtype_has_singleton diff --git a/tools/lvmexecute_split.py b/tools/lvmexecute_split.py index 48d66cb0..f4a78960 100644 --- a/tools/lvmexecute_split.py +++ b/tools/lvmexecute_split.py @@ -32,6 +32,9 @@ source = """// This file is part of the Luau programming language and is license """ function = "" +signature = "" + +includeInsts = ["LOP_NEWCLOSURE", "LOP_NAMECALL", "LOP_FORGPREP", "LOP_GETVARARGS", "LOP_DUPCLOSURE", "LOP_PREPVARARGS", "LOP_COVERAGE", "LOP_BREAK", "LOP_GETGLOBAL", "LOP_SETGLOBAL", "LOP_GETTABLEKS", "LOP_SETTABLEKS"] state = 0 @@ -44,7 +47,6 @@ for line in input: if match: inst = match[1] signature = "const Instruction* execute_" + inst + "(lua_State* L, const Instruction* pc, StkId base, TValue* k)" - header += signature + ";\n" function = signature + "\n" function += "{\n" function += " [[maybe_unused]] Closure* cl = clvalue(L->ci->func);\n" @@ -84,7 +86,10 @@ for line in input: function = function[:-len(finalline)] function += " return pc;\n}\n" - source += function + "\n" + if inst in includeInsts: + header += signature + ";\n" + source += function + "\n" + state = 0 # skip LUA_CUSTOM_EXECUTION code blocks diff --git a/tools/stack-usage-reporter.py b/tools/stack-usage-reporter.py new file mode 100644 index 00000000..91e74887 --- /dev/null +++ b/tools/stack-usage-reporter.py @@ -0,0 +1,173 @@ +#!/usr/bin/python +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# The purpose of this script is to analyze disassembly generated by objdump or +# dumpbin to print (or to compare) the stack usage of functions/methods. +# This is a quickly written script, so it is quite possible it may not handle +# all code properly. +# +# The script expects the user to create a text assembly dump to be passed to +# the script. +# +# objdump Example +# objdump --demangle --disassemble objfile.o > objfile.s +# +# dumpbin Example +# dumpbin /disasm objfile.obj > objfile.s +# +# If the script is passed a single file, then all stack size information that +# is found it printed. If two files are passed, then the script compares the +# stack usage of the two files (useful for A/B comparisons). +# Currently more than two input files are not supported. (But adding support shouldn't +# be very difficult.) +# +# Note: The script only handles x64 disassembly. Supporting x86 is likely +# trivial, but ARM support could be difficult. +# Thus far the script has been tested with MSVC on Win64 and clang on OSX. + +import argparse +import re + +blank_re = re.compile('\s*') + +class LineReader: + def __init__(self, lines): + self.lines = list(reversed(lines)) + def get_line(self): + return self.lines.pop(-1) + def peek_line(self): + return self.lines[-1] + def consume_blank_lines(self): + while blank_re.fullmatch(self.peek_line()): + self.get_line() + def is_empty(self): + return len(self.lines) == 0 + +def parse_objdump_assembly(in_file): + results = {} + text_section_re = re.compile('Disassembly of section __TEXT,__text:\s*') + symbol_re = re.compile('[^<]*<(.*)>:\s*') + stack_alloc = re.compile('.*subq\s*\$(\d*), %rsp\s*') + + lr = LineReader(in_file.readlines()) + + def find_stack_alloc_size(): + while True: + if lr.is_empty(): + return None + if blank_re.fullmatch(lr.peek_line()): + return None + + line = lr.get_line() + mo = stack_alloc.fullmatch(line) + if mo: + lr.consume_blank_lines() + return int(mo.group(1)) + + # Find beginning of disassembly + while not text_section_re.fullmatch(lr.get_line()): + pass + + # Scan for symbols + while not lr.is_empty(): + lr.consume_blank_lines() + if lr.is_empty(): + break + line = lr.get_line() + mo = symbol_re.fullmatch(line) + # Found a symbol + if mo: + symbol = mo.group(1) + stack_size = find_stack_alloc_size() + if stack_size != None: + results[symbol] = stack_size + + return results + +def parse_dumpbin_assembly(in_file): + results = {} + + file_type_re = re.compile('File Type: COFF OBJECT\s*') + symbol_re = re.compile('[^(]*\((.*)\):\s*') + summary_re = re.compile('\s*Summary\s*') + stack_alloc = re.compile('.*sub\s*rsp,([A-Z0-9]*)h\s*') + + lr = LineReader(in_file.readlines()) + + def find_stack_alloc_size(): + while True: + if lr.is_empty(): + return None + if blank_re.fullmatch(lr.peek_line()): + return None + + line = lr.get_line() + mo = stack_alloc.fullmatch(line) + if mo: + lr.consume_blank_lines() + return int(mo.group(1), 16) # return value in decimal + + # Find beginning of disassembly + while not file_type_re.fullmatch(lr.get_line()): + pass + + # Scan for symbols + while not lr.is_empty(): + lr.consume_blank_lines() + if lr.is_empty(): + break + line = lr.get_line() + if summary_re.fullmatch(line): + break + mo = symbol_re.fullmatch(line) + # Found a symbol + if mo: + symbol = mo.group(1) + stack_size = find_stack_alloc_size() + if stack_size != None: + results[symbol] = stack_size + return results + +def main(): + parser = argparse.ArgumentParser(description='Tool used for reporting or comparing the stack usage of functions/methods') + parser.add_argument('--format', choices=['dumpbin', 'objdump'], required=True, help='Specifies the program used to generate the input files') + parser.add_argument('--input', action='append', required=True, help='Input assembly file. This option may be specified multiple times.') + parser.add_argument('--md-output', action='store_true', help='Show table output in markdown format') + parser.add_argument('--only-diffs', action='store_true', help='Only show stack info when it differs between the input files') + args = parser.parse_args() + + parsers = {'dumpbin': parse_dumpbin_assembly, 'objdump' : parse_objdump_assembly} + parse_func = parsers[args.format] + + input_results = [] + for input_name in args.input: + with open(input_name) as in_file: + results = parse_func(in_file) + input_results.append(results) + + if len(input_results) == 1: + # Print out the results sorted by size + size_sorted = sorted([(size, symbol) for symbol, size in results.items()], reverse=True) + print(input_name) + for size, symbol in size_sorted: + print(f'{size:10}\t{symbol}') + print() + elif len(input_results) == 2: + common_symbols = set(input_results[0].keys()).intersection(set(input_results[1].keys())) + print(f'Found {len(common_symbols)} common symbols') + stack_sizes = sorted([(input_results[0][sym], input_results[1][sym], sym) for sym in common_symbols], reverse=True) + if args.md_output: + print('Before | After | Symbol') + print('-- | -- | --') + for size0, size1, symbol in stack_sizes: + if args.only_diffs and size0 == size1: + continue + if args.md_output: + print(f'{size0} | {size1} | {symbol}') + else: + print(f'{size0:10}\t{size1:10}\t{symbol}') + else: + print("TODO support more than 2 inputs") + +if __name__ == '__main__': + main() diff --git a/tools/test_dcr.py b/tools/test_dcr.py index 76bf11ac..6d553b64 100644 --- a/tools/test_dcr.py +++ b/tools/test_dcr.py @@ -42,6 +42,8 @@ class Handler(x.ContentHandler): self.fail_count = 0 self.test_count = 0 + self.crashed_tests = [] + def startElement(self, name, attrs): if name == "TestSuite": self.currentTest.append(attrs["name"]) @@ -69,6 +71,10 @@ class Handler(x.ContentHandler): elif name == "OverallResultsTestCases": self.numSkippedTests = safeParseInt(attrs.get("skipped", 0)) + elif name == "Exception": + if attrs.get("crash") == "true": + self.crashed_tests.append(makeDottedName(self.currentTest)) + def endElement(self, name): if name == "TestCase": self.currentTest.pop() @@ -192,15 +198,23 @@ def main(): print(name, file=f) print_stderr("Updated faillist.txt") - if handler.numSkippedTests > 0: - print_stderr( - f"{handler.numSkippedTests} test(s) were skipped! That probably means that a test segfaulted!" - ) - sys.exit(1) + if handler.crashed_tests: + print_stderr() + for test in handler.crashed_tests: + print_stderr( + f"{c.Fore.RED}{test}{c.Fore.RESET} threw an exception and crashed the test process!" + ) - ok = all( - not passed == (dottedName in failList) - for dottedName, passed in handler.results.items() + if handler.numSkippedTests > 0: + print_stderr(f"{handler.numSkippedTests} test(s) were skipped!") + + ok = ( + not handler.crashed_tests + and handler.numSkippedTests == 0 + and all( + not passed == (dottedName in failList) + for dottedName, passed in handler.results.items() + ) ) if ok: