diff --git a/Analysis/include/Luau/Documentation.h b/Analysis/include/Luau/Documentation.h index 7b609b4f..68ff3a7c 100644 --- a/Analysis/include/Luau/Documentation.h +++ b/Analysis/include/Luau/Documentation.h @@ -12,10 +12,17 @@ namespace Luau struct FunctionDocumentation; struct TableDocumentation; struct OverloadedFunctionDocumentation; +struct BasicDocumentation; -using Documentation = Luau::Variant; +using Documentation = Luau::Variant; using DocumentationSymbol = std::string; +struct BasicDocumentation +{ + std::string documentation; + std::string learnMoreLink; +}; + struct FunctionParameterDocumentation { std::string name; @@ -29,6 +36,7 @@ struct FunctionDocumentation std::string documentation; std::vector parameters; std::vector returns; + std::string learnMoreLink; }; struct OverloadedFunctionDocumentation @@ -43,6 +51,7 @@ struct TableDocumentation { std::string documentation; Luau::DenseHashMap keys; + std::string learnMoreLink; }; using DocumentationDatabase = Luau::DenseHashMap; diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index e5683fc4..50379c1c 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -23,10 +23,11 @@ struct ToStringNameMap struct ToStringOptions { - bool exhaustive = false; // If true, we produce complete output rather than comprehensible output - bool useLineBreaks = false; // If true, we insert new lines to separate long results such as table entries/metatable. - bool functionTypeArguments = false; // If true, output function type argument names when they are available - bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' + bool exhaustive = false; // If true, we produce complete output rather than comprehensible output + bool useLineBreaks = false; // If true, we insert new lines to separate long results such as table entries/metatable. + bool functionTypeArguments = false; // If true, output function type argument names when they are available + bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' + bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level. size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); std::optional nameMap; @@ -64,6 +65,8 @@ inline std::string toString(TypePackId ty) std::string toString(const TypeVar& tv, const ToStringOptions& opts = {}); std::string toString(const TypePackVar& tp, const ToStringOptions& opts = {}); +std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts = {}); + // It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class // These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression void dump(TypeId ty); diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 306ac77d..78d642c5 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -175,10 +175,10 @@ struct TypeChecker std::vector> getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall); std::optional> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, TypePackId argPack, TypePack* args, const std::vector& argLocations, const ExprResult& argListResult, - std::vector& overloadsThatMatchArgCount, std::vector& errors); + std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors); bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector& argLocations, const std::vector& errors); - ExprResult reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, + void reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, const std::vector& argLocations, const std::vector& overloads, const std::vector& overloadsThatMatchArgCount, const std::vector& errors); @@ -282,6 +282,14 @@ public: // Wrapper for merge(l, r, toUnion) but without the lambda junk. void merge(RefinementMap& l, const RefinementMap& r); + // Produce an "emergency backup type" for recovery from type errors. + // This comes in two flavours, depening on whether or not we can make a good guess + // for an error recovery type. + TypeId errorRecoveryType(TypeId guess); + TypePackId errorRecoveryTypePack(TypePackId guess); + TypeId errorRecoveryType(const ScopePtr& scope); + TypePackId errorRecoveryTypePack(const ScopePtr& scope); + private: void prepareErrorsForDisplay(ErrorVec& errVec); void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data); @@ -297,6 +305,10 @@ private: TypeId freshType(const ScopePtr& scope); TypeId freshType(TypeLevel level); + // Produce a new singleton type var. + TypeId singletonType(bool value); + TypeId singletonType(std::string value); + // Returns nullopt if the predicate filters down the TypeId to 0 options. std::optional filterMap(TypeId type, TypeIdPredicate predicate); @@ -330,8 +342,8 @@ private: const std::vector& typePackParams, const Location& location); // Note: `scope` must be a fresh scope. - std::pair, std::vector> createGenericTypes( - const ScopePtr& scope, std::optional levelOpt, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames); + std::pair, std::vector> createGenericTypes(const ScopePtr& scope, std::optional levelOpt, + const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames); public: ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); @@ -347,7 +359,6 @@ private: void resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); void resolve(const IsAPredicate& isaP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); void resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); - void DEPRECATED_resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); void resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); bool isNonstrictMode() const; @@ -387,12 +398,9 @@ public: const TypeId booleanType; const TypeId threadType; const TypeId anyType; - - const TypeId errorType; const TypeId optionalNumberType; const TypePackId anyTypePack; - const TypePackId errorTypePack; private: int checkRecursionCount = 0; diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 6bd7932d..093ea431 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -108,6 +108,79 @@ struct PrimitiveTypeVar } }; +// Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md +// Types for true and false +struct BoolSingleton +{ + bool value; + + bool operator==(const BoolSingleton& rhs) const + { + return value == rhs.value; + } + + bool operator!=(const BoolSingleton& rhs) const + { + return !(*this == rhs); + } +}; + +// Types for "foo", "bar" etc. +struct StringSingleton +{ + std::string value; + + bool operator==(const StringSingleton& rhs) const + { + return value == rhs.value; + } + + bool operator!=(const StringSingleton& rhs) const + { + return !(*this == rhs); + } +}; + +// No type for float singletons, partly because === isn't any equalivalence on floats +// (NaN != NaN). + +using SingletonVariant = Luau::Variant; + +struct SingletonTypeVar +{ + explicit SingletonTypeVar(const SingletonVariant& variant) + : variant(variant) + { + } + + explicit SingletonTypeVar(SingletonVariant&& variant) + : variant(std::move(variant)) + { + } + + // Default operator== is C++20. + bool operator==(const SingletonTypeVar& rhs) const + { + return variant == rhs.variant; + } + + bool operator!=(const SingletonTypeVar& rhs) const + { + return !(*this == rhs); + } + + SingletonVariant variant; +}; + +template +const T* get(const SingletonTypeVar* stv) +{ + if (stv) + return get_if(&stv->variant); + else + return nullptr; +} + struct FunctionArgument { Name name; @@ -332,8 +405,8 @@ struct LazyTypeVar using ErrorTypeVar = Unifiable::Error; -using TypeVariant = Unifiable::Variant; +using TypeVariant = Unifiable::Variant; struct TypeVar final { @@ -410,6 +483,9 @@ bool isGeneric(const TypeId ty); // Checks if a type may be instantiated to one containing generic type binders bool maybeGeneric(const TypeId ty); +// Checks if a type is of the form T1|...|Tn where one of the Ti is a singleton +bool maybeSingleton(TypeId ty); + struct SingletonTypes { const TypeId nilType; @@ -418,16 +494,19 @@ struct SingletonTypes const TypeId booleanType; const TypeId threadType; const TypeId anyType; - const TypeId errorType; const TypeId optionalNumberType; const TypePackId anyTypePack; - const TypePackId errorTypePack; SingletonTypes(); SingletonTypes(const SingletonTypes&) = delete; void operator=(const SingletonTypes&) = delete; + TypeId errorRecoveryType(TypeId guess); + TypePackId errorRecoveryTypePack(TypePackId guess); + TypeId errorRecoveryType(); + TypePackId errorRecoveryTypePack(); + private: std::unique_ptr arena; TypeId makeStringMetatable(); diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index c2e07e46..b47610fc 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -105,6 +105,8 @@ private: struct Error { + // This constructor has to be public, since it's used in TypeVar and TypePack, + // but shouldn't be called directly. Please use errorRecoveryType() instead. Error(); int index; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index be0aadd0..503034a1 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -65,6 +65,7 @@ struct Unifier private: void tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall = false, bool isIntersection = false); void tryUnifyPrimitives(TypeId superTy, TypeId subTy); + void tryUnifySingletons(TypeId superTy, TypeId subTy); void tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall = false); void tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); void DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 1c94bb68..6fc0b3f8 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -14,6 +14,7 @@ LUAU_FASTFLAGVARIABLE(ElseElseIfCompletionImprovements, false); LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport) +LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -198,11 +199,24 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ UnifierSharedState unifierState(&iceReporter); Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, unifierState); - unifier.tryUnify(expectedType, actualType); + if (FFlag::LuauAutocompleteAvoidMutation) + { + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + expectedType = clone(expectedType, *typeArena, seenTypes, seenTypePacks, nullptr); + actualType = clone(actualType, *typeArena, seenTypes, seenTypePacks, nullptr); - bool ok = unifier.errors.empty(); - unifier.log.rollback(); - return ok; + auto errors = unifier.canUnify(expectedType, actualType); + return errors.empty(); + } + else + { + unifier.tryUnify(expectedType, actualType); + + bool ok = unifier.errors.empty(); + unifier.log.rollback(); + return ok; + } }; auto expr = node->asExpr(); @@ -1496,11 +1510,9 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName if (!sourceModule) return {}; - TypeChecker& typeChecker = - (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); - ModulePtr module = - (frontend.options.typecheckTwice ? frontend.moduleResolverForAutocomplete.getModule(moduleName) - : frontend.moduleResolver.getModule(moduleName)); + TypeChecker& typeChecker = (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); + ModulePtr module = (frontend.options.typecheckTwice ? frontend.moduleResolverForAutocomplete.getModule(moduleName) + : frontend.moduleResolver.getModule(moduleName)); if (!module) return {}; @@ -1527,8 +1539,7 @@ OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view sourceModule->mode = Mode::Strict; sourceModule->commentLocations = std::move(result.commentLocations); - TypeChecker& typeChecker = - (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); + TypeChecker& typeChecker = (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); ModulePtr module = typeChecker.check(*sourceModule, Mode::Strict); diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 96703ef1..9f5c8250 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -153,6 +153,7 @@ declare function gcinfo(): number wrap: ((A...) -> R...) -> any, yield: (A...) -> R..., isyieldable: () -> boolean, + close: (thread) -> (boolean, any?) } declare table: { diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 46ff2c72..f80d50a7 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -180,13 +180,13 @@ struct ErrorConverter switch (e.context) { case CountMismatch::Return: - return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " + - std::to_string(e.actual) + " " + actualVerb + " returned here"; + return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " + std::to_string(e.actual) + " " + + actualVerb + " returned here"; case CountMismatch::Result: // It is alright if right hand side produces more values than the // left hand side accepts. In this context consider only the opposite case. - return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " + - std::to_string(e.actual) + " are required here"; + return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " + std::to_string(e.actual) + + " are required here"; case CountMismatch::Arg: if (FFlag::LuauTypeAliasPacks) return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 2f411274..1e97705d 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -22,7 +22,6 @@ LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false) LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false) LUAU_FASTFLAG(LuauNewRequireTrace2) -LUAU_FASTFLAGVARIABLE(LuauClearScopes, false) namespace Luau { @@ -458,8 +457,7 @@ CheckResult Frontend::check(const ModuleName& name) module->astTypes.clear(); module->astExpectedTypes.clear(); module->astOriginalCallTypes.clear(); - if (FFlag::LuauClearScopes) - module->scopes.resize(1); + module->scopes.resize(1); } if (mode != Mode::NoCheck) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 880ffd2e..32a0646a 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -161,6 +161,7 @@ struct TypeCloner void operator()(const Unifiable::Bound& t); void operator()(const Unifiable::Error& t); void operator()(const PrimitiveTypeVar& t); + void operator()(const SingletonTypeVar& t); void operator()(const FunctionTypeVar& t); void operator()(const TableTypeVar& t); void operator()(const MetatableTypeVar& t); @@ -199,7 +200,9 @@ struct TypePackCloner if (encounteredFreeType) *encounteredFreeType = true; - seenTypePacks[typePackId] = dest.addTypePack(TypePackVar{Unifiable::Error{}}); + TypePackId err = singletonTypes.errorRecoveryTypePack(singletonTypes.anyTypePack); + TypePackId cloned = dest.addTypePack(*err); + seenTypePacks[typePackId] = cloned; } void operator()(const Unifiable::Generic& t) @@ -251,8 +254,9 @@ void TypeCloner::operator()(const Unifiable::Free& t) { if (encounteredFreeType) *encounteredFreeType = true; - - seenTypes[typeId] = dest.addType(ErrorTypeVar{}); + TypeId err = singletonTypes.errorRecoveryType(singletonTypes.anyType); + TypeId cloned = dest.addType(*err); + seenTypes[typeId] = cloned; } void TypeCloner::operator()(const Unifiable::Generic& t) @@ -270,11 +274,17 @@ void TypeCloner::operator()(const Unifiable::Error& t) { defaultClone(t); } + void TypeCloner::operator()(const PrimitiveTypeVar& t) { defaultClone(t); } +void TypeCloner::operator()(const SingletonTypeVar& t) +{ + defaultClone(t); +} + void TypeCloner::operator()(const FunctionTypeVar& t) { TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 885fd489..735bfa50 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -350,6 +350,23 @@ struct TypeVarStringifier } } + void operator()(TypeId, const SingletonTypeVar& stv) + { + if (const BoolSingleton* bs = Luau::get(&stv)) + state.emit(bs->value ? "true" : "false"); + else if (const StringSingleton* ss = Luau::get(&stv)) + { + state.emit("\""); + state.emit(escape(ss->value)); + state.emit("\""); + } + else + { + LUAU_ASSERT(!"Unknown singleton type"); + throw std::runtime_error("Unknown singleton type"); + } + } + void operator()(TypeId, const FunctionTypeVar& ftv) { if (state.hasSeen(&ftv)) @@ -359,6 +376,7 @@ struct TypeVarStringifier return; } + // We should not be respecting opts.hideNamedFunctionTypeParameters here. if (ftv.generics.size() > 0 || ftv.genericPacks.size() > 0) { state.emit("<"); @@ -514,7 +532,14 @@ struct TypeVarStringifier break; } - state.emit(name); + if (isIdentifier(name)) + state.emit(name); + else + { + state.emit("[\""); + state.emit(escape(name)); + state.emit("\"]"); + } state.emit(": "); stringify(prop.type); comma = true; @@ -1084,6 +1109,94 @@ std::string toString(const TypePackVar& tp, const ToStringOptions& opts) return toString(const_cast(&tp), std::move(opts)); } +std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts) +{ + std::string s = prefix; + + auto toString_ = [&opts](TypeId ty) -> std::string { + ToStringResult res = toStringDetailed(ty, opts); + opts.nameMap = std::move(res.nameMap); + return res.name; + }; + + auto toStringPack_ = [&opts](TypePackId ty) -> std::string { + ToStringResult res = toStringDetailed(ty, opts); + opts.nameMap = std::move(res.nameMap); + return res.name; + }; + + if (!opts.hideNamedFunctionTypeParameters && (!ftv.generics.empty() || !ftv.genericPacks.empty())) + { + s += "<"; + + bool first = true; + for (TypeId g : ftv.generics) + { + if (!first) + s += ", "; + first = false; + s += toString_(g); + } + + for (TypePackId gp : ftv.genericPacks) + { + if (!first) + s += ", "; + first = false; + s += toStringPack_(gp); + } + + s += ">"; + } + + s += "("; + + auto argPackIter = begin(ftv.argTypes); + auto argNameIter = ftv.argNames.begin(); + + bool first = true; + while (argPackIter != end(ftv.argTypes)) + { + if (!first) + s += ", "; + first = false; + + // argNames is guaranteed to be equal to argTypes iff argNames is not empty. + // We don't currently respect opts.functionTypeArguments. I don't think this function should. + if (!ftv.argNames.empty()) + s += (*argNameIter ? (*argNameIter)->name : "_") + ": "; + s += toString_(*argPackIter); + + ++argPackIter; + if (!ftv.argNames.empty()) + { + LUAU_ASSERT(argNameIter != ftv.argNames.end()); + ++argNameIter; + } + } + + if (argPackIter.tail()) + { + if (auto vtp = get(*argPackIter.tail())) + s += ", ...: " + toString_(vtp->ty); + else + s += ", ...: " + toStringPack_(*argPackIter.tail()); + } + + s += "): "; + + size_t retSize = size(ftv.retType); + bool hasTail = !finite(ftv.retType); + if (retSize == 0 && !hasTail) + s += "()"; + else if ((retSize == 0 && hasTail) || (retSize == 1 && !hasTail)) + s += toStringPack_(ftv.retType); + else + s += "(" + toStringPack_(ftv.retType) + ")"; + + return s; +} + void dump(TypeId ty) { ToStringOptions opts; diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 7d880af4..6627fbe3 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -14,61 +14,6 @@ LUAU_FASTFLAG(LuauTypeAliasPacks) namespace { - -std::string escape(std::string_view s) -{ - std::string r; - r.reserve(s.size() + 50); // arbitrary number to guess how many characters we'll be inserting - - for (uint8_t c : s) - { - if (c >= ' ' && c != '\\' && c != '\'' && c != '\"') - r += c; - else - { - r += '\\'; - - switch (c) - { - case '\a': - r += 'a'; - break; - case '\b': - r += 'b'; - break; - case '\f': - r += 'f'; - break; - case '\n': - r += 'n'; - break; - case '\r': - r += 'r'; - break; - case '\t': - r += 't'; - break; - case '\v': - r += 'v'; - break; - case '\'': - r += '\''; - break; - case '\"': - r += '\"'; - break; - case '\\': - r += '\\'; - break; - default: - Luau::formatAppend(r, "%03u", c); - } - } - } - - return r; -} - bool isIdentifierStartChar(char c) { return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || c == '_'; diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 11aa7b39..af6d2543 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -96,6 +96,22 @@ public: return nullptr; } } + + AstType* operator()(const SingletonTypeVar& stv) + { + if (const BoolSingleton* bs = get(&stv)) + return allocator->alloc(Location(), bs->value); + else if (const StringSingleton* ss = get(&stv)) + { + AstArray value; + value.data = const_cast(ss->value.c_str()); + value.size = strlen(value.data); + return allocator->alloc(Location(), value); + } + else + return nullptr; + } + AstType* operator()(const AnyTypeVar&) { return allocator->alloc(Location(), std::nullopt, AstName("any")); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 96b79923..b2ae94c7 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -36,6 +36,9 @@ LUAU_FASTFLAG(LuauSubstitutionDontReplaceIgnoredTypes) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAG(LuauNewRequireTrace2) LUAU_FASTFLAG(LuauTypeAliasPacks) +LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) +LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) +LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) namespace Luau { @@ -211,10 +214,8 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan , booleanType(singletonTypes.booleanType) , threadType(singletonTypes.threadType) , anyType(singletonTypes.anyType) - , errorType(singletonTypes.errorType) , optionalNumberType(singletonTypes.optionalNumberType) , anyTypePack(singletonTypes.anyTypePack) - , errorTypePack(singletonTypes.errorTypePack) { globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); @@ -484,7 +485,7 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std TypeId type = bindings[name].type; if (get(follow(type))) { - *asMutable(type) = ErrorTypeVar{}; + *asMutable(type) = *errorRecoveryType(anyType); reportError(TypeError{typealias->location, OccursCheckFailed{}}); } } @@ -719,7 +720,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) else if (auto tail = valueIter.tail()) { if (get(*tail)) - right = errorType; + right = errorRecoveryType(scope); else if (auto vtp = get(*tail)) right = vtp->ty; else if (get(*tail)) @@ -961,7 +962,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) else if (get(callRetPack) || !first(callRetPack)) { for (TypeId var : varTypes) - unify(var, errorType, forin.location); + unify(var, errorRecoveryType(scope), forin.location); return check(loopScope, *forin.body); } @@ -979,7 +980,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) const FunctionTypeVar* iterFunc = get(iterTy); if (!iterFunc) { - TypeId varTy = get(iterTy) ? anyType : errorType; + TypeId varTy = get(iterTy) ? anyType : errorRecoveryType(loopScope); for (TypeId var : varTypes) unify(var, varTy, forin.location); @@ -1152,9 +1153,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}}); if (FFlag::LuauTypeAliasPacks) - bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorType}; + bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorRecoveryType(anyType)}; else - bindingsMap[name] = TypeFun{binding->typeParams, errorType}; + bindingsMap[name] = TypeFun{binding->typeParams, errorRecoveryType(anyType)}; } else { @@ -1398,7 +1399,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) { reportErrorCodeTooComplex(expr.location); - return {errorType}; + return {errorRecoveryType(scope)}; } ExprResult result; @@ -1407,12 +1408,22 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& result = checkExpr(scope, *a->expr); else if (expr.is()) result = {nilType}; - else if (expr.is()) - result = {booleanType}; + else if (const AstExprConstantBool* bexpr = expr.as()) + { + if (FFlag::LuauSingletonTypes && expectedType && maybeSingleton(*expectedType)) + result = {singletonType(bexpr->value)}; + else + result = {booleanType}; + } + else if (const AstExprConstantString* sexpr = expr.as()) + { + if (FFlag::LuauSingletonTypes && expectedType && maybeSingleton(*expectedType)) + result = {singletonType(std::string(sexpr->value.data, sexpr->value.size))}; + else + result = {stringType}; + } else if (expr.is()) result = {numberType}; - else if (expr.is()) - result = {stringType}; else if (auto a = expr.as()) result = checkExpr(scope, *a); else if (auto a = expr.as()) @@ -1485,7 +1496,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprLo // TODO: tempting to ice here, but this breaks very often because our toposort doesn't enforce this constraint // ice("AstExprLocal exists but no binding definition for it?", expr.location); reportError(TypeError{expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}}); - return {errorType}; + return {errorRecoveryType(scope)}; } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGlobal& expr) @@ -1497,7 +1508,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGl return {*ty, {TruthyPredicate{std::move(*lvalue), expr.location}}}; reportError(TypeError{expr.location, UnknownSymbol{expr.name.value, UnknownSymbol::Binding}}); - return {errorType}; + return {errorRecoveryType(scope)}; } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVarargs& expr) @@ -1517,14 +1528,14 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVa return {head}; } if (get(varargPack)) - return {errorType}; + return {errorRecoveryType(scope)}; else if (auto vtp = get(varargPack)) return {vtp->ty}; else if (get(varargPack)) { // TODO: Better error? reportError(expr.location, GenericError{"Trying to get a type from a variadic type parameter"}); - return {errorType}; + return {errorRecoveryType(scope)}; } else ice("Unknown TypePack type in checkExpr(AstExprVarargs)!"); @@ -1547,7 +1558,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa return {head, std::move(result.predicates)}; } if (get(retPack)) - return {errorType, std::move(result.predicates)}; + return {errorRecoveryType(scope), std::move(result.predicates)}; else if (auto vtp = get(retPack)) return {vtp->ty, std::move(result.predicates)}; else if (get(retPack)) @@ -1572,7 +1583,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIn if (std::optional ty = getIndexTypeFromType(scope, lhsType, name, expr.location, true)) return {*ty}; - return {errorType}; + return {errorRecoveryType(scope)}; } std::optional TypeChecker::findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location) @@ -1876,6 +1887,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa std::vector> fieldTypes(expr.items.size); const TableTypeVar* expectedTable = nullptr; + const UnionTypeVar* expectedUnion = nullptr; std::optional expectedIndexType; std::optional expectedIndexResultType; @@ -1894,6 +1906,9 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa } } } + else if (FFlag::LuauExpectedTypesOfProperties) + if (const UnionTypeVar* utv = get(follow(*expectedType))) + expectedUnion = utv; } for (size_t i = 0; i < expr.items.size; ++i) @@ -1916,6 +1931,18 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa if (auto prop = expectedTable->props.find(key->value.data); prop != expectedTable->props.end()) expectedResultType = prop->second.type; } + else if (FFlag::LuauExpectedTypesOfProperties && expectedUnion) + { + std::vector expectedResultTypes; + for (TypeId expectedOption : expectedUnion) + if (const TableTypeVar* ttv = get(follow(expectedOption))) + if (auto prop = ttv->props.find(key->value.data); prop != ttv->props.end()) + expectedResultTypes.push_back(prop->second.type); + if (expectedResultTypes.size() == 1) + expectedResultType = expectedResultTypes[0]; + else if (expectedResultTypes.size() > 1) + expectedResultType = addType(UnionTypeVar{expectedResultTypes}); + } } else { @@ -1958,21 +1985,22 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn { TypeId actualFunctionType = instantiate(scope, *fnt, expr.location); TypePackId arguments = addTypePack({operandType}); - TypePackId retType = freshTypePack(scope); - TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retType)); + TypePackId retTypePack = freshTypePack(scope); + TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); Unifier state = mkUnifier(expr.location); state.tryUnify(expectedFunctionType, actualFunctionType, /*isFunctionCall*/ true); + TypeId retType = first(retTypePack).value_or(nilType); if (!state.errors.empty()) - return {errorType}; + retType = errorRecoveryType(retType); - return {first(retType).value_or(nilType)}; + return {retType}; } reportError(expr.location, GenericError{format("Unary operator '%s' not supported by type '%s'", toString(expr.op).c_str(), toString(operandType).c_str())}); - return {errorType}; + return {errorRecoveryType(scope)}; } reportErrors(tryUnify(numberType, operandType, expr.location)); @@ -1984,7 +2012,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn operandType = stripFromNilAndReport(operandType, expr.location); if (get(operandType)) - return {errorType}; + return {errorRecoveryType(scope)}; if (get(operandType)) return {numberType}; // Not strictly correct: metatables permit overriding this @@ -2044,7 +2072,7 @@ TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const Location& location, b if (unify(a, b, location)) return a; - return errorType; + return errorRecoveryType(anyType); } if (*a == *b) @@ -2166,11 +2194,13 @@ TypeId TypeChecker::checkRelationalOperation( std::optional leftMetatable = isString(lhsType) ? std::nullopt : getMetatable(follow(lhsType)); std::optional rightMetatable = isString(rhsType) ? std::nullopt : getMetatable(follow(rhsType)); + // TODO: this check seems odd, the second part is redundant + // is it meant to be if (leftMetatable && rightMetatable && leftMetatable != rightMetatable) if (bool(leftMetatable) != bool(rightMetatable) && leftMetatable != rightMetatable) { reportError(expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); - return errorType; + return errorRecoveryType(booleanType); } if (leftMetatable) @@ -2188,7 +2218,7 @@ TypeId TypeChecker::checkRelationalOperation( if (!state.errors.empty()) { reportError(expr.location, GenericError{format("Metamethod '%s' must return type 'boolean'", metamethodName.c_str())}); - return errorType; + return errorRecoveryType(booleanType); } } } @@ -2206,7 +2236,7 @@ TypeId TypeChecker::checkRelationalOperation( { reportError( expr.location, GenericError{format("Table %s does not offer metamethod %s", toString(lhsType).c_str(), metamethodName.c_str())}); - return errorType; + return errorRecoveryType(booleanType); } } @@ -2214,14 +2244,14 @@ TypeId TypeChecker::checkRelationalOperation( { auto name = getIdentifierOfBaseVar(expr.left); reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Comparison}); - return errorType; + return errorRecoveryType(booleanType); } if (needsMetamethod) { reportError(expr.location, GenericError{format("Type %s cannot be compared with %s because it has no metatable", toString(lhsType).c_str(), toString(expr.op).c_str())}); - return errorType; + return errorRecoveryType(booleanType); } return booleanType; @@ -2266,7 +2296,8 @@ TypeId TypeChecker::checkBinaryOperation( { auto name = getIdentifierOfBaseVar(expr.left); reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation}); - return errorType; + if (!FFlag::LuauErrorRecoveryType) + return errorRecoveryType(scope); } // If we know nothing at all about the lhs type, we can usually say nothing about the result. @@ -2296,18 +2327,33 @@ TypeId TypeChecker::checkBinaryOperation( auto checkMetatableCall = [this, &scope, &expr](TypeId fnt, TypeId lhst, TypeId rhst) -> TypeId { TypeId actualFunctionType = instantiate(scope, fnt, expr.location); TypePackId arguments = addTypePack({lhst, rhst}); - TypePackId retType = freshTypePack(scope); - TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retType)); + TypePackId retTypePack = freshTypePack(scope); + TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); Unifier state = mkUnifier(expr.location); state.tryUnify(expectedFunctionType, actualFunctionType, /*isFunctionCall*/ true); reportErrors(state.errors); + bool hasErrors = !state.errors.empty(); - if (!state.errors.empty()) - return errorType; + if (FFlag::LuauErrorRecoveryType && hasErrors) + { + // If there are unification errors, the return type may still be unknown + // so we loosen the argument types to see if that helps. + TypePackId fallbackArguments = freshTypePack(scope); + TypeId fallbackFunctionType = addType(FunctionTypeVar(scope->level, fallbackArguments, retTypePack)); + state.log.rollback(); + state.errors.clear(); + state.tryUnify(fallbackFunctionType, actualFunctionType, /*isFunctionCall*/ true); + if (!state.errors.empty()) + state.log.rollback(); + } - return first(retType).value_or(nilType); + TypeId retType = first(retTypePack).value_or(nilType); + if (hasErrors) + retType = errorRecoveryType(retType); + + return retType; }; std::string op = opToMetaTableEntry(expr.op); @@ -2321,7 +2367,8 @@ TypeId TypeChecker::checkBinaryOperation( reportError(expr.location, GenericError{format("Binary operator '%s' not supported by types '%s' and '%s'", toString(expr.op).c_str(), toString(lhsType).c_str(), toString(rhsType).c_str())}); - return errorType; + + return errorRecoveryType(scope); } switch (expr.op) @@ -2414,11 +2461,9 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTy ExprResult result = checkExpr(scope, *expr.expr, annotationType); ErrorVec errorVec = canUnify(result.type, annotationType, expr.location); + reportErrors(errorVec); if (!errorVec.empty()) - { - reportErrors(errorVec); - return {errorType, std::move(result.predicates)}; - } + annotationType = errorRecoveryType(annotationType); return {annotationType, std::move(result.predicates)}; } @@ -2434,7 +2479,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprEr // any type errors that may arise from it are going to be useless. currentModule->errors.resize(oldSize); - return {errorType}; + return {errorRecoveryType(scope)}; } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr) @@ -2476,7 +2521,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope { for (AstExpr* expr : a->expressions) checkExpr(scope, *expr); - return std::pair(errorType, nullptr); + return {errorRecoveryType(scope), nullptr}; } else ice("Unexpected AST node in checkLValue", expr.location); @@ -2488,7 +2533,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope return {*ty, nullptr}; reportError(expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}); - return {errorType, nullptr}; + return {errorRecoveryType(scope), nullptr}; } std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr) @@ -2545,24 +2590,25 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope { Unifier state = mkUnifier(expr.location); state.tryUnify(indexer->indexType, stringType); + TypeId retType = indexer->indexResultType; if (!state.errors.empty()) { state.log.rollback(); reportError(expr.location, UnknownProperty{lhs, name}); - return std::pair(errorType, nullptr); + retType = errorRecoveryType(retType); } - return std::pair(indexer->indexResultType, nullptr); + return std::pair(retType, nullptr); } else if (lhsTable->state == TableState::Sealed) { reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } else { reportError(TypeError{expr.location, GenericError{"Internal error: generic tables are not lvalues"}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } } else if (const ClassTypeVar* lhsClass = get(lhs)) @@ -2571,7 +2617,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!prop) { reportError(TypeError{expr.location, UnknownProperty{lhs, name}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } return std::pair(prop->type, nullptr); @@ -2585,12 +2631,12 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (isTableIntersection(lhs)) { reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } } reportError(TypeError{expr.location, NotATable{lhs}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr) @@ -2615,7 +2661,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!prop) { reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } return std::pair(prop->type, nullptr); } @@ -2626,7 +2672,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!exprTable) { reportError(TypeError{expr.expr->location, NotATable{exprType}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } if (value) @@ -2678,7 +2724,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) if (isNonstrictMode()) return globalScope->bindings[name].typeId; - return errorType; + return errorRecoveryType(scope); } else { @@ -2705,20 +2751,21 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) TableTypeVar* ttv = getMutableTableType(lhsType); if (!ttv) { - if (!isTableIntersection(lhsType)) + if (!FFlag::LuauErrorRecoveryType && !isTableIntersection(lhsType)) + // This error now gets reported when we check the function body. reportError(TypeError{funName.location, OnlyTablesCanHaveMethods{lhsType}}); - return errorType; + return errorRecoveryType(scope); } // Cannot extend sealed table, but we dont report an error here because it will be reported during AstStatFunction check if (lhsType->persistent || ttv->state == TableState::Sealed) - return errorType; + return errorRecoveryType(scope); Name name = indexName->index.value; if (ttv->props.count(name)) - return errorType; + return errorRecoveryType(scope); Property& property = ttv->props[name]; @@ -2728,9 +2775,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) return property.type; } else if (funName.is()) - { - return errorType; - } + return errorRecoveryType(scope); else { ice("Unexpected AST node type", funName.location); @@ -2991,7 +3036,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A else if (expr.is()) { if (!scope->varargPack) - return {addTypePack({addType(ErrorTypeVar())})}; + return {errorRecoveryTypePack(scope)}; return {*scope->varargPack}; } @@ -3095,10 +3140,9 @@ void TypeChecker::checkArgumentList( if (get(tail)) { // Unify remaining parameters so we don't leave any free-types hanging around. - TypeId argTy = errorType; while (paramIter != endIter) { - state.tryUnify(*paramIter, argTy); + state.tryUnify(*paramIter, errorRecoveryType(anyType)); ++paramIter; } return; @@ -3157,7 +3201,7 @@ void TypeChecker::checkArgumentList( { while (argIter != endIter) { - unify(*argIter, errorType, state.location); + unify(*argIter, errorRecoveryType(scope), state.location); ++argIter; } // For this case, we want the error span to cover every errant extra parameter @@ -3246,7 +3290,8 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A // For each overload // Compare parameter and argument types // Report any errors (also speculate dot vs colon warnings!) - // If there are no errors, return the resulting return type + // Return the resulting return type (even if there are errors) + // If there are no matching overloads, unify with (a...) -> (b...) and return b... TypeId selfType = nullptr; TypeId functionType = nullptr; @@ -3268,8 +3313,8 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A } else { - functionType = errorType; - actualFunctionType = errorType; + functionType = errorRecoveryType(scope); + actualFunctionType = functionType; } } else @@ -3296,7 +3341,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A TypePackId argPack = argListResult.type; if (get(argPack)) - return ExprResult{errorTypePack}; + return {errorRecoveryTypePack(scope)}; TypePack* args = getMutable(argPack); LUAU_ASSERT(args != nullptr); @@ -3314,19 +3359,34 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A std::vector errors; // errors encountered for each overload std::vector overloadsThatMatchArgCount; + std::vector overloadsThatDont; for (TypeId fn : overloads) { fn = follow(fn); - if (auto ret = checkCallOverload(scope, expr, fn, retPack, argPack, args, argLocations, argListResult, overloadsThatMatchArgCount, errors)) + if (auto ret = checkCallOverload( + scope, expr, fn, retPack, argPack, args, argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors)) return *ret; } if (handleSelfCallMismatch(scope, expr, args, argLocations, errors)) return {retPack}; - return reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors); + reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors); + + if (FFlag::LuauErrorRecoveryType) + { + const FunctionTypeVar* overload = nullptr; + if (!overloadsThatMatchArgCount.empty()) + overload = get(overloadsThatMatchArgCount[0]); + if (!overload && !overloadsThatDont.empty()) + overload = get(overloadsThatDont[0]); + if (overload) + return {errorRecoveryTypePack(overload->retType)}; + } + + return {errorRecoveryTypePack(retPack)}; } std::vector> TypeChecker::getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall) @@ -3382,7 +3442,7 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st std::optional> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, TypePackId argPack, TypePack* args, const std::vector& argLocations, const ExprResult& argListResult, - std::vector& overloadsThatMatchArgCount, std::vector& errors) + std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors) { fn = stripFromNilAndReport(fn, expr.func->location); @@ -3394,7 +3454,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope if (get(fn)) { - return {{addTypePack(TypePackVar{Unifiable::Error{}})}}; + return {{errorRecoveryTypePack(scope)}}; } if (get(fn)) @@ -3427,14 +3487,14 @@ std::optional> TypeChecker::checkCallOverload(const Scope TypeId fn = *ty; fn = instantiate(scope, fn, expr.func->location); - return checkCallOverload( - scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, metaArgLocations, argListResult, overloadsThatMatchArgCount, errors); + return checkCallOverload(scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, metaArgLocations, argListResult, + overloadsThatMatchArgCount, overloadsThatDont, errors); } } reportError(TypeError{expr.func->location, CannotCallNonFunction{fn}}); - unify(retPack, errorTypePack, expr.func->location); - return {{errorTypePack}}; + unify(retPack, errorRecoveryTypePack(scope), expr.func->location); + return {{errorRecoveryTypePack(retPack)}}; } // When this function type has magic functions and did return something, we select that overload instead. @@ -3476,6 +3536,8 @@ std::optional> TypeChecker::checkCallOverload(const Scope if (!argMismatch) overloadsThatMatchArgCount.push_back(fn); + else if (FFlag::LuauErrorRecoveryType) + overloadsThatDont.push_back(fn); errors.emplace_back(std::move(state.errors), args->head, ftv); state.log.rollback(); @@ -3586,14 +3648,14 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal return false; } -ExprResult TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, - TypePackId argPack, const std::vector& argLocations, const std::vector& overloads, - const std::vector& overloadsThatMatchArgCount, const std::vector& errors) +void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, + const std::vector& argLocations, const std::vector& overloads, const std::vector& overloadsThatMatchArgCount, + const std::vector& errors) { if (overloads.size() == 1) { reportErrors(std::get<0>(errors.front())); - return {errorTypePack}; + return; } std::vector overloadTypes = overloadsThatMatchArgCount; @@ -3622,7 +3684,7 @@ ExprResult TypeChecker::reportOverloadResolutionError(const ScopePtr // If only one overload matched, we don't need this error because we provided the previous errors. if (overloadsThatMatchArgCount.size() == 1) - return {errorTypePack}; + return; } std::string s; @@ -3655,7 +3717,7 @@ ExprResult TypeChecker::reportOverloadResolutionError(const ScopePtr reportError(expr.func->location, ExtraInformation{"Other overloads are also not viable: " + s}); // No viable overload - return {errorTypePack}; + return; } ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, @@ -3740,7 +3802,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module if (FFlag::LuauStrictRequire && currentModule->mode == Mode::Strict) { reportError(TypeError{location, UnknownRequire{}}); - return errorType; + return errorRecoveryType(anyType); } return anyType; @@ -3758,14 +3820,14 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module reportError(TypeError{location, UnknownRequire{reportedModulePath}}); } - return errorType; + return errorRecoveryType(scope); } if (module->type != SourceCode::Module) { std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}); - return errorType; + return errorRecoveryType(scope); } std::optional moduleType = first(module->getModuleScope()->returnType); @@ -3773,7 +3835,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module { std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}); - return errorType; + return errorRecoveryType(scope); } SeenTypes seenTypes; @@ -4078,7 +4140,7 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location if (!qty.has_value()) { reportError(location, UnificationTooComplex{}); - return errorType; + return errorRecoveryType(scope); } if (ty == *qty) @@ -4101,7 +4163,7 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat else { reportError(location, UnificationTooComplex{}); - return errorType; + return errorRecoveryType(scope); } } @@ -4116,7 +4178,7 @@ TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) else { reportError(location, UnificationTooComplex{}); - return errorType; + return errorRecoveryType(anyType); } } @@ -4131,7 +4193,7 @@ TypePackId TypeChecker::anyify(const ScopePtr& scope, TypePackId ty, Location lo else { reportError(location, UnificationTooComplex{}); - return errorTypePack; + return errorRecoveryTypePack(anyTypePack); } } @@ -4279,6 +4341,38 @@ TypeId TypeChecker::freshType(TypeLevel level) return currentModule->internalTypes.addType(TypeVar(FreeTypeVar(level))); } +TypeId TypeChecker::singletonType(bool value) +{ + // TODO: cache singleton types + return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(BoolSingleton{value}))); +} + +TypeId TypeChecker::singletonType(std::string value) +{ + // TODO: cache singleton types + return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(StringSingleton{std::move(value)}))); +} + +TypeId TypeChecker::errorRecoveryType(const ScopePtr& scope) +{ + return singletonTypes.errorRecoveryType(); +} + +TypeId TypeChecker::errorRecoveryType(TypeId guess) +{ + return singletonTypes.errorRecoveryType(guess); +} + +TypePackId TypeChecker::errorRecoveryTypePack(const ScopePtr& scope) +{ + return singletonTypes.errorRecoveryTypePack(); +} + +TypePackId TypeChecker::errorRecoveryTypePack(TypePackId guess) +{ + return singletonTypes.errorRecoveryTypePack(guess); +} + std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) { std::vector types = Luau::filterMap(type, predicate); @@ -4350,7 +4444,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (lit->parameters.size != 1 || !lit->parameters.data[0].type) { reportError(TypeError{annotation.location, GenericError{"_luau_print requires one generic parameter"}}); - return addType(ErrorTypeVar{}); + return errorRecoveryType(anyType); } ToStringOptions opts; @@ -4368,7 +4462,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (!tf) { if (lit->name == Parser::errorName) - return addType(ErrorTypeVar{}); + return errorRecoveryType(scope); std::string typeName; if (lit->hasPrefix) @@ -4380,7 +4474,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation else reportError(TypeError{annotation.location, UnknownSymbol{typeName, UnknownSymbol::Type}}); - return addType(ErrorTypeVar{}); + return errorRecoveryType(scope); } if (lit->parameters.size == 0 && tf->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || tf->typePackParams.empty())) @@ -4390,14 +4484,17 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation else if (!FFlag::LuauTypeAliasPacks && lit->parameters.size != tf->typeParams.size()) { reportError(TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, lit->parameters.size, 0}}); - return addType(ErrorTypeVar{}); + if (!FFlag::LuauErrorRecoveryType) + return errorRecoveryType(scope); } - else if (FFlag::LuauTypeAliasPacks) + + if (FFlag::LuauTypeAliasPacks) { if (!lit->hasParameterList && !tf->typePackParams.empty()) { reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); - return addType(ErrorTypeVar{}); + if (!FFlag::LuauErrorRecoveryType) + return errorRecoveryType(scope); } std::vector typeParams; @@ -4445,7 +4542,17 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation { reportError( TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); - return addType(ErrorTypeVar{}); + + if (FFlag::LuauErrorRecoveryType) + { + // Pad the types out with error recovery types + while (typeParams.size() < tf->typeParams.size()) + typeParams.push_back(errorRecoveryType(scope)); + while (typePackParams.size() < tf->typePackParams.size()) + typePackParams.push_back(errorRecoveryTypePack(scope)); + } + else + return errorRecoveryType(scope); } if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams && typePackParams == tf->typePackParams) @@ -4464,6 +4571,14 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation for (const auto& param : lit->parameters) typeParams.push_back(resolveType(scope, *param.type)); + if (FFlag::LuauErrorRecoveryType) + { + // If there aren't enough type parameters, pad them out with error recovery types + // (we've already reported the error) + while (typeParams.size() < lit->parameters.size) + typeParams.push_back(errorRecoveryType(scope)); + } + if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams) { // If the generic parameters and the type arguments are the same, we are about to @@ -4483,8 +4598,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation props[prop.name.value] = {resolveType(scope, *prop.type)}; if (const auto& indexer = table->indexer) - tableIndexer = TableIndexer( - resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); + tableIndexer = TableIndexer(resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); return addType(TableTypeVar{ props, tableIndexer, scope->level, @@ -4536,14 +4650,20 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation return addType(IntersectionTypeVar{types}); } - else if (annotation.is()) + else if (const auto& tsb = annotation.as()) { - return addType(ErrorTypeVar{}); + return singletonType(tsb->value); } + else if (const auto& tss = annotation.as()) + { + return singletonType(std::string(tss->value.data, tss->value.size)); + } + else if (annotation.is()) + return errorRecoveryType(scope); else { reportError(TypeError{annotation.location, GenericError{"Unknown type annotation?"}}); - return addType(ErrorTypeVar{}); + return errorRecoveryType(scope); } } @@ -4584,7 +4704,7 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack else reportError(TypeError{generic->location, UnknownSymbol{genericName, UnknownSymbol::Type}}); - return addTypePack(TypePackVar{Unifiable::Error{}}); + return errorRecoveryTypePack(scope); } return *genericTy; @@ -4706,12 +4826,12 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, if (!maybeInstantiated.has_value()) { reportError(location, UnificationTooComplex{}); - return errorType; + return errorRecoveryType(scope); } if (FFlag::LuauRecursiveTypeParameterRestriction && applyTypeFunction.encounteredForwardedType) { reportError(TypeError{location, GenericError{"Recursive type being used with different parameters"}}); - return errorType; + return errorRecoveryType(scope); } TypeId instantiated = *maybeInstantiated; @@ -4773,8 +4893,8 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, return instantiated; } -std::pair, std::vector> TypeChecker::createGenericTypes( - const ScopePtr& scope, std::optional levelOpt, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) +std::pair, std::vector> TypeChecker::createGenericTypes(const ScopePtr& scope, std::optional levelOpt, + const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) { LUAU_ASSERT(scope->parent); @@ -5043,7 +5163,7 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement addRefinement(refis, isaP.lvalue, *result); else { - addRefinement(refis, isaP.lvalue, errorType); + addRefinement(refis, isaP.lvalue, errorRecoveryType(scope)); errVec.push_back(TypeError{isaP.location, TypeMismatch{isaP.ty, *ty}}); } } @@ -5107,7 +5227,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec addRefinement(refis, typeguardP.lvalue, *result); else { - addRefinement(refis, typeguardP.lvalue, errorType); + addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); if (sense) errVec.push_back( TypeError{typeguardP.location, GenericError{"Type '" + toString(*ty) + "' has no overlap with '" + typeguardP.kind + "'"}}); @@ -5118,7 +5238,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec auto fail = [&](const TypeErrorData& err) { errVec.push_back(TypeError{typeguardP.location, err}); - addRefinement(refis, typeguardP.lvalue, errorType); + addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); }; if (!typeguardP.isTypeof) @@ -5139,28 +5259,6 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec return resolve(IsAPredicate{std::move(typeguardP.lvalue), typeguardP.location, type}, errVec, refis, scope, sense); } -void TypeChecker::DEPRECATED_resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) -{ - if (!sense) - return; - - static std::vector primitives{ - "string", "number", "boolean", "nil", "thread", - "table", // no op. Requires special handling. - "function", // no op. Requires special handling. - "userdata", // no op. Requires special handling. - }; - - if (auto typeFun = globalScope->lookupType(typeguardP.kind); - typeFun && typeFun->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || typeFun->typePackParams.empty())) - { - if (auto it = std::find(primitives.begin(), primitives.end(), typeguardP.kind); it != primitives.end()) - addRefinement(refis, typeguardP.lvalue, typeFun->type); - else if (typeguardP.isTypeof) - addRefinement(refis, typeguardP.lvalue, typeFun->type); - } -} - void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) { // This refinement will require success typing to do everything correctly. For now, we can get most of the way there. diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 228b1926..d3221c73 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -286,5 +286,4 @@ TypePack* asMutable(const TypePack* tp) { return const_cast(tp); } - } // namespace Luau diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 4ee5cb82..924bf082 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -21,6 +21,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTFLAG(LuauTypeAliasPacks) LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) +LUAU_FASTFLAG(LuauErrorRecoveryType) namespace Luau { @@ -305,6 +306,18 @@ bool maybeGeneric(TypeId ty) return isGeneric(ty); } +bool maybeSingleton(TypeId ty) +{ + ty = follow(ty); + if (get(ty)) + return true; + if (const UnionTypeVar* utv = get(ty)) + for (TypeId option : utv) + if (get(follow(option))) + return true; + return false; +} + FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retType, std::optional defn, bool hasSelf) : argTypes(argTypes) , retType(retType) @@ -562,10 +575,8 @@ SingletonTypes::SingletonTypes() , booleanType(&booleanType_) , threadType(&threadType_) , anyType(&anyType_) - , errorType(&errorType_) , optionalNumberType(&optionalNumberType_) , anyTypePack(&anyTypePack_) - , errorTypePack(&errorTypePack_) , arena(new TypeArena) { TypeId stringMetatable = makeStringMetatable(); @@ -634,6 +645,32 @@ TypeId SingletonTypes::makeStringMetatable() return arena->addType(TableTypeVar{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); } +TypeId SingletonTypes::errorRecoveryType() +{ + return &errorType_; +} + +TypePackId SingletonTypes::errorRecoveryTypePack() +{ + return &errorTypePack_; +} + +TypeId SingletonTypes::errorRecoveryType(TypeId guess) +{ + if (FFlag::LuauErrorRecoveryType) + return guess; + else + return &errorType_; +} + +TypePackId SingletonTypes::errorRecoveryTypePack(TypePackId guess) +{ + if (FFlag::LuauErrorRecoveryType) + return guess; + else + return &errorTypePack_; +} + SingletonTypes singletonTypes; void persist(TypeId ty) @@ -1141,6 +1178,11 @@ struct QVarFinder return false; } + bool operator()(const SingletonTypeVar&) const + { + return false; + } + bool operator()(const FunctionTypeVar& ftv) const { if (hasGeneric(ftv.argTypes)) @@ -1412,7 +1454,7 @@ static std::vector parseFormatString(TypeChecker& typechecker, const cha else if (strchr(options, data[i])) result.push_back(typechecker.numberType); else - result.push_back(typechecker.errorType); + result.push_back(typechecker.errorRecoveryType(typechecker.anyType)); } } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 82f621b6..e1a52be4 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -22,7 +22,9 @@ LUAU_FASTFLAGVARIABLE(LuauTypecheckOpts, false) LUAU_FASTFLAG(LuauShareTxnSeen); LUAU_FASTFLAGVARIABLE(LuauCacheUnifyTableResults, false) LUAU_FASTFLAGVARIABLE(LuauExtendedTypeMismatchError, false) +LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAGVARIABLE(LuauExtendedClassMismatchError, false) +LUAU_FASTFLAG(LuauErrorRecoveryType); namespace Luau { @@ -211,6 +213,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool { occursCheck(subTy, superTy); + // The occurrence check might have caused superTy no longer to be a free type if (!get(subTy)) { log(subTy); @@ -221,10 +224,20 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool } else if (l && r) { - log(superTy); + if (!FFlag::LuauErrorRecoveryType) + log(superTy); occursCheck(superTy, subTy); r->level = min(r->level, l->level); - *asMutable(superTy) = BoundTypeVar(subTy); + + // The occurrence check might have caused superTy no longer to be a free type + if (!FFlag::LuauErrorRecoveryType) + *asMutable(superTy) = BoundTypeVar(subTy); + else if (!get(superTy)) + { + log(superTy); + *asMutable(superTy) = BoundTypeVar(subTy); + } + return; } else if (l) @@ -240,6 +253,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool return; } + // The occurrence check might have caused superTy no longer to be a free type if (!get(superTy)) { if (auto rightLevel = getMutableLevel(subTy)) @@ -251,6 +265,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool log(superTy); *asMutable(superTy) = BoundTypeVar(subTy); } + return; } else if (r) @@ -512,6 +527,9 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool else if (get(superTy) && get(subTy)) tryUnifyPrimitives(superTy, subTy); + else if (FFlag::LuauSingletonTypes && (get(superTy) || get(superTy)) && get(subTy)) + tryUnifySingletons(superTy, subTy); + else if (get(superTy) && get(subTy)) tryUnifyFunctions(superTy, subTy, isFunctionCall); @@ -723,17 +741,18 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal { occursCheck(superTp, subTp); + // The occurrence check might have caused superTp no longer to be a free type if (!get(superTp)) { log(superTp); *asMutable(superTp) = Unifiable::Bound(subTp); } } - else if (get(subTp)) { occursCheck(subTp, superTp); + // The occurrence check might have caused superTp no longer to be a free type if (!get(subTp)) { log(subTp); @@ -874,13 +893,13 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal while (superIter.good()) { - tryUnify_(singletonTypes.errorType, *superIter); + tryUnify_(singletonTypes.errorRecoveryType(), *superIter); superIter.advance(); } while (subIter.good()) { - tryUnify_(singletonTypes.errorType, *subIter); + tryUnify_(singletonTypes.errorRecoveryType(), *subIter); subIter.advance(); } @@ -906,6 +925,27 @@ void Unifier::tryUnifyPrimitives(TypeId superTy, TypeId subTy) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } +void Unifier::tryUnifySingletons(TypeId superTy, TypeId subTy) +{ + const PrimitiveTypeVar* lp = get(superTy); + const SingletonTypeVar* ls = get(superTy); + const SingletonTypeVar* rs = get(subTy); + + if ((!lp && !ls) || !rs) + ice("passed non singleton/primitive types to unifySingletons"); + + if (ls && *ls == *rs) + return; + + if (lp && lp->type == PrimitiveTypeVar::Boolean && get(rs) && variance == Covariant) + return; + + if (lp && lp->type == PrimitiveTypeVar::String && get(rs) && variance == Covariant) + return; + + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); +} + void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall) { FunctionTypeVar* lf = getMutable(superTy); @@ -1023,7 +1063,8 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) } // And vice versa if we're invariant - if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !lt->indexer && lt->state != TableState::Unsealed && lt->state != TableState::Free) + if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !lt->indexer && lt->state != TableState::Unsealed && + lt->state != TableState::Free) { for (const auto& [propName, subProp] : rt->props) { @@ -1038,7 +1079,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) return; } } - + // Reminder: left is the supertype, right is the subtype. // Width subtyping: any property in the supertype must be in the subtype, // and the types must agree. @@ -1634,9 +1675,8 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) { ok = false; errors.push_back(TypeError{location, UnknownProperty{superTy, propName}}); - if (!FFlag::LuauExtendedClassMismatchError) - tryUnify_(prop.type, singletonTypes.errorType); + tryUnify_(prop.type, singletonTypes.errorRecoveryType()); } else { @@ -1952,7 +1992,7 @@ void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) { LUAU_ASSERT(get(any)); - const TypeId anyTy = singletonTypes.errorType; + const TypeId anyTy = singletonTypes.errorRecoveryType(); if (FFlag::LuauTypecheckOpts) { @@ -2046,7 +2086,7 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHash { errors.push_back(TypeError{location, OccursCheckFailed{}}); log(needle); - *asMutable(needle) = ErrorTypeVar{}; + *asMutable(needle) = *singletonTypes.errorRecoveryType(); return; } @@ -2134,7 +2174,7 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, Dense { errors.push_back(TypeError{location, OccursCheckFailed{}}); log(needle); - *asMutable(needle) = ErrorTypeVar{}; + *asMutable(needle) = *singletonTypes.errorRecoveryTypePack(); return; } diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index a2189f7b..5b4bfa03 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -255,6 +255,14 @@ public: { return visit((class AstType*)node); } + virtual bool visit(class AstTypeSingletonBool* node) + { + return visit((class AstType*)node); + } + virtual bool visit(class AstTypeSingletonString* node) + { + return visit((class AstType*)node); + } virtual bool visit(class AstTypeError* node) { return visit((class AstType*)node); @@ -1158,6 +1166,30 @@ public: unsigned messageIndex; }; +class AstTypeSingletonBool : public AstType +{ +public: + LUAU_RTTI(AstTypeSingletonBool) + + AstTypeSingletonBool(const Location& location, bool value); + + void visit(AstVisitor* visitor) override; + + bool value; +}; + +class AstTypeSingletonString : public AstType +{ +public: + LUAU_RTTI(AstTypeSingletonString) + + AstTypeSingletonString(const Location& location, const AstArray& value); + + void visit(AstVisitor* visitor) override; + + const AstArray value; +}; + class AstTypePack : public AstNode { public: diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 39c7d925..87ebc48b 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -286,6 +286,7 @@ private: // `<' typeAnnotation[, ...] `>' AstArray parseTypeParams(); + std::optional> parseCharArray(); AstExpr* parseString(); AstLocal* pushLocal(const Binding& binding); diff --git a/Ast/include/Luau/StringUtils.h b/Ast/include/Luau/StringUtils.h index 4f7673fa..6ecf0606 100644 --- a/Ast/include/Luau/StringUtils.h +++ b/Ast/include/Luau/StringUtils.h @@ -34,4 +34,6 @@ bool equalsLower(std::string_view lhs, std::string_view rhs); size_t hashRange(const char* data, size_t size); +std::string escape(std::string_view s); +bool isIdentifier(std::string_view s); } // namespace Luau diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index b1209faa..e709894d 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -841,6 +841,28 @@ void AstTypeIntersection::visit(AstVisitor* visitor) } } +AstTypeSingletonBool::AstTypeSingletonBool(const Location& location, bool value) + : AstType(ClassIndex(), location) + , value(value) +{ +} + +void AstTypeSingletonBool::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + +AstTypeSingletonString::AstTypeSingletonString(const Location& location, const AstArray& value) + : AstType(ClassIndex(), location) + , value(value) +{ +} + +void AstTypeSingletonString::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + AstTypeError::AstTypeError(const Location& location, const AstArray& types, bool isMissing, unsigned messageIndex) : AstType(ClassIndex(), location) , types(types) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index a1bad65e..bc63e37d 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -16,6 +16,7 @@ LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) LUAU_FASTFLAGVARIABLE(LuauTypeAliasPacks, false) LUAU_FASTFLAGVARIABLE(LuauParseTypePackTypeParameters, false) LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) +LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctionTypeBegin, false) namespace Luau @@ -1278,7 +1279,27 @@ AstType* Parser::parseTableTypeAnnotation() while (lexer.current().type != '}') { - if (lexer.current().type == '[') + if (FFlag::LuauParseSingletonTypes && lexer.current().type == '[' && + (lexer.lookahead().type == Lexeme::RawString || lexer.lookahead().type == Lexeme::QuotedString)) + { + const Lexeme begin = lexer.current(); + nextLexeme(); // [ + std::optional> chars = parseCharArray(); + + expectMatchAndConsume(']', begin); + expectAndConsume(':', "table field"); + + AstType* type = parseTypeAnnotation(); + + // TODO: since AstName conains a char*, it can't contain null + bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); + + if (chars && !containsNull) + props.push_back({AstName(chars->data), begin.location, type}); + else + report(begin.location, "String literal contains malformed escape sequence"); + } + else if (lexer.current().type == '[') { if (indexer) { @@ -1528,6 +1549,32 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) nextLexeme(); return {allocator.alloc(begin, std::nullopt, nameNil), {}}; } + else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::ReservedTrue) + { + nextLexeme(); + return {allocator.alloc(begin, true)}; + } + else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::ReservedFalse) + { + nextLexeme(); + return {allocator.alloc(begin, false)}; + } + else if (FFlag::LuauParseSingletonTypes && (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString)) + { + if (std::optional> value = parseCharArray()) + { + AstArray svalue = *value; + return {allocator.alloc(begin, svalue)}; + } + else + return {reportTypeAnnotationError(begin, {}, /*isMissing*/ false, "String literal contains malformed escape sequence")}; + } + else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::BrokenString) + { + Location location = lexer.current().location; + nextLexeme(); + return {reportTypeAnnotationError(location, {}, /*isMissing*/ false, "Malformed string")}; + } else if (lexer.current().type == Lexeme::Name) { std::optional prefix; @@ -2416,7 +2463,7 @@ AstArray Parser::parseTypeParams() return copy(parameters); } -AstExpr* Parser::parseString() +std::optional> Parser::parseCharArray() { LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString); @@ -2426,11 +2473,8 @@ AstExpr* Parser::parseString() { if (!Lexer::fixupQuotedString(scratchData)) { - Location location = lexer.current().location; - nextLexeme(); - - return reportExprError(location, {}, "String literal contains malformed escape sequence"); + return std::nullopt; } } else @@ -2438,12 +2482,18 @@ AstExpr* Parser::parseString() Lexer::fixupMultilineString(scratchData); } - Location start = lexer.current().location; AstArray value = copy(scratchData); - nextLexeme(); + return value; +} - return allocator.alloc(start, value); +AstExpr* Parser::parseString() +{ + Location location = lexer.current().location; + if (std::optional> value = parseCharArray()) + return allocator.alloc(location, *value); + else + return reportExprError(location, {}, "String literal contains malformed escape sequence"); } AstLocal* Parser::pushLocal(const Binding& binding) diff --git a/Ast/src/StringUtils.cpp b/Ast/src/StringUtils.cpp index 24b2283a..9c7fed31 100644 --- a/Ast/src/StringUtils.cpp +++ b/Ast/src/StringUtils.cpp @@ -225,4 +225,62 @@ size_t hashRange(const char* data, size_t size) return hash; } +bool isIdentifier(std::string_view s) +{ + return (s.find_first_not_of("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ01234567890_") == std::string::npos); +} + +std::string escape(std::string_view s) +{ + std::string r; + r.reserve(s.size() + 50); // arbitrary number to guess how many characters we'll be inserting + + for (uint8_t c : s) + { + if (c >= ' ' && c != '\\' && c != '\'' && c != '\"') + r += c; + else + { + r += '\\'; + + switch (c) + { + case '\a': + r += 'a'; + break; + case '\b': + r += 'b'; + break; + case '\f': + r += 'f'; + break; + case '\n': + r += 'n'; + break; + case '\r': + r += 'r'; + break; + case '\t': + r += 't'; + break; + case '\v': + r += 'v'; + break; + case '\'': + r += '\''; + break; + case '\"': + r += '\"'; + break; + case '\\': + r += '\\'; + break; + default: + Luau::formatAppend(r, "%03u", c); + } + } + } + + return r; +} } // namespace Luau diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 798ead06..4194d5ae 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -236,32 +236,12 @@ int main(int argc, char** argv) Luau::registerBuiltinTypes(frontend.typeChecker); Luau::freeze(frontend.typeChecker.globalTypes); + std::vector files = getSourceFiles(argc, argv); + int failed = 0; - for (int i = 1; i < argc; ++i) - { - if (argv[i][0] == '-') - continue; - - if (isDirectory(argv[i])) - { - traverseDirectory(argv[i], [&](const std::string& name) { - // Look for .luau first and if absent, fall back to .lua - if (name.length() > 5 && name.rfind(".luau") == name.length() - 5) - { - failed += !analyzeFile(frontend, name.c_str(), format, annotate); - } - else if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) - { - failed += !analyzeFile(frontend, name.c_str(), format, annotate); - } - }); - } - else - { - failed += !analyzeFile(frontend, argv[i], format, annotate); - } - } + for (const std::string& path : files) + failed += !analyzeFile(frontend, path.c_str(), format, annotate); if (!configResolver.configErrors.empty()) { diff --git a/CLI/FileUtils.cpp b/CLI/FileUtils.cpp index fff5e429..b3c9557b 100644 --- a/CLI/FileUtils.cpp +++ b/CLI/FileUtils.cpp @@ -223,3 +223,40 @@ std::optional getParentPath(const std::string& path) return ""; } + +static std::string getExtension(const std::string& path) +{ + std::string::size_type dot = path.find_last_of(".\\/"); + + if (dot == std::string::npos || path[dot] != '.') + return ""; + + return path.substr(dot); +} + +std::vector getSourceFiles(int argc, char** argv) +{ + std::vector files; + + for (int i = 1; i < argc; ++i) + { + if (argv[i][0] == '-') + continue; + + if (isDirectory(argv[i])) + { + traverseDirectory(argv[i], [&](const std::string& name) { + std::string ext = getExtension(name); + + if (ext == ".lua" || ext == ".luau") + files.push_back(name); + }); + } + else + { + files.push_back(argv[i]); + } + } + + return files; +} diff --git a/CLI/FileUtils.h b/CLI/FileUtils.h index f7fbe8af..da11f512 100644 --- a/CLI/FileUtils.h +++ b/CLI/FileUtils.h @@ -4,6 +4,7 @@ #include #include #include +#include std::optional readFile(const std::string& name); @@ -12,3 +13,5 @@ bool traverseDirectory(const std::string& path, const std::function getParentPath(const std::string& path); + +std::vector getSourceFiles(int argc, char** argv); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 82fc4e42..1414ffde 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -20,7 +20,7 @@ enum class CompileFormat { - Default, + Text, Binary }; @@ -33,7 +33,7 @@ static int lua_loadstring(lua_State* L) lua_setsafeenv(L, LUA_ENVIRONINDEX, false); std::string bytecode = Luau::compile(std::string(s, l)); - if (luau_load(L, chunkname, bytecode.data(), bytecode.size()) == 0) + if (luau_load(L, chunkname, bytecode.data(), bytecode.size(), 0) == 0) return 1; lua_pushnil(L); @@ -80,7 +80,7 @@ static int lua_require(lua_State* L) // now we can compile & run module on the new thread std::string bytecode = Luau::compile(*source); - if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0) + if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { int status = lua_resume(ML, L, 0); @@ -151,7 +151,7 @@ static std::string runCode(lua_State* L, const std::string& source) { std::string bytecode = Luau::compile(source); - if (luau_load(L, "=stdin", bytecode.data(), bytecode.size()) != 0) + if (luau_load(L, "=stdin", bytecode.data(), bytecode.size(), 0) != 0) { size_t len; const char* msg = lua_tolstring(L, -1, &len); @@ -370,7 +370,7 @@ static bool runFile(const char* name, lua_State* GL) std::string bytecode = Luau::compile(*source); int status = 0; - if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0) + if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { status = lua_resume(L, NULL, 0); } @@ -379,11 +379,7 @@ static bool runFile(const char* name, lua_State* GL) status = LUA_ERRSYNTAX; } - if (status == 0) - { - return true; - } - else + if (status != 0) { std::string error; @@ -400,8 +396,10 @@ static bool runFile(const char* name, lua_State* GL) error += lua_debugtrace(L); fprintf(stderr, "%s", error.c_str()); - return false; } + + lua_pop(GL, 1); + return status == 0; } static void report(const char* name, const Luau::Location& location, const char* type, const char* message) @@ -431,14 +429,18 @@ static bool compileFile(const char* name, CompileFormat format) try { Luau::BytecodeBuilder bcb; - bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source); - bcb.setDumpSource(*source); + + if (format == CompileFormat::Text) + { + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source); + bcb.setDumpSource(*source); + } Luau::compileOrThrow(bcb, *source); switch (format) { - case CompileFormat::Default: + case CompileFormat::Text: printf("%s", bcb.dumpEverything().c_str()); break; case CompileFormat::Binary: @@ -504,7 +506,7 @@ int main(int argc, char** argv) if (argc >= 2 && strncmp(argv[1], "--compile", strlen("--compile")) == 0) { - CompileFormat format = CompileFormat::Default; + CompileFormat format = CompileFormat::Text; if (strcmp(argv[1], "--compile=binary") == 0) format = CompileFormat::Binary; @@ -514,27 +516,12 @@ int main(int argc, char** argv) _setmode(_fileno(stdout), _O_BINARY); #endif + std::vector files = getSourceFiles(argc, argv); + int failed = 0; - for (int i = 2; i < argc; ++i) - { - if (argv[i][0] == '-') - continue; - - if (isDirectory(argv[i])) - { - traverseDirectory(argv[i], [&](const std::string& name) { - if (name.length() > 5 && name.rfind(".luau") == name.length() - 5) - failed += !compileFile(name.c_str(), format); - else if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) - failed += !compileFile(name.c_str(), format); - }); - } - else - { - failed += !compileFile(argv[i], format); - } - } + for (const std::string& path : files) + failed += !compileFile(path.c_str(), format); return failed; } @@ -548,33 +535,25 @@ int main(int argc, char** argv) int profile = 0; for (int i = 1; i < argc; ++i) + { + if (argv[i][0] != '-') + continue; + if (strcmp(argv[i], "--profile") == 0) profile = 10000; // default to 10 KHz else if (strncmp(argv[i], "--profile=", 10) == 0) profile = atoi(argv[i] + 10); + } if (profile) profilerStart(L, profile); + std::vector files = getSourceFiles(argc, argv); + int failed = 0; - for (int i = 1; i < argc; ++i) - { - if (argv[i][0] == '-') - continue; - - if (isDirectory(argv[i])) - { - traverseDirectory(argv[i], [&](const std::string& name) { - if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) - failed += !runFile(name.c_str(), L); - }); - } - else - { - failed += !runFile(argv[i], L); - } - } + for (const std::string& path : files) + failed += !runFile(path.c_str(), L); if (profile) { diff --git a/Compiler/include/Luau/Compiler.h b/Compiler/include/Luau/Compiler.h index 4f88e602..65e962da 100644 --- a/Compiler/include/Luau/Compiler.h +++ b/Compiler/include/Luau/Compiler.h @@ -13,11 +13,9 @@ class AstNameTable; class BytecodeBuilder; class BytecodeEncoder; +// Note: this structure is duplicated in luacode.h, don't forget to change these in sync! struct CompileOptions { - // default bytecode version target; can be used to compile code for older clients - int bytecodeVersion = 1; - // 0 - no optimization // 1 - baseline optimization level that doesn't prevent debuggability // 2 - includes optimizations that harm debuggability such as inlining diff --git a/Compiler/include/luacode.h b/Compiler/include/luacode.h new file mode 100644 index 00000000..e235a2e7 --- /dev/null +++ b/Compiler/include/luacode.h @@ -0,0 +1,39 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include + +/* Can be used to reconfigure visibility/exports for public APIs */ +#ifndef LUACODE_API +#define LUACODE_API extern +#endif + +typedef struct lua_CompileOptions lua_CompileOptions; + +struct lua_CompileOptions +{ + // 0 - no optimization + // 1 - baseline optimization level that doesn't prevent debuggability + // 2 - includes optimizations that harm debuggability such as inlining + int optimizationLevel; // default=1 + + // 0 - no debugging support + // 1 - line info & function names only; sufficient for backtraces + // 2 - full debug info with local & upvalue names; necessary for debugger + int debugLevel; // default=1 + + // 0 - no code coverage support + // 1 - statement coverage + // 2 - statement and expression coverage (verbose) + int coverageLevel; // default=0 + + // global builtin to construct vectors; disabled by default + const char* vectorLib; + const char* vectorCtor; + + // null-terminated array of globals that are mutable; disables the import optimization for fields accessed through these + const char** mutableGlobals; +}; + +/* compile source to bytecode; when source compilation fails, the resulting bytecode contains the encoded error. use free() to destroy */ +LUACODE_API char* luau_compile(const char* source, size_t size, lua_CompileOptions* options, size_t* outsize); diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 9712f02f..5b93c1dc 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -11,9 +11,6 @@ #include LUAU_FASTFLAGVARIABLE(LuauPreloadClosures, false) -LUAU_FASTFLAGVARIABLE(LuauPreloadClosuresFenv, false) -LUAU_FASTFLAGVARIABLE(LuauPreloadClosuresUpval, false) -LUAU_FASTFLAGVARIABLE(LuauGenericSpecialGlobals, false) LUAU_FASTFLAG(LuauIfElseExpressionBaseSupport) LUAU_FASTFLAGVARIABLE(LuauBit32CountBuiltin, false) @@ -24,9 +21,6 @@ static const uint32_t kMaxRegisterCount = 255; static const uint32_t kMaxUpvalueCount = 200; static const uint32_t kMaxLocalCount = 200; -// TODO: Remove with LuauGenericSpecialGlobals -static const char* kSpecialGlobals[] = {"Game", "Workspace", "_G", "game", "plugin", "script", "shared", "workspace"}; - CompileError::CompileError(const Location& location, const std::string& message) : location(location) , message(message) @@ -466,7 +460,7 @@ struct Compiler bool shared = false; - if (FFlag::LuauPreloadClosuresUpval) + if (FFlag::LuauPreloadClosures) { // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it @@ -482,18 +476,6 @@ struct Compiler } } } - // Optimization: when closure has no upvalues, instead of allocating it every time we can share closure objects - // (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it is used) - else if (FFlag::LuauPreloadClosures && options.optimizationLevel >= 1 && f->upvals.empty() && !setfenvUsed) - { - int32_t cid = bytecode.addConstantClosure(f->id); - - if (cid >= 0 && cid < 32768) - { - bytecode.emitAD(LOP_DUPCLOSURE, target, cid); - return; - } - } if (!shared) bytecode.emitAD(LOP_NEWCLOSURE, target, pid); @@ -3298,8 +3280,7 @@ struct Compiler bool visit(AstStatLocalFunction* node) override { // record local->function association for some optimizations - if (FFlag::LuauPreloadClosuresUpval) - self->locals[node->name].func = node->func; + self->locals[node->name].func = node->func; return true; } @@ -3711,24 +3692,13 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName Compiler compiler(bytecode, options); // since access to some global objects may result in values that change over time, we block imports from non-readonly tables - if (FFlag::LuauGenericSpecialGlobals) - { - if (AstName name = names.get("_G"); name.value) - compiler.globals[name].writable = true; + if (AstName name = names.get("_G"); name.value) + compiler.globals[name].writable = true; - if (options.mutableGlobals) - for (const char** ptr = options.mutableGlobals; *ptr; ++ptr) - if (AstName name = names.get(*ptr); name.value) - compiler.globals[name].writable = true; - } - else - { - for (const char* global : kSpecialGlobals) - { - if (AstName name = names.get(global); name.value) + if (options.mutableGlobals) + for (const char** ptr = options.mutableGlobals; *ptr; ++ptr) + if (AstName name = names.get(*ptr); name.value) compiler.globals[name].writable = true; - } - } // this visitor traverses the AST to analyze mutability of locals/globals, filling Local::written and Global::written Compiler::AssignmentVisitor assignmentVisitor(&compiler); @@ -3742,7 +3712,7 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName } // this visitor tracks calls to getfenv/setfenv and disables some optimizations when they are found - if (FFlag::LuauPreloadClosuresFenv && options.optimizationLevel >= 1 && (names.get("getfenv").value || names.get("setfenv").value)) + if (options.optimizationLevel >= 1 && (names.get("getfenv").value || names.get("setfenv").value)) { Compiler::FenvVisitor fenvVisitor(compiler.getfenvUsed, compiler.setfenvUsed); root->visit(&fenvVisitor); diff --git a/Compiler/src/lcode.cpp b/Compiler/src/lcode.cpp new file mode 100644 index 00000000..ee150b17 --- /dev/null +++ b/Compiler/src/lcode.cpp @@ -0,0 +1,29 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "luacode.h" + +#include "Luau/Compiler.h" + +#include + +char* luau_compile(const char* source, size_t size, lua_CompileOptions* options, size_t* outsize) +{ + LUAU_ASSERT(outsize); + + Luau::CompileOptions opts; + + if (options) + { + static_assert(sizeof(lua_CompileOptions) == sizeof(Luau::CompileOptions), "C and C++ interface must match"); + memcpy(static_cast(&opts), options, sizeof(opts)); + } + + std::string result = compile(std::string(source, size), opts); + + char* copy = static_cast(malloc(result.size())); + if (!copy) + return nullptr; + + memcpy(copy, result.data(), result.size()); + *outsize = result.size(); + return copy; +} diff --git a/README.md b/README.md index ffb464d0..d79a0bd6 100644 --- a/README.md +++ b/README.md @@ -40,8 +40,13 @@ make config=release luau luau-analyze To integrate Luau into your CMake application projects, at the minimum you'll need to depend on `Luau.Compiler` and `Luau.VM` projects. From there you need to create a new Luau state (using Lua 5.x API such as `lua_newstate`), compile source to bytecode and load it into the VM like this: ```cpp -std::string bytecode = Luau::compile(source); // needs Luau/Compiler.h include -if (luau_load(L, chunkname, bytecode.data(), bytecode.size()) == 0) +// needs lua.h and luacode.h +size_t bytecodeSize = 0; +char* bytecode = luau_compile(source, strlen(source), NULL, &bytecodeSize); +int result = luau_load(L, chunkname, bytecode, bytecodeSize, 0); +free(bytecode); + +if (result == 0) return 1; /* return chunk main function */ ``` diff --git a/SECURITY.md b/SECURITY.md index 3fc9f66d..ad929775 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -7,7 +7,7 @@ Any source code can not result in memory safety errors or crashes during its com Note that Luau does not provide termination guarantees - some code may exhaust CPU or RAM resources on the system during compilation or execution. The runtime expects valid bytecode as an input. Feeding bytecode that was not produced by Luau compiler into the VM is not supported and -doesn't come with any security guarantees; make sure to sign the bytecode when it crosses a network or file system boundary to avoid tampering. +doesn't come with any security guarantees; make sure to sign and/or encrypt the bytecode when it crosses a network or file system boundary to avoid tampering. # Reporting a Vulnerability diff --git a/Sources.cmake b/Sources.cmake index c30cf77d..23b931c6 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -25,9 +25,11 @@ target_sources(Luau.Compiler PRIVATE Compiler/include/Luau/Bytecode.h Compiler/include/Luau/BytecodeBuilder.h Compiler/include/Luau/Compiler.h + Compiler/include/luacode.h Compiler/src/BytecodeBuilder.cpp Compiler/src/Compiler.cpp + Compiler/src/lcode.cpp ) # Luau.Analysis Sources @@ -204,6 +206,7 @@ if(TARGET Luau.UnitTest) tests/TypeInfer.intersectionTypes.test.cpp tests/TypeInfer.provisional.test.cpp tests/TypeInfer.refinements.test.cpp + tests/TypeInfer.singletons.test.cpp tests/TypeInfer.tables.test.cpp tests/TypeInfer.test.cpp tests/TypeInfer.tryUnify.test.cpp diff --git a/VM/include/lua.h b/VM/include/lua.h index a9d3e875..1568d191 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -102,6 +102,8 @@ LUA_API lua_State* lua_newstate(lua_Alloc f, void* ud); LUA_API void lua_close(lua_State* L); LUA_API lua_State* lua_newthread(lua_State* L); LUA_API lua_State* lua_mainthread(lua_State* L); +LUA_API void lua_resetthread(lua_State* L); +LUA_API int lua_isthreadreset(lua_State* L); /* ** basic stack manipulation @@ -162,8 +164,7 @@ LUA_API void lua_pushlstring(lua_State* L, const char* s, size_t l); LUA_API void lua_pushstring(lua_State* L, const char* s); LUA_API const char* lua_pushvfstring(lua_State* L, const char* fmt, va_list argp); LUA_API LUA_PRINTF_ATTR(2, 3) const char* lua_pushfstringL(lua_State* L, const char* fmt, ...); -LUA_API void lua_pushcfunction( - lua_State* L, lua_CFunction fn, const char* debugname = NULL, int nup = 0, lua_Continuation cont = NULL); +LUA_API void lua_pushcclosurek(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont); LUA_API void lua_pushboolean(lua_State* L, int b); LUA_API void lua_pushlightuserdata(lua_State* L, void* p); LUA_API int lua_pushthread(lua_State* L); @@ -178,9 +179,9 @@ LUA_API void lua_rawget(lua_State* L, int idx); LUA_API void lua_rawgeti(lua_State* L, int idx, int n); LUA_API void lua_createtable(lua_State* L, int narr, int nrec); -LUA_API void lua_setreadonly(lua_State* L, int idx, bool value); +LUA_API void lua_setreadonly(lua_State* L, int idx, int enabled); LUA_API int lua_getreadonly(lua_State* L, int idx); -LUA_API void lua_setsafeenv(lua_State* L, int idx, bool value); +LUA_API void lua_setsafeenv(lua_State* L, int idx, int enabled); LUA_API void* lua_newuserdata(lua_State* L, size_t sz, int tag); LUA_API void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)); @@ -200,7 +201,7 @@ LUA_API int lua_setfenv(lua_State* L, int idx); /* ** `load' and `call' functions (load and run Luau bytecode) */ -LUA_API int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size, int env = 0); +LUA_API int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size, int env); LUA_API void lua_call(lua_State* L, int nargs, int nresults); LUA_API int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc); @@ -293,6 +294,8 @@ LUA_API void lua_unref(lua_State* L, int ref); #define lua_isnoneornil(L, n) (lua_type(L, (n)) <= LUA_TNIL) #define lua_pushliteral(L, s) lua_pushlstring(L, "" s, (sizeof(s) / sizeof(char)) - 1) +#define lua_pushcfunction(L, fn, debugname) lua_pushcclosurek(L, fn, debugname, 0, NULL) +#define lua_pushcclosure(L, fn, debugname, nup) lua_pushcclosurek(L, fn, debugname, nup, NULL) #define lua_setglobal(L, s) lua_setfield(L, LUA_GLOBALSINDEX, (s)) #define lua_getglobal(L, s) lua_getfield(L, LUA_GLOBALSINDEX, (s)) @@ -319,8 +322,8 @@ LUA_API const char* lua_setlocal(lua_State* L, int level, int n); LUA_API const char* lua_getupvalue(lua_State* L, int funcindex, int n); LUA_API const char* lua_setupvalue(lua_State* L, int funcindex, int n); -LUA_API void lua_singlestep(lua_State* L, bool singlestep); -LUA_API void lua_breakpoint(lua_State* L, int funcindex, int line, bool enable); +LUA_API void lua_singlestep(lua_State* L, int enabled); +LUA_API void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled); /* Warning: this function is not thread-safe since it stores the result in a shared global array! Only use for debugging. */ LUA_API const char* lua_debugtrace(lua_State* L); @@ -361,6 +364,7 @@ struct lua_Callbacks void (*debuginterrupt)(lua_State* L, lua_Debug* ar); /* gets called when thread execution is interrupted by break in another thread */ void (*debugprotectederror)(lua_State* L); /* gets called when protected call results in an error */ }; +typedef struct lua_Callbacks lua_Callbacks; LUA_API lua_Callbacks* lua_callbacks(lua_State* L); diff --git a/VM/include/lualib.h b/VM/include/lualib.h index 30cffaff..fa836955 100644 --- a/VM/include/lualib.h +++ b/VM/include/lualib.h @@ -8,11 +8,12 @@ #define luaL_typeerror(L, narg, tname) luaL_typeerrorL(L, narg, tname) #define luaL_argerror(L, narg, extramsg) luaL_argerrorL(L, narg, extramsg) -typedef struct luaL_Reg +struct luaL_Reg { const char* name; lua_CFunction func; -} luaL_Reg; +}; +typedef struct luaL_Reg luaL_Reg; LUALIB_API void luaL_register(lua_State* L, const char* libname, const luaL_Reg* l); LUALIB_API int luaL_getmetafield(lua_State* L, int obj, const char* e); @@ -75,6 +76,7 @@ struct luaL_Buffer struct TString* storage; char buffer[LUA_BUFFERSIZE]; }; +typedef struct luaL_Buffer luaL_Buffer; // when internal buffer storage is exhausted, a mutable string value 'storage' will be placed on the stack // in general, functions expect the mutable string buffer to be placed on top of the stack (top-1) diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 7e742644..a79ba0d4 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -593,7 +593,7 @@ const char* lua_pushfstringL(lua_State* L, const char* fmt, ...) return ret; } -void lua_pushcfunction(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont) +void lua_pushcclosurek(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont) { luaC_checkGC(L); luaC_checkthreadsleep(L); @@ -698,13 +698,13 @@ void lua_createtable(lua_State* L, int narray, int nrec) return; } -void lua_setreadonly(lua_State* L, int objindex, bool value) +void lua_setreadonly(lua_State* L, int objindex, int enabled) { const TValue* o = index2adr(L, objindex); api_check(L, ttistable(o)); Table* t = hvalue(o); api_check(L, t != hvalue(registry(L))); - t->readonly = value; + t->readonly = bool(enabled); return; } @@ -717,12 +717,12 @@ int lua_getreadonly(lua_State* L, int objindex) return res; } -void lua_setsafeenv(lua_State* L, int objindex, bool value) +void lua_setsafeenv(lua_State* L, int objindex, int enabled) { const TValue* o = index2adr(L, objindex); api_check(L, ttistable(o)); Table* t = hvalue(o); - t->safeenv = value; + t->safeenv = bool(enabled); return; } diff --git a/VM/src/lbaselib.cpp b/VM/src/lbaselib.cpp index 87fc1631..61798e2b 100644 --- a/VM/src/lbaselib.cpp +++ b/VM/src/lbaselib.cpp @@ -436,8 +436,8 @@ static const luaL_Reg base_funcs[] = { static void auxopen(lua_State* L, const char* name, lua_CFunction f, lua_CFunction u) { - lua_pushcfunction(L, u); - lua_pushcfunction(L, f, name, 1); + lua_pushcfunction(L, u, NULL); + lua_pushcclosure(L, f, name, 1); lua_setfield(L, -2, name); } @@ -456,10 +456,10 @@ LUALIB_API int luaopen_base(lua_State* L) auxopen(L, "ipairs", luaB_ipairs, luaB_inext); auxopen(L, "pairs", luaB_pairs, luaB_next); - lua_pushcfunction(L, luaB_pcally, "pcall", 0, luaB_pcallcont); + lua_pushcclosurek(L, luaB_pcally, "pcall", 0, luaB_pcallcont); lua_setfield(L, -2, "pcall"); - lua_pushcfunction(L, luaB_xpcally, "xpcall", 0, luaB_xpcallcont); + lua_pushcclosurek(L, luaB_xpcally, "xpcall", 0, luaB_xpcallcont); lua_setfield(L, -2, "xpcall"); return 1; diff --git a/VM/src/lcorolib.cpp b/VM/src/lcorolib.cpp index 9724c0e7..0178fae8 100644 --- a/VM/src/lcorolib.cpp +++ b/VM/src/lcorolib.cpp @@ -5,6 +5,8 @@ #include "lstate.h" #include "lvm.h" +LUAU_FASTFLAGVARIABLE(LuauCoroutineClose, false) + #define CO_RUN 0 /* running */ #define CO_SUS 1 /* suspended */ #define CO_NOR 2 /* 'normal' (it resumed another coroutine) */ @@ -208,8 +210,7 @@ static int cowrap(lua_State* L) { cocreate(L); - lua_pushcfunction(L, auxwrapy, NULL, 1, auxwrapcont); - + lua_pushcclosurek(L, auxwrapy, NULL, 1, auxwrapcont); return 1; } @@ -232,6 +233,34 @@ static int coyieldable(lua_State* L) return 1; } +static int coclose(lua_State* L) +{ + if (!FFlag::LuauCoroutineClose) + luaL_error(L, "coroutine.close is not enabled"); + + lua_State* co = lua_tothread(L, 1); + luaL_argexpected(L, co, 1, "thread"); + + int status = auxstatus(L, co); + if (status != CO_DEAD && status != CO_SUS) + luaL_error(L, "cannot close %s coroutine", statnames[status]); + + if (co->status == LUA_OK || co->status == LUA_YIELD) + { + lua_pushboolean(L, true); + lua_resetthread(co); + return 1; + } + else + { + lua_pushboolean(L, false); + if (lua_gettop(co)) + lua_xmove(co, L, 1); /* move error message */ + lua_resetthread(co); + return 2; + } +} + static const luaL_Reg co_funcs[] = { {"create", cocreate}, {"running", corunning}, @@ -239,6 +268,7 @@ static const luaL_Reg co_funcs[] = { {"wrap", cowrap}, {"yield", coyield}, {"isyieldable", coyieldable}, + {"close", coclose}, {NULL, NULL}, }; @@ -246,7 +276,7 @@ LUALIB_API int luaopen_coroutine(lua_State* L) { luaL_register(L, LUA_COLIBNAME, co_funcs); - lua_pushcfunction(L, coresumey, "resume", 0, coresumecont); + lua_pushcclosurek(L, coresumey, "resume", 0, coresumecont); lua_setfield(L, -2, "resume"); return 1; diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index 1890e682..d77f84ef 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -316,7 +316,7 @@ void luaG_breakpoint(lua_State* L, Proto* p, int line, bool enable) p->debuginsn[j] = LUAU_INSN_OP(p->code[j]); } - uint8_t op = enable ? LOP_BREAK : LUAU_INSN_OP(p->code[i]); + uint8_t op = enable ? LOP_BREAK : LUAU_INSN_OP(p->debuginsn[i]); // patch just the opcode byte, leave arguments alone p->code[i] &= ~0xff; @@ -357,17 +357,17 @@ int luaG_getline(Proto* p, int pc) return p->abslineinfo[pc >> p->linegaplog2] + p->lineinfo[pc]; } -void lua_singlestep(lua_State* L, bool singlestep) +void lua_singlestep(lua_State* L, int enabled) { - L->singlestep = singlestep; + L->singlestep = bool(enabled); } -void lua_breakpoint(lua_State* L, int funcindex, int line, bool enable) +void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled) { const TValue* func = luaA_toobject(L, funcindex); api_check(L, ttisfunction(func) && !clvalue(func)->isC); - luaG_breakpoint(L, clvalue(func)->l.p, line, enable); + luaG_breakpoint(L, clvalue(func)->l.p, line, bool(enabled)); } static size_t append(char* buf, size_t bufsize, size_t offset, const char* data) diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 328b47e6..1259d461 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -19,6 +19,7 @@ LUAU_FASTFLAGVARIABLE(LuauExceptionMessageFix, false) LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false) +LUAU_FASTFLAG(LuauCoroutineClose) /* ** {====================================================== @@ -300,7 +301,10 @@ static void resume(lua_State* L, void* ud) if (L->status == 0) { // start coroutine - LUAU_ASSERT(L->ci == L->base_ci && firstArg > L->base); + LUAU_ASSERT(L->ci == L->base_ci && firstArg >= L->base); + if (FFlag::LuauCoroutineClose && firstArg == L->base) + luaG_runerror(L, "cannot resume dead coroutine"); + if (luau_precall(L, firstArg - 1, LUA_MULTRET) != PCRLUA) return; diff --git a/VM/src/linit.cpp b/VM/src/linit.cpp index bf5e738f..4e40165a 100644 --- a/VM/src/linit.cpp +++ b/VM/src/linit.cpp @@ -22,7 +22,7 @@ LUALIB_API void luaL_openlibs(lua_State* L) const luaL_Reg* lib = lualibs; for (; lib->func; lib++) { - lua_pushcfunction(L, lib->func); + lua_pushcfunction(L, lib->func, NULL); lua_pushstring(L, lib->name); lua_call(L, 1, 0); } diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index 0b2dfb69..24e97063 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -124,6 +124,34 @@ void luaE_freethread(lua_State* L, lua_State* L1) luaM_free(L, L1, sizeof(lua_State), L1->memcat); } +void lua_resetthread(lua_State* L) +{ + /* close upvalues before clearing anything */ + luaF_close(L, L->stack); + /* clear call frames */ + CallInfo* ci = L->base_ci; + ci->func = L->stack; + ci->base = ci->func + 1; + ci->top = ci->base + LUA_MINSTACK; + setnilvalue(ci->func); + L->ci = ci; + luaD_reallocCI(L, BASIC_CI_SIZE); + /* clear thread state */ + L->status = LUA_OK; + L->base = L->ci->base; + L->top = L->ci->base; + L->nCcalls = L->baseCcalls = 0; + /* clear thread stack */ + luaD_reallocstack(L, BASIC_STACK_SIZE); + for (int i = 0; i < L->stacksize; i++) + setnilvalue(L->stack + i); +} + +int lua_isthreadreset(lua_State* L) +{ + return L->ci == L->base_ci && L->base == L->top && L->status == LUA_OK; +} + lua_State* lua_newstate(lua_Alloc f, void* ud) { int i; diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp index 80a34483..b576f809 100644 --- a/VM/src/lstrlib.cpp +++ b/VM/src/lstrlib.cpp @@ -748,7 +748,7 @@ static int gmatch(lua_State* L) luaL_checkstring(L, 2); lua_settop(L, 2); lua_pushinteger(L, 0); - lua_pushcfunction(L, gmatch_aux, NULL, 3); + lua_pushcclosure(L, gmatch_aux, NULL, 3); return 1; } diff --git a/VM/src/lutf8lib.cpp b/VM/src/lutf8lib.cpp index 6a026296..378de3d0 100644 --- a/VM/src/lutf8lib.cpp +++ b/VM/src/lutf8lib.cpp @@ -265,7 +265,7 @@ static int iter_aux(lua_State* L) static int iter_codes(lua_State* L) { luaL_checkstring(L, 1); - lua_pushcfunction(L, iter_aux); + lua_pushcfunction(L, iter_aux, NULL); lua_pushvalue(L, 1); lua_pushinteger(L, 0); return 3; diff --git a/bench/bench.py b/bench/bench.py index b23ca891..39f219f3 100644 --- a/bench/bench.py +++ b/bench/bench.py @@ -25,8 +25,8 @@ try: import scipy from scipy import stats except ModuleNotFoundError: - print("scipy package is required") - exit(1) + print("Warning: scipy package is not installed, confidence values will not be available") + stats = None scriptdir = os.path.dirname(os.path.realpath(__file__)) defaultVm = 'luau.exe' if os.name == "nt" else './luau' @@ -200,11 +200,14 @@ def finalizeResult(result): result.sampleStdDev = math.sqrt(sumOfSquares / (result.count - 1)) result.unbiasedEst = result.sampleStdDev * result.sampleStdDev - # Two-tailed distribution with 95% conf. - tValue = stats.t.ppf(1 - 0.05 / 2, result.count - 1) + if stats: + # Two-tailed distribution with 95% conf. + tValue = stats.t.ppf(1 - 0.05 / 2, result.count - 1) - # Compute confidence interval - result.sampleConfidenceInterval = tValue * result.sampleStdDev / math.sqrt(result.count) + # Compute confidence interval + result.sampleConfidenceInterval = tValue * result.sampleStdDev / math.sqrt(result.count) + else: + result.sampleConfidenceInterval = result.sampleStdDev else: result.sampleStdDev = 0 result.unbiasedEst = 0 @@ -377,14 +380,19 @@ def analyzeResult(subdir, main, comparisons): tStat = abs(main.avg - compare.avg) / (pooledStdDev * math.sqrt(2 / main.count)) degreesOfFreedom = 2 * main.count - 2 - # Two-tailed distribution with 95% conf. - tCritical = stats.t.ppf(1 - 0.05 / 2, degreesOfFreedom) + if stats: + # Two-tailed distribution with 95% conf. + tCritical = stats.t.ppf(1 - 0.05 / 2, degreesOfFreedom) - noSignificantDifference = tStat < tCritical + noSignificantDifference = tStat < tCritical + pValue = 2 * (1 - stats.t.cdf(tStat, df = degreesOfFreedom)) + else: + noSignificantDifference = None + pValue = -1 - pValue = 2 * (1 - stats.t.cdf(tStat, df = degreesOfFreedom)) - - if noSignificantDifference: + if noSignificantDifference is None: + verdict = "" + elif noSignificantDifference: verdict = "likely same" elif main.avg < compare.avg: verdict = "likely worse" diff --git a/bench/gc/test_LB_mandel.lua b/bench/gc/test_LB_mandel.lua index 4be78502..fe5b4eb2 100644 --- a/bench/gc/test_LB_mandel.lua +++ b/bench/gc/test_LB_mandel.lua @@ -88,7 +88,7 @@ for i=1,N do local y=ymin+(j-1)*dy S = S + level(x,y) end - -- if i % 10 == 0 then print(collectgarbage"count") end + -- if i % 10 == 0 then print(collectgarbage("count")) end end print(S) diff --git a/bench/tests/shootout/mandel.lua b/bench/tests/shootout/mandel.lua index 4be78502..fe5b4eb2 100644 --- a/bench/tests/shootout/mandel.lua +++ b/bench/tests/shootout/mandel.lua @@ -88,7 +88,7 @@ for i=1,N do local y=ymin+(j-1)*dy S = S + level(x,y) end - -- if i % 10 == 0 then print(collectgarbage"count") end + -- if i % 10 == 0 then print(collectgarbage("count")) end end print(S) diff --git a/bench/tests/shootout/qt.lua b/bench/tests/shootout/qt.lua index de962a74..79cbe38b 100644 --- a/bench/tests/shootout/qt.lua +++ b/bench/tests/shootout/qt.lua @@ -275,7 +275,7 @@ local function memory(s) local t=os.clock() --local dt=string.format("%f",t-t0) local dt=t-t0 - --io.stdout:write(s,"\t",dt," sec\t",t," sec\t",math.floor(collectgarbage"count"/1024),"M\n") + --io.stdout:write(s,"\t",dt," sec\t",t," sec\t",math.floor(collectgarbage("count")/1024),"M\n") t0=t end @@ -286,7 +286,7 @@ local function do_(f,s) end local function julia(l,a,b) -memory"begin" +memory("begin") cx=a cy=b root=newcell() exterior=newcell() exterior.color=white @@ -297,14 +297,14 @@ memory"begin" do_(update,"update") repeat N=0 color(root,Rxmin,Rxmax,Rymin,Rymax) --print("color",N) - until N==0 memory"color" + until N==0 memory("color") repeat N=0 prewhite(root,Rxmin,Rxmax,Rymin,Rymax) --print("prewhite",N) - until N==0 memory"prewhite" + until N==0 memory("prewhite") do_(recolor,"recolor") do_(colorup,"colorup") --print("colorup",N) local g,b=do_(area,"area") --print("area",g,b,g+b) - show(i) memory"output" + show(i) memory("output") --print("edges",nE) end end diff --git a/docs/_pages/library.md b/docs/_pages/library.md index 0a6e3ec9..28a9f389 100644 --- a/docs/_pages/library.md +++ b/docs/_pages/library.md @@ -759,7 +759,7 @@ Otherwise, `s` is interpreted as a [date format string](https://www.cplusplus.co function os.difftime(a: number, b: number): number ``` -Calculates the difference in seconds between `a` and `b`; provided for compatibility. +Calculates the difference in seconds between `a` and `b`; provided for compatibility only. Please use `a - b` instead. ``` function os.time(t: table?): number diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index c85fac7d..ae2399e4 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -257,7 +257,7 @@ DEFINE_PROTO_FUZZER(const luau::StatBlock& message) lua_State* L = lua_newthread(globalState); luaL_sandboxthread(L); - if (luau_load(L, "=fuzz", bytecode.data(), bytecode.size()) == 0) + if (luau_load(L, "=fuzz", bytecode.data(), bytecode.size(), 0) == 0) { interruptDeadline = std::chrono::system_clock::now() + kInterruptTimeout; diff --git a/rfcs/function-coroutine-close.md b/rfcs/function-coroutine-close.md new file mode 100644 index 00000000..635050c4 --- /dev/null +++ b/rfcs/function-coroutine-close.md @@ -0,0 +1,34 @@ +# coroutine.close + +## Summary + +Add `coroutine.close` function from Lua 5.4 that takes a suspended coroutine and makes it "dead" (non-runnable). + +## Motivation + +When implementing various higher level objects on top of coroutines, such as promises, it can be useful to cancel the coroutine execution externally - when the caller is not +interested in getting the results anymore, execution can be aborted. Since coroutines don't provide a way to do that externally, this requires the framework to implement +cancellation on top of coroutines by keeping extra status/token and checking that token in all places where the coroutine is resumed. + +Since coroutine execution can be aborted with an error at any point, coroutines already implement support for "dead" status. If it were possible to externally transition a coroutine +to that status, it would be easier to implement cancellable promises on top of coroutines. + +## Design + +We implement Lua 5.4 behavior exactly with the exception of to-be-closed variables that we don't support. Quoting Lua 5.4 manual: + +> coroutine.close (co) +> Closes coroutine co, that is, puts the coroutine in a dead state. The given coroutine must be dead or suspended. In case of error (either the original error that stopped the coroutine or errors in closing methods), returns false plus the error object; otherwise returns true. + +The `co` argument must be a coroutine object (of type `thread`). + +After closing the coroutine, it gets transitioned to dead state which means that `coroutine.status` will return `"dead"` and attempts to resume the coroutine will fail. In addition, the coroutine stack (which can be accessed via `debug.traceback` or `debug.info`) will become empty. Calling `coroutine.close` on a closed coroutine will return `true` - after closing, the coroutine transitions into a "dead" state with no error information. + +## Drawbacks + +None known, as this function doesn't introduce any existing states to coroutines, and is similar to running the coroutine to completion/error. + +## Alternatives + +Lua's name for this function is likely in part motivated by to-be-closed variables that we don't support. As such, a more appropriate name could be `coroutine.cancel` which also +aligns with use cases better. However, since the semantics is otherwise the same, using the same name as Lua 5.4 reduces library fragmentation. diff --git a/rfcs/syntax-singleton-types.md b/rfcs/syntax-singleton-types.md index 749006c7..26ea3028 100644 --- a/rfcs/syntax-singleton-types.md +++ b/rfcs/syntax-singleton-types.md @@ -48,6 +48,18 @@ type Animals = "Dog" | "Cat" | "Bird" type TrueOrNil = true? ``` +Adding constant strings as type means that it is now legal to write +`{["foo"]:T}` as a table type. This should be parsed as a property, +not an indexer. For example: +```lua + type T = { + ["foo"]: number, + ["$$bar"]: string, + baz: boolean, + } +``` +The table type `T` is a table with three properties and no indexer. + ### Semantics You are allowed to provide a constant value to the generic primitive type. diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 8a7798f3..5a7c8602 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -91,10 +91,6 @@ struct ACFixture : ACFixtureImpl { }; -struct UnfrozenACFixture : ACFixtureImpl -{ -}; - TEST_SUITE_BEGIN("AutocompleteTest"); TEST_CASE_FIXTURE(ACFixture, "empty_program") @@ -1919,9 +1915,10 @@ local bar: @1= foo CHECK(!ac.entryMap.count("foo")); } -// CLI-45692: Remove UnfrozenACFixture here -TEST_CASE_FIXTURE(UnfrozenACFixture, "type_correct_function_no_parenthesis") +TEST_CASE_FIXTURE(ACFixture, "type_correct_function_no_parenthesis") { + ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + check(R"( local function target(a: (number) -> number) return a(4) end local function bar1(a: number) return -a end @@ -1950,9 +1947,10 @@ local fp: @1= f CHECK(ac.entryMap.count("({ x: number, y: number }) -> number")); } -// CLI-45692: Remove UnfrozenACFixture here -TEST_CASE_FIXTURE(UnfrozenACFixture, "type_correct_keywords") +TEST_CASE_FIXTURE(ACFixture, "type_correct_keywords") { + ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + check(R"( local function a(x: boolean) end local function b(x: number?) end @@ -2484,7 +2482,7 @@ local t = { CHECK(ac.entryMap.count("second")); } -TEST_CASE_FIXTURE(Fixture, "autocomplete_documentation_symbols") +TEST_CASE_FIXTURE(UnfrozenFixture, "autocomplete_documentation_symbols") { loadDefinition(R"( declare y: { diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 7f03019c..4ce8d08a 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -11,8 +11,6 @@ #include LUAU_FASTFLAG(LuauPreloadClosures) -LUAU_FASTFLAG(LuauPreloadClosuresFenv) -LUAU_FASTFLAG(LuauPreloadClosuresUpval) LUAU_FASTFLAG(LuauGenericSpecialGlobals) using namespace Luau; @@ -2797,7 +2795,7 @@ CAPTURE UPVAL U1 RETURN R0 1 )"); - if (FFlag::LuauPreloadClosuresUpval) + if (FFlag::LuauPreloadClosures) { // recursive capture CHECK_EQ("\n" + compileFunction("local function foo() return foo() end", 1), R"( @@ -3479,15 +3477,13 @@ CAPTURE VAL R0 RETURN R1 1 )"); - if (FFlag::LuauPreloadClosuresFenv) - { - // if they don't need upvalues but we sense that environment may be modified, we disable this to avoid fenv-related identity confusion - CHECK_EQ("\n" + compileFunction(R"( + // if they don't need upvalues but we sense that environment may be modified, we disable this to avoid fenv-related identity confusion + CHECK_EQ("\n" + compileFunction(R"( setfenv(1, {}) return function() print("hi") end )", - 1), - R"( + 1), + R"( GETIMPORT R0 1 LOADN R1 1 NEWTABLE R2 0 0 @@ -3496,23 +3492,21 @@ NEWCLOSURE R0 P0 RETURN R0 1 )"); - // note that fenv analysis isn't flow-sensitive right now, which is sort of a feature - CHECK_EQ("\n" + compileFunction(R"( + // note that fenv analysis isn't flow-sensitive right now, which is sort of a feature + CHECK_EQ("\n" + compileFunction(R"( if false then setfenv(1, {}) end return function() print("hi") end )", - 1), - R"( + 1), + R"( NEWCLOSURE R0 P0 RETURN R0 1 )"); - } } TEST_CASE("SharedClosure") { ScopedFastFlag sff1("LuauPreloadClosures", true); - ScopedFastFlag sff2("LuauPreloadClosuresUpval", true); // closures can be shared even if functions refer to upvalues, as long as upvalues are top-level CHECK_EQ("\n" + compileFunction(R"( @@ -3671,7 +3665,7 @@ RETURN R0 0 )"); } -TEST_CASE("LuauGenericSpecialGlobals") +TEST_CASE("MutableGlobals") { const char* source = R"( print() @@ -3685,43 +3679,6 @@ shared.print() workspace.print() )"; - { - ScopedFastFlag genericSpecialGlobals{"LuauGenericSpecialGlobals", false}; - - // Check Roblox globals are here - CHECK_EQ("\n" + compileFunction0(source), R"( -GETIMPORT R0 1 -CALL R0 0 0 -GETIMPORT R1 3 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 5 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 7 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 9 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 11 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 13 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 15 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 17 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -RETURN R0 0 -)"); - } - - ScopedFastFlag genericSpecialGlobals{"LuauGenericSpecialGlobals", true}; - // Check Roblox globals are no longer here CHECK_EQ("\n" + compileFunction0(source), R"( GETIMPORT R0 1 diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index c1b790b9..e495a213 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -1,5 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Compiler.h" +#include "lua.h" +#include "lualib.h" +#include "luacode.h" #include "Luau/BuiltinDefinitions.h" #include "Luau/ModuleResolver.h" @@ -10,9 +12,6 @@ #include "doctest.h" #include "ScopedFlags.h" -#include "lua.h" -#include "lualib.h" - #include #include @@ -49,8 +48,12 @@ static int lua_loadstring(lua_State* L) lua_setsafeenv(L, LUA_ENVIRONINDEX, false); - std::string bytecode = Luau::compile(std::string(s, l)); - if (luau_load(L, chunkname, bytecode.data(), bytecode.size()) == 0) + size_t bytecodeSize = 0; + char* bytecode = luau_compile(s, l, nullptr, &bytecodeSize); + int result = luau_load(L, chunkname, bytecode, bytecodeSize, 0); + free(bytecode); + + if (result == 0) return 1; lua_pushnil(L); @@ -179,21 +182,17 @@ static StateRef runConformance( std::string chunkname = "=" + std::string(name); - Luau::CompileOptions copts; + lua_CompileOptions copts = {}; + copts.optimizationLevel = 1; // default copts.debugLevel = 2; // for debugger tests copts.vectorCtor = "vector"; // for vector tests - std::string bytecode = Luau::compile(source, copts); - int status = 0; + size_t bytecodeSize = 0; + char* bytecode = luau_compile(source.data(), source.size(), &copts, &bytecodeSize); + int result = luau_load(L, chunkname.c_str(), bytecode, bytecodeSize, 0); + free(bytecode); - if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0) - { - status = lua_resume(L, nullptr, 0); - } - else - { - status = LUA_ERRSYNTAX; - } + int status = (result == 0) ? lua_resume(L, nullptr, 0) : LUA_ERRSYNTAX; while (yield && (status == LUA_YIELD || status == LUA_BREAK)) { @@ -332,53 +331,61 @@ TEST_CASE("UTF8") TEST_CASE("Coroutine") { + ScopedFastFlag sff("LuauCoroutineClose", true); + runConformance("coroutine.lua"); } +static int cxxthrow(lua_State* L) +{ +#if LUA_USE_LONGJMP + luaL_error(L, "oops"); +#else + throw std::runtime_error("oops"); +#endif +} + TEST_CASE("PCall") { runConformance("pcall.lua", [](lua_State* L) { - lua_pushcfunction(L, [](lua_State* L) -> int { -#if LUA_USE_LONGJMP - luaL_error(L, "oops"); -#else - throw std::runtime_error("oops"); -#endif - }); + lua_pushcfunction(L, cxxthrow, "cxxthrow"); lua_setglobal(L, "cxxthrow"); - lua_pushcfunction(L, [](lua_State* L) -> int { - lua_State* co = lua_tothread(L, 1); - lua_xmove(L, co, 1); - lua_resumeerror(co, L); - return 0; - }); + lua_pushcfunction( + L, + [](lua_State* L) -> int { + lua_State* co = lua_tothread(L, 1); + lua_xmove(L, co, 1); + lua_resumeerror(co, L); + return 0; + }, + "resumeerror"); lua_setglobal(L, "resumeerror"); }); } TEST_CASE("Pack") { - ScopedFastFlag sff{ "LuauStrPackUBCastFix", true }; - + ScopedFastFlag sff{"LuauStrPackUBCastFix", true}; + runConformance("tpack.lua"); } TEST_CASE("Vector") { runConformance("vector.lua", [](lua_State* L) { - lua_pushcfunction(L, lua_vector); + lua_pushcfunction(L, lua_vector, "vector"); lua_setglobal(L, "vector"); lua_pushvector(L, 0.0f, 0.0f, 0.0f); luaL_newmetatable(L, "vector"); lua_pushstring(L, "__index"); - lua_pushcfunction(L, lua_vector_index); + lua_pushcfunction(L, lua_vector_index, nullptr); lua_settable(L, -3); lua_pushstring(L, "__namecall"); - lua_pushcfunction(L, lua_vector_namecall); + lua_pushcfunction(L, lua_vector_namecall, nullptr); lua_settable(L, -3); lua_setreadonly(L, -1, true); @@ -513,15 +520,19 @@ TEST_CASE("Debugger") }; // add breakpoint() function - lua_pushcfunction(L, [](lua_State* L) -> int { - int line = luaL_checkinteger(L, 1); + lua_pushcfunction( + L, + [](lua_State* L) -> int { + int line = luaL_checkinteger(L, 1); + bool enabled = lua_isboolean(L, 2) ? lua_toboolean(L, 2) : true; - lua_Debug ar = {}; - lua_getinfo(L, 1, "f", &ar); + lua_Debug ar = {}; + lua_getinfo(L, 1, "f", &ar); - lua_breakpoint(L, -1, line, true); - return 0; - }); + lua_breakpoint(L, -1, line, enabled); + return 0; + }, + "breakpoint"); lua_setglobal(L, "breakpoint"); }, [](lua_State* L) { @@ -744,7 +755,7 @@ TEST_CASE("ExceptionObject") if (nsize == 0) { free(ptr); - return NULL; + return nullptr; } else if (nsize > 512 * 1024) { diff --git a/tests/IostreamOptional.h b/tests/IostreamOptional.h index e55b5b0c..e0756bad 100644 --- a/tests/IostreamOptional.h +++ b/tests/IostreamOptional.h @@ -4,7 +4,8 @@ #include #include -namespace std { +namespace std +{ inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&) { diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 18f55d2c..7a3543c7 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -203,6 +203,8 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types") { + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + TypeVar freeTy(FreeTypeVar{TypeLevel{}}); TypePackVar freeTp(FreeTypePack{TypeLevel{}}); @@ -212,12 +214,12 @@ TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types") bool encounteredFreeType = false; TypeId clonedTy = clone(&freeTy, dest, seenTypes, seenTypePacks, &encounteredFreeType); - CHECK(Luau::get(clonedTy)); + CHECK_EQ("any", toString(clonedTy)); CHECK(encounteredFreeType); encounteredFreeType = false; TypePackId clonedTp = clone(&freeTp, dest, seenTypes, seenTypePacks, &encounteredFreeType); - CHECK(Luau::get(clonedTp)); + CHECK_EQ("...any", toString(clonedTp)); CHECK(encounteredFreeType); } diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index b076e9ad..80a258f5 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -198,7 +198,8 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_correctly_use_matching_table TypeVar tv{ttv}; - ToStringOptions o{/* exhaustive= */ false, /* useLineBreaks= */ false, /* functionTypeArguments= */ false, /* hideTableKind= */ false, 40}; + ToStringOptions o; + o.maxTableLength = 40; CHECK_EQ(toString(&tv, o), "{| a: number, b: number, c: number, d: number, e: number, ... 5 more ... |}"); } @@ -395,7 +396,7 @@ local function target(callback: nil) return callback(4, "hello") end )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(toString(requireType("target")), "(nil) -> (*unknown*)"); + CHECK_EQ("(nil) -> (*unknown*)", toString(requireType("target"))); } TEST_CASE_FIXTURE(Fixture, "toStringGenericPack") @@ -469,4 +470,110 @@ TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") CHECK_EQ(toString(tableTy), "Table"); } +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_id") +{ + CheckResult result = check(R"( + local function id(x) return x end + )"); + + TypeId ty = requireType("id"); + const FunctionTypeVar* ftv = get(follow(ty)); + + CHECK_EQ("id(x: a): a", toStringNamedFunction("id", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") +{ + CheckResult result = check(R"( + local function map(arr, fn) + local t = {} + for i = 0, #arr do + t[i] = fn(arr[i]) + end + return t + end + )"); + + TypeId ty = requireType("map"); + const FunctionTypeVar* ftv = get(follow(ty)); + + CHECK_EQ("map(arr: {a}, fn: (a) -> b): {b}", toStringNamedFunction("map", *ftv)); +} + +TEST_CASE("toStringNamedFunction_unit_f") +{ + TypePackVar empty{TypePack{}}; + FunctionTypeVar ftv{&empty, &empty, {}, false}; + CHECK_EQ("f(): ()", toStringNamedFunction("f", ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics") +{ + CheckResult result = check(R"( + local function f(x: a, ...): (a, a, b...) + return x, x, ... + end + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + CHECK_EQ("f(x: a, ...: any): (a, a, b...)", toStringNamedFunction("f", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics2") +{ + CheckResult result = check(R"( + local function f(): ...number + return 1, 2, 3 + end + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + CHECK_EQ("f(): ...number", toStringNamedFunction("f", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics3") +{ + CheckResult result = check(R"( + local function f(): (string, ...number) + return 'a', 1, 2, 3 + end + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + CHECK_EQ("f(): (string, ...number)", toStringNamedFunction("f", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_type_annotation_has_partial_argnames") +{ + CheckResult result = check(R"( + local f: (number, y: number) -> number + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + CHECK_EQ("f(_: number, y: number): number", toStringNamedFunction("f", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_type_params") +{ + CheckResult result = check(R"( + local function f(x: T, g: (T) -> U)): () + end + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + ToStringOptions opts; + opts.hideNamedFunctionTypeParameters = true; + CHECK_EQ("f(x: T, g: (T) -> U): ()", toStringNamedFunction("f", *ftv, opts)); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index c27f8083..74ce155c 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -109,7 +109,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_stop_typechecking_after_reporting_duplicate_typ CheckResult result = check(R"( type A = number type A = string -- Redefinition of type 'A', previously defined at line 1 - local foo: string = 1 -- No "Type 'number' could not be converted into 'string'" + local foo: string = 1 -- "Type 'number' could not be converted into 'string'" )"); LUAU_REQUIRE_ERROR_COUNT(2, result); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 2e400164..091c2f01 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -381,6 +381,8 @@ TEST_CASE_FIXTURE(Fixture, "typeof_expr") TEST_CASE_FIXTURE(Fixture, "corecursive_types_error_on_tight_loop") { + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + CheckResult result = check(R"( type A = B type B = A @@ -390,7 +392,7 @@ TEST_CASE_FIXTURE(Fixture, "corecursive_types_error_on_tight_loop") )"); TypeId fType = requireType("aa"); - const ErrorTypeVar* ftv = get(follow(fType)); + const AnyTypeVar* ftv = get(follow(fType)); REQUIRE(ftv != nullptr); REQUIRE(!result.errors.empty()); } diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index de2f0154..88c2dc85 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -289,7 +289,7 @@ TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods") end )"); // TODO: Should typecheck but currently errors CLI-39916 - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "infer_generic_property") @@ -352,7 +352,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_leak_generic_types") -- so this assignment should fail local b: boolean = f(true) )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "dont_leak_inferred_generic_types") @@ -368,7 +368,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_leak_inferred_generic_types") local y: number = id(37) end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "dont_substitute_bound_types") diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 733fc39b..fe8e7ff9 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -704,9 +704,10 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0])); else CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0])); + CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" - if (FFlag::LuauQuantifyInPlace2) + if (FFlag::LuauQuantifyInPlace2) CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" else CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp new file mode 100644 index 00000000..5f95efd5 --- /dev/null +++ b/tests/TypeInfer.singletons.test.cpp @@ -0,0 +1,377 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" +#include "Luau/BuiltinDefinitions.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeSingletons"); + +TEST_CASE_FIXTURE(Fixture, "bool_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: true = true + local b: false = false + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "string_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: "foo" = "foo" + local b: "bar" = "bar" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "bool_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: true = false + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'false' could not be converted into 'true'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "string_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: "foo" = "bar" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type '\"bar\"' could not be converted into '\"foo\"'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "string_singletons_escape_chars") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: "\n" = "\000\r" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type '"\000\r"' could not be converted into '"\n"')", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "bool_singleton_subtype") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: true = true + local b: boolean = a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "string_singleton_subtype") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: "foo" = "foo" + local b: string = a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + function f(a: true, b: "foo") end + f(true, "foo") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + function f(a: true, b: "foo") end + f(true, "bar") + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type '\"bar\"' could not be converted into '\"foo\"'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + function f(a, b) end + local g : ((true, string) -> ()) & ((false, number) -> ()) = (f::any) + g(true, "foo") + g(false, 37) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + function f(a, b) end + local g : ((true, string) -> ()) & ((false, number) -> ()) = (f::any) + g(true, 37) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); + CHECK_EQ("Other overloads are also not viable: (false, number) -> ()", toString(result.errors[1])); +} + +TEST_CASE_FIXTURE(Fixture, "enums_using_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type MyEnum = "foo" | "bar" | "baz" + local a : MyEnum = "foo" + local b : MyEnum = "bar" + local c : MyEnum = "baz" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + {"LuauExtendedTypeMismatchError", true}, + }; + + CheckResult result = check(R"( + type MyEnum = "foo" | "bar" | "baz" + local a : MyEnum = "bang" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type '\"bang\"' could not be converted into '\"bar\" | \"baz\" | \"foo\"'; none of the union options are compatible", + toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_subtyping") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type MyEnum1 = "foo" | "bar" + type MyEnum2 = MyEnum1 | "baz" + local a : MyEnum1 = "foo" + local b : MyEnum2 = a + local c : string = b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + {"LuauExpectedTypesOfProperties", true}, + }; + + CheckResult result = check(R"( + type Dog = { tag: "Dog", howls: boolean } + type Cat = { tag: "Cat", meows: boolean } + type Animal = Dog | Cat + local a : Dog = { tag = "Dog", howls = true } + local b : Animal = { tag = "Cat", meows = true } + local c : Animal = a + c = b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type Dog = { tag: "Dog", howls: boolean } + type Cat = { tag: "Cat", meows: boolean } + type Animal = Dog | Cat + local a : Animal = { tag = "Cat", howls = true } + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tagged_unions_immutable_tag") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type Dog = { tag: "Dog", howls: boolean } + type Cat = { tag: "Cat", meows: boolean } + type Animal = Dog | Cat + local a : Animal = { tag = "Cat", meows = true } + a.tag = "Dog" + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "table_properties_singleton_strings") +{ + ScopedFastFlag sffs[] = { + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + --!strict + type T = { + ["foo"] : number, + ["$$bar"] : string, + baz : boolean + } + local t: T = { + ["foo"] = 37, + ["$$bar"] = "hi", + baz = true + } + local a: number = t.foo + local b: string = t["$$bar"] + local c: boolean = t.baz + t.foo = 5 + t["$$bar"] = "lo" + t.baz = false + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} +TEST_CASE_FIXTURE(Fixture, "table_properties_singleton_strings_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + --!strict + type T = { + ["$$bar"] : string, + } + local t: T = { + ["$$bar"] = "hi", + } + t["$$bar"] = 5 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "table_properties_alias_or_parens_is_indexer") +{ + ScopedFastFlag sffs[] = { + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + --!strict + type S = "bar" + type T = { + [("foo")] : number, + [S] : string, + } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Syntax error: Cannot have more than one table indexer", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") +{ + ScopedFastFlag sffs[] = { + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + --!strict + local x: { ["<>"] : number } + x = { ["\n"] = 5 } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Table type '{| ["\n"]: number |}' not compatible with type '{| ["<>"]: number |}' because the former is missing field '<>')", + toString(result.errors[0])); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 30d9130a..99fd8339 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -362,7 +362,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_error") CHECK_EQ(2, result.errors.size()); TypeId p = requireType("p"); - CHECK_EQ(*p, *typeChecker.errorType); + CHECK_EQ("*unknown*", toString(p)); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function") @@ -480,7 +480,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any2") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(typeChecker.anyType, requireType("a")); + CHECK_EQ("any", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any") @@ -496,7 +496,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(typeChecker.anyType, requireType("a")); + CHECK_EQ("any", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2") @@ -512,7 +512,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(typeChecker.anyType, requireType("a")); + CHECK_EQ("any", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") @@ -526,7 +526,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(typeChecker.errorType, requireType("a")); + CHECK_EQ("*unknown*", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") @@ -542,7 +542,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(typeChecker.errorType, requireType("a")); + CHECK_EQ("*unknown*", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_custom_iterator") @@ -673,7 +673,7 @@ TEST_CASE_FIXTURE(Fixture, "string_index") REQUIRE(nat); CHECK_EQ("string", toString(nat->ty)); - CHECK(get(requireType("t"))); + CHECK_EQ("*unknown*", toString(requireType("t"))); } TEST_CASE_FIXTURE(Fixture, "length_of_error_type_does_not_produce_an_error") @@ -1456,7 +1456,7 @@ TEST_CASE_FIXTURE(Fixture, "require_module_that_does_not_export") auto hootyType = requireType(bModule, "Hooty"); - CHECK_MESSAGE(get(follow(hootyType)) != nullptr, "Should be an error: " << toString(hootyType)); + CHECK_EQ("*unknown*", toString(hootyType)); } TEST_CASE_FIXTURE(Fixture, "warn_on_lowercase_parent_property") @@ -2032,7 +2032,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_4") CHECK_EQ(*arg0->indexer->indexResultType, *arg1Args[1]); } -TEST_CASE_FIXTURE(Fixture, "error_types_propagate") +TEST_CASE_FIXTURE(Fixture, "type_errors_infer_types") { CheckResult result = check(R"( local err = (true).x @@ -2049,10 +2049,10 @@ TEST_CASE_FIXTURE(Fixture, "error_types_propagate") CHECK_EQ("boolean", toString(err->table)); CHECK_EQ("x", err->key); - CHECK(nullptr != get(requireType("c"))); - CHECK(nullptr != get(requireType("d"))); - CHECK(nullptr != get(requireType("e"))); - CHECK(nullptr != get(requireType("f"))); + CHECK_EQ("*unknown*", toString(requireType("c"))); + CHECK_EQ("*unknown*", toString(requireType("d"))); + CHECK_EQ("*unknown*", toString(requireType("e"))); + CHECK_EQ("*unknown*", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") @@ -2068,7 +2068,7 @@ TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") CHECK_EQ("unknown", err->name); - CHECK(nullptr != get(requireType("a"))); + CHECK_EQ("*unknown*", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") @@ -2077,9 +2077,7 @@ TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") local a = Utility.Create "Foo" {} )"); - TypeId aType = requireType("a"); - - REQUIRE_MESSAGE(nullptr != get(aType), "Not an error: " << toString(aType)); + CHECK_EQ("*unknown*", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable") @@ -2146,6 +2144,8 @@ TEST_CASE_FIXTURE(Fixture, "some_primitive_binary_ops") TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection") { + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + CheckResult result = check(R"( --!strict local Vec3 = {} @@ -2175,11 +2175,13 @@ TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersectio CHECK_EQ("Vec3", toString(requireType("b"))); CHECK_EQ("Vec3", toString(requireType("c"))); CHECK_EQ("Vec3", toString(requireType("d"))); - CHECK(get(requireType("e"))); + CHECK_EQ("Vec3", toString(requireType("e"))); } TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection_on_rhs") { + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + CheckResult result = check(R"( --!strict local Vec3 = {} @@ -2209,7 +2211,7 @@ TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersectio CHECK_EQ("Vec3", toString(requireType("b"))); CHECK_EQ("Vec3", toString(requireType("c"))); CHECK_EQ("Vec3", toString(requireType("d"))); - CHECK(get(requireType("e"))); + CHECK_EQ("Vec3", toString(requireType("e"))); } TEST_CASE_FIXTURE(Fixture, "compare_numbers") @@ -2901,6 +2903,8 @@ end TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfNumber") { + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + CheckResult result = check(R"( local x: number = 9999 function x:y(z: number) @@ -2908,7 +2912,7 @@ function x:y(z: number) end )"); - LUAU_REQUIRE_ERROR_COUNT(3, result); + LUAU_REQUIRE_ERROR_COUNT(2, result); } TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfError") @@ -2920,7 +2924,7 @@ function x:y(z: number) end )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "CallOrOfFunctions") @@ -3799,7 +3803,7 @@ TEST_CASE_FIXTURE(Fixture, "UnknownGlobalCompoundAssign") print(a) )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); + LUAU_REQUIRE_ERRORS(result); CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); } @@ -4215,7 +4219,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying") std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); REQUIRE(t0); - CHECK(get(t0->type)); + CHECK_EQ("*unknown*", toString(t0->type)); auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { return get(err); @@ -4238,7 +4242,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional") std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); REQUIRE(t0); - CHECK(get(t0->type)); + CHECK_EQ("*unknown*", toString(t0->type)); auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { return get(err); @@ -4394,6 +4398,25 @@ TEST_CASE_FIXTURE(Fixture, "record_matching_overload") CHECK_EQ(toString(*it), "(number) -> number"); } +TEST_CASE_FIXTURE(Fixture, "return_type_by_overload") +{ + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + + CheckResult result = check(R"( + type Overload = ((string) -> string) & ((number, number) -> number) + local abc: Overload + local x = abc(true) + local y = abc(true,true) + local z = abc(true,true,true) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("string", toString(requireType("x"))); + CHECK_EQ("number", toString(requireType("y"))); + // Should this be string|number? + CHECK_EQ("string", toString(requireType("z"))); +} + TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") { // Simple direct arg to arg propagation @@ -4740,4 +4763,20 @@ TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions3") } } +TEST_CASE_FIXTURE(Fixture, "type_error_addition") +{ + CheckResult result = check(R"( +--!strict +local foo = makesandwich() +local bar = foo.nutrition + 100 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + // We should definitely get this error + CHECK_EQ("Unknown global 'makesandwich'", toString(result.errors[0])); + // We get this error if makesandwich() returns a free type + // CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'foo'", toString(result.errors[1])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 2d697fc9..9f9a007f 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -121,9 +121,26 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "members_of_failed_typepack_unification_are_u LUAU_REQUIRE_ERROR_COUNT(1, result); - TypeId bType = requireType("b"); + CHECK_EQ("a", toString(requireType("a"))); + CHECK_EQ("*unknown*", toString(requireType("b"))); +} - CHECK_MESSAGE(get(bType), "Should be an error: " << toString(bType)); +TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_constrained") +{ + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + + CheckResult result = check(R"( + function f(arg: number) return arg end + local a + local b + local c = f(a, b) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("a", toString(requireType("a"))); + CHECK_EQ("*unknown*", toString(requireType("b"))); + CHECK_EQ("number", toString(requireType("c"))); } TEST_CASE_FIXTURE(TryUnifyFixture, "typepack_unification_should_trim_free_tails") @@ -167,15 +184,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_tails_respect_progress") CHECK(state.errors.empty()); } -TEST_CASE_FIXTURE(TryUnifyFixture, "unifying_variadic_pack_with_error_should_work") -{ - TypePackId variadicPack = arena.addTypePack(TypePackVar{VariadicTypePack{typeChecker.numberType}}); - TypePackId errorPack = arena.addTypePack(TypePack{{typeChecker.numberType}, arena.addTypePack(TypePackVar{Unifiable::Error{}})}); - - state.tryUnify(variadicPack, errorPack); - REQUIRE_EQ(0, state.errors.size()); -} - TEST_CASE_FIXTURE(TryUnifyFixture, "variadics_should_use_reversed_properly") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 9f29b642..48496b89 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -200,8 +200,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property") CHECK_EQ(mup->missing[0], *bTy); CHECK_EQ(mup->key, "x"); - TypeId r = requireType("r"); - CHECK_MESSAGE(get(r), "Expected error, got " << toString(r)); + CHECK_EQ("*unknown*", toString(requireType("r"))); } TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_property_of_type_any") @@ -283,7 +282,7 @@ local c = b:foo(1, 2) CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "optional_union_follow") +TEST_CASE_FIXTURE(UnfrozenFixture, "optional_union_follow") { CheckResult result = check(R"( local y: number? = 2 diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index 687fff1e..693f68df 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -876,4 +876,4 @@ assert(concat(typeof(5), typeof(nil), typeof({}), typeof(newproxy())) == "number testgetfenv() -- DONT MOVE THIS LINE -return'OK' +return 'OK' diff --git a/tests/conformance/closure.lua b/tests/conformance/closure.lua index aac42c56..f32d5bdc 100644 --- a/tests/conformance/closure.lua +++ b/tests/conformance/closure.lua @@ -419,11 +419,5 @@ co = coroutine.create(function () return loadstring("return a")() end) -a = {a = 15} --- debug.setfenv(co, a) --- assert(debug.getfenv(co) == a) --- assert(select(2, coroutine.resume(co)) == a) --- assert(select(2, coroutine.resume(co)) == a.a) - -return'OK' +return 'OK' diff --git a/tests/conformance/constructs.lua b/tests/conformance/constructs.lua index 16c63b00..f133501f 100644 --- a/tests/conformance/constructs.lua +++ b/tests/conformance/constructs.lua @@ -237,4 +237,4 @@ repeat i = i+1 until i==c -return'OK' +return 'OK' diff --git a/tests/conformance/coroutine.lua b/tests/conformance/coroutine.lua index 75329642..f2ecc96b 100644 --- a/tests/conformance/coroutine.lua +++ b/tests/conformance/coroutine.lua @@ -319,4 +319,58 @@ for i=0,30 do assert(#T2 == 1 or T2[#T2] == 42) end -return'OK' +-- test coroutine.close +do + -- ok to close a dead coroutine + local co = coroutine.create(type) + assert(coroutine.resume(co, "testing 'coroutine.close'")) + assert(coroutine.status(co) == "dead") + local st, msg = coroutine.close(co) + assert(st and msg == nil) + -- also ok to close it again + st, msg = coroutine.close(co) + assert(st and msg == nil) + + + -- cannot close the running coroutine + coroutine.wrap(function() + local st, msg = pcall(coroutine.close, coroutine.running()) + assert(not st and string.find(msg, "running")) + end)() + + -- cannot close a "normal" coroutine + coroutine.wrap(function() + local co = coroutine.running() + coroutine.wrap(function () + local st, msg = pcall(coroutine.close, co) + assert(not st and string.find(msg, "normal")) + end)() + end)() + + -- closing a coroutine after an error + local co = coroutine.create(error) + local obj = {42} + local st, msg = coroutine.resume(co, obj) + assert(not st and msg == obj) + st, msg = coroutine.close(co) + assert(not st and msg == obj) + -- after closing, no more errors + st, msg = coroutine.close(co) + assert(st and msg == nil) + + -- closing a coroutine that has outstanding upvalues + local f + local co = coroutine.create(function() + local a = 42 + f = function() return a end + coroutine.yield() + a = 20 + end) + coroutine.resume(co) + assert(f() == 42) + st, msg = coroutine.close(co) + assert(st and msg == nil) + assert(f() == 42) +end + +return 'OK' diff --git a/tests/conformance/datetime.lua b/tests/conformance/datetime.lua index 21ef60d7..ca35cf2f 100644 --- a/tests/conformance/datetime.lua +++ b/tests/conformance/datetime.lua @@ -74,4 +74,4 @@ assert(os.difftime(t1,t2) == 60*2-19) assert(os.time({ year = 1970, day = 1, month = 1, hour = 0}) == 0) -return'OK' +return 'OK' diff --git a/tests/conformance/debug.lua b/tests/conformance/debug.lua index ee79a14f..9cf3c742 100644 --- a/tests/conformance/debug.lua +++ b/tests/conformance/debug.lua @@ -98,4 +98,4 @@ assert(quuz(function(...) end) == "0 true") assert(quuz(function(a, b) end) == "2 false") assert(quuz(function(a, b, ...) end) == "2 true") -return'OK' +return 'OK' diff --git a/tests/conformance/debugger.lua b/tests/conformance/debugger.lua index 5e69fc6b..6ba99fb9 100644 --- a/tests/conformance/debugger.lua +++ b/tests/conformance/debugger.lua @@ -45,4 +45,13 @@ breakpoint(38) -- break inside corobad() local co = coroutine.create(corobad) assert(coroutine.resume(co) == false) -- this breaks, resumes and dies! +function bar() + print("in bar") +end + +breakpoint(49) +breakpoint(49, false) -- validate that disabling breakpoints works + +bar() + return 'OK' diff --git a/tests/conformance/errors.lua b/tests/conformance/errors.lua index eded14e9..d5ff215b 100644 --- a/tests/conformance/errors.lua +++ b/tests/conformance/errors.lua @@ -34,15 +34,15 @@ assert(doit("error('hi', 0)") == 'hi') assert(doit("unpack({}, 1, n=2^30)")) assert(doit("a=math.sin()")) assert(not doit("tostring(1)") and doit("tostring()")) -assert(doit"tonumber()") -assert(doit"repeat until 1; a") +assert(doit("tonumber()")) +assert(doit("repeat until 1; a")) checksyntax("break label", "", "label", 1) -assert(doit";") -assert(doit"a=1;;") -assert(doit"return;;") -assert(doit"assert(false)") -assert(doit"assert(nil)") -assert(doit"a=math.sin\n(3)") +assert(doit(";")) +assert(doit("a=1;;")) +assert(doit("return;;")) +assert(doit("assert(false)")) +assert(doit("assert(nil)")) +assert(doit("a=math.sin\n(3)")) assert(doit("function a (... , ...) end")) assert(doit("function a (, ...) end")) @@ -59,7 +59,7 @@ checkmessage("a=1; local a,bbbb=2,3; a = math.sin(1) and bbbb(3)", "local 'bbbb'") checkmessage("a={}; do local a=1 end a:bbbb(3)", "method 'bbbb'") checkmessage("local a={}; a.bbbb(3)", "field 'bbbb'") -assert(not string.find(doit"a={13}; local bbbb=1; a[bbbb](3)", "'bbbb'")) +assert(not string.find(doit("a={13}; local bbbb=1; a[bbbb](3)"), "'bbbb'")) checkmessage("a={13}; local bbbb=1; a[bbbb](3)", "number") aaa = nil @@ -67,14 +67,14 @@ checkmessage("aaa.bbb:ddd(9)", "global 'aaa'") checkmessage("local aaa={bbb=1}; aaa.bbb:ddd(9)", "field 'bbb'") checkmessage("local aaa={bbb={}}; aaa.bbb:ddd(9)", "method 'ddd'") checkmessage("local a,b,c; (function () a = b+1 end)()", "upvalue 'b'") -assert(not doit"local aaa={bbb={ddd=next}}; aaa.bbb:ddd(nil)") +assert(not doit("local aaa={bbb={ddd=next}}; aaa.bbb:ddd(nil)")) checkmessage("b=1; local aaa='a'; x=aaa+b", "local 'aaa'") checkmessage("aaa={}; x=3/aaa", "global 'aaa'") checkmessage("aaa='2'; b=nil;x=aaa*b", "global 'b'") checkmessage("aaa={}; x=-aaa", "global 'aaa'") -assert(not string.find(doit"aaa={}; x=(aaa or aaa)+(aaa and aaa)", "'aaa'")) -assert(not string.find(doit"aaa={}; (aaa or aaa)()", "'aaa'")) +assert(not string.find(doit("aaa={}; x=(aaa or aaa)+(aaa and aaa)"), "'aaa'")) +assert(not string.find(doit("aaa={}; (aaa or aaa)()"), "'aaa'")) checkmessage([[aaa=9 repeat until 3==3 @@ -122,10 +122,10 @@ function lineerror (s) return line and line+0 end -assert(lineerror"local a\n for i=1,'a' do \n print(i) \n end" == 2) --- assert(lineerror"\n local a \n for k,v in 3 \n do \n print(k) \n end" == 3) --- assert(lineerror"\n\n for k,v in \n 3 \n do \n print(k) \n end" == 4) -assert(lineerror"function a.x.y ()\na=a+1\nend" == 1) +assert(lineerror("local a\n for i=1,'a' do \n print(i) \n end") == 2) +-- assert(lineerror("\n local a \n for k,v in 3 \n do \n print(k) \n end") == 3) +-- assert(lineerror("\n\n for k,v in \n 3 \n do \n print(k) \n end") == 4) +assert(lineerror("function a.x.y ()\na=a+1\nend") == 1) local p = [[ function g() f() end diff --git a/tests/conformance/gc.lua b/tests/conformance/gc.lua index 4263dfda..6d9eb854 100644 --- a/tests/conformance/gc.lua +++ b/tests/conformance/gc.lua @@ -77,7 +77,7 @@ end local function dosteps (siz) collectgarbage() - collectgarbage"stop" + collectgarbage("stop") local a = {} for i=1,100 do a[i] = {{}}; local b = {} end local x = gcinfo() @@ -99,11 +99,11 @@ assert(dosteps(10000) == 1) do local x = gcinfo() collectgarbage() - collectgarbage"stop" + collectgarbage("stop") repeat local a = {} until gcinfo() > 1000 - collectgarbage"restart" + collectgarbage("restart") repeat local a = {} until gcinfo() < 1000 @@ -123,7 +123,7 @@ for n in pairs(b) do end b = nil collectgarbage() -for n in pairs(a) do error'cannot be here' end +for n in pairs(a) do error("cannot be here") end for i=1,lim do a[i] = i end for i=1,lim do assert(a[i] == i) end diff --git a/tests/conformance/nextvar.lua b/tests/conformance/nextvar.lua index 7f9b7596..94ba5ccf 100644 --- a/tests/conformance/nextvar.lua +++ b/tests/conformance/nextvar.lua @@ -368,9 +368,9 @@ assert(next(a,nil) == 1000 and next(a,1000) == nil) assert(next({}) == nil) assert(next({}, nil) == nil) -for a,b in pairs{} do error"not here" end -for i=1,0 do error'not here' end -for i=0,1,-1 do error'not here' end +for a,b in pairs{} do error("not here") end +for i=1,0 do error("not here") end +for i=0,1,-1 do error("not here") end a = nil; for i=1,1 do assert(not a); a=1 end; assert(a) a = nil; for i=1,1,-1 do assert(not a); a=1 end; assert(a) diff --git a/tests/conformance/pcall.lua b/tests/conformance/pcall.lua index a2072d2c..84ac2ba1 100644 --- a/tests/conformance/pcall.lua +++ b/tests/conformance/pcall.lua @@ -144,4 +144,4 @@ coroutine.resume(co) resumeerror(co, "fail") checkresults({ true, false, "fail" }, coroutine.resume(co)) -return'OK' +return 'OK' diff --git a/tests/conformance/utf8.lua b/tests/conformance/utf8.lua index 024cb16d..bfd7a1ac 100644 --- a/tests/conformance/utf8.lua +++ b/tests/conformance/utf8.lua @@ -205,4 +205,4 @@ for p, c in string.gmatch(x, "()(" .. utf8.charpattern .. ")") do end end -return'OK' +return 'OK'