diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 9ee75004..aff3c4d9 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -277,11 +277,20 @@ struct MissingUnionProperty bool operator==(const MissingUnionProperty& rhs) const; }; +struct TypesAreUnrelated +{ + TypeId left; + TypeId right; + + bool operator==(const TypesAreUnrelated& rhs) const; +}; + using TypeErrorData = Variant; + DuplicateGenericParameter, CannotInferBinaryOperation, MissingProperties, SwappedGenericTypeParameter, OptionalValueAccess, MissingUnionProperty, + TypesAreUnrelated>; struct TypeError { diff --git a/Analysis/include/Luau/IostreamHelpers.h b/Analysis/include/Luau/IostreamHelpers.h index f9e9cd48..ee994296 100644 --- a/Analysis/include/Luau/IostreamHelpers.h +++ b/Analysis/include/Luau/IostreamHelpers.h @@ -36,6 +36,10 @@ std::ostream& operator<<(std::ostream& lhs, const IllegalRequire& error); std::ostream& operator<<(std::ostream& lhs, const ModuleHasCyclicDependency& error); std::ostream& operator<<(std::ostream& lhs, const DuplicateGenericParameter& error); std::ostream& operator<<(std::ostream& lhs, const CannotInferBinaryOperation& error); +std::ostream& operator<<(std::ostream& lhs, const SwappedGenericTypeParameter& error); +std::ostream& operator<<(std::ostream& lhs, const OptionalValueAccess& error); +std::ostream& operator<<(std::ostream& lhs, const MissingUnionProperty& error); +std::ostream& operator<<(std::ostream& lhs, const TypesAreUnrelated& error); std::ostream& operator<<(std::ostream& lhs, const TableState& tv); std::ostream& operator<<(std::ostream& lhs, const TypeVar& tv); diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index 50379c1c..a97bf6d6 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -69,8 +69,8 @@ std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeV // It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class // These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression -void dump(TypeId ty); -void dump(TypePackId ty); +std::string dump(TypeId ty); +std::string dump(TypePackId ty); std::string generateName(size_t n); diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 9f553bc1..451976e4 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -156,13 +156,14 @@ struct TypeChecker // Returns both the type of the lvalue and its binding (if the caller wants to mutate the binding). // Note: the binding may be null. + // TODO: remove second return value with FFlagLuauUpdateFunctionNameBinding std::pair checkLValueBinding(const ScopePtr& scope, const AstExpr& expr); std::pair checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr); std::pair checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr); std::pair checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr); std::pair checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr); - TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName); + TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level); std::pair checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr, std::optional originalNameLoc, std::optional expectedType); void checkFunctionBody(const ScopePtr& scope, TypeId type, const AstExprFunction& function); @@ -174,7 +175,7 @@ struct TypeChecker ExprResult checkExprPack(const ScopePtr& scope, const AstExprCall& expr); std::vector> getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall); std::optional> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, - TypePackId argPack, TypePack* args, const std::vector& argLocations, const ExprResult& argListResult, + TypePackId argPack, TypePack* args, const std::vector* argLocations, const ExprResult& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors); bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector& argLocations, const std::vector& errors); @@ -277,7 +278,7 @@ public: [[noreturn]] void ice(const std::string& message); ScopePtr childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel = 0); - ScopePtr childScope(const ScopePtr& parent, const Location& location, int subLevel = 0); + ScopePtr childScope(const ScopePtr& parent, const Location& location); // Wrapper for merge(l, r, toUnion) but without the lambda junk. void merge(RefinementMap& l, const RefinementMap& r); diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 8c4c2f34..f6829ec3 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -499,6 +499,7 @@ struct SingletonTypes const TypePackId anyTypePack; SingletonTypes(); + ~SingletonTypes(); SingletonTypes(const SingletonTypes&) = delete; void operator=(const SingletonTypes&) = delete; @@ -509,10 +510,12 @@ struct SingletonTypes private: std::unique_ptr arena; + bool debugFreezeArena = false; + TypeId makeStringMetatable(); }; -extern SingletonTypes singletonTypes; +SingletonTypes& getSingletonTypes(); void persist(TypeId ty); void persist(TypePackId tp); @@ -523,9 +526,6 @@ TypeLevel* getMutableLevel(TypeId ty); const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name); bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent); -bool hasGeneric(TypeId ty); -bool hasGeneric(TypePackId tp); - TypeVar* asMutable(TypeId ty); template diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index b47610fc..e8eafe68 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -24,7 +24,7 @@ struct TypeLevel int level = 0; int subLevel = 0; - // Returns true if the typelevel "this" is "bigger" than rhs + // Returns true if the level of "this" belongs to an equal or larger scope than that of rhs bool subsumes(const TypeLevel& rhs) const { if (level < rhs.level) @@ -38,6 +38,15 @@ struct TypeLevel return false; } + // Returns true if the level of "this" belongs to a larger (not equal) scope than that of rhs + bool subsumesStrict(const TypeLevel& rhs) const + { + if (level == rhs.level && subLevel == rhs.subLevel) + return false; + else + return subsumes(rhs); + } + TypeLevel incr() const { TypeLevel result; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 4588cdd8..7681b966 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -91,6 +91,9 @@ private: [[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message); + + // Available after regular type pack unification errors + std::optional firstPackErrorPos; }; } // namespace Luau diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index db2d1d0e..4b583792 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -15,6 +15,7 @@ LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport) LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); LUAU_FASTFLAGVARIABLE(LuauAutocompletePreferToCallFunctions, false); +LUAU_FASTFLAGVARIABLE(LuauAutocompleteFirstArg, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -190,7 +191,48 @@ static ParenthesesRecommendation getParenRecommendation(TypeId id, const std::ve return ParenthesesRecommendation::None; } -static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typeArena, AstNode* node, TypeId ty) +static std::optional findExpectedTypeAt(const Module& module, AstNode* node, Position position) +{ + LUAU_ASSERT(FFlag::LuauAutocompleteFirstArg); + + auto expr = node->asExpr(); + if (!expr) + return std::nullopt; + + // Extra care for first function call argument location + // When we don't have anything inside () yet, we also don't have an AST node to base our lookup + if (AstExprCall* exprCall = expr->as()) + { + if (exprCall->args.size == 0 && exprCall->argLocation.contains(position)) + { + auto it = module.astTypes.find(exprCall->func); + + if (!it) + return std::nullopt; + + const FunctionTypeVar* ftv = get(follow(*it)); + + if (!ftv) + return std::nullopt; + + auto [head, tail] = flatten(ftv->argTypes); + unsigned index = exprCall->self ? 1 : 0; + + if (index < head.size()) + return head[index]; + + return std::nullopt; + } + } + + auto it = module.astExpectedTypes.find(expr); + if (!it) + return std::nullopt; + + return *it; +} + +static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typeArena, AstNode* node, Position position, TypeId ty) { ty = follow(ty); @@ -220,15 +262,29 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ } }; - auto expr = node->asExpr(); - if (!expr) - return TypeCorrectKind::None; + TypeId expectedType; - auto it = module.astExpectedTypes.find(expr); - if (!it) - return TypeCorrectKind::None; + if (FFlag::LuauAutocompleteFirstArg) + { + auto typeAtPosition = findExpectedTypeAt(module, node, position); - TypeId expectedType = follow(*it); + if (!typeAtPosition) + return TypeCorrectKind::None; + + expectedType = follow(*typeAtPosition); + } + else + { + auto expr = node->asExpr(); + if (!expr) + return TypeCorrectKind::None; + + auto it = module.astExpectedTypes.find(expr); + if (!it) + return TypeCorrectKind::None; + + expectedType = follow(*it); + } if (FFlag::LuauAutocompletePreferToCallFunctions) { @@ -333,8 +389,8 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId if (result.count(name) == 0 && name != Parser::errorName) { Luau::TypeId type = Luau::follow(prop.type); - TypeCorrectKind typeCorrect = - indexType == PropIndexType::Key ? TypeCorrectKind::Correct : checkTypeCorrectKind(module, typeArena, nodes.back(), type); + TypeCorrectKind typeCorrect = indexType == PropIndexType::Key ? TypeCorrectKind::Correct + : checkTypeCorrectKind(module, typeArena, nodes.back(), {{}, {}}, type); ParenthesesRecommendation parens = indexType == PropIndexType::Key ? ParenthesesRecommendation::None : getParenRecommendation(type, nodes, typeCorrect); @@ -692,17 +748,31 @@ std::optional returnFirstNonnullOptionOfType(const UnionTypeVar* utv) return ret; } -static std::optional functionIsExpectedAt(const Module& module, AstNode* node) +static std::optional functionIsExpectedAt(const Module& module, AstNode* node, Position position) { - auto expr = node->asExpr(); - if (!expr) - return std::nullopt; + TypeId expectedType; - auto it = module.astExpectedTypes.find(expr); - if (!it) - return std::nullopt; + if (FFlag::LuauAutocompleteFirstArg) + { + auto typeAtPosition = findExpectedTypeAt(module, node, position); - TypeId expectedType = follow(*it); + if (!typeAtPosition) + return std::nullopt; + + expectedType = follow(*typeAtPosition); + } + else + { + auto expr = node->asExpr(); + if (!expr) + return std::nullopt; + + auto it = module.astExpectedTypes.find(expr); + if (!it) + return std::nullopt; + + expectedType = follow(*it); + } if (get(expectedType)) return true; @@ -1171,7 +1241,7 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul std::string n = toString(name); if (!result.count(n)) { - TypeCorrectKind typeCorrect = checkTypeCorrectKind(module, typeArena, node, binding.typeId); + TypeCorrectKind typeCorrect = checkTypeCorrectKind(module, typeArena, node, position, binding.typeId); result[n] = {AutocompleteEntryKind::Binding, binding.typeId, binding.deprecated, false, typeCorrect, std::nullopt, std::nullopt, binding.documentationSymbol, {}, getParenRecommendation(binding.typeId, ancestry, typeCorrect)}; @@ -1181,9 +1251,10 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul scope = scope->parent; } - TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, typeChecker.nilType); - TypeCorrectKind correctForBoolean = checkTypeCorrectKind(module, typeArena, node, typeChecker.booleanType); - TypeCorrectKind correctForFunction = functionIsExpectedAt(module, node).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.nilType); + TypeCorrectKind correctForBoolean = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.booleanType); + TypeCorrectKind correctForFunction = + functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; if (FFlag::LuauIfElseExpressionAnalysisSupport) result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index bac94a2b..d527414a 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -217,9 +217,9 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypeId genericK = arena.addType(GenericTypeVar{"K"}); TypeId genericV = arena.addType(GenericTypeVar{"V"}); - TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level}); + TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level, TableState::Generic}); - std::optional stringMetatableTy = getMetatable(singletonTypes.stringType); + std::optional stringMetatableTy = getMetatable(getSingletonTypes().stringType); LUAU_ASSERT(stringMetatableTy); const TableTypeVar* stringMetatableTable = get(follow(*stringMetatableTy)); LUAU_ASSERT(stringMetatableTable); @@ -271,7 +271,10 @@ void registerBuiltinTypes(TypeChecker& typeChecker) persist(pair.second.typeId); if (TableTypeVar* ttv = getMutable(pair.second.typeId)) - ttv->name = toString(pair.first); + { + if (!ttv->name) + ttv->name = toString(pair.first); + } } attachMagicFunction(getGlobalBinding(typeChecker, "assert"), magicFunctionAssert); diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 9f5c8250..d0afa742 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" +LUAU_FASTFLAGVARIABLE(LuauFixTonumberReturnType, false) + namespace Luau { @@ -113,7 +115,6 @@ declare function gcinfo(): number declare function error(message: T, level: number?) declare function tostring(value: T): string - declare function tonumber(value: T, radix: number?): number declare function rawequal(a: T1, b: T2): boolean declare function rawget(tab: {[K]: V}, k: K): V @@ -204,7 +205,14 @@ declare function gcinfo(): number std::string getBuiltinDefinitionSource() { - return kBuiltinDefinitionLuaSrc; + std::string result = kBuiltinDefinitionLuaSrc; + + if (FFlag::LuauFixTonumberReturnType) + result += "declare function tonumber(value: T, radix: number?): number?\n"; + else + result += "declare function tonumber(value: T, radix: number?): number\n"; + + return result; } } // namespace Luau diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 8334bd62..ce832c6b 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -58,7 +58,7 @@ struct ErrorConverter result += "\ncaused by:\n "; if (!tm.reason.empty()) - result += tm.reason + ". "; + result += tm.reason + " "; result += Luau::toString(*tm.error); } @@ -410,6 +410,11 @@ struct ErrorConverter return ss + " in the type '" + toString(e.type) + "'"; } + + std::string operator()(const TypesAreUnrelated& e) const + { + return "Cannot cast '" + toString(e.left) + "' into '" + toString(e.right) + "' because the types are unrelated"; + } }; struct InvalidNameChecker @@ -658,6 +663,11 @@ bool MissingUnionProperty::operator==(const MissingUnionProperty& rhs) const return *type == *rhs.type && key == rhs.key; } +bool TypesAreUnrelated::operator==(const TypesAreUnrelated& rhs) const +{ + return left == rhs.left && right == rhs.right; +} + std::string toString(const TypeError& error) { ErrorConverter converter; @@ -793,6 +803,11 @@ void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks& for (auto& ty : e.missing) ty = clone(ty); } + else if constexpr (std::is_same_v) + { + e.left = clone(e.left); + e.right = clone(e.right); + } else static_assert(always_false_v, "Non-exhaustive type switch"); } diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index ac46b5a4..5bc76ade 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -262,6 +262,12 @@ std::ostream& operator<<(std::ostream& stream, const MissingUnionProperty& error return stream << " }, key = '" + error.key + "' }"; } +std::ostream& operator<<(std::ostream& stream, const TypesAreUnrelated& error) +{ + stream << "TypesAreUnrelated { left = '" + toString(error.left) + "', right = '" + toString(error.right) + "' }"; + return stream; +} + std::ostream& operator<<(std::ostream& stream, const TableState& tv) { return stream << static_cast::type>(tv); diff --git a/Analysis/src/JsonEncoder.cpp b/Analysis/src/JsonEncoder.cpp index c7f623ee..23491a5a 100644 --- a/Analysis/src/JsonEncoder.cpp +++ b/Analysis/src/JsonEncoder.cpp @@ -262,7 +262,7 @@ struct AstJsonEncoder : public AstVisitor if (comma) writeRaw(","); else - comma = false; + comma = true; write(a); } @@ -379,7 +379,7 @@ struct AstJsonEncoder : public AstVisitor if (comma) writeRaw(","); else - comma = false; + comma = true; write(prop); } }); diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index b4b6eb42..e1e53c97 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -13,7 +13,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) -LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 0) namespace Luau @@ -23,7 +22,7 @@ static bool contains(Position pos, Comment comment) { if (comment.location.contains(pos)) return true; - else if (FFlag::LuauCaptureBrokenCommentSpans && comment.type == Lexeme::BrokenComment && + else if (comment.type == Lexeme::BrokenComment && comment.location.begin <= pos) // Broken comments are broken specifically because they don't have an end return true; else if (comment.type == Lexeme::Comment && comment.location.end == pos) @@ -194,7 +193,7 @@ struct TypePackCloner { cloneState.encounteredFreeType = true; - TypePackId err = singletonTypes.errorRecoveryTypePack(singletonTypes.anyTypePack); + TypePackId err = getSingletonTypes().errorRecoveryTypePack(getSingletonTypes().anyTypePack); TypePackId cloned = dest.addTypePack(*err); seenTypePacks[typePackId] = cloned; } @@ -247,7 +246,7 @@ void TypeCloner::defaultClone(const T& t) void TypeCloner::operator()(const Unifiable::Free& t) { cloneState.encounteredFreeType = true; - TypeId err = singletonTypes.errorRecoveryType(singletonTypes.anyType); + TypeId err = getSingletonTypes().errorRecoveryType(getSingletonTypes().anyType); TypeId cloned = dest.addType(*err); seenTypes[typeId] = cloned; } @@ -421,9 +420,6 @@ TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypeP Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into. } - if (FFlag::DebugLuauTrackOwningArena) - asMutable(res)->owningArena = &dest; - return res; } @@ -440,12 +436,11 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks { TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks, cloneState}; Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. + + // TODO: Make this work when the arena of 'res' might be frozen asMutable(res)->documentationSymbol = typeId->documentationSymbol; } - if (FFlag::DebugLuauTrackOwningArena) - asMutable(res)->owningArena = &dest; - return res; } @@ -508,8 +503,8 @@ bool Module::clonePublicInterface() if (moduleScope->varargPack) moduleScope->varargPack = clone(*moduleScope->varargPack, interfaceTypes, seenTypes, seenTypePacks, cloneState); - for (auto& pair : moduleScope->exportedTypeBindings) - pair.second = clone(pair.second, interfaceTypes, seenTypes, seenTypePacks, cloneState); + for (auto& [name, tf] : moduleScope->exportedTypeBindings) + tf = clone(tf, interfaceTypes, seenTypes, seenTypePacks, cloneState); for (TypeId ty : moduleScope->returnType) if (get(follow(ty))) diff --git a/Analysis/src/Predicate.cpp b/Analysis/src/Predicate.cpp index 848627cf..7bd8001e 100644 --- a/Analysis/src/Predicate.cpp +++ b/Analysis/src/Predicate.cpp @@ -24,7 +24,7 @@ std::optional tryGetLValue(const AstExpr& node) else if (auto indexexpr = expr->as()) { if (auto lvalue = tryGetLValue(*indexexpr->expr)) - if (auto string = indexexpr->expr->as()) + if (auto string = indexexpr->index->as()) return Field{std::make_shared(*lvalue), std::string(string->value.data, string->value.size)}; } diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 6322096c..a6be5348 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -13,6 +13,13 @@ LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) LUAU_FASTFLAGVARIABLE(LuauFunctionArgumentNameSize, false) +/* + * Prefix generic typenames with gen- + * Additionally, free types will be prefixed with free- and suffixed with their level. eg free-a-4 + * Fair warning: Setting this will break a lot of Luau unit tests. + */ +LUAU_FASTFLAGVARIABLE(DebugLuauVerboseTypeNames, false) + namespace Luau { @@ -290,7 +297,15 @@ struct TypeVarStringifier void operator()(TypeId ty, const Unifiable::Free& ftv) { state.result.invalid = true; + if (FFlag::DebugLuauVerboseTypeNames) + state.emit("free-"); state.emit(state.getName(ty)); + + if (FFlag::DebugLuauVerboseTypeNames) + { + state.emit("-"); + state.emit(std::to_string(ftv.level.level)); + } } void operator()(TypeId, const BoundTypeVar& btv) @@ -802,6 +817,8 @@ struct TypePackStringifier void operator()(TypePackId tp, const GenericTypePack& pack) { + if (FFlag::DebugLuauVerboseTypeNames) + state.emit("gen-"); if (pack.explicitName) { state.result.nameMap.typePacks[tp] = pack.name; @@ -817,7 +834,16 @@ struct TypePackStringifier void operator()(TypePackId tp, const FreeTypePack& pack) { state.result.invalid = true; + if (FFlag::DebugLuauVerboseTypeNames) + state.emit("free-"); state.emit(state.getName(tp)); + + if (FFlag::DebugLuauVerboseTypeNames) + { + state.emit("-"); + state.emit(std::to_string(pack.level.level)); + } + state.emit("..."); } @@ -1181,20 +1207,24 @@ std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeV return s; } -void dump(TypeId ty) +std::string dump(TypeId ty) { ToStringOptions opts; opts.exhaustive = true; opts.functionTypeArguments = true; - printf("%s\n", toString(ty, opts).c_str()); + std::string s = toString(ty, opts); + printf("%s\n", s.c_str()); + return s; } -void dump(TypePackId ty) +std::string dump(TypePackId ty) { ToStringOptions opts; opts.exhaustive = true; opts.functionTypeArguments = true; - printf("%s\n", toString(ty, opts).c_str()); + std::string s = toString(ty, opts); + printf("%s\n", s.c_str()); + return s; } std::string generateName(size_t i) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 617bf482..abbc2901 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -9,9 +9,9 @@ #include "Luau/Scope.h" #include "Luau/Substitution.h" #include "Luau/TopoSortStatements.h" -#include "Luau/ToString.h" #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" +#include "Luau/ToString.h" #include "Luau/TypeVar.h" #include "Luau/TimeTrace.h" @@ -29,7 +29,6 @@ LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) -LUAU_FASTFLAGVARIABLE(LuauStrictRequire, false) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) @@ -37,6 +36,12 @@ LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauTailArgumentTypeInfo, false) LUAU_FASTFLAGVARIABLE(LuauModuleRequireErrorPack, false) +LUAU_FASTFLAGVARIABLE(LuauRefiLookupFromIndexExpr, false) +LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false) +LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false) +LUAU_FASTFLAGVARIABLE(LuauFixRecursiveMetatableCall, false) +LUAU_FASTFLAGVARIABLE(LuauBidirectionalAsExpr, false) +LUAU_FASTFLAGVARIABLE(LuauUpdateFunctionNameBinding, false) namespace Luau { @@ -206,14 +211,14 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan : resolver(resolver) , iceHandler(iceHandler) , unifierState(iceHandler) - , nilType(singletonTypes.nilType) - , numberType(singletonTypes.numberType) - , stringType(singletonTypes.stringType) - , booleanType(singletonTypes.booleanType) - , threadType(singletonTypes.threadType) - , anyType(singletonTypes.anyType) - , optionalNumberType(singletonTypes.optionalNumberType) - , anyTypePack(singletonTypes.anyTypePack) + , nilType(getSingletonTypes().nilType) + , numberType(getSingletonTypes().numberType) + , stringType(getSingletonTypes().stringType) + , booleanType(getSingletonTypes().booleanType) + , threadType(getSingletonTypes().threadType) + , anyType(getSingletonTypes().anyType) + , optionalNumberType(getSingletonTypes().optionalNumberType) + , anyTypePack(getSingletonTypes().anyTypePack) { globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); @@ -443,7 +448,7 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) functionDecls[*protoIter] = pair; ++subLevel; - TypeId leftType = checkFunctionName(scope, *fun->name); + TypeId leftType = checkFunctionName(scope, *fun->name, funScope->level); unify(leftType, funTy, fun->location); } else if (auto fun = (*protoIter)->as()) @@ -711,14 +716,15 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) } else if (auto tail = valueIter.tail()) { - if (get(*tail)) + TypePackId tailPack = follow(*tail); + if (get(tailPack)) right = errorRecoveryType(scope); - else if (auto vtp = get(*tail)) + else if (auto vtp = get(tailPack)) right = vtp->ty; - else if (get(*tail)) + else if (get(tailPack)) { - *asMutable(*tail) = TypePack{{left}}; - growingPack = getMutable(*tail); + *asMutable(tailPack) = TypePack{{left}}; + growingPack = getMutable(tailPack); } } @@ -1107,8 +1113,27 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco unify(leftType, ty, function.location); - if (leftTypeBinding) - *leftTypeBinding = follow(quantify(funScope, leftType, function.name->location)); + if (FFlag::LuauUpdateFunctionNameBinding) + { + LUAU_ASSERT(function.name->is() || function.name->is()); + + if (auto exprIndexName = function.name->as()) + { + if (auto typeIt = currentModule->astTypes.find(exprIndexName->expr)) + { + if (auto ttv = getMutableTableType(*typeIt)) + { + if (auto it = ttv->props.find(exprIndexName->index.value); it != ttv->props.end()) + it->second.type = follow(quantify(funScope, leftType, function.name->location)); + } + } + } + } + else + { + if (leftTypeBinding) + *leftTypeBinding = follow(quantify(funScope, leftType, function.name->location)); + } } } @@ -1148,8 +1173,10 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias } else { - ScopePtr aliasScope = - FFlag::LuauQuantifyInPlace2 ? childScope(scope, typealias.location, subLevel) : childScope(scope, typealias.location); + ScopePtr aliasScope = childScope(scope, typealias.location); + aliasScope->level = scope->level.incr(); + if (FFlag::LuauProperTypeLevels) + aliasScope->level.subLevel = subLevel; auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks); @@ -1166,6 +1193,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias ice("Not predeclared"); ScopePtr aliasScope = childScope(scope, typealias.location); + aliasScope->level = scope->level.incr(); for (TypeId ty : binding->typeParams) { @@ -1505,9 +1533,9 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa else if (auto vtp = get(retPack)) return {vtp->ty, std::move(result.predicates)}; else if (get(retPack)) - ice("Unexpected abstract type pack!"); + ice("Unexpected abstract type pack!", expr.location); else - ice("Unknown TypePack type!"); + ice("Unknown TypePack type!", expr.location); } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexName& expr) @@ -1574,7 +1602,7 @@ std::optional TypeChecker::getIndexTypeFromType( } else if (tableType->state == TableState::Free) { - TypeId result = freshType(scope); + TypeId result = FFlag::LuauAscribeCorrectLevelToInferredProperitesOfFreeTables ? freshType(tableType->level) : freshType(scope); tableType->props[name] = {result}; return result; } @@ -1738,7 +1766,16 @@ TypeId TypeChecker::stripFromNilAndReport(TypeId ty, const Location& location) ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr) { - return {checkLValue(scope, expr)}; + TypeId ty = checkLValue(scope, expr); + + if (FFlag::LuauRefiLookupFromIndexExpr) + { + if (std::optional lvalue = tryGetLValue(expr)) + if (std::optional refiTy = resolveLValue(scope, *lvalue)) + return {*refiTy, {TruthyPredicate{std::move(*lvalue), expr.location}}}; + } + + return {ty}; } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType) @@ -2421,12 +2458,27 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTy TypeId annotationType = resolveType(scope, *expr.annotation); ExprResult result = checkExpr(scope, *expr.expr, annotationType); - ErrorVec errorVec = canUnify(result.type, annotationType, expr.location); - reportErrors(errorVec); - if (!errorVec.empty()) - annotationType = errorRecoveryType(annotationType); + if (FFlag::LuauBidirectionalAsExpr) + { + // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. + if (canUnify(result.type, annotationType, expr.location).empty()) + return {annotationType, std::move(result.predicates)}; - return {annotationType, std::move(result.predicates)}; + if (canUnify(annotationType, result.type, expr.location).empty()) + return {annotationType, std::move(result.predicates)}; + + reportError(expr.location, TypesAreUnrelated{result.type, annotationType}); + return {errorRecoveryType(annotationType), std::move(result.predicates)}; + } + else + { + ErrorVec errorVec = canUnify(result.type, annotationType, expr.location); + reportErrors(errorVec); + if (!errorVec.empty()) + annotationType = errorRecoveryType(annotationType); + + return {annotationType, std::move(result.predicates)}; + } } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprError& expr) @@ -2674,8 +2726,15 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope // Answers the question: "Can I define another function with this name?" // Primarily about detecting duplicates. -TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) +TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level) { + auto freshTy = [&]() { + if (FFlag::LuauProperTypeLevels) + return freshType(level); + else + return freshType(scope); + }; + if (auto globalName = funName.as()) { const ScopePtr& globalScope = currentModule->getModuleScope(); @@ -2689,7 +2748,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) } else { - TypeId ty = freshType(scope); + TypeId ty = freshTy(); globalScope->bindings[name] = {ty, funName.location}; return ty; } @@ -2699,7 +2758,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) Symbol name = localName->local; Binding& binding = scope->bindings[name]; if (binding.typeId == nullptr) - binding = {freshType(scope), funName.location}; + binding = {freshTy(), funName.location}; return binding.typeId; } @@ -2730,7 +2789,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) Property& property = ttv->props[name]; - property.type = freshType(scope); + property.type = freshTy(); property.location = indexName->indexLocation; ttv->methodDefinitionLocations[name] = funName.location; return property.type; @@ -3327,7 +3386,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A fn = follow(fn); if (auto ret = checkCallOverload( - scope, expr, fn, retPack, argPack, args, argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors)) + scope, expr, fn, retPack, argPack, args, &argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors)) return *ret; } @@ -3402,9 +3461,11 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st } std::optional> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, - TypePackId argPack, TypePack* args, const std::vector& argLocations, const ExprResult& argListResult, + TypePackId argPack, TypePack* args, const std::vector* argLocations, const ExprResult& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors) { + LUAU_ASSERT(argLocations); + fn = stripFromNilAndReport(fn, expr.func->location); if (get(fn)) @@ -3428,31 +3489,44 @@ std::optional> TypeChecker::checkCallOverload(const Scope return {{retPack}}; } - const FunctionTypeVar* ftv = get(fn); - if (!ftv) + std::vector metaArgLocations; + + // Might be a callable table + if (const MetatableTypeVar* mttv = get(fn)) { - // Might be a callable table - if (const MetatableTypeVar* mttv = get(fn)) + if (std::optional ty = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, false)) { - if (std::optional ty = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, false)) + // Construct arguments with 'self' added in front + TypePackId metaCallArgPack = addTypePack(TypePackVar(TypePack{args->head, args->tail})); + + TypePack* metaCallArgs = getMutable(metaCallArgPack); + metaCallArgs->head.insert(metaCallArgs->head.begin(), fn); + + metaArgLocations = *argLocations; + metaArgLocations.insert(metaArgLocations.begin(), expr.func->location); + + if (FFlag::LuauFixRecursiveMetatableCall) { - // Construct arguments with 'self' added in front - TypePackId metaCallArgPack = addTypePack(TypePackVar(TypePack{args->head, args->tail})); - - TypePack* metaCallArgs = getMutable(metaCallArgPack); - metaCallArgs->head.insert(metaCallArgs->head.begin(), fn); - - std::vector metaArgLocations = argLocations; - metaArgLocations.insert(metaArgLocations.begin(), expr.func->location); + fn = instantiate(scope, *ty, expr.func->location); + argPack = metaCallArgPack; + args = metaCallArgs; + argLocations = &metaArgLocations; + } + else + { TypeId fn = *ty; fn = instantiate(scope, fn, expr.func->location); - return checkCallOverload(scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, metaArgLocations, argListResult, + return checkCallOverload(scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, &metaArgLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors); } } + } + const FunctionTypeVar* ftv = get(fn); + if (!ftv) + { reportError(TypeError{expr.func->location, CannotCallNonFunction{fn}}); unify(retPack, errorRecoveryTypePack(scope), expr.func->location); return {{errorRecoveryTypePack(retPack)}}; @@ -3477,7 +3551,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope return {}; } - checkArgumentList(scope, state, argPack, ftv->argTypes, argLocations); + checkArgumentList(scope, state, argPack, ftv->argTypes, *argLocations); if (!state.errors.empty()) { @@ -3772,7 +3846,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module if (moduleInfo.name.empty()) { - if (FFlag::LuauStrictRequire && currentModule->mode == Mode::Strict) + if (currentModule->mode == Mode::Strict) { reportError(TypeError{location, UnknownRequire{}}); return errorRecoveryType(anyType); @@ -4268,9 +4342,11 @@ ScopePtr TypeChecker::childFunctionScope(const ScopePtr& parent, const Location& } // Creates a new Scope and carries forward the varargs from the parent. -ScopePtr TypeChecker::childScope(const ScopePtr& parent, const Location& location, int subLevel) +ScopePtr TypeChecker::childScope(const ScopePtr& parent, const Location& location) { - ScopePtr scope = std::make_shared(parent, subLevel); + ScopePtr scope = std::make_shared(parent); + if (FFlag::LuauProperTypeLevels) + scope->level = parent->level; scope->varargPack = parent->varargPack; currentModule->scopes.push_back(std::make_pair(location, scope)); @@ -4329,22 +4405,22 @@ TypeId TypeChecker::singletonType(std::string value) TypeId TypeChecker::errorRecoveryType(const ScopePtr& scope) { - return singletonTypes.errorRecoveryType(); + return getSingletonTypes().errorRecoveryType(); } TypeId TypeChecker::errorRecoveryType(TypeId guess) { - return singletonTypes.errorRecoveryType(guess); + return getSingletonTypes().errorRecoveryType(guess); } TypePackId TypeChecker::errorRecoveryTypePack(const ScopePtr& scope) { - return singletonTypes.errorRecoveryTypePack(); + return getSingletonTypes().errorRecoveryTypePack(); } TypePackId TypeChecker::errorRecoveryTypePack(TypePackId guess) { - return singletonTypes.errorRecoveryTypePack(guess); + return getSingletonTypes().errorRecoveryTypePack(guess); } std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) @@ -4547,6 +4623,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation else if (const auto& func = annotation.as()) { ScopePtr funcScope = childScope(scope, func->location); + funcScope->level = scope->level.incr(); auto [generics, genericPacks] = createGenericTypes(funcScope, std::nullopt, annotation, func->generics, func->genericPacks); diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 0d9d91e0..8c6d5e49 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -19,7 +19,7 @@ std::optional findMetatableEntry(ErrorVec& errors, const ScopePtr& globa TypeId unwrapped = follow(*metatable); if (get(unwrapped)) - return singletonTypes.anyType; + return getSingletonTypes().anyType; const TableTypeVar* mtt = getTableType(unwrapped); if (!mtt) @@ -61,12 +61,12 @@ std::optional findTablePropertyRespectingMeta(ErrorVec& errors, const Sc { std::optional r = first(follow(itf->retType)); if (!r) - return singletonTypes.nilType; + return getSingletonTypes().nilType; else return *r; } else if (get(index)) - return singletonTypes.anyType; + return getSingletonTypes().anyType; else errors.push_back(TypeError{location, GenericError{"__index should either be a function or table. Got " + toString(index)}}); diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 62715af5..571b13ca 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -21,6 +21,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) LUAU_FASTFLAG(LuauErrorRecoveryType) +LUAU_FASTFLAG(DebugLuauFreezeArena) namespace Luau { @@ -579,11 +580,25 @@ SingletonTypes::SingletonTypes() , arena(new TypeArena) { TypeId stringMetatable = makeStringMetatable(); - stringType_.ty = PrimitiveTypeVar{PrimitiveTypeVar::String, makeStringMetatable()}; + stringType_.ty = PrimitiveTypeVar{PrimitiveTypeVar::String, stringMetatable}; persist(stringMetatable); + + debugFreezeArena = FFlag::DebugLuauFreezeArena; freeze(*arena); } +SingletonTypes::~SingletonTypes() +{ + // Destroy the arena with the same memory management flags it was created with + bool prevFlag = FFlag::DebugLuauFreezeArena; + FFlag::DebugLuauFreezeArena.value = debugFreezeArena; + + unfreeze(*arena); + arena.reset(nullptr); + + FFlag::DebugLuauFreezeArena.value = prevFlag; +} + TypeId SingletonTypes::makeStringMetatable() { const TypeId optionalNumber = arena->addType(UnionTypeVar{{nilType, numberType}}); @@ -641,6 +656,9 @@ TypeId SingletonTypes::makeStringMetatable() TypeId tableType = arena->addType(TableTypeVar{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); + if (TableTypeVar* ttv = getMutable(tableType)) + ttv->name = "string"; + return arena->addType(TableTypeVar{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); } @@ -670,7 +688,11 @@ TypePackId SingletonTypes::errorRecoveryTypePack(TypePackId guess) return &errorTypePack_; } -SingletonTypes singletonTypes; +SingletonTypes& getSingletonTypes() +{ + static SingletonTypes singletonTypes; + return singletonTypes; +} void persist(TypeId ty) { @@ -719,6 +741,18 @@ void persist(TypeId ty) for (TypeId opt : itv->parts) queue.push_back(opt); } + else if (auto mtv = get(t)) + { + queue.push_back(mtv->table); + queue.push_back(mtv->metatable); + } + else if (get(t) || get(t) || get(t) || get(t) || get(t)) + { + } + else + { + LUAU_ASSERT(!"TypeId is not supported in a persist call"); + } } } @@ -736,6 +770,17 @@ void persist(TypePackId tp) if (p->tail) persist(*p->tail); } + else if (auto vtp = get(tp)) + { + persist(vtp->ty); + } + else if (get(tp)) + { + } + else + { + LUAU_ASSERT(!"TypePackId is not supported in a persist call"); + } } const TypeLevel* getLevel(TypeId ty) @@ -757,167 +802,6 @@ TypeLevel* getMutableLevel(TypeId ty) return const_cast(getLevel(ty)); } -struct QVarFinder -{ - mutable DenseHashSet seen; - - QVarFinder() - : seen(nullptr) - { - } - - bool hasSeen(const void* tv) const - { - if (seen.contains(tv)) - return true; - - seen.insert(tv); - return false; - } - - bool hasGeneric(TypeId tid) const - { - if (hasSeen(&tid->ty)) - return false; - - return Luau::visit(*this, tid->ty); - } - - bool hasGeneric(TypePackId tp) const - { - if (hasSeen(&tp->ty)) - return false; - - return Luau::visit(*this, tp->ty); - } - - bool operator()(const Unifiable::Free&) const - { - return false; - } - - bool operator()(const Unifiable::Bound& bound) const - { - return hasGeneric(bound.boundTo); - } - - bool operator()(const Unifiable::Generic&) const - { - return true; - } - bool operator()(const Unifiable::Error&) const - { - return false; - } - bool operator()(const PrimitiveTypeVar&) const - { - return false; - } - - bool operator()(const SingletonTypeVar&) const - { - return false; - } - - bool operator()(const FunctionTypeVar& ftv) const - { - if (hasGeneric(ftv.argTypes)) - return true; - return hasGeneric(ftv.retType); - } - - bool operator()(const TableTypeVar& ttv) const - { - if (ttv.state == TableState::Generic) - return true; - - if (ttv.indexer) - { - if (hasGeneric(ttv.indexer->indexType)) - return true; - if (hasGeneric(ttv.indexer->indexResultType)) - return true; - } - - for (const auto& [_name, prop] : ttv.props) - { - if (hasGeneric(prop.type)) - return true; - } - - return false; - } - - bool operator()(const MetatableTypeVar& mtv) const - { - return hasGeneric(mtv.table) || hasGeneric(mtv.metatable); - } - - bool operator()(const ClassTypeVar& ctv) const - { - for (const auto& [name, prop] : ctv.props) - { - if (hasGeneric(prop.type)) - return true; - } - - if (ctv.parent) - return hasGeneric(*ctv.parent); - - return false; - } - - bool operator()(const AnyTypeVar&) const - { - return false; - } - - bool operator()(const UnionTypeVar& utv) const - { - for (TypeId tid : utv.options) - if (hasGeneric(tid)) - return true; - - return false; - } - - bool operator()(const IntersectionTypeVar& utv) const - { - for (TypeId tid : utv.parts) - if (hasGeneric(tid)) - return true; - - return false; - } - - bool operator()(const LazyTypeVar&) const - { - return false; - } - - bool operator()(const Unifiable::Bound& bound) const - { - return hasGeneric(bound.boundTo); - } - - bool operator()(const TypePack& pack) const - { - for (TypeId ty : pack.head) - if (hasGeneric(ty)) - return true; - - if (pack.tail) - return hasGeneric(*pack.tail); - - return false; - } - - bool operator()(const VariadicTypePack& pack) const - { - return hasGeneric(pack.ty); - } -}; - const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name) { while (cls) @@ -953,16 +837,6 @@ bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent) return false; } -bool hasGeneric(TypeId ty) -{ - return Luau::visit(QVarFinder{}, ty->ty); -} - -bool hasGeneric(TypePackId tp) -{ - return Luau::visit(QVarFinder{}, tp->ty); -} - UnionTypeVarIterator::UnionTypeVarIterator(const UnionTypeVar* utv) { LUAU_ASSERT(utv); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index d0b18837..c5aab856 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -14,7 +14,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); -LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance, false); +LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false) LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) @@ -22,9 +22,82 @@ LUAU_FASTFLAGVARIABLE(LuauExtendedTypeMismatchError, false) LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAGVARIABLE(LuauExtendedClassMismatchError, false) LUAU_FASTFLAG(LuauErrorRecoveryType); +LUAU_FASTFLAG(LuauProperTypeLevels); +LUAU_FASTFLAGVARIABLE(LuauExtendedUnionMismatchError, false) +LUAU_FASTFLAGVARIABLE(LuauExtendedFunctionMismatchError, false) namespace Luau { + +struct PromoteTypeLevels +{ + TxnLog& log; + TypeLevel minLevel; + + explicit PromoteTypeLevels(TxnLog& log, TypeLevel minLevel) + : log(log) + , minLevel(minLevel) + {} + + template + void promote(TID ty, T* t) + { + LUAU_ASSERT(t); + if (minLevel.subsumesStrict(t->level)) + { + log(ty); + t->level = minLevel; + } + } + + template + void cycle(TID) {} + + template + bool operator()(TID, const T&) + { + return true; + } + + bool operator()(TypeId ty, const FreeTypeVar&) + { + promote(ty, getMutable(ty)); + return true; + } + + bool operator()(TypeId ty, const FunctionTypeVar&) + { + promote(ty, getMutable(ty)); + return true; + } + + bool operator()(TypeId ty, const TableTypeVar&) + { + promote(ty, getMutable(ty)); + return true; + } + + bool operator()(TypePackId tp, const FreeTypePack&) + { + promote(tp, getMutable(tp)); + return true; + } +}; + +void promoteTypeLevels(TxnLog& log, TypeLevel minLevel, TypeId ty) +{ + PromoteTypeLevels ptl{log, minLevel}; + DenseHashSet seen{nullptr}; + visitTypeVarOnce(ty, ptl, seen); +} + +void promoteTypeLevels(TxnLog& log, TypeLevel minLevel, TypePackId tp) +{ + PromoteTypeLevels ptl{log, minLevel}; + DenseHashSet seen{nullptr}; + visitTypeVarOnce(tp, ptl, seen); +} + struct SkipCacheForType { SkipCacheForType(const DenseHashMap& skipCacheForType) @@ -127,6 +200,29 @@ static std::optional hasUnificationTooComplex(const ErrorVec& errors) return *it; } +// Used for tagged union matching heuristic, returns first singleton type field +static std::optional> getTableMatchTag(TypeId type) +{ + LUAU_ASSERT(FFlag::LuauExtendedUnionMismatchError); + + type = follow(type); + + if (auto ttv = get(type)) + { + for (auto&& [name, prop] : ttv->props) + { + if (auto sing = get(follow(prop.type))) + return {{name, sing}}; + } + } + else if (auto mttv = get(type)) + { + return getTableMatchTag(mttv->table); + } + + return std::nullopt; +} + Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState) : types(types) , mode(mode) @@ -214,9 +310,11 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool { occursCheck(superTy, subTy); + TypeLevel superLevel = l->level; + // Unification can't change the level of a generic. auto rightGeneric = get(subTy); - if (rightGeneric && !rightGeneric->level.subsumes(l->level)) + if (rightGeneric && !rightGeneric->level.subsumes(superLevel)) { // TODO: a more informative error message? CLI-39912 errors.push_back(TypeError{location, GenericError{"Generic subtype escaping scope"}}); @@ -226,7 +324,9 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool // The occurrence check might have caused superTy no longer to be a free type if (!get(superTy)) { - if (auto rightLevel = getMutableLevel(subTy)) + if (FFlag::LuauProperTypeLevels) + promoteTypeLevels(log, superLevel, subTy); + else if (auto rightLevel = getMutableLevel(subTy)) { if (!rightLevel->subsumes(l->level)) *rightLevel = l->level; @@ -240,6 +340,8 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool } else if (r) { + TypeLevel subLevel = r->level; + occursCheck(subTy, superTy); // Unification can't change the level of a generic. @@ -253,10 +355,16 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (!get(subTy)) { - if (auto leftLevel = getMutableLevel(superTy)) + if (FFlag::LuauProperTypeLevels) + promoteTypeLevels(log, subLevel, superTy); + + if (auto superLevel = getMutableLevel(superTy)) { - if (!leftLevel->subsumes(r->level)) - *leftLevel = r->level; + if (!superLevel->subsumes(r->level)) + { + log(superTy); + *superLevel = r->level; + } } log(subTy); @@ -327,7 +435,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool else if (failed) { if (FFlag::LuauExtendedTypeMismatchError && firstFailedOption) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible", *firstFailedOption}}); + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}}); else errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } @@ -338,28 +446,46 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool bool found = false; std::optional unificationTooComplex; + size_t failedOptionCount = 0; + std::optional failedOption; + + bool foundHeuristic = false; size_t startIndex = 0; if (FFlag::LuauUnionHeuristic) { - bool found = false; - - const std::string* subName = getName(subTy); - if (subName) + if (const std::string* subName = getName(subTy)) { for (size_t i = 0; i < uv->options.size(); ++i) { const std::string* optionName = getName(uv->options[i]); if (optionName && *optionName == *subName) { - found = true; + foundHeuristic = true; startIndex = i; break; } } } - if (!found && cacheEnabled) + if (FFlag::LuauExtendedUnionMismatchError) + { + if (auto subMatchTag = getTableMatchTag(subTy)) + { + for (size_t i = 0; i < uv->options.size(); ++i) + { + auto optionMatchTag = getTableMatchTag(uv->options[i]); + if (optionMatchTag && optionMatchTag->first == subMatchTag->first && *optionMatchTag->second == *subMatchTag->second) + { + foundHeuristic = true; + startIndex = i; + break; + } + } + } + } + + if (!foundHeuristic && cacheEnabled) { for (size_t i = 0; i < uv->options.size(); ++i) { @@ -390,15 +516,27 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool { unificationTooComplex = e; } + else if (FFlag::LuauExtendedUnionMismatchError && !isNil(type)) + { + failedOptionCount++; + + if (!failedOption) + failedOption = {innerState.errors.front()}; + } innerState.log.rollback(); } if (unificationTooComplex) + { errors.push_back(*unificationTooComplex); + } else if (!found) { - if (FFlag::LuauExtendedTypeMismatchError) + if (FFlag::LuauExtendedUnionMismatchError && (failedOptionCount == 1 || foundHeuristic) && failedOption) + errors.push_back( + TypeError{location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}}); + else if (FFlag::LuauExtendedTypeMismatchError) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); else errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); @@ -431,7 +569,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (unificationTooComplex) errors.push_back(*unificationTooComplex); else if (firstFailedOption) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible", *firstFailedOption}}); + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); } else { @@ -771,6 +909,10 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal if (superIter.good() && subIter.good()) { tryUnify_(*superIter, *subIter); + + if (FFlag::LuauExtendedFunctionMismatchError && !errors.empty() && !firstPackErrorPos) + firstPackErrorPos = loopCount; + superIter.advance(); subIter.advance(); continue; @@ -853,13 +995,13 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal while (superIter.good()) { - tryUnify_(singletonTypes.errorRecoveryType(), *superIter); + tryUnify_(getSingletonTypes().errorRecoveryType(), *superIter); superIter.advance(); } while (subIter.good()) { - tryUnify_(singletonTypes.errorRecoveryType(), *subIter); + tryUnify_(getSingletonTypes().errorRecoveryType(), *subIter); subIter.advance(); } @@ -917,14 +1059,22 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal if (numGenerics != rf->generics.size()) { numGenerics = std::min(lf->generics.size(), rf->generics.size()); - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + + if (FFlag::LuauExtendedFunctionMismatchError) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}}); + else + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } size_t numGenericPacks = lf->genericPacks.size(); if (numGenericPacks != rf->genericPacks.size()) { numGenericPacks = std::min(lf->genericPacks.size(), rf->genericPacks.size()); - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + + if (FFlag::LuauExtendedFunctionMismatchError) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}}); + else + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } for (size_t i = 0; i < numGenerics; i++) @@ -936,13 +1086,49 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal { Unifier innerState = makeChildUnifier(); - ctx = CountMismatch::Arg; - innerState.tryUnify_(rf->argTypes, lf->argTypes, isFunctionCall); + if (FFlag::LuauExtendedFunctionMismatchError) + { + innerState.ctx = CountMismatch::Arg; + innerState.tryUnify_(rf->argTypes, lf->argTypes, isFunctionCall); - ctx = CountMismatch::Result; - innerState.tryUnify_(lf->retType, rf->retType); + bool reported = !innerState.errors.empty(); - checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); + if (auto e = hasUnificationTooComplex(innerState.errors)) + errors.push_back(*e); + else if (!innerState.errors.empty() && innerState.firstPackErrorPos) + errors.push_back( + TypeError{location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), + innerState.errors.front()}}); + else if (!innerState.errors.empty()) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); + + innerState.ctx = CountMismatch::Result; + innerState.tryUnify_(lf->retType, rf->retType); + + if (!reported) + { + if (auto e = hasUnificationTooComplex(innerState.errors)) + errors.push_back(*e); + else if (!innerState.errors.empty() && size(lf->retType) == 1 && finite(lf->retType)) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}}); + else if (!innerState.errors.empty() && innerState.firstPackErrorPos) + errors.push_back( + TypeError{location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), + innerState.errors.front()}}); + else if (!innerState.errors.empty()) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); + } + } + else + { + ctx = CountMismatch::Arg; + innerState.tryUnify_(rf->argTypes, lf->argTypes, isFunctionCall); + + ctx = CountMismatch::Result; + innerState.tryUnify_(lf->retType, rf->retType); + + checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); + } log.concat(std::move(innerState.log)); } @@ -994,7 +1180,7 @@ struct Resetter void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) { - if (!FFlag::LuauTableSubtypingVariance) + if (!FFlag::LuauTableSubtypingVariance2) return DEPRECATED_tryUnifyTables(left, right, isIntersection); TableTypeVar* lt = getMutable(left); @@ -1133,7 +1319,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) // TODO: hopefully readonly/writeonly properties will fix this. Property clone = prop; clone.type = deeplyOptional(clone.type); - log(lt); + log(left); lt->props[name] = clone; } else if (variance == Covariant) @@ -1146,7 +1332,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) } else if (lt->state == TableState::Free) { - log(lt); + log(left); lt->props[name] = prop; } else @@ -1176,7 +1362,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) // e.g. table.insert(t, 1) where t is a non-sealed table and doesn't have an indexer. // TODO: we only need to do this if the supertype's indexer is read/write // since that can add indexed elements. - log(rt); + log(right); rt->indexer = lt->indexer; } } @@ -1185,7 +1371,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) // Symmetric if we are invariant if (lt->state == TableState::Unsealed || lt->state == TableState::Free) { - log(lt); + log(left); lt->indexer = rt->indexer; } } @@ -1241,15 +1427,15 @@ TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map see TableTypeVar* resultTtv = getMutable(result); for (auto& [name, prop] : resultTtv->props) prop.type = deeplyOptional(prop.type, seen); - return types->addType(UnionTypeVar{{singletonTypes.nilType, result}}); + return types->addType(UnionTypeVar{{getSingletonTypes().nilType, result}}); } else - return types->addType(UnionTypeVar{{singletonTypes.nilType, ty}}); + return types->addType(UnionTypeVar{{getSingletonTypes().nilType, ty}}); } void Unifier::DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection) { - LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance); + LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2); Resetter resetter{&variance}; variance = Invariant; @@ -1467,7 +1653,7 @@ void Unifier::tryUnifySealedTables(TypeId left, TypeId right, bool isIntersectio } else if (lt->indexer) { - innerState.tryUnify_(lt->indexer->indexType, singletonTypes.stringType); + innerState.tryUnify_(lt->indexer->indexType, getSingletonTypes().stringType); // We already try to unify properties in both tables. // Skip those and just look for the ones remaining and see if they fit into the indexer. for (const auto& [name, type] : rt->props) @@ -1636,7 +1822,7 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) ok = false; errors.push_back(TypeError{location, UnknownProperty{superTy, propName}}); if (!FFlag::LuauExtendedClassMismatchError) - tryUnify_(prop.type, singletonTypes.errorRecoveryType()); + tryUnify_(prop.type, getSingletonTypes().errorRecoveryType()); } else { @@ -1825,7 +2011,7 @@ void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) if (get(ty) || get(ty) || get(ty)) return; - const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{singletonTypes.anyType}}); + const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{getSingletonTypes().anyType}}); const TypePackId anyTP = get(any) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); @@ -1834,14 +2020,14 @@ void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) sharedState.tempSeenTy.clear(); sharedState.tempSeenTp.clear(); - Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, singletonTypes.anyType, anyTP); + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, getSingletonTypes().anyType, anyTP); } void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) { LUAU_ASSERT(get(any)); - const TypeId anyTy = singletonTypes.errorRecoveryType(); + const TypeId anyTy = getSingletonTypes().errorRecoveryType(); std::vector queue; @@ -1887,7 +2073,7 @@ void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays { errors.push_back(TypeError{location, OccursCheckFailed{}}); log(needle); - *asMutable(needle) = *singletonTypes.errorRecoveryType(); + *asMutable(needle) = *getSingletonTypes().errorRecoveryType(); return; } @@ -1951,7 +2137,7 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ { errors.push_back(TypeError{location, OccursCheckFailed{}}); log(needle); - *asMutable(needle) = *singletonTypes.errorRecoveryTypePack(); + *asMutable(needle) = *getSingletonTypes().errorRecoveryTypePack(); return; } @@ -2005,7 +2191,7 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const s errors.push_back(*e); else if (!innerErrors.empty()) errors.push_back( - TypeError{location, TypeMismatch{wantedType, givenType, format("Property '%s' is not compatible", prop.c_str()), innerErrors.front()}}); + TypeError{location, TypeMismatch{wantedType, givenType, format("Property '%s' is not compatible.", prop.c_str()), innerErrors.front()}}); } void Unifier::ice(const std::string& message, const Location& location) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 3d0d5b7e..dd24f27c 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -10,7 +10,6 @@ // See docs/SyntaxChanges.md for an explanation. LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) -LUAU_FASTFLAGVARIABLE(LuauCaptureBrokenCommentSpans, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) @@ -159,7 +158,7 @@ ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& n { std::vector hotcomments; - while (isComment(p.lexer.current()) || (FFlag::LuauCaptureBrokenCommentSpans && p.lexer.current().type == Lexeme::BrokenComment)) + while (isComment(p.lexer.current()) || p.lexer.current().type == Lexeme::BrokenComment) { const char* text = p.lexer.current().data; unsigned int length = p.lexer.current().length; @@ -2780,7 +2779,7 @@ const Lexeme& Parser::nextLexeme() const Lexeme& lexeme = lexer.next(/*skipComments*/ false); // Subtlety: Broken comments are weird because we record them as comments AND pass them to the parser as a lexeme. // The parser will turn this into a proper syntax error. - if (FFlag::LuauCaptureBrokenCommentSpans && lexeme.type == Lexeme::BrokenComment) + if (lexeme.type == Lexeme::BrokenComment) commentLocations.push_back(Comment{lexeme.type, lexeme.location}); if (isComment(lexeme)) commentLocations.push_back(Comment{lexeme.type, lexeme.location}); diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 9230d80d..aecb619a 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -11,26 +11,33 @@ enum class ReportFormat { Default, - Luacheck + Luacheck, + Gnu, }; -static void report(ReportFormat format, const char* name, const Luau::Location& location, const char* type, const char* message) +static void report(ReportFormat format, const char* name, const Luau::Location& loc, const char* type, const char* message) { switch (format) { case ReportFormat::Default: - fprintf(stderr, "%s(%d,%d): %s: %s\n", name, location.begin.line + 1, location.begin.column + 1, type, message); + fprintf(stderr, "%s(%d,%d): %s: %s\n", name, loc.begin.line + 1, loc.begin.column + 1, type, message); break; case ReportFormat::Luacheck: { // Note: luacheck's end column is inclusive but our end column is exclusive // In addition, luacheck doesn't support multi-line messages, so if the error is multiline we'll fake end column as 100 and hope for the best - int columnEnd = (location.begin.line == location.end.line) ? location.end.column : 100; + int columnEnd = (loc.begin.line == loc.end.line) ? loc.end.column : 100; - fprintf(stdout, "%s:%d:%d-%d: (W0) %s: %s\n", name, location.begin.line + 1, location.begin.column + 1, columnEnd, type, message); + // Use stdout to match luacheck behavior + fprintf(stdout, "%s:%d:%d-%d: (W0) %s: %s\n", name, loc.begin.line + 1, loc.begin.column + 1, columnEnd, type, message); break; } + + case ReportFormat::Gnu: + // Note: GNU end column is inclusive but our end column is exclusive + fprintf(stderr, "%s:%d.%d-%d.%d: %s: %s\n", name, loc.begin.line + 1, loc.begin.column + 1, loc.end.line + 1, loc.end.column, type, message); + break; } } @@ -97,6 +104,7 @@ static void displayHelp(const char* argv0) printf("\n"); printf("Available options:\n"); printf(" --formatter=plain: report analysis errors in Luacheck-compatible format\n"); + printf(" --formatter=gnu: report analysis errors in GNU-compatible format\n"); } static int assertionHandler(const char* expr, const char* file, int line) @@ -201,6 +209,8 @@ int main(int argc, char** argv) if (strcmp(argv[i], "--formatter=plain") == 0) format = ReportFormat::Luacheck; + else if (strcmp(argv[i], "--formatter=gnu") == 0) + format = ReportFormat::Gnu; else if (strcmp(argv[i], "--annotate") == 0) annotate = true; } diff --git a/CLI/Coverage.cpp b/CLI/Coverage.cpp new file mode 100644 index 00000000..254df3f0 --- /dev/null +++ b/CLI/Coverage.cpp @@ -0,0 +1,88 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Coverage.h" + +#include "lua.h" + +#include +#include + +struct Coverage +{ + lua_State* L = nullptr; + std::vector functions; +} gCoverage; + +void coverageInit(lua_State* L) +{ + gCoverage.L = lua_mainthread(L); +} + +bool coverageActive() +{ + return gCoverage.L != nullptr; +} + +void coverageTrack(lua_State* L, int funcindex) +{ + int ref = lua_ref(L, funcindex); + gCoverage.functions.push_back(ref); +} + +static void coverageCallback(void* context, const char* function, int linedefined, int depth, const int* hits, size_t size) +{ + FILE* f = static_cast(context); + + std::string name; + + if (depth == 0) + name = "
"; + else if (function) + name = std::string(function) + ":" + std::to_string(linedefined); + else + name = ":" + std::to_string(linedefined); + + fprintf(f, "FN:%d,%s\n", linedefined, name.c_str()); + + for (size_t i = 0; i < size; ++i) + if (hits[i] != -1) + { + fprintf(f, "FNDA:%d,%s\n", hits[i], name.c_str()); + break; + } + + for (size_t i = 0; i < size; ++i) + if (hits[i] != -1) + fprintf(f, "DA:%d,%d\n", int(i), hits[i]); +} + +void coverageDump(const char* path) +{ + lua_State* L = gCoverage.L; + + FILE* f = fopen(path, "w"); + if (!f) + { + fprintf(stderr, "Error opening coverage %s\n", path); + return; + } + + fprintf(f, "TN:\n"); + + for (int fref: gCoverage.functions) + { + lua_getref(L, fref); + + lua_Debug ar = {}; + lua_getinfo(L, -1, "s", &ar); + + fprintf(f, "SF:%s\n", ar.short_src); + lua_getcoverage(L, -1, f, coverageCallback); + fprintf(f, "end_of_record\n"); + + lua_pop(L, 1); + } + + fclose(f); + + printf("Coverage dump written to %s (%d functions)\n", path, int(gCoverage.functions.size())); +} diff --git a/CLI/Coverage.h b/CLI/Coverage.h new file mode 100644 index 00000000..74be4e5c --- /dev/null +++ b/CLI/Coverage.h @@ -0,0 +1,10 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +struct lua_State; + +void coverageInit(lua_State* L); +bool coverageActive(); + +void coverageTrack(lua_State* L, int funcindex); +void coverageDump(const char* path); diff --git a/CLI/FileUtils.cpp b/CLI/FileUtils.cpp index b3c9557b..cb993dfe 100644 --- a/CLI/FileUtils.cpp +++ b/CLI/FileUtils.cpp @@ -67,6 +67,10 @@ std::optional readFile(const std::string& name) if (read != size_t(length)) return std::nullopt; + // Skip first line if it's a shebang + if (length > 2 && result[0] == '#' && result[1] == '!') + result.erase(0, result.find('\n')); + return result; } diff --git a/CLI/Profiler.cpp b/CLI/Profiler.cpp index c6d15a7f..30a171f0 100644 --- a/CLI/Profiler.cpp +++ b/CLI/Profiler.cpp @@ -110,12 +110,12 @@ void profilerStop() gProfiler.thread.join(); } -void profilerDump(const char* name) +void profilerDump(const char* path) { - FILE* f = fopen(name, "wb"); + FILE* f = fopen(path, "wb"); if (!f) { - fprintf(stderr, "Error opening profile %s\n", name); + fprintf(stderr, "Error opening profile %s\n", path); return; } @@ -129,7 +129,7 @@ void profilerDump(const char* name) fclose(f); - printf("Profiler dump written to %s (total runtime %.3f seconds, %lld samples, %lld stacks)\n", name, double(total) / 1e6, + printf("Profiler dump written to %s (total runtime %.3f seconds, %lld samples, %lld stacks)\n", path, double(total) / 1e6, static_cast(gProfiler.samples.load()), static_cast(gProfiler.data.size())); uint64_t totalgc = 0; diff --git a/CLI/Profiler.h b/CLI/Profiler.h index 0a407e47..67b1acfd 100644 --- a/CLI/Profiler.h +++ b/CLI/Profiler.h @@ -5,4 +5,4 @@ struct lua_State; void profilerStart(lua_State* L, int frequency); void profilerStop(); -void profilerDump(const char* name); \ No newline at end of file +void profilerDump(const char* path); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 2cdd0062..35c02f2c 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -8,6 +8,7 @@ #include "FileUtils.h" #include "Profiler.h" +#include "Coverage.h" #include "linenoise.hpp" @@ -24,6 +25,16 @@ enum class CompileFormat Binary }; +static Luau::CompileOptions copts() +{ + Luau::CompileOptions result = {}; + result.optimizationLevel = 1; + result.debugLevel = 1; + result.coverageLevel = coverageActive() ? 2 : 0; + + return result; +} + static int lua_loadstring(lua_State* L) { size_t l = 0; @@ -32,7 +43,7 @@ static int lua_loadstring(lua_State* L) lua_setsafeenv(L, LUA_ENVIRONINDEX, false); - std::string bytecode = Luau::compile(std::string(s, l)); + std::string bytecode = Luau::compile(std::string(s, l), copts()); if (luau_load(L, chunkname, bytecode.data(), bytecode.size(), 0) == 0) return 1; @@ -79,9 +90,12 @@ static int lua_require(lua_State* L) luaL_sandboxthread(ML); // now we can compile & run module on the new thread - std::string bytecode = Luau::compile(*source); + std::string bytecode = Luau::compile(*source, copts()); if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { + if (coverageActive()) + coverageTrack(ML, -1); + int status = lua_resume(ML, L, 0); if (status == 0) @@ -149,7 +163,7 @@ static void setupState(lua_State* L) static std::string runCode(lua_State* L, const std::string& source) { - std::string bytecode = Luau::compile(source); + std::string bytecode = Luau::compile(source, copts()); if (luau_load(L, "=stdin", bytecode.data(), bytecode.size(), 0) != 0) { @@ -329,11 +343,14 @@ static bool runFile(const char* name, lua_State* GL) std::string chunkname = "=" + std::string(name); - std::string bytecode = Luau::compile(*source); + std::string bytecode = Luau::compile(*source, copts()); int status = 0; if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { + if (coverageActive()) + coverageTrack(L, -1); + status = lua_resume(L, NULL, 0); } else @@ -437,6 +454,7 @@ static void displayHelp(const char* argv0) printf("\n"); printf("Available options:\n"); printf(" --profile[=N]: profile the code using N Hz sampling (default 10000) and output results to profile.out\n"); + printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); } static int assertionHandler(const char* expr, const char* file, int line) @@ -495,6 +513,7 @@ int main(int argc, char** argv) setupState(L); int profile = 0; + bool coverage = false; for (int i = 1; i < argc; ++i) { @@ -505,11 +524,16 @@ int main(int argc, char** argv) profile = 10000; // default to 10 KHz else if (strncmp(argv[i], "--profile=", 10) == 0) profile = atoi(argv[i] + 10); + else if (strcmp(argv[i], "--coverage") == 0) + coverage = true; } if (profile) profilerStart(L, profile); + if (coverage) + coverageInit(L); + std::vector files = getSourceFiles(argc, argv); int failed = 0; @@ -523,6 +547,9 @@ int main(int argc, char** argv) profilerDump("profile.out"); } + if (coverage) + coverageDump("coverage.out"); + return failed; } } diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 2c1e85ff..8f74ffed 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -10,7 +10,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauPreloadClosures, false) LUAU_FASTFLAG(LuauIfElseExpressionBaseSupport) LUAU_FASTFLAGVARIABLE(LuauBit32CountBuiltin, false) @@ -462,20 +461,17 @@ struct Compiler bool shared = false; - if (FFlag::LuauPreloadClosures) + // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure + // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it + // is used) + if (options.optimizationLevel >= 1 && shouldShareClosure(expr) && !setfenvUsed) { - // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure - // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it - // is used) - if (options.optimizationLevel >= 1 && shouldShareClosure(expr) && !setfenvUsed) - { - int32_t cid = bytecode.addConstantClosure(f->id); + int32_t cid = bytecode.addConstantClosure(f->id); - if (cid >= 0 && cid < 32768) - { - bytecode.emitAD(LOP_DUPCLOSURE, target, cid); - shared = true; - } + if (cid >= 0 && cid < 32768) + { + bytecode.emitAD(LOP_DUPCLOSURE, target, cid); + shared = true; } } diff --git a/Makefile b/Makefile index 15c7ff7a..b144cac6 100644 --- a/Makefile +++ b/Makefile @@ -27,7 +27,7 @@ TESTS_SOURCES=$(wildcard tests/*.cpp) TESTS_OBJECTS=$(TESTS_SOURCES:%=$(BUILD)/%.o) TESTS_TARGET=$(BUILD)/luau-tests -REPL_CLI_SOURCES=CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Repl.cpp +REPL_CLI_SOURCES=CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp REPL_CLI_OBJECTS=$(REPL_CLI_SOURCES:%=$(BUILD)/%.o) REPL_CLI_TARGET=$(BUILD)/luau @@ -128,10 +128,10 @@ luau-size: luau # executable target aliases luau: $(REPL_CLI_TARGET) - cp $^ $@ + ln -fs $^ $@ luau-analyze: $(ANALYZE_CLI_TARGET) - cp $^ $@ + ln -fs $^ $@ # executable targets $(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) diff --git a/Sources.cmake b/Sources.cmake index 57df9b91..14834b3a 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -133,6 +133,7 @@ target_sources(Luau.VM PRIVATE VM/src/ltable.cpp VM/src/ltablib.cpp VM/src/ltm.cpp + VM/src/ludata.cpp VM/src/lutf8lib.cpp VM/src/lvmexecute.cpp VM/src/lvmload.cpp @@ -152,12 +153,15 @@ target_sources(Luau.VM PRIVATE VM/src/lstring.h VM/src/ltable.h VM/src/ltm.h + VM/src/ludata.h VM/src/lvm.h ) if(TARGET Luau.Repl.CLI) # Luau.Repl.CLI Sources target_sources(Luau.Repl.CLI PRIVATE + CLI/Coverage.h + CLI/Coverage.cpp CLI/FileUtils.h CLI/FileUtils.cpp CLI/Profiler.h diff --git a/VM/include/lua.h b/VM/include/lua.h index 7078acd0..55902160 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -334,6 +334,10 @@ LUA_API const char* lua_setupvalue(lua_State* L, int funcindex, int n); LUA_API void lua_singlestep(lua_State* L, int enabled); LUA_API void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled); +typedef void (*lua_Coverage)(void* context, const char* function, int linedefined, int depth, const int* hits, size_t size); + +LUA_API void lua_getcoverage(lua_State* L, int funcindex, void* context, lua_Coverage callback); + /* Warning: this function is not thread-safe since it stores the result in a shared global array! Only use for debugging. */ LUA_API const char* lua_debugtrace(lua_State* L); diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 76043b9c..a65b0325 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -8,6 +8,7 @@ #include "lfunc.h" #include "lgc.h" #include "ldo.h" +#include "ludata.h" #include "lvm.h" #include "lnumutils.h" @@ -43,36 +44,30 @@ static Table* getcurrenv(lua_State* L) } } -static LUAU_NOINLINE TValue* index2adrslow(lua_State* L, int idx) +static LUAU_NOINLINE TValue* pseudo2addr(lua_State* L, int idx) { - api_check(L, idx <= 0); - if (idx > LUA_REGISTRYINDEX) + api_check(L, lua_ispseudo(idx)); + switch (idx) + { /* pseudo-indices */ + case LUA_REGISTRYINDEX: + return registry(L); + case LUA_ENVIRONINDEX: { - api_check(L, idx != 0 && -idx <= L->top - L->base); - return L->top + idx; + sethvalue(L, &L->env, getcurrenv(L)); + return &L->env; + } + case LUA_GLOBALSINDEX: + return gt(L); + default: + { + Closure* func = curr_func(L); + idx = LUA_GLOBALSINDEX - idx; + return (idx <= func->nupvalues) ? &func->c.upvals[idx - 1] : cast_to(TValue*, luaO_nilobject); + } } - else - switch (idx) - { /* pseudo-indices */ - case LUA_REGISTRYINDEX: - return registry(L); - case LUA_ENVIRONINDEX: - { - sethvalue(L, &L->env, getcurrenv(L)); - return &L->env; - } - case LUA_GLOBALSINDEX: - return gt(L); - default: - { - Closure* func = curr_func(L); - idx = LUA_GLOBALSINDEX - idx; - return (idx <= func->nupvalues) ? &func->c.upvals[idx - 1] : cast_to(TValue*, luaO_nilobject); - } - } } -static LUAU_FORCEINLINE TValue* index2adr(lua_State* L, int idx) +static LUAU_FORCEINLINE TValue* index2addr(lua_State* L, int idx) { if (idx > 0) { @@ -83,15 +78,20 @@ static LUAU_FORCEINLINE TValue* index2adr(lua_State* L, int idx) else return o; } + else if (idx > LUA_REGISTRYINDEX) + { + api_check(L, idx != 0 && -idx <= L->top - L->base); + return L->top + idx; + } else { - return index2adrslow(L, idx); + return pseudo2addr(L, idx); } } const TValue* luaA_toobject(lua_State* L, int idx) { - StkId p = index2adr(L, idx); + StkId p = index2addr(L, idx); return (p == luaO_nilobject) ? NULL : p; } @@ -145,7 +145,7 @@ void lua_xpush(lua_State* from, lua_State* to, int idx) { api_check(from, from->global == to->global); luaC_checkthreadsleep(to); - setobj2s(to, to->top, index2adr(from, idx)); + setobj2s(to, to->top, index2addr(from, idx)); api_incr_top(to); return; } @@ -202,7 +202,7 @@ void lua_settop(lua_State* L, int idx) void lua_remove(lua_State* L, int idx) { - StkId p = index2adr(L, idx); + StkId p = index2addr(L, idx); api_checkvalidindex(L, p); while (++p < L->top) setobjs2s(L, p - 1, p); @@ -213,7 +213,7 @@ void lua_remove(lua_State* L, int idx) void lua_insert(lua_State* L, int idx) { luaC_checkthreadsleep(L); - StkId p = index2adr(L, idx); + StkId p = index2addr(L, idx); api_checkvalidindex(L, p); for (StkId q = L->top; q > p; q--) setobjs2s(L, q, q - 1); @@ -228,7 +228,7 @@ void lua_replace(lua_State* L, int idx) luaG_runerror(L, "no calling environment"); api_checknelems(L, 1); luaC_checkthreadsleep(L); - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); api_checkvalidindex(L, o); if (idx == LUA_ENVIRONINDEX) { @@ -250,7 +250,7 @@ void lua_replace(lua_State* L, int idx) void lua_pushvalue(lua_State* L, int idx) { luaC_checkthreadsleep(L); - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); setobj2s(L, L->top, o); api_incr_top(L); return; @@ -262,7 +262,7 @@ void lua_pushvalue(lua_State* L, int idx) int lua_type(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); return (o == luaO_nilobject) ? LUA_TNONE : ttype(o); } @@ -273,20 +273,20 @@ const char* lua_typename(lua_State* L, int t) int lua_iscfunction(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); return iscfunction(o); } int lua_isLfunction(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); return isLfunction(o); } int lua_isnumber(lua_State* L, int idx) { TValue n; - const TValue* o = index2adr(L, idx); + const TValue* o = index2addr(L, idx); return tonumber(o, &n); } @@ -298,14 +298,14 @@ int lua_isstring(lua_State* L, int idx) int lua_isuserdata(lua_State* L, int idx) { - const TValue* o = index2adr(L, idx); + const TValue* o = index2addr(L, idx); return (ttisuserdata(o) || ttislightuserdata(o)); } int lua_rawequal(lua_State* L, int index1, int index2) { - StkId o1 = index2adr(L, index1); - StkId o2 = index2adr(L, index2); + StkId o1 = index2addr(L, index1); + StkId o2 = index2addr(L, index2); return (o1 == luaO_nilobject || o2 == luaO_nilobject) ? 0 : luaO_rawequalObj(o1, o2); } @@ -313,8 +313,8 @@ int lua_equal(lua_State* L, int index1, int index2) { StkId o1, o2; int i; - o1 = index2adr(L, index1); - o2 = index2adr(L, index2); + o1 = index2addr(L, index1); + o2 = index2addr(L, index2); i = (o1 == luaO_nilobject || o2 == luaO_nilobject) ? 0 : equalobj(L, o1, o2); return i; } @@ -323,8 +323,8 @@ int lua_lessthan(lua_State* L, int index1, int index2) { StkId o1, o2; int i; - o1 = index2adr(L, index1); - o2 = index2adr(L, index2); + o1 = index2addr(L, index1); + o2 = index2addr(L, index2); i = (o1 == luaO_nilobject || o2 == luaO_nilobject) ? 0 : luaV_lessthan(L, o1, o2); return i; } @@ -332,7 +332,7 @@ int lua_lessthan(lua_State* L, int index1, int index2) double lua_tonumberx(lua_State* L, int idx, int* isnum) { TValue n; - const TValue* o = index2adr(L, idx); + const TValue* o = index2addr(L, idx); if (tonumber(o, &n)) { if (isnum) @@ -350,7 +350,7 @@ double lua_tonumberx(lua_State* L, int idx, int* isnum) int lua_tointegerx(lua_State* L, int idx, int* isnum) { TValue n; - const TValue* o = index2adr(L, idx); + const TValue* o = index2addr(L, idx); if (tonumber(o, &n)) { int res; @@ -371,7 +371,7 @@ int lua_tointegerx(lua_State* L, int idx, int* isnum) unsigned lua_tounsignedx(lua_State* L, int idx, int* isnum) { TValue n; - const TValue* o = index2adr(L, idx); + const TValue* o = index2addr(L, idx); if (tonumber(o, &n)) { unsigned res; @@ -391,13 +391,13 @@ unsigned lua_tounsignedx(lua_State* L, int idx, int* isnum) int lua_toboolean(lua_State* L, int idx) { - const TValue* o = index2adr(L, idx); + const TValue* o = index2addr(L, idx); return !l_isfalse(o); } const char* lua_tolstring(lua_State* L, int idx, size_t* len) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); if (!ttisstring(o)) { luaC_checkthreadsleep(L); @@ -408,7 +408,7 @@ const char* lua_tolstring(lua_State* L, int idx, size_t* len) return NULL; } luaC_checkGC(L); - o = index2adr(L, idx); /* previous call may reallocate the stack */ + o = index2addr(L, idx); /* previous call may reallocate the stack */ } if (len != NULL) *len = tsvalue(o)->len; @@ -417,7 +417,7 @@ const char* lua_tolstring(lua_State* L, int idx, size_t* len) const char* lua_tostringatom(lua_State* L, int idx, int* atom) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); if (!ttisstring(o)) return NULL; const TString* s = tsvalue(o); @@ -438,7 +438,7 @@ const char* lua_namecallatom(lua_State* L, int* atom) const float* lua_tovector(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); if (!ttisvector(o)) { return NULL; @@ -448,7 +448,7 @@ const float* lua_tovector(lua_State* L, int idx) int lua_objlen(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); switch (ttype(o)) { case LUA_TSTRING: @@ -469,13 +469,13 @@ int lua_objlen(lua_State* L, int idx) lua_CFunction lua_tocfunction(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); return (!iscfunction(o)) ? NULL : cast_to(lua_CFunction, clvalue(o)->c.f); } void* lua_touserdata(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); switch (ttype(o)) { case LUA_TUSERDATA: @@ -489,13 +489,13 @@ void* lua_touserdata(lua_State* L, int idx) void* lua_touserdatatagged(lua_State* L, int idx, int tag) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); return (ttisuserdata(o) && uvalue(o)->tag == tag) ? uvalue(o)->data : NULL; } int lua_userdatatag(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); if (ttisuserdata(o)) return uvalue(o)->tag; return -1; @@ -503,13 +503,13 @@ int lua_userdatatag(lua_State* L, int idx) lua_State* lua_tothread(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); return (!ttisthread(o)) ? NULL : thvalue(o); } const void* lua_topointer(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); switch (ttype(o)) { case LUA_TTABLE: @@ -657,7 +657,7 @@ int lua_pushthread(lua_State* L) void lua_gettable(lua_State* L, int idx) { luaC_checkthreadsleep(L); - StkId t = index2adr(L, idx); + StkId t = index2addr(L, idx); api_checkvalidindex(L, t); luaV_gettable(L, t, L->top - 1, L->top - 1); return; @@ -666,7 +666,7 @@ void lua_gettable(lua_State* L, int idx) void lua_getfield(lua_State* L, int idx, const char* k) { luaC_checkthreadsleep(L); - StkId t = index2adr(L, idx); + StkId t = index2addr(L, idx); api_checkvalidindex(L, t); TValue key; setsvalue(L, &key, luaS_new(L, k)); @@ -678,7 +678,7 @@ void lua_getfield(lua_State* L, int idx, const char* k) void lua_rawgetfield(lua_State* L, int idx, const char* k) { luaC_checkthreadsleep(L); - StkId t = index2adr(L, idx); + StkId t = index2addr(L, idx); api_check(L, ttistable(t)); TValue key; setsvalue(L, &key, luaS_new(L, k)); @@ -690,7 +690,7 @@ void lua_rawgetfield(lua_State* L, int idx, const char* k) void lua_rawget(lua_State* L, int idx) { luaC_checkthreadsleep(L); - StkId t = index2adr(L, idx); + StkId t = index2addr(L, idx); api_check(L, ttistable(t)); setobj2s(L, L->top - 1, luaH_get(hvalue(t), L->top - 1)); return; @@ -699,7 +699,7 @@ void lua_rawget(lua_State* L, int idx) void lua_rawgeti(lua_State* L, int idx, int n) { luaC_checkthreadsleep(L); - StkId t = index2adr(L, idx); + StkId t = index2addr(L, idx); api_check(L, ttistable(t)); setobj2s(L, L->top, luaH_getnum(hvalue(t), n)); api_incr_top(L); @@ -717,7 +717,7 @@ void lua_createtable(lua_State* L, int narray, int nrec) void lua_setreadonly(lua_State* L, int objindex, int enabled) { - const TValue* o = index2adr(L, objindex); + const TValue* o = index2addr(L, objindex); api_check(L, ttistable(o)); Table* t = hvalue(o); api_check(L, t != hvalue(registry(L))); @@ -727,7 +727,7 @@ void lua_setreadonly(lua_State* L, int objindex, int enabled) int lua_getreadonly(lua_State* L, int objindex) { - const TValue* o = index2adr(L, objindex); + const TValue* o = index2addr(L, objindex); api_check(L, ttistable(o)); Table* t = hvalue(o); int res = t->readonly; @@ -736,7 +736,7 @@ int lua_getreadonly(lua_State* L, int objindex) void lua_setsafeenv(lua_State* L, int objindex, int enabled) { - const TValue* o = index2adr(L, objindex); + const TValue* o = index2addr(L, objindex); api_check(L, ttistable(o)); Table* t = hvalue(o); t->safeenv = bool(enabled); @@ -748,7 +748,7 @@ int lua_getmetatable(lua_State* L, int objindex) const TValue* obj; Table* mt = NULL; int res; - obj = index2adr(L, objindex); + obj = index2addr(L, objindex); switch (ttype(obj)) { case LUA_TTABLE: @@ -775,7 +775,7 @@ int lua_getmetatable(lua_State* L, int objindex) void lua_getfenv(lua_State* L, int idx) { StkId o; - o = index2adr(L, idx); + o = index2addr(L, idx); api_checkvalidindex(L, o); switch (ttype(o)) { @@ -801,7 +801,7 @@ void lua_settable(lua_State* L, int idx) { StkId t; api_checknelems(L, 2); - t = index2adr(L, idx); + t = index2addr(L, idx); api_checkvalidindex(L, t); luaV_settable(L, t, L->top - 2, L->top - 1); L->top -= 2; /* pop index and value */ @@ -813,7 +813,7 @@ void lua_setfield(lua_State* L, int idx, const char* k) StkId t; TValue key; api_checknelems(L, 1); - t = index2adr(L, idx); + t = index2addr(L, idx); api_checkvalidindex(L, t); setsvalue(L, &key, luaS_new(L, k)); luaV_settable(L, t, &key, L->top - 1); @@ -825,7 +825,7 @@ void lua_rawset(lua_State* L, int idx) { StkId t; api_checknelems(L, 2); - t = index2adr(L, idx); + t = index2addr(L, idx); api_check(L, ttistable(t)); if (hvalue(t)->readonly) luaG_runerror(L, "Attempt to modify a readonly table"); @@ -839,7 +839,7 @@ void lua_rawseti(lua_State* L, int idx, int n) { StkId o; api_checknelems(L, 1); - o = index2adr(L, idx); + o = index2addr(L, idx); api_check(L, ttistable(o)); if (hvalue(o)->readonly) luaG_runerror(L, "Attempt to modify a readonly table"); @@ -854,7 +854,7 @@ int lua_setmetatable(lua_State* L, int objindex) TValue* obj; Table* mt; api_checknelems(L, 1); - obj = index2adr(L, objindex); + obj = index2addr(L, objindex); api_checkvalidindex(L, obj); if (ttisnil(L->top - 1)) mt = NULL; @@ -896,7 +896,7 @@ int lua_setfenv(lua_State* L, int idx) StkId o; int res = 1; api_checknelems(L, 1); - o = index2adr(L, idx); + o = index2addr(L, idx); api_checkvalidindex(L, o); api_check(L, ttistable(L->top - 1)); switch (ttype(o)) @@ -987,7 +987,7 @@ int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc) func = 0; else { - StkId o = index2adr(L, errfunc); + StkId o = index2addr(L, errfunc); api_checkvalidindex(L, o); func = savestack(L, o); } @@ -1150,7 +1150,7 @@ l_noret lua_error(lua_State* L) int lua_next(lua_State* L, int idx) { luaC_checkthreadsleep(L); - StkId t = index2adr(L, idx); + StkId t = index2addr(L, idx); api_check(L, ttistable(t)); int more = luaH_next(L, hvalue(t), L->top - 1); if (more) @@ -1187,7 +1187,7 @@ void* lua_newuserdatatagged(lua_State* L, size_t sz, int tag) api_check(L, unsigned(tag) < LUA_UTAG_LIMIT); luaC_checkGC(L); luaC_checkthreadsleep(L); - Udata* u = luaS_newudata(L, sz, tag); + Udata* u = luaU_newudata(L, sz, tag); setuvalue(L, L->top, u); api_incr_top(L); return u->data; @@ -1197,7 +1197,7 @@ void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)) { luaC_checkGC(L); luaC_checkthreadsleep(L); - Udata* u = luaS_newudata(L, sz + sizeof(dtor), UTAG_IDTOR); + Udata* u = luaU_newudata(L, sz + sizeof(dtor), UTAG_IDTOR); memcpy(&u->data + sz, &dtor, sizeof(dtor)); setuvalue(L, L->top, u); api_incr_top(L); @@ -1232,7 +1232,7 @@ const char* lua_getupvalue(lua_State* L, int funcindex, int n) { luaC_checkthreadsleep(L); TValue* val; - const char* name = aux_upvalue(index2adr(L, funcindex), n, &val); + const char* name = aux_upvalue(index2addr(L, funcindex), n, &val); if (name) { setobj2s(L, L->top, val); @@ -1246,7 +1246,7 @@ const char* lua_setupvalue(lua_State* L, int funcindex, int n) const char* name; TValue* val; StkId fi; - fi = index2adr(L, funcindex); + fi = index2addr(L, funcindex); api_checknelems(L, 1); name = aux_upvalue(fi, n, &val); if (name) @@ -1270,7 +1270,7 @@ int lua_ref(lua_State* L, int idx) api_check(L, idx != LUA_REGISTRYINDEX); /* idx is a stack index for value */ int ref = LUA_REFNIL; global_State* g = L->global; - StkId p = index2adr(L, idx); + StkId p = index2addr(L, idx); if (!ttisnil(p)) { Table* reg = hvalue(registry(L)); diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index d77f84ef..9fe1885f 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -370,6 +370,69 @@ void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled) luaG_breakpoint(L, clvalue(func)->l.p, line, bool(enabled)); } +static int getmaxline(Proto* p) +{ + int result = -1; + + for (int i = 0; i < p->sizecode; ++i) + { + int line = luaG_getline(p, i); + result = result < line ? line : result; + } + + for (int i = 0; i < p->sizep; ++i) + { + int psize = getmaxline(p->p[i]); + result = result < psize ? psize : result; + } + + return result; +} + +static void getcoverage(Proto* p, int depth, int* buffer, size_t size, void* context, lua_Coverage callback) +{ + memset(buffer, -1, size * sizeof(int)); + + for (int i = 0; i < p->sizecode; ++i) + { + Instruction insn = p->code[i]; + if (LUAU_INSN_OP(insn) != LOP_COVERAGE) + continue; + + int line = luaG_getline(p, i); + int hits = LUAU_INSN_E(insn); + + LUAU_ASSERT(size_t(line) < size); + buffer[line] = buffer[line] < hits ? hits : buffer[line]; + } + + const char* debugname = p->debugname ? getstr(p->debugname) : NULL; + int linedefined = luaG_getline(p, 0); + + callback(context, debugname, linedefined, depth, buffer, size); + + for (int i = 0; i < p->sizep; ++i) + getcoverage(p->p[i], depth + 1, buffer, size, context, callback); +} + +void lua_getcoverage(lua_State* L, int funcindex, void* context, lua_Coverage callback) +{ + const TValue* func = luaA_toobject(L, funcindex); + api_check(L, ttisfunction(func) && !clvalue(func)->isC); + + Proto* p = clvalue(func)->l.p; + + size_t size = getmaxline(p) + 1; + if (size == 0) + return; + + int* buffer = luaM_newarray(L, size, int, 0); + + getcoverage(p, 0, buffer, size, context, callback); + + luaM_freearray(L, buffer, size, int, 0); +} + static size_t append(char* buf, size_t bufsize, size_t offset, const char* data) { size_t size = strlen(data); diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index ab416041..7393fc74 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -8,11 +8,10 @@ #include "lfunc.h" #include "lstring.h" #include "ldo.h" +#include "ludata.h" #include -LUAU_FASTFLAGVARIABLE(LuauSeparateAtomic, false) - LUAU_FASTFLAG(LuauArrayBoundary) #define GC_SWEEPMAX 40 @@ -59,10 +58,6 @@ static void recordGcStateTime(global_State* g, int startgcstate, double seconds, case GCSpropagate: case GCSpropagateagain: g->gcstats.currcycle.marktime += seconds; - - // atomic step had to be performed during the switch and it's tracked separately - if (!FFlag::LuauSeparateAtomic && g->gcstate == GCSsweepstring) - g->gcstats.currcycle.marktime -= g->gcstats.currcycle.atomictime; break; case GCSatomic: g->gcstats.currcycle.atomictime += seconds; @@ -488,7 +483,7 @@ static void freeobj(lua_State* L, GCObject* o) luaS_free(L, gco2ts(o)); break; case LUA_TUSERDATA: - luaS_freeudata(L, gco2u(o)); + luaU_freeudata(L, gco2u(o)); break; default: LUAU_ASSERT(0); @@ -632,17 +627,9 @@ static size_t remarkupvals(global_State* g) static size_t atomic(lua_State* L) { global_State* g = L->global; + LUAU_ASSERT(g->gcstate == GCSatomic); + size_t work = 0; - - if (FFlag::LuauSeparateAtomic) - { - LUAU_ASSERT(g->gcstate == GCSatomic); - } - else - { - g->gcstate = GCSatomic; - } - /* remark occasional upvalues of (maybe) dead threads */ work += remarkupvals(g); /* traverse objects caught by write barrier and by 'remarkupvals' */ @@ -666,11 +653,6 @@ static size_t atomic(lua_State* L) g->sweepgc = &g->rootgc; g->gcstate = GCSsweepstring; - if (!FFlag::LuauSeparateAtomic) - { - GC_INTERRUPT(GCSatomic); - } - return work; } @@ -716,22 +698,7 @@ static size_t gcstep(lua_State* L, size_t limit) if (!g->gray) /* no more `gray' objects */ { - if (FFlag::LuauSeparateAtomic) - { - g->gcstate = GCSatomic; - } - else - { - double starttimestamp = lua_clock(); - - g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; - - atomic(L); /* finish mark phase */ - LUAU_ASSERT(g->gcstate == GCSsweepstring); - - g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; - } + g->gcstate = GCSatomic; } break; } @@ -853,7 +820,7 @@ static size_t getheaptrigger(global_State* g, size_t heapgoal) void luaC_step(lua_State* L, bool assist) { global_State* g = L->global; - ptrdiff_t lim = (g->gcstepsize / 100) * g->gcstepmul; /* how much to work */ + int lim = (g->gcstepsize / 100) * g->gcstepmul; /* how much to work */ LUAU_ASSERT(g->totalbytes >= g->GCthreshold); size_t debt = g->totalbytes - g->GCthreshold; @@ -908,7 +875,7 @@ void luaC_fullgc(lua_State* L) if (g->gcstate == GCSpause) startGcCycleStats(g); - if (g->gcstate <= (FFlag::LuauSeparateAtomic ? GCSatomic : GCSpropagateagain)) + if (g->gcstate <= GCSatomic) { /* reset sweep marks to sweep all elements (returning them to white) */ g->sweepstrgc = 0; @@ -1049,7 +1016,7 @@ int64_t luaC_allocationrate(lua_State* L) global_State* g = L->global; const double durationthreshold = 1e-3; // avoid measuring intervals smaller than 1ms - if (g->gcstate <= (FFlag::LuauSeparateAtomic ? GCSatomic : GCSpropagateagain)) + if (g->gcstate <= GCSatomic) { double duration = lua_clock() - g->gcstats.lastcycle.endtimestamp; diff --git a/VM/src/lgcdebug.cpp b/VM/src/lgcdebug.cpp index a79e7b95..f6f7a878 100644 --- a/VM/src/lgcdebug.cpp +++ b/VM/src/lgcdebug.cpp @@ -7,6 +7,7 @@ #include "ltable.h" #include "lfunc.h" #include "lstring.h" +#include "ludata.h" #include #include diff --git a/VM/src/lobject.h b/VM/src/lobject.h index ba040af6..fd0a15b7 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -78,15 +78,7 @@ typedef struct lua_TValue #define thvalue(o) check_exp(ttisthread(o), &(o)->value.gc->th) #define upvalue(o) check_exp(ttisupval(o), &(o)->value.gc->uv) -// beware bit magic: a value is false if it's nil or boolean false -// baseline implementation: (ttisnil(o) || (ttisboolean(o) && bvalue(o) == 0)) -// we'd like a branchless version of this which helps with performance, and a very fast version -// so our strategy is to always read the boolean value (not using bvalue(o) because that asserts when type isn't boolean) -// we then combine it with type to produce 0/1 as follows: -// - when type is nil (0), & makes the result 0 -// - when type is boolean (1), we effectively only look at the bottom bit, so result is 0 iff boolean value is 0 -// - when type is different, it must have some of the top bits set - we keep all top bits of boolean value so the result is non-0 -#define l_isfalse(o) (!(((o)->value.b | ~1) & ttype(o))) +#define l_isfalse(o) (ttisnil(o) || (ttisboolean(o) && bvalue(o) == 0)) /* ** for internal debug only diff --git a/VM/src/lstring.cpp b/VM/src/lstring.cpp index 18ee1cda..a9e90d17 100644 --- a/VM/src/lstring.cpp +++ b/VM/src/lstring.cpp @@ -206,32 +206,3 @@ void luaS_free(lua_State* L, TString* ts) L->global->strt.nuse--; luaM_free(L, ts, sizestring(ts->len), ts->memcat); } - -Udata* luaS_newudata(lua_State* L, size_t s, int tag) -{ - if (s > INT_MAX - sizeof(Udata)) - luaM_toobig(L); - Udata* u = luaM_new(L, Udata, sizeudata(s), L->activememcat); - luaC_link(L, u, LUA_TUSERDATA); - u->len = int(s); - u->metatable = NULL; - LUAU_ASSERT(tag >= 0 && tag <= 255); - u->tag = uint8_t(tag); - return u; -} - -void luaS_freeudata(lua_State* L, Udata* u) -{ - LUAU_ASSERT(u->tag < LUA_UTAG_LIMIT || u->tag == UTAG_IDTOR); - - void (*dtor)(void*) = nullptr; - if (u->tag == UTAG_IDTOR) - memcpy(&dtor, &u->data + u->len - sizeof(dtor), sizeof(dtor)); - else if (u->tag) - dtor = L->global->udatagc[u->tag]; - - if (dtor) - dtor(u->data); - - luaM_free(L, u, sizeudata(u->len), u->memcat); -} diff --git a/VM/src/lstring.h b/VM/src/lstring.h index 612da28d..3fd0bd39 100644 --- a/VM/src/lstring.h +++ b/VM/src/lstring.h @@ -8,11 +8,7 @@ /* string size limit */ #define MAXSSIZE (1 << 30) -/* special tag value is used for user data with inline dtors */ -#define UTAG_IDTOR LUA_UTAG_LIMIT - #define sizestring(len) (offsetof(TString, data) + len + 1) -#define sizeudata(len) (offsetof(Udata, data) + len) #define luaS_new(L, s) (luaS_newlstr(L, s, strlen(s))) #define luaS_newliteral(L, s) (luaS_newlstr(L, "" s, (sizeof(s) / sizeof(char)) - 1)) @@ -26,8 +22,5 @@ LUAI_FUNC void luaS_resize(lua_State* L, int newsize); LUAI_FUNC TString* luaS_newlstr(lua_State* L, const char* str, size_t l); LUAI_FUNC void luaS_free(lua_State* L, TString* ts); -LUAI_FUNC Udata* luaS_newudata(lua_State* L, size_t s, int tag); -LUAI_FUNC void luaS_freeudata(lua_State* L, Udata* u); - LUAI_FUNC TString* luaS_bufstart(lua_State* L, size_t size); LUAI_FUNC TString* luaS_buffinish(lua_State* L, TString* ts); diff --git a/VM/src/ludata.cpp b/VM/src/ludata.cpp new file mode 100644 index 00000000..d180c388 --- /dev/null +++ b/VM/src/ludata.cpp @@ -0,0 +1,37 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "ludata.h" + +#include "lgc.h" +#include "lmem.h" + +#include + +Udata* luaU_newudata(lua_State* L, size_t s, int tag) +{ + if (s > INT_MAX - sizeof(Udata)) + luaM_toobig(L); + Udata* u = luaM_new(L, Udata, sizeudata(s), L->activememcat); + luaC_link(L, u, LUA_TUSERDATA); + u->len = int(s); + u->metatable = NULL; + LUAU_ASSERT(tag >= 0 && tag <= 255); + u->tag = uint8_t(tag); + return u; +} + +void luaU_freeudata(lua_State* L, Udata* u) +{ + LUAU_ASSERT(u->tag < LUA_UTAG_LIMIT || u->tag == UTAG_IDTOR); + + void (*dtor)(void*) = nullptr; + if (u->tag == UTAG_IDTOR) + memcpy(&dtor, &u->data + u->len - sizeof(dtor), sizeof(dtor)); + else if (u->tag) + dtor = L->global->udatagc[u->tag]; + + if (dtor) + dtor(u->data); + + luaM_free(L, u, sizeudata(u->len), u->memcat); +} diff --git a/VM/src/ludata.h b/VM/src/ludata.h new file mode 100644 index 00000000..59cb85bd --- /dev/null +++ b/VM/src/ludata.h @@ -0,0 +1,13 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#pragma once + +#include "lobject.h" + +/* special tag value is used for user data with inline dtors */ +#define UTAG_IDTOR LUA_UTAG_LIMIT + +#define sizeudata(len) (offsetof(Udata, data) + len) + +LUAI_FUNC Udata* luaU_newudata(lua_State* L, size_t s, int tag); +LUAI_FUNC void luaU_freeudata(lua_State* L, Udata* u); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index bf8d493e..cebeeb58 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -63,7 +63,8 @@ #define VM_KV(i) (LUAU_ASSERT(unsigned(i) < unsigned(cl->l.p->sizek)), &k[i]) #define VM_UV(i) (LUAU_ASSERT(unsigned(i) < unsigned(cl->nupvalues)), &cl->l.uprefs[i]) -#define VM_PATCH_C(pc, slot) ((uint8_t*)(pc))[3] = uint8_t(slot) +#define VM_PATCH_C(pc, slot) *const_cast(pc) = ((uint8_t(slot) << 24) | (0x00ffffffu & *(pc))) +#define VM_PATCH_E(pc, slot) *const_cast(pc) = ((uint32_t(slot) << 8) | (0x000000ffu & *(pc))) // NOTE: If debugging the Luau code, disable this macro to prevent timeouts from // occurring when tracing code in Visual Studio / XCode @@ -120,7 +121,7 @@ */ #if VM_USE_CGOTO #define VM_CASE(op) CASE_##op: -#define VM_NEXT() goto*(SingleStep ? &&dispatch : kDispatchTable[*(uint8_t*)pc]) +#define VM_NEXT() goto*(SingleStep ? &&dispatch : kDispatchTable[LUAU_INSN_OP(*pc)]) #define VM_CONTINUE(op) goto* kDispatchTable[uint8_t(op)] #else #define VM_CASE(op) case op: @@ -325,7 +326,7 @@ static void luau_execute(lua_State* L) // ... and singlestep logic :) if (SingleStep) { - if (L->global->cb.debugstep && !luau_skipstep(*(uint8_t*)pc)) + if (L->global->cb.debugstep && !luau_skipstep(LUAU_INSN_OP(*pc))) { VM_PROTECT(luau_callhook(L, L->global->cb.debugstep, NULL)); @@ -335,13 +336,12 @@ static void luau_execute(lua_State* L) } #if VM_USE_CGOTO - VM_CONTINUE(*(uint8_t*)pc); + VM_CONTINUE(LUAU_INSN_OP(*pc)); #endif } #if !VM_USE_CGOTO - // Note: this assumes that LUAU_INSN_OP() decodes the first byte (aka least significant byte in the little endian encoding) - size_t dispatchOp = *(uint8_t*)pc; + size_t dispatchOp = LUAU_INSN_OP(*pc); dispatchContinue: switch (dispatchOp) @@ -2577,7 +2577,7 @@ static void luau_execute(lua_State* L) // update hits with saturated add and patch the instruction in place hits = (hits < (1 << 23) - 1) ? hits + 1 : hits; - ((uint32_t*)pc)[-1] = LOP_COVERAGE | (uint32_t(hits) << 8); + VM_PATCH_E(pc - 1, hits); VM_NEXT(); } diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index 740a4cfd..5d802277 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -53,7 +53,7 @@ const float* luaV_tovector(const TValue* obj) static void callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p1, const TValue* p2) { ptrdiff_t result = savestack(L, res); - // RBOLOX: using stack room beyond top is technically safe here, but for very complicated reasons: + // using stack room beyond top is technically safe here, but for very complicated reasons: // * The stack guarantees 1 + EXTRA_STACK room beyond stack_last (see luaD_reallocstack) will be allocated // * we cannot move luaD_checkstack above because the arguments are *sometimes* pointers to the lua // stack and checkstack may invalidate those pointers @@ -74,7 +74,7 @@ static void callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p1 static void callTM(lua_State* L, const TValue* f, const TValue* p1, const TValue* p2, const TValue* p3) { - // RBOLOX: using stack room beyond top is technically safe here, but for very complicated reasons: + // using stack room beyond top is technically safe here, but for very complicated reasons: // * The stack guarantees 1 + EXTRA_STACK room beyond stack_last (see luaD_reallocstack) will be allocated // * we cannot move luaD_checkstack above because the arguments are *sometimes* pointers to the lua // stack and checkstack may invalidate those pointers diff --git a/bench/tests/sunspider/3d-raytrace.lua b/bench/tests/sunspider/3d-raytrace.lua index 60e4f61e..c8f6b5dc 100644 --- a/bench/tests/sunspider/3d-raytrace.lua +++ b/bench/tests/sunspider/3d-raytrace.lua @@ -451,15 +451,16 @@ function raytraceScene() end function arrayToCanvasCommands(pixels) - local s = 'Test\nvar pixels = ['; + local s = {}; + table.insert(s, 'Test\nvar pixels = ['); for y = 0,size-1 do - s = s .. "["; + table.insert(s, "["); for x = 0,size-1 do - s = s .. "[" .. math.floor(pixels[y + 1][x + 1][1] * 255) .. "," .. math.floor(pixels[y + 1][x + 1][2] * 255) .. "," .. math.floor(pixels[y + 1][x + 1][3] * 255) .. "],"; + table.insert(s, "[" .. math.floor(pixels[y + 1][x + 1][1] * 255) .. "," .. math.floor(pixels[y + 1][x + 1][2] * 255) .. "," .. math.floor(pixels[y + 1][x + 1][3] * 255) .. "],"); end - s = s .. "],"; + table.insert(s, "],"); end - s = s .. '];\n var canvas = document.getElementById("renderCanvas").getContext("2d");\n\ + table.insert(s, '];\n var canvas = document.getElementById("renderCanvas").getContext("2d");\n\ \n\ \n\ var size = ' .. size .. ';\n\ @@ -479,9 +480,9 @@ for (var y = 0; y < size; y++) {\n\ canvas.setFillColor(l[0], l[1], l[2], 1);\n\ canvas.fillRect(x, y, 1, 1);\n\ }\n\ -}'; +}'); - return s; + return table.concat(s); end testOutput = arrayToCanvasCommands(raytraceScene()); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 3b74a99e..62a9999b 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -513,8 +513,6 @@ TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_the_end_of_a_comme TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_broken_comment") { - ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; - check(R"( --[[ @1 )"); @@ -526,8 +524,6 @@ TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_broken_co TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_broken_comment_at_the_very_end_of_the_file") { - ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; - check("--[[@1"); auto ac = autocomplete('1'); @@ -2625,4 +2621,55 @@ local a: A<(number, s@1> CHECK(ac.entryMap.count("string")); } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_first_function_arg_expected_type") +{ + ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + ScopedFastFlag luauAutocompleteFirstArg("LuauAutocompleteFirstArg", true); + + check(R"( +local function foo1() return 1 end +local function foo2() return "1" end + +local function bar0() return "got" .. a end +local function bar1(a: number) return "got " .. a end +local function bar2(a: number, b: string) return "got " .. a .. b end + +local t = {} +function t:bar1(a: number) return "got " .. a end + +local r1 = bar0(@1) +local r2 = bar1(@2) +local r3 = bar2(@3) +local r4 = t:bar1(@4) + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("foo1")); + CHECK(ac.entryMap["foo1"].typeCorrect == TypeCorrectKind::None); + REQUIRE(ac.entryMap.count("foo2")); + CHECK(ac.entryMap["foo2"].typeCorrect == TypeCorrectKind::None); + + ac = autocomplete('2'); + + REQUIRE(ac.entryMap.count("foo1")); + CHECK(ac.entryMap["foo1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); + REQUIRE(ac.entryMap.count("foo2")); + CHECK(ac.entryMap["foo2"].typeCorrect == TypeCorrectKind::None); + + ac = autocomplete('3'); + + REQUIRE(ac.entryMap.count("foo1")); + CHECK(ac.entryMap["foo1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); + REQUIRE(ac.entryMap.count("foo2")); + CHECK(ac.entryMap["foo2"].typeCorrect == TypeCorrectKind::None); + + ac = autocomplete('4'); + + REQUIRE(ac.entryMap.count("foo1")); + CHECK(ac.entryMap["foo1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); + REQUIRE(ac.entryMap.count("foo2")); + CHECK(ac.entryMap["foo2"].typeCorrect == TypeCorrectKind::None); +} + TEST_SUITE_END(); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 6ba39ada..95811b3f 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -10,9 +10,6 @@ #include #include -LUAU_FASTFLAG(LuauPreloadClosures) -LUAU_FASTFLAG(LuauGenericSpecialGlobals) - using namespace Luau; static std::string compileFunction(const char* source, uint32_t id) @@ -74,20 +71,10 @@ TEST_CASE("BasicFunction") bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); Luau::compileOrThrow(bcb, "local function foo(a, b) return b end"); - if (FFlag::LuauPreloadClosures) - { - CHECK_EQ("\n" + bcb.dumpFunction(1), R"( + CHECK_EQ("\n" + bcb.dumpFunction(1), R"( DUPCLOSURE R0 K0 RETURN R0 0 )"); - } - else - { - CHECK_EQ("\n" + bcb.dumpFunction(1), R"( -NEWCLOSURE R0 P0 -RETURN R0 0 -)"); - } CHECK_EQ("\n" + bcb.dumpFunction(0), R"( RETURN R1 1 @@ -2859,47 +2846,35 @@ CAPTURE UPVAL U1 RETURN R0 1 )"); - if (FFlag::LuauPreloadClosures) - { - // recursive capture - CHECK_EQ("\n" + compileFunction("local function foo() return foo() end", 1), R"( + // recursive capture + CHECK_EQ("\n" + compileFunction("local function foo() return foo() end", 1), R"( DUPCLOSURE R0 K0 CAPTURE VAL R0 RETURN R0 0 )"); - // multi-level recursive capture - CHECK_EQ("\n" + compileFunction("local function foo() return function() return foo() end end", 1), R"( + // multi-level recursive capture + CHECK_EQ("\n" + compileFunction("local function foo() return function() return foo() end end", 1), R"( DUPCLOSURE R0 K0 CAPTURE UPVAL U0 RETURN R0 1 )"); - // multi-level recursive capture where function isn't top-level - // note: this should probably be optimized to DUPCLOSURE but doing that requires a different upval tracking flow in the compiler - CHECK_EQ("\n" + compileFunction(R"( + // multi-level recursive capture where function isn't top-level + // note: this should probably be optimized to DUPCLOSURE but doing that requires a different upval tracking flow in the compiler + CHECK_EQ("\n" + compileFunction(R"( local function foo() local function bar() return function() return bar() end end end )", - 1), - R"( + 1), + R"( NEWCLOSURE R0 P0 CAPTURE UPVAL U0 RETURN R0 1 )"); - } - else - { - // recursive capture - CHECK_EQ("\n" + compileFunction("local function foo() return foo() end", 1), R"( -NEWCLOSURE R0 P0 -CAPTURE VAL R0 -RETURN R0 0 -)"); - } } TEST_CASE("OutOfLocals") @@ -3504,8 +3479,6 @@ local t = { TEST_CASE("ConstantClosure") { - ScopedFastFlag sff("LuauPreloadClosures", true); - // closures without upvalues are created when bytecode is loaded CHECK_EQ("\n" + compileFunction(R"( return function() end @@ -3570,8 +3543,6 @@ RETURN R0 1 TEST_CASE("SharedClosure") { - ScopedFastFlag sff1("LuauPreloadClosures", true); - // closures can be shared even if functions refer to upvalues, as long as upvalues are top-level CHECK_EQ("\n" + compileFunction(R"( local val = ... diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index b2aad316..b055a38e 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -123,8 +123,8 @@ int lua_silence(lua_State* L) using StateRef = std::unique_ptr; -static StateRef runConformance( - const char* name, void (*setup)(lua_State* L) = nullptr, void (*yield)(lua_State* L) = nullptr, lua_State* initialLuaState = nullptr) +static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = nullptr, void (*yield)(lua_State* L) = nullptr, + lua_State* initialLuaState = nullptr, lua_CompileOptions* copts = nullptr) { std::string path = __FILE__; path.erase(path.find_last_of("\\/")); @@ -180,13 +180,8 @@ static StateRef runConformance( std::string chunkname = "=" + std::string(name); - lua_CompileOptions copts = {}; - copts.optimizationLevel = 1; // default - copts.debugLevel = 2; // for debugger tests - copts.vectorCtor = "vector"; // for vector tests - size_t bytecodeSize = 0; - char* bytecode = luau_compile(source.data(), source.size(), &copts, &bytecodeSize); + char* bytecode = luau_compile(source.data(), source.size(), copts, &bytecodeSize); int result = luau_load(L, chunkname.c_str(), bytecode, bytecodeSize, 0); free(bytecode); @@ -373,29 +368,37 @@ TEST_CASE("Vector") { ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; - runConformance("vector.lua", [](lua_State* L) { - lua_pushcfunction(L, lua_vector, "vector"); - lua_setglobal(L, "vector"); + lua_CompileOptions copts = {}; + copts.optimizationLevel = 1; + copts.debugLevel = 1; + copts.vectorCtor = "vector"; + + runConformance( + "vector.lua", + [](lua_State* L) { + lua_pushcfunction(L, lua_vector, "vector"); + lua_setglobal(L, "vector"); #if LUA_VECTOR_SIZE == 4 - lua_pushvector(L, 0.0f, 0.0f, 0.0f, 0.0f); + lua_pushvector(L, 0.0f, 0.0f, 0.0f, 0.0f); #else - lua_pushvector(L, 0.0f, 0.0f, 0.0f); + lua_pushvector(L, 0.0f, 0.0f, 0.0f); #endif - luaL_newmetatable(L, "vector"); + luaL_newmetatable(L, "vector"); - lua_pushstring(L, "__index"); - lua_pushcfunction(L, lua_vector_index, nullptr); - lua_settable(L, -3); + lua_pushstring(L, "__index"); + lua_pushcfunction(L, lua_vector_index, nullptr); + lua_settable(L, -3); - lua_pushstring(L, "__namecall"); - lua_pushcfunction(L, lua_vector_namecall, nullptr); - lua_settable(L, -3); + lua_pushstring(L, "__namecall"); + lua_pushcfunction(L, lua_vector_namecall, nullptr); + lua_settable(L, -3); - lua_setreadonly(L, -1, true); - lua_setmetatable(L, -2); - lua_pop(L, 1); - }); + lua_setreadonly(L, -1, true); + lua_setmetatable(L, -2); + lua_pop(L, 1); + }, + nullptr, nullptr, &copts); } static void populateRTTI(lua_State* L, Luau::TypeId type) @@ -499,6 +502,10 @@ TEST_CASE("Debugger") breakhits = 0; interruptedthread = nullptr; + lua_CompileOptions copts = {}; + copts.optimizationLevel = 1; + copts.debugLevel = 2; + runConformance( "debugger.lua", [](lua_State* L) { @@ -614,7 +621,8 @@ TEST_CASE("Debugger") lua_resume(interruptedthread, nullptr, 0); interruptedthread = nullptr; } - }); + }, + nullptr, &copts); CHECK(breakhits == 10); // 2 hits per breakpoint } @@ -863,4 +871,46 @@ TEST_CASE("TagMethodError") }); } +TEST_CASE("Coverage") +{ + lua_CompileOptions copts = {}; + copts.optimizationLevel = 1; + copts.debugLevel = 1; + copts.coverageLevel = 2; + + runConformance( + "coverage.lua", + [](lua_State* L) { + lua_pushcfunction( + L, + [](lua_State* L) -> int { + luaL_argexpected(L, lua_isLfunction(L, 1), 1, "function"); + + lua_newtable(L); + lua_getcoverage(L, 1, L, [](void* context, const char* function, int linedefined, int depth, const int* hits, size_t size) { + lua_State* L = static_cast(context); + + lua_newtable(L); + + lua_pushstring(L, function); + lua_setfield(L, -2, "name"); + + for (size_t i = 0; i < size; ++i) + if (hits[i] != -1) + { + lua_pushinteger(L, hits[i]); + lua_rawseti(L, -2, int(i)); + } + + lua_rawseti(L, -2, lua_objlen(L, -2) + 1); + }); + + return 1; + }, + "getcoverage"); + lua_setglobal(L, "getcoverage"); + }, + nullptr, nullptr, &copts); +} + TEST_SUITE_END(); diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 2800d2fe..e3993cc5 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -278,7 +278,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") #if defined(_DEBUG) || defined(_NOOPT) int limit = 250; #else - int limit = 500; + int limit = 400; #endif ScopedFastInt luauTypeCloneRecursionLimit{"LuauTypeCloneRecursionLimit", limit}; @@ -287,7 +287,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") TypeId table = src.addType(TableTypeVar{}); TypeId nested = table; - for (unsigned i = 0; i < limit + 100; i++) + for (int i = 0; i < limit + 100; i++) { TableTypeVar* ttv = getMutable(nested); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 72d3a9a6..5abcb09a 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2303,8 +2303,6 @@ TEST_CASE_FIXTURE(Fixture, "capture_comments") TEST_CASE_FIXTURE(Fixture, "capture_broken_comment_at_the_start_of_the_file") { - ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; - ParseOptions options; options.captureComments = true; @@ -2319,8 +2317,6 @@ TEST_CASE_FIXTURE(Fixture, "capture_broken_comment_at_the_start_of_the_file") TEST_CASE_FIXTURE(Fixture, "capture_broken_comment") { - ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; - ParseOptions options; options.captureComments = true; diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 091c2f01..71ff4e1b 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -207,6 +207,36 @@ TEST_CASE_FIXTURE(Fixture, "as_expr_does_not_propagate_type_info") CHECK_EQ("number", toString(requireType("b"))); } +TEST_CASE_FIXTURE(Fixture, "as_expr_is_bidirectional") +{ + ScopedFastFlag sff{"LuauBidirectionalAsExpr", true}; + + CheckResult result = check(R"( + local a = 55 :: number? + local b = a :: number + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number?", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "as_expr_warns_on_unrelated_cast") +{ + ScopedFastFlag sff{"LuauBidirectionalAsExpr", true}; + ScopedFastFlag sff2{"LuauErrorRecoveryType", true}; + + CheckResult result = check(R"( + local a = 55 :: string + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Cannot cast 'number' into 'string' because the types are unrelated", toString(result.errors[0])); + CHECK_EQ("string", toString(requireType("a"))); +} + TEST_CASE_FIXTURE(Fixture, "type_annotations_inside_function_bodies") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 1e2eae14..1d8135d4 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -7,6 +7,8 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauFixTonumberReturnType) + using namespace Luau; TEST_SUITE_BEGIN("BuiltinTests"); @@ -814,6 +816,30 @@ TEST_CASE_FIXTURE(Fixture, "string_format_report_all_type_errors_at_correct_posi CHECK_EQ(TypeErrorData(TypeMismatch{stringType, booleanType}), result.errors[2].data); } +TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type") +{ + CheckResult result = check(R"( + --!strict + local b: number = tonumber('asdf') + )"); + + if (FFlag::LuauFixTonumberReturnType) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number?' could not be converted into 'number'", toString(result.errors[0])); + } +} + +TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type2") +{ + CheckResult result = check(R"( + --!strict + local b: number = tonumber('asdf') or 1 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "dont_add_definitions_to_persistent_types") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index aba50891..b62044fa 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -9,6 +9,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauExtendedFunctionMismatchError) + TEST_SUITE_BEGIN("GenericsTests"); TEST_CASE_FIXTURE(Fixture, "check_generic_function") @@ -644,4 +646,42 @@ f(1, 2, 3) CHECK_EQ(toString(*ty, opts), "(a: number, number, number) -> ()"); } +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_generic_types") +{ + CheckResult result = check(R"( +type C = () -> () +type D = () -> () + +local c: C +local d: D = c + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauExtendedFunctionMismatchError) + CHECK_EQ( + toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type parameters)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '() -> ()')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_generic_pack") +{ + CheckResult result = check(R"( +type C = () -> () +type D = () -> () + +local c: C +local d: D = c + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauExtendedFunctionMismatchError) + CHECK_EQ(toString(result.errors[0]), + R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type pack parameters)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '() -> ()')"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index fe8e7ff9..688680c1 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -279,7 +279,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_non_binary_expressions_actually_resolve_const TEST_CASE_FIXTURE(Fixture, "assign_table_with_refined_property_with_a_similar_type_is_illegal") { - ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; + ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( @@ -1085,4 +1085,19 @@ TEST_CASE_FIXTURE(Fixture, "type_comparison_ifelse_expression") CHECK_EQ("any", toString(requireTypeAtPosition({6, 66}))); } +TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string") +{ + ScopedFastFlag sff{"LuauRefiLookupFromIndexExpr", true}; + + CheckResult result = check(R"( + type T = { [string]: { prop: number }? } + local t: T = {} + if t["hello"] then + local foo = t["hello"].prop + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 5f95efd5..1621ef32 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -374,4 +374,54 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") toString(result.errors[0])); } +TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_string") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + {"LuauUnionHeuristic", true}, + {"LuauExpectedTypesOfProperties", true}, + {"LuauExtendedUnionMismatchError", true}, + }; + + CheckResult result = check(R"( +type Cat = { tag: 'cat', catfood: string } +type Dog = { tag: 'dog', dogfood: string } +type Animal = Cat | Dog + +local a: Animal = { tag = 'cat', cafood = 'something' } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type 'a' could not be converted into 'Cat | Dog' +caused by: + None of the union options are compatible. For example: Table type 'a' not compatible with type 'Cat' because the former is missing field 'catfood')", + toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_bool") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + {"LuauUnionHeuristic", true}, + {"LuauExpectedTypesOfProperties", true}, + {"LuauExtendedUnionMismatchError", true}, + }; + + CheckResult result = check(R"( +type Good = { success: true, result: string } +type Bad = { success: false, error: string } +type Result = Good | Bad + +local a: Result = { success = false, result = 'something' } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type 'a' could not be converted into 'Bad | Good' +caused by: + None of the union options are compatible. For example: Table type 'a' not compatible with type 'Bad' because the former is missing field 'error')", + toString(result.errors[0])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index cb72faaf..3ea9b80c 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -12,6 +12,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauExtendedFunctionMismatchError) + TEST_SUITE_BEGIN("TableTests"); TEST_CASE_FIXTURE(Fixture, "basic") @@ -275,7 +277,7 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification") TEST_CASE_FIXTURE(Fixture, "open_table_unification_2") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( local a = {} @@ -346,7 +348,7 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_1") TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( --!strict @@ -369,7 +371,7 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_3") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( local T = {} @@ -476,7 +478,7 @@ TEST_CASE_FIXTURE(Fixture, "ok_to_add_property_to_free_table") TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_assignment") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( --!strict @@ -511,7 +513,7 @@ TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_function_ TEST_CASE_FIXTURE(Fixture, "width_subtyping") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( --!strict @@ -771,7 +773,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_for_left_unsealed_table_from_right_han TEST_CASE_FIXTURE(Fixture, "sealed_table_value_can_infer_an_indexer") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( local t: { a: string, [number]: string } = { a = "foo" } @@ -782,7 +784,7 @@ TEST_CASE_FIXTURE(Fixture, "sealed_table_value_can_infer_an_indexer") TEST_CASE_FIXTURE(Fixture, "array_factory_function") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( function empty() return {} end @@ -1465,7 +1467,7 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer2") TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( local function foo(a: {[string]: number, a: string}) end @@ -1550,7 +1552,7 @@ TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_dont_report_multipl TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_is_ok") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( local vec3 = {x = 1, y = 2, z = 3} @@ -1937,7 +1939,7 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in_strict") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( --!strict @@ -1952,7 +1954,7 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in TEST_CASE_FIXTURE(Fixture, "error_detailed_prop") { - ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; // Only for new path + ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( @@ -1971,7 +1973,7 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_prop_nested") { - ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; // Only for new path + ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( @@ -1995,7 +1997,7 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_metatable_prop") { - ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; // Only for new path + ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( @@ -2015,11 +2017,22 @@ caused by: caused by: Property 'y' is not compatible. Type 'string' could not be converted into 'number')"); - CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' + if (FFlag::LuauExtendedFunctionMismatchError) + { + CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' +caused by: + Type '{| __call: (a, b) -> () |}' could not be converted into '{| __call: (a) -> () |}' +caused by: + Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()'; different number of generic type parameters)"); + } + else + { + CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' caused by: Type '{| __call: (a, b) -> () |}' could not be converted into '{| __call: (a) -> () |}' caused by: Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()')"); + } } TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") @@ -2027,7 +2040,7 @@ TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") ScopedFastFlag sffs[] { {"LuauPropertiesGetExpectedType", true}, {"LuauExpectedTypesOfProperties", true}, - {"LuauTableSubtypingVariance", true}, + {"LuauTableSubtypingVariance2", true}, }; CheckResult result = check(R"( @@ -2048,7 +2061,7 @@ TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_error") ScopedFastFlag sffs[] { {"LuauPropertiesGetExpectedType", true}, {"LuauExpectedTypesOfProperties", true}, - {"LuauTableSubtypingVariance", true}, + {"LuauTableSubtypingVariance2", true}, {"LuauExtendedTypeMismatchError", true}, }; @@ -2076,7 +2089,7 @@ TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_with_indexer") ScopedFastFlag sffs[] { {"LuauPropertiesGetExpectedType", true}, {"LuauExpectedTypesOfProperties", true}, - {"LuauTableSubtypingVariance", true}, + {"LuauTableSubtypingVariance2", true}, }; CheckResult result = check(R"( @@ -2092,4 +2105,18 @@ a.p = { x = 9 } LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "recursive_metatable_type_call") +{ + ScopedFastFlag luauFixRecursiveMetatableCall{"LuauFixRecursiveMetatableCall", true}; + + CheckResult result = check(R"( +local b +b = setmetatable({}, {__call = b}) +b() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Cannot call non-function t1 where t1 = { @metatable {| __call: t1 |}, { } })"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index e3222a41..ad9ea827 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -16,6 +16,7 @@ LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr) LUAU_FASTFLAG(LuauEqConstraint) +LUAU_FASTFLAG(LuauExtendedFunctionMismatchError) using namespace Luau; @@ -2084,7 +2085,7 @@ TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable") { CheckResult result = check(R"( function add(a: number, b: string) - return a + tonumber(b), a .. b + return a + (tonumber(b) :: number), a .. b end local n, s = add(2,"3") )"); @@ -2485,7 +2486,7 @@ TEST_CASE_FIXTURE(Fixture, "inferring_crazy_table_should_also_be_quick") CheckResult result = check(R"( --!strict function f(U) - U(w:s(an):c()():c():U(s):c():c():U(s):c():U(s):cU()):c():U(s):c():U(s):c():c():U(s):c():U(s):cU() + U(w:s(an):c()():c():U(s):c():c():U(s):c():U(s):cU()):c():U(s):c():U(s):c():c():U(s):c():U(s):cU() end )"); @@ -3329,7 +3330,7 @@ TEST_CASE_FIXTURE(Fixture, "relation_op_on_any_lhs_where_rhs_maybe_has_metatable { CheckResult result = check(R"( local x - print((x == true and (x .. "y")) .. 1) + print((x == true and (x .. "y")) .. 1) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); @@ -4473,7 +4474,18 @@ f(function(a, b, c, ...) return a + b end) )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Type '(number, number, a) -> number' could not be converted into '(number, number) -> number'", toString(result.errors[0])); + + if (FFlag::LuauExtendedFunctionMismatchError) + { + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' +caused by: + Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", + toString(result.errors[0])); + } + else + { + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number')", toString(result.errors[0])); + } // Infer from variadic packs into elements result = check(R"( @@ -4604,7 +4616,17 @@ local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not i )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'", toString(result.errors[0])); + if (FFlag::LuauExtendedFunctionMismatchError) + { + CHECK_EQ( + "Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'; different number of generic type " + "parameters", + toString(result.errors[0])); + } + else + { + CHECK_EQ("Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'", toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "infer_return_value_type") @@ -4799,4 +4821,211 @@ local ModuleA = require(game.A) CHECK_EQ("*unknown*", toString(*oty)); } +/* + * If it wasn't instantly obvious, we have the fuzzer to thank for this gem of a test. + * + * We had an issue here where the scope for the `if` block here would + * have an elevated TypeLevel even though there is no function nesting going on. + * This would result in a free typevar for the type of _ that was much higher than + * it should be. This type would be erroneously quantified in the definition of `aaa`. + * This in turn caused an ice when evaluating `_()` in the while loop. + */ +TEST_CASE_FIXTURE(Fixture, "free_typevars_introduced_within_control_flow_constructs_do_not_get_an_elevated_TypeLevel") +{ + check(R"( + --!strict + if _ then + _[_], _ = nil + _() + end + + local aaa = function():typeof(_) return 1 end + + if aaa then + while _() do + end + end + )"); + + // No ice()? No problem. +} + +/* + * This is a bit elaborate. Bear with me. + * + * The type of _ becomes free with the first statement. With the second, we unify it with a function. + * + * At this point, it is important that the newly created fresh types of this new function type are promoted + * to the same level as the original free type. If we do not, they are incorrectly ascribed the level of the + * containing function. + * + * If this is allowed to happen, the final lambda erroneously quantifies the type of _ to something ridiculous + * just before we typecheck the invocation to _. + */ +TEST_CASE_FIXTURE(Fixture, "fuzzer_found_this") +{ + check(R"( + l0, _ = nil + + local function p() + _() + end + + a = _( + function():(typeof(p),typeof(_)) + end + )[nil] + )"); +} + +/* + * We had an issue where part of the type of pairs() was an unsealed table. + * This test depends on FFlagDebugLuauFreezeArena to trigger it. + */ +TEST_CASE_FIXTURE(Fixture, "pairs_parameters_are_not_unsealed_tables") +{ + check(R"( + function _(l0:{n0:any}) + _ = pairs + end + )"); +} + +TEST_CASE_FIXTURE(Fixture, "inferred_methods_of_free_tables_have_the_same_level_as_the_enclosing_table") +{ + check(R"( + function Base64FileReader(data) + local reader = {} + local index: number + + function reader:PeekByte() + return data:byte(index) + end + + function reader:Byte() + return data:byte(index - 1) + end + + return reader + end + + Base64FileReader() + + function ReadMidiEvents(data) + + local reader = Base64FileReader(data) + + while reader:HasMore() do + (reader:Byte() % 128) + end + end + )"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg_count") +{ + ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; + + CheckResult result = check(R"( +type A = (number, number) -> string +type B = (number) -> string + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number) -> string' +caused by: + Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg") +{ + ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; + + CheckResult result = check(R"( +type A = (number, number) -> string +type B = (number, string) -> string + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number, string) -> string' +caused by: + Argument #2 type is not compatible. Type 'string' could not be converted into 'number')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_count") +{ + ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; + + CheckResult result = check(R"( +type A = (number, number) -> (number) +type B = (number, number) -> (number, boolean) + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> number' could not be converted into '(number, number) -> (number, boolean)' +caused by: + Function only returns 1 value. 2 are required here)"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret") +{ + ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; + + CheckResult result = check(R"( +type A = (number, number) -> string +type B = (number, number) -> number + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number, number) -> number' +caused by: + Return type is not compatible. Type 'string' could not be converted into 'number')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_mult") +{ + ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; + + CheckResult result = check(R"( +type A = (number, number) -> (number, string) +type B = (number, number) -> (number, boolean) + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ( + toString(result.errors[0]), R"(Type '(number, number) -> (number, string)' could not be converted into '(number, number) -> (number, boolean)' +caused by: + Return #2 type is not compatible. Type 'string' could not be converted into 'boolean')"); +} + +TEST_CASE_FIXTURE(Fixture, "table_function_check_use_after_free") +{ + ScopedFastFlag luauUnifyFunctionCheckResult{"LuauUpdateFunctionNameBinding", true}; + + CheckResult result = check(R"( +local t = {} + +function t.x(value) + for k,v in pairs(t) do end +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 9f9a007f..f55b46a4 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -214,4 +214,32 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "cli_41095_concat_log_in_sealed_table_unifica CHECK_EQ(toString(result.errors[1]), "Available overloads: ({a}, a) -> (); and ({a}, number, a) -> ()"); } +TEST_CASE_FIXTURE(TryUnifyFixture, "undo_new_prop_on_unsealed_table") +{ + ScopedFastFlag flags[] = { + {"LuauTableSubtypingVariance2", true}, + }; + // I am not sure how to make this happen in Luau code. + + TypeId unsealedTable = arena.addType(TableTypeVar{TableState::Unsealed, TypeLevel{}}); + TypeId sealedTable = arena.addType(TableTypeVar{ + {{"prop", Property{getSingletonTypes().numberType}}}, + std::nullopt, + TypeLevel{}, + TableState::Sealed + }); + + const TableTypeVar* ttv = get(unsealedTable); + REQUIRE(ttv); + + state.tryUnify(unsealedTable, sealedTable); + + // To be honest, it's really quite spooky here that we're amending an unsealed table in this case. + CHECK(!ttv->props.empty()); + + state.log.rollback(); + + CHECK(ttv->props.empty()); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 48496b89..b095a0db 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -462,4 +462,20 @@ local a: XYZ = { w = 4 } CHECK_EQ(toString(result.errors[0]), R"(Type 'a' could not be converted into 'X | Y | Z'; none of the union options are compatible)"); } +TEST_CASE_FIXTURE(Fixture, "error_detailed_optional") +{ + ScopedFastFlag luauExtendedUnionMismatchError{"LuauExtendedUnionMismatchError", true}; + + CheckResult result = check(R"( +type X = { x: number } + +local a: X? = { w = 4 } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'a' could not be converted into 'X?' +caused by: + None of the union options are compatible. For example: Table type 'a' not compatible with type 'X' because the former is missing field 'x')"); +} + TEST_SUITE_END(); diff --git a/tests/conformance/coverage.lua b/tests/conformance/coverage.lua new file mode 100644 index 00000000..f899603f --- /dev/null +++ b/tests/conformance/coverage.lua @@ -0,0 +1,64 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print("testing coverage") + +function foo() + local x = 1 + local y = 2 + assert(x + y) +end + +function bar() + local function one(x) + return x + end + + local two = function(x) + return x + end + + one(1) +end + +function validate(stats, hits, misses) + local checked = {} + + for _,l in ipairs(hits) do + if not (stats[l] and stats[l] > 0) then + return false, string.format("expected line %d to be hit", l) + end + checked[l] = true + end + + for _,l in ipairs(misses) do + if not (stats[l] and stats[l] == 0) then + return false, string.format("expected line %d to be missed", l) + end + checked[l] = true + end + + for k,v in pairs(stats) do + if type(k) == "number" and not checked[k] then + return false, string.format("expected line %d to be absent", k) + end + end + + return true +end + +foo() +c = getcoverage(foo) +assert(#c == 1) +assert(c[1].name == "foo") +assert(validate(c[1], {5, 6, 7}, {})) + +bar() +c = getcoverage(bar) +assert(#c == 3) +assert(c[1].name == "bar") +assert(validate(c[1], {11, 15, 19}, {})) +assert(c[2].name == "one") +assert(validate(c[2], {12}, {})) +assert(c[3].name == nil) +assert(validate(c[3], {}, {16})) + +return 'OK'