This commit is contained in:
Pelanyo Kamara 2021-11-21 21:37:28 +00:00
commit 921598bf9c
No known key found for this signature in database
GPG key ID: 848AD95363B749B5
80 changed files with 1878 additions and 614 deletions

View file

@ -12,10 +12,17 @@ namespace Luau
struct FunctionDocumentation; struct FunctionDocumentation;
struct TableDocumentation; struct TableDocumentation;
struct OverloadedFunctionDocumentation; struct OverloadedFunctionDocumentation;
struct BasicDocumentation;
using Documentation = Luau::Variant<std::string, FunctionDocumentation, TableDocumentation, OverloadedFunctionDocumentation>; using Documentation = Luau::Variant<BasicDocumentation, FunctionDocumentation, TableDocumentation, OverloadedFunctionDocumentation>;
using DocumentationSymbol = std::string; using DocumentationSymbol = std::string;
struct BasicDocumentation
{
std::string documentation;
std::string learnMoreLink;
};
struct FunctionParameterDocumentation struct FunctionParameterDocumentation
{ {
std::string name; std::string name;
@ -29,6 +36,7 @@ struct FunctionDocumentation
std::string documentation; std::string documentation;
std::vector<FunctionParameterDocumentation> parameters; std::vector<FunctionParameterDocumentation> parameters;
std::vector<DocumentationSymbol> returns; std::vector<DocumentationSymbol> returns;
std::string learnMoreLink;
}; };
struct OverloadedFunctionDocumentation struct OverloadedFunctionDocumentation
@ -43,6 +51,7 @@ struct TableDocumentation
{ {
std::string documentation; std::string documentation;
Luau::DenseHashMap<std::string, DocumentationSymbol> keys; Luau::DenseHashMap<std::string, DocumentationSymbol> keys;
std::string learnMoreLink;
}; };
using DocumentationDatabase = Luau::DenseHashMap<DocumentationSymbol, Documentation>; using DocumentationDatabase = Luau::DenseHashMap<DocumentationSymbol, Documentation>;

View file

@ -27,6 +27,7 @@ struct ToStringOptions
bool useLineBreaks = false; // If true, we insert new lines to separate long results such as table entries/metatable. bool useLineBreaks = false; // If true, we insert new lines to separate long results such as table entries/metatable.
bool functionTypeArguments = false; // If true, output function type argument names when they are available bool functionTypeArguments = false; // If true, output function type argument names when they are available
bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}'
bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level.
size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars
size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength);
std::optional<ToStringNameMap> nameMap; std::optional<ToStringNameMap> nameMap;
@ -64,6 +65,8 @@ inline std::string toString(TypePackId ty)
std::string toString(const TypeVar& tv, const ToStringOptions& opts = {}); std::string toString(const TypeVar& tv, const ToStringOptions& opts = {});
std::string toString(const TypePackVar& tp, const ToStringOptions& opts = {}); std::string toString(const TypePackVar& tp, const ToStringOptions& opts = {});
std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts = {});
// It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class // It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class
// These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression // These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression
void dump(TypeId ty); void dump(TypeId ty);

View file

@ -175,10 +175,10 @@ struct TypeChecker
std::vector<std::optional<TypeId>> getExpectedTypesForCall(const std::vector<TypeId>& overloads, size_t argumentCount, bool selfCall); std::vector<std::optional<TypeId>> getExpectedTypesForCall(const std::vector<TypeId>& overloads, size_t argumentCount, bool selfCall);
std::optional<ExprResult<TypePackId>> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, std::optional<ExprResult<TypePackId>> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack,
TypePackId argPack, TypePack* args, const std::vector<Location>& argLocations, const ExprResult<TypePackId>& argListResult, TypePackId argPack, TypePack* args, const std::vector<Location>& argLocations, const ExprResult<TypePackId>& argListResult,
std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<OverloadErrorEntry>& errors); std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<TypeId>& overloadsThatDont, std::vector<OverloadErrorEntry>& errors);
bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector<Location>& argLocations, bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector<Location>& argLocations,
const std::vector<OverloadErrorEntry>& errors); const std::vector<OverloadErrorEntry>& errors);
ExprResult<TypePackId> reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, void reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack,
const std::vector<Location>& argLocations, const std::vector<TypeId>& overloads, const std::vector<TypeId>& overloadsThatMatchArgCount, const std::vector<Location>& argLocations, const std::vector<TypeId>& overloads, const std::vector<TypeId>& overloadsThatMatchArgCount,
const std::vector<OverloadErrorEntry>& errors); const std::vector<OverloadErrorEntry>& errors);
@ -282,6 +282,14 @@ public:
// Wrapper for merge(l, r, toUnion) but without the lambda junk. // Wrapper for merge(l, r, toUnion) but without the lambda junk.
void merge(RefinementMap& l, const RefinementMap& r); void merge(RefinementMap& l, const RefinementMap& r);
// Produce an "emergency backup type" for recovery from type errors.
// This comes in two flavours, depening on whether or not we can make a good guess
// for an error recovery type.
TypeId errorRecoveryType(TypeId guess);
TypePackId errorRecoveryTypePack(TypePackId guess);
TypeId errorRecoveryType(const ScopePtr& scope);
TypePackId errorRecoveryTypePack(const ScopePtr& scope);
private: private:
void prepareErrorsForDisplay(ErrorVec& errVec); void prepareErrorsForDisplay(ErrorVec& errVec);
void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data); void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data);
@ -297,6 +305,10 @@ private:
TypeId freshType(const ScopePtr& scope); TypeId freshType(const ScopePtr& scope);
TypeId freshType(TypeLevel level); TypeId freshType(TypeLevel level);
// Produce a new singleton type var.
TypeId singletonType(bool value);
TypeId singletonType(std::string value);
// Returns nullopt if the predicate filters down the TypeId to 0 options. // Returns nullopt if the predicate filters down the TypeId to 0 options.
std::optional<TypeId> filterMap(TypeId type, TypeIdPredicate predicate); std::optional<TypeId> filterMap(TypeId type, TypeIdPredicate predicate);
@ -330,8 +342,8 @@ private:
const std::vector<TypePackId>& typePackParams, const Location& location); const std::vector<TypePackId>& typePackParams, const Location& location);
// Note: `scope` must be a fresh scope. // Note: `scope` must be a fresh scope.
std::pair<std::vector<TypeId>, std::vector<TypePackId>> createGenericTypes( std::pair<std::vector<TypeId>, std::vector<TypePackId>> createGenericTypes(const ScopePtr& scope, std::optional<TypeLevel> levelOpt,
const ScopePtr& scope, std::optional<TypeLevel> levelOpt, const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames); const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames);
public: public:
ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense);
@ -347,7 +359,6 @@ private:
void resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); void resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense);
void resolve(const IsAPredicate& isaP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); void resolve(const IsAPredicate& isaP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense);
void resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); void resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense);
void DEPRECATED_resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense);
void resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); void resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense);
bool isNonstrictMode() const; bool isNonstrictMode() const;
@ -387,12 +398,9 @@ public:
const TypeId booleanType; const TypeId booleanType;
const TypeId threadType; const TypeId threadType;
const TypeId anyType; const TypeId anyType;
const TypeId errorType;
const TypeId optionalNumberType; const TypeId optionalNumberType;
const TypePackId anyTypePack; const TypePackId anyTypePack;
const TypePackId errorTypePack;
private: private:
int checkRecursionCount = 0; int checkRecursionCount = 0;

View file

@ -108,6 +108,79 @@ struct PrimitiveTypeVar
} }
}; };
// Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md
// Types for true and false
struct BoolSingleton
{
bool value;
bool operator==(const BoolSingleton& rhs) const
{
return value == rhs.value;
}
bool operator!=(const BoolSingleton& rhs) const
{
return !(*this == rhs);
}
};
// Types for "foo", "bar" etc.
struct StringSingleton
{
std::string value;
bool operator==(const StringSingleton& rhs) const
{
return value == rhs.value;
}
bool operator!=(const StringSingleton& rhs) const
{
return !(*this == rhs);
}
};
// No type for float singletons, partly because === isn't any equalivalence on floats
// (NaN != NaN).
using SingletonVariant = Luau::Variant<BoolSingleton, StringSingleton>;
struct SingletonTypeVar
{
explicit SingletonTypeVar(const SingletonVariant& variant)
: variant(variant)
{
}
explicit SingletonTypeVar(SingletonVariant&& variant)
: variant(std::move(variant))
{
}
// Default operator== is C++20.
bool operator==(const SingletonTypeVar& rhs) const
{
return variant == rhs.variant;
}
bool operator!=(const SingletonTypeVar& rhs) const
{
return !(*this == rhs);
}
SingletonVariant variant;
};
template<typename T>
const T* get(const SingletonTypeVar* stv)
{
if (stv)
return get_if<T>(&stv->variant);
else
return nullptr;
}
struct FunctionArgument struct FunctionArgument
{ {
Name name; Name name;
@ -332,8 +405,8 @@ struct LazyTypeVar
using ErrorTypeVar = Unifiable::Error; using ErrorTypeVar = Unifiable::Error;
using TypeVariant = Unifiable::Variant<TypeId, PrimitiveTypeVar, FunctionTypeVar, TableTypeVar, MetatableTypeVar, ClassTypeVar, AnyTypeVar, using TypeVariant = Unifiable::Variant<TypeId, PrimitiveTypeVar, SingletonTypeVar, FunctionTypeVar, TableTypeVar, MetatableTypeVar, ClassTypeVar,
UnionTypeVar, IntersectionTypeVar, LazyTypeVar>; AnyTypeVar, UnionTypeVar, IntersectionTypeVar, LazyTypeVar>;
struct TypeVar final struct TypeVar final
{ {
@ -410,6 +483,9 @@ bool isGeneric(const TypeId ty);
// Checks if a type may be instantiated to one containing generic type binders // Checks if a type may be instantiated to one containing generic type binders
bool maybeGeneric(const TypeId ty); bool maybeGeneric(const TypeId ty);
// Checks if a type is of the form T1|...|Tn where one of the Ti is a singleton
bool maybeSingleton(TypeId ty);
struct SingletonTypes struct SingletonTypes
{ {
const TypeId nilType; const TypeId nilType;
@ -418,16 +494,19 @@ struct SingletonTypes
const TypeId booleanType; const TypeId booleanType;
const TypeId threadType; const TypeId threadType;
const TypeId anyType; const TypeId anyType;
const TypeId errorType;
const TypeId optionalNumberType; const TypeId optionalNumberType;
const TypePackId anyTypePack; const TypePackId anyTypePack;
const TypePackId errorTypePack;
SingletonTypes(); SingletonTypes();
SingletonTypes(const SingletonTypes&) = delete; SingletonTypes(const SingletonTypes&) = delete;
void operator=(const SingletonTypes&) = delete; void operator=(const SingletonTypes&) = delete;
TypeId errorRecoveryType(TypeId guess);
TypePackId errorRecoveryTypePack(TypePackId guess);
TypeId errorRecoveryType();
TypePackId errorRecoveryTypePack();
private: private:
std::unique_ptr<struct TypeArena> arena; std::unique_ptr<struct TypeArena> arena;
TypeId makeStringMetatable(); TypeId makeStringMetatable();

View file

@ -105,6 +105,8 @@ private:
struct Error struct Error
{ {
// This constructor has to be public, since it's used in TypeVar and TypePack,
// but shouldn't be called directly. Please use errorRecoveryType() instead.
Error(); Error();
int index; int index;

View file

@ -65,6 +65,7 @@ struct Unifier
private: private:
void tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall = false, bool isIntersection = false); void tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall = false, bool isIntersection = false);
void tryUnifyPrimitives(TypeId superTy, TypeId subTy); void tryUnifyPrimitives(TypeId superTy, TypeId subTy);
void tryUnifySingletons(TypeId superTy, TypeId subTy);
void tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall = false); void tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall = false);
void tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); void tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false);
void DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); void DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false);

View file

@ -14,6 +14,7 @@
LUAU_FASTFLAGVARIABLE(ElseElseIfCompletionImprovements, false); LUAU_FASTFLAGVARIABLE(ElseElseIfCompletionImprovements, false);
LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport) LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport)
LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false);
static const std::unordered_set<std::string> kStatementStartingKeywords = { static const std::unordered_set<std::string> kStatementStartingKeywords = {
"while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"};
@ -198,11 +199,24 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ
UnifierSharedState unifierState(&iceReporter); UnifierSharedState unifierState(&iceReporter);
Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, unifierState); Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, unifierState);
if (FFlag::LuauAutocompleteAvoidMutation)
{
SeenTypes seenTypes;
SeenTypePacks seenTypePacks;
expectedType = clone(expectedType, *typeArena, seenTypes, seenTypePacks, nullptr);
actualType = clone(actualType, *typeArena, seenTypes, seenTypePacks, nullptr);
auto errors = unifier.canUnify(expectedType, actualType);
return errors.empty();
}
else
{
unifier.tryUnify(expectedType, actualType); unifier.tryUnify(expectedType, actualType);
bool ok = unifier.errors.empty(); bool ok = unifier.errors.empty();
unifier.log.rollback(); unifier.log.rollback();
return ok; return ok;
}
}; };
auto expr = node->asExpr(); auto expr = node->asExpr();
@ -1496,10 +1510,8 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName
if (!sourceModule) if (!sourceModule)
return {}; return {};
TypeChecker& typeChecker = TypeChecker& typeChecker = (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker);
(frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); ModulePtr module = (frontend.options.typecheckTwice ? frontend.moduleResolverForAutocomplete.getModule(moduleName)
ModulePtr module =
(frontend.options.typecheckTwice ? frontend.moduleResolverForAutocomplete.getModule(moduleName)
: frontend.moduleResolver.getModule(moduleName)); : frontend.moduleResolver.getModule(moduleName));
if (!module) if (!module)
@ -1527,8 +1539,7 @@ OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view
sourceModule->mode = Mode::Strict; sourceModule->mode = Mode::Strict;
sourceModule->commentLocations = std::move(result.commentLocations); sourceModule->commentLocations = std::move(result.commentLocations);
TypeChecker& typeChecker = TypeChecker& typeChecker = (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker);
(frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker);
ModulePtr module = typeChecker.check(*sourceModule, Mode::Strict); ModulePtr module = typeChecker.check(*sourceModule, Mode::Strict);

View file

@ -153,6 +153,7 @@ declare function gcinfo(): number
wrap: <A..., R...>((A...) -> R...) -> any, wrap: <A..., R...>((A...) -> R...) -> any,
yield: <A..., R...>(A...) -> R..., yield: <A..., R...>(A...) -> R...,
isyieldable: () -> boolean, isyieldable: () -> boolean,
close: (thread) -> (boolean, any?)
} }
declare table: { declare table: {

View file

@ -180,13 +180,13 @@ struct ErrorConverter
switch (e.context) switch (e.context)
{ {
case CountMismatch::Return: case CountMismatch::Return:
return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " + return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " + std::to_string(e.actual) + " " +
std::to_string(e.actual) + " " + actualVerb + " returned here"; actualVerb + " returned here";
case CountMismatch::Result: case CountMismatch::Result:
// It is alright if right hand side produces more values than the // It is alright if right hand side produces more values than the
// left hand side accepts. In this context consider only the opposite case. // left hand side accepts. In this context consider only the opposite case.
return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " + return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " + std::to_string(e.actual) +
std::to_string(e.actual) + " are required here"; " are required here";
case CountMismatch::Arg: case CountMismatch::Arg:
if (FFlag::LuauTypeAliasPacks) if (FFlag::LuauTypeAliasPacks)
return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual);

View file

@ -22,7 +22,6 @@ LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false)
LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAG(LuauTraceRequireLookupChild)
LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false) LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false)
LUAU_FASTFLAG(LuauNewRequireTrace2) LUAU_FASTFLAG(LuauNewRequireTrace2)
LUAU_FASTFLAGVARIABLE(LuauClearScopes, false)
namespace Luau namespace Luau
{ {
@ -458,7 +457,6 @@ CheckResult Frontend::check(const ModuleName& name)
module->astTypes.clear(); module->astTypes.clear();
module->astExpectedTypes.clear(); module->astExpectedTypes.clear();
module->astOriginalCallTypes.clear(); module->astOriginalCallTypes.clear();
if (FFlag::LuauClearScopes)
module->scopes.resize(1); module->scopes.resize(1);
} }

View file

@ -161,6 +161,7 @@ struct TypeCloner
void operator()(const Unifiable::Bound<TypeId>& t); void operator()(const Unifiable::Bound<TypeId>& t);
void operator()(const Unifiable::Error& t); void operator()(const Unifiable::Error& t);
void operator()(const PrimitiveTypeVar& t); void operator()(const PrimitiveTypeVar& t);
void operator()(const SingletonTypeVar& t);
void operator()(const FunctionTypeVar& t); void operator()(const FunctionTypeVar& t);
void operator()(const TableTypeVar& t); void operator()(const TableTypeVar& t);
void operator()(const MetatableTypeVar& t); void operator()(const MetatableTypeVar& t);
@ -199,7 +200,9 @@ struct TypePackCloner
if (encounteredFreeType) if (encounteredFreeType)
*encounteredFreeType = true; *encounteredFreeType = true;
seenTypePacks[typePackId] = dest.addTypePack(TypePackVar{Unifiable::Error{}}); TypePackId err = singletonTypes.errorRecoveryTypePack(singletonTypes.anyTypePack);
TypePackId cloned = dest.addTypePack(*err);
seenTypePacks[typePackId] = cloned;
} }
void operator()(const Unifiable::Generic& t) void operator()(const Unifiable::Generic& t)
@ -251,8 +254,9 @@ void TypeCloner::operator()(const Unifiable::Free& t)
{ {
if (encounteredFreeType) if (encounteredFreeType)
*encounteredFreeType = true; *encounteredFreeType = true;
TypeId err = singletonTypes.errorRecoveryType(singletonTypes.anyType);
seenTypes[typeId] = dest.addType(ErrorTypeVar{}); TypeId cloned = dest.addType(*err);
seenTypes[typeId] = cloned;
} }
void TypeCloner::operator()(const Unifiable::Generic& t) void TypeCloner::operator()(const Unifiable::Generic& t)
@ -270,11 +274,17 @@ void TypeCloner::operator()(const Unifiable::Error& t)
{ {
defaultClone(t); defaultClone(t);
} }
void TypeCloner::operator()(const PrimitiveTypeVar& t) void TypeCloner::operator()(const PrimitiveTypeVar& t)
{ {
defaultClone(t); defaultClone(t);
} }
void TypeCloner::operator()(const SingletonTypeVar& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const FunctionTypeVar& t) void TypeCloner::operator()(const FunctionTypeVar& t)
{ {
TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf});

View file

@ -350,6 +350,23 @@ struct TypeVarStringifier
} }
} }
void operator()(TypeId, const SingletonTypeVar& stv)
{
if (const BoolSingleton* bs = Luau::get<BoolSingleton>(&stv))
state.emit(bs->value ? "true" : "false");
else if (const StringSingleton* ss = Luau::get<StringSingleton>(&stv))
{
state.emit("\"");
state.emit(escape(ss->value));
state.emit("\"");
}
else
{
LUAU_ASSERT(!"Unknown singleton type");
throw std::runtime_error("Unknown singleton type");
}
}
void operator()(TypeId, const FunctionTypeVar& ftv) void operator()(TypeId, const FunctionTypeVar& ftv)
{ {
if (state.hasSeen(&ftv)) if (state.hasSeen(&ftv))
@ -359,6 +376,7 @@ struct TypeVarStringifier
return; return;
} }
// We should not be respecting opts.hideNamedFunctionTypeParameters here.
if (ftv.generics.size() > 0 || ftv.genericPacks.size() > 0) if (ftv.generics.size() > 0 || ftv.genericPacks.size() > 0)
{ {
state.emit("<"); state.emit("<");
@ -514,7 +532,14 @@ struct TypeVarStringifier
break; break;
} }
if (isIdentifier(name))
state.emit(name); state.emit(name);
else
{
state.emit("[\"");
state.emit(escape(name));
state.emit("\"]");
}
state.emit(": "); state.emit(": ");
stringify(prop.type); stringify(prop.type);
comma = true; comma = true;
@ -1084,6 +1109,94 @@ std::string toString(const TypePackVar& tp, const ToStringOptions& opts)
return toString(const_cast<TypePackId>(&tp), std::move(opts)); return toString(const_cast<TypePackId>(&tp), std::move(opts));
} }
std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts)
{
std::string s = prefix;
auto toString_ = [&opts](TypeId ty) -> std::string {
ToStringResult res = toStringDetailed(ty, opts);
opts.nameMap = std::move(res.nameMap);
return res.name;
};
auto toStringPack_ = [&opts](TypePackId ty) -> std::string {
ToStringResult res = toStringDetailed(ty, opts);
opts.nameMap = std::move(res.nameMap);
return res.name;
};
if (!opts.hideNamedFunctionTypeParameters && (!ftv.generics.empty() || !ftv.genericPacks.empty()))
{
s += "<";
bool first = true;
for (TypeId g : ftv.generics)
{
if (!first)
s += ", ";
first = false;
s += toString_(g);
}
for (TypePackId gp : ftv.genericPacks)
{
if (!first)
s += ", ";
first = false;
s += toStringPack_(gp);
}
s += ">";
}
s += "(";
auto argPackIter = begin(ftv.argTypes);
auto argNameIter = ftv.argNames.begin();
bool first = true;
while (argPackIter != end(ftv.argTypes))
{
if (!first)
s += ", ";
first = false;
// argNames is guaranteed to be equal to argTypes iff argNames is not empty.
// We don't currently respect opts.functionTypeArguments. I don't think this function should.
if (!ftv.argNames.empty())
s += (*argNameIter ? (*argNameIter)->name : "_") + ": ";
s += toString_(*argPackIter);
++argPackIter;
if (!ftv.argNames.empty())
{
LUAU_ASSERT(argNameIter != ftv.argNames.end());
++argNameIter;
}
}
if (argPackIter.tail())
{
if (auto vtp = get<VariadicTypePack>(*argPackIter.tail()))
s += ", ...: " + toString_(vtp->ty);
else
s += ", ...: " + toStringPack_(*argPackIter.tail());
}
s += "): ";
size_t retSize = size(ftv.retType);
bool hasTail = !finite(ftv.retType);
if (retSize == 0 && !hasTail)
s += "()";
else if ((retSize == 0 && hasTail) || (retSize == 1 && !hasTail))
s += toStringPack_(ftv.retType);
else
s += "(" + toStringPack_(ftv.retType) + ")";
return s;
}
void dump(TypeId ty) void dump(TypeId ty)
{ {
ToStringOptions opts; ToStringOptions opts;

View file

@ -14,61 +14,6 @@ LUAU_FASTFLAG(LuauTypeAliasPacks)
namespace namespace
{ {
std::string escape(std::string_view s)
{
std::string r;
r.reserve(s.size() + 50); // arbitrary number to guess how many characters we'll be inserting
for (uint8_t c : s)
{
if (c >= ' ' && c != '\\' && c != '\'' && c != '\"')
r += c;
else
{
r += '\\';
switch (c)
{
case '\a':
r += 'a';
break;
case '\b':
r += 'b';
break;
case '\f':
r += 'f';
break;
case '\n':
r += 'n';
break;
case '\r':
r += 'r';
break;
case '\t':
r += 't';
break;
case '\v':
r += 'v';
break;
case '\'':
r += '\'';
break;
case '\"':
r += '\"';
break;
case '\\':
r += '\\';
break;
default:
Luau::formatAppend(r, "%03u", c);
}
}
}
return r;
}
bool isIdentifierStartChar(char c) bool isIdentifierStartChar(char c)
{ {
return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || c == '_'; return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || c == '_';

View file

@ -96,6 +96,22 @@ public:
return nullptr; return nullptr;
} }
} }
AstType* operator()(const SingletonTypeVar& stv)
{
if (const BoolSingleton* bs = get<BoolSingleton>(&stv))
return allocator->alloc<AstTypeSingletonBool>(Location(), bs->value);
else if (const StringSingleton* ss = get<StringSingleton>(&stv))
{
AstArray<char> value;
value.data = const_cast<char*>(ss->value.c_str());
value.size = strlen(value.data);
return allocator->alloc<AstTypeSingletonString>(Location(), value);
}
else
return nullptr;
}
AstType* operator()(const AnyTypeVar&) AstType* operator()(const AnyTypeVar&)
{ {
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("any")); return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("any"));

View file

@ -36,6 +36,9 @@ LUAU_FASTFLAG(LuauSubstitutionDontReplaceIgnoredTypes)
LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false)
LUAU_FASTFLAG(LuauNewRequireTrace2) LUAU_FASTFLAG(LuauNewRequireTrace2)
LUAU_FASTFLAG(LuauTypeAliasPacks) LUAU_FASTFLAG(LuauTypeAliasPacks)
LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false)
LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false)
LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false)
namespace Luau namespace Luau
{ {
@ -211,10 +214,8 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan
, booleanType(singletonTypes.booleanType) , booleanType(singletonTypes.booleanType)
, threadType(singletonTypes.threadType) , threadType(singletonTypes.threadType)
, anyType(singletonTypes.anyType) , anyType(singletonTypes.anyType)
, errorType(singletonTypes.errorType)
, optionalNumberType(singletonTypes.optionalNumberType) , optionalNumberType(singletonTypes.optionalNumberType)
, anyTypePack(singletonTypes.anyTypePack) , anyTypePack(singletonTypes.anyTypePack)
, errorTypePack(singletonTypes.errorTypePack)
{ {
globalScope = std::make_shared<Scope>(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); globalScope = std::make_shared<Scope>(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}}));
@ -484,7 +485,7 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std
TypeId type = bindings[name].type; TypeId type = bindings[name].type;
if (get<FreeTypeVar>(follow(type))) if (get<FreeTypeVar>(follow(type)))
{ {
*asMutable(type) = ErrorTypeVar{}; *asMutable(type) = *errorRecoveryType(anyType);
reportError(TypeError{typealias->location, OccursCheckFailed{}}); reportError(TypeError{typealias->location, OccursCheckFailed{}});
} }
} }
@ -719,7 +720,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign)
else if (auto tail = valueIter.tail()) else if (auto tail = valueIter.tail())
{ {
if (get<Unifiable::Error>(*tail)) if (get<Unifiable::Error>(*tail))
right = errorType; right = errorRecoveryType(scope);
else if (auto vtp = get<VariadicTypePack>(*tail)) else if (auto vtp = get<VariadicTypePack>(*tail))
right = vtp->ty; right = vtp->ty;
else if (get<Unifiable::Free>(*tail)) else if (get<Unifiable::Free>(*tail))
@ -961,7 +962,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin)
else if (get<Unifiable::Error>(callRetPack) || !first(callRetPack)) else if (get<Unifiable::Error>(callRetPack) || !first(callRetPack))
{ {
for (TypeId var : varTypes) for (TypeId var : varTypes)
unify(var, errorType, forin.location); unify(var, errorRecoveryType(scope), forin.location);
return check(loopScope, *forin.body); return check(loopScope, *forin.body);
} }
@ -979,7 +980,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin)
const FunctionTypeVar* iterFunc = get<FunctionTypeVar>(iterTy); const FunctionTypeVar* iterFunc = get<FunctionTypeVar>(iterTy);
if (!iterFunc) if (!iterFunc)
{ {
TypeId varTy = get<AnyTypeVar>(iterTy) ? anyType : errorType; TypeId varTy = get<AnyTypeVar>(iterTy) ? anyType : errorRecoveryType(loopScope);
for (TypeId var : varTypes) for (TypeId var : varTypes)
unify(var, varTy, forin.location); unify(var, varTy, forin.location);
@ -1152,9 +1153,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias
reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}}); reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}});
if (FFlag::LuauTypeAliasPacks) if (FFlag::LuauTypeAliasPacks)
bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorType}; bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorRecoveryType(anyType)};
else else
bindingsMap[name] = TypeFun{binding->typeParams, errorType}; bindingsMap[name] = TypeFun{binding->typeParams, errorRecoveryType(anyType)};
} }
else else
{ {
@ -1398,7 +1399,7 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr&
if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit)
{ {
reportErrorCodeTooComplex(expr.location); reportErrorCodeTooComplex(expr.location);
return {errorType}; return {errorRecoveryType(scope)};
} }
ExprResult<TypeId> result; ExprResult<TypeId> result;
@ -1407,12 +1408,22 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr&
result = checkExpr(scope, *a->expr); result = checkExpr(scope, *a->expr);
else if (expr.is<AstExprConstantNil>()) else if (expr.is<AstExprConstantNil>())
result = {nilType}; result = {nilType};
else if (expr.is<AstExprConstantBool>()) else if (const AstExprConstantBool* bexpr = expr.as<AstExprConstantBool>())
{
if (FFlag::LuauSingletonTypes && expectedType && maybeSingleton(*expectedType))
result = {singletonType(bexpr->value)};
else
result = {booleanType}; result = {booleanType};
}
else if (const AstExprConstantString* sexpr = expr.as<AstExprConstantString>())
{
if (FFlag::LuauSingletonTypes && expectedType && maybeSingleton(*expectedType))
result = {singletonType(std::string(sexpr->value.data, sexpr->value.size))};
else
result = {stringType};
}
else if (expr.is<AstExprConstantNumber>()) else if (expr.is<AstExprConstantNumber>())
result = {numberType}; result = {numberType};
else if (expr.is<AstExprConstantString>())
result = {stringType};
else if (auto a = expr.as<AstExprLocal>()) else if (auto a = expr.as<AstExprLocal>())
result = checkExpr(scope, *a); result = checkExpr(scope, *a);
else if (auto a = expr.as<AstExprGlobal>()) else if (auto a = expr.as<AstExprGlobal>())
@ -1485,7 +1496,7 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprLo
// TODO: tempting to ice here, but this breaks very often because our toposort doesn't enforce this constraint // TODO: tempting to ice here, but this breaks very often because our toposort doesn't enforce this constraint
// ice("AstExprLocal exists but no binding definition for it?", expr.location); // ice("AstExprLocal exists but no binding definition for it?", expr.location);
reportError(TypeError{expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}}); reportError(TypeError{expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}});
return {errorType}; return {errorRecoveryType(scope)};
} }
ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGlobal& expr) ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGlobal& expr)
@ -1497,7 +1508,7 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGl
return {*ty, {TruthyPredicate{std::move(*lvalue), expr.location}}}; return {*ty, {TruthyPredicate{std::move(*lvalue), expr.location}}};
reportError(TypeError{expr.location, UnknownSymbol{expr.name.value, UnknownSymbol::Binding}}); reportError(TypeError{expr.location, UnknownSymbol{expr.name.value, UnknownSymbol::Binding}});
return {errorType}; return {errorRecoveryType(scope)};
} }
ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVarargs& expr) ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVarargs& expr)
@ -1517,14 +1528,14 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVa
return {head}; return {head};
} }
if (get<ErrorTypeVar>(varargPack)) if (get<ErrorTypeVar>(varargPack))
return {errorType}; return {errorRecoveryType(scope)};
else if (auto vtp = get<VariadicTypePack>(varargPack)) else if (auto vtp = get<VariadicTypePack>(varargPack))
return {vtp->ty}; return {vtp->ty};
else if (get<Unifiable::Generic>(varargPack)) else if (get<Unifiable::Generic>(varargPack))
{ {
// TODO: Better error? // TODO: Better error?
reportError(expr.location, GenericError{"Trying to get a type from a variadic type parameter"}); reportError(expr.location, GenericError{"Trying to get a type from a variadic type parameter"});
return {errorType}; return {errorRecoveryType(scope)};
} }
else else
ice("Unknown TypePack type in checkExpr(AstExprVarargs)!"); ice("Unknown TypePack type in checkExpr(AstExprVarargs)!");
@ -1547,7 +1558,7 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa
return {head, std::move(result.predicates)}; return {head, std::move(result.predicates)};
} }
if (get<Unifiable::Error>(retPack)) if (get<Unifiable::Error>(retPack))
return {errorType, std::move(result.predicates)}; return {errorRecoveryType(scope), std::move(result.predicates)};
else if (auto vtp = get<VariadicTypePack>(retPack)) else if (auto vtp = get<VariadicTypePack>(retPack))
return {vtp->ty, std::move(result.predicates)}; return {vtp->ty, std::move(result.predicates)};
else if (get<Unifiable::Generic>(retPack)) else if (get<Unifiable::Generic>(retPack))
@ -1572,7 +1583,7 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIn
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, lhsType, name, expr.location, true)) if (std::optional<TypeId> ty = getIndexTypeFromType(scope, lhsType, name, expr.location, true))
return {*ty}; return {*ty};
return {errorType}; return {errorRecoveryType(scope)};
} }
std::optional<TypeId> TypeChecker::findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location) std::optional<TypeId> TypeChecker::findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location)
@ -1876,6 +1887,7 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa
std::vector<std::pair<TypeId, TypeId>> fieldTypes(expr.items.size); std::vector<std::pair<TypeId, TypeId>> fieldTypes(expr.items.size);
const TableTypeVar* expectedTable = nullptr; const TableTypeVar* expectedTable = nullptr;
const UnionTypeVar* expectedUnion = nullptr;
std::optional<TypeId> expectedIndexType; std::optional<TypeId> expectedIndexType;
std::optional<TypeId> expectedIndexResultType; std::optional<TypeId> expectedIndexResultType;
@ -1894,6 +1906,9 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa
} }
} }
} }
else if (FFlag::LuauExpectedTypesOfProperties)
if (const UnionTypeVar* utv = get<UnionTypeVar>(follow(*expectedType)))
expectedUnion = utv;
} }
for (size_t i = 0; i < expr.items.size; ++i) for (size_t i = 0; i < expr.items.size; ++i)
@ -1916,6 +1931,18 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa
if (auto prop = expectedTable->props.find(key->value.data); prop != expectedTable->props.end()) if (auto prop = expectedTable->props.find(key->value.data); prop != expectedTable->props.end())
expectedResultType = prop->second.type; expectedResultType = prop->second.type;
} }
else if (FFlag::LuauExpectedTypesOfProperties && expectedUnion)
{
std::vector<TypeId> expectedResultTypes;
for (TypeId expectedOption : expectedUnion)
if (const TableTypeVar* ttv = get<TableTypeVar>(follow(expectedOption)))
if (auto prop = ttv->props.find(key->value.data); prop != ttv->props.end())
expectedResultTypes.push_back(prop->second.type);
if (expectedResultTypes.size() == 1)
expectedResultType = expectedResultTypes[0];
else if (expectedResultTypes.size() > 1)
expectedResultType = addType(UnionTypeVar{expectedResultTypes});
}
} }
else else
{ {
@ -1958,21 +1985,22 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn
{ {
TypeId actualFunctionType = instantiate(scope, *fnt, expr.location); TypeId actualFunctionType = instantiate(scope, *fnt, expr.location);
TypePackId arguments = addTypePack({operandType}); TypePackId arguments = addTypePack({operandType});
TypePackId retType = freshTypePack(scope); TypePackId retTypePack = freshTypePack(scope);
TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retType)); TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack));
Unifier state = mkUnifier(expr.location); Unifier state = mkUnifier(expr.location);
state.tryUnify(expectedFunctionType, actualFunctionType, /*isFunctionCall*/ true); state.tryUnify(expectedFunctionType, actualFunctionType, /*isFunctionCall*/ true);
TypeId retType = first(retTypePack).value_or(nilType);
if (!state.errors.empty()) if (!state.errors.empty())
return {errorType}; retType = errorRecoveryType(retType);
return {first(retType).value_or(nilType)}; return {retType};
} }
reportError(expr.location, reportError(expr.location,
GenericError{format("Unary operator '%s' not supported by type '%s'", toString(expr.op).c_str(), toString(operandType).c_str())}); GenericError{format("Unary operator '%s' not supported by type '%s'", toString(expr.op).c_str(), toString(operandType).c_str())});
return {errorType}; return {errorRecoveryType(scope)};
} }
reportErrors(tryUnify(numberType, operandType, expr.location)); reportErrors(tryUnify(numberType, operandType, expr.location));
@ -1984,7 +2012,7 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn
operandType = stripFromNilAndReport(operandType, expr.location); operandType = stripFromNilAndReport(operandType, expr.location);
if (get<ErrorTypeVar>(operandType)) if (get<ErrorTypeVar>(operandType))
return {errorType}; return {errorRecoveryType(scope)};
if (get<AnyTypeVar>(operandType)) if (get<AnyTypeVar>(operandType))
return {numberType}; // Not strictly correct: metatables permit overriding this return {numberType}; // Not strictly correct: metatables permit overriding this
@ -2044,7 +2072,7 @@ TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const Location& location, b
if (unify(a, b, location)) if (unify(a, b, location))
return a; return a;
return errorType; return errorRecoveryType(anyType);
} }
if (*a == *b) if (*a == *b)
@ -2166,11 +2194,13 @@ TypeId TypeChecker::checkRelationalOperation(
std::optional<TypeId> leftMetatable = isString(lhsType) ? std::nullopt : getMetatable(follow(lhsType)); std::optional<TypeId> leftMetatable = isString(lhsType) ? std::nullopt : getMetatable(follow(lhsType));
std::optional<TypeId> rightMetatable = isString(rhsType) ? std::nullopt : getMetatable(follow(rhsType)); std::optional<TypeId> rightMetatable = isString(rhsType) ? std::nullopt : getMetatable(follow(rhsType));
// TODO: this check seems odd, the second part is redundant
// is it meant to be if (leftMetatable && rightMetatable && leftMetatable != rightMetatable)
if (bool(leftMetatable) != bool(rightMetatable) && leftMetatable != rightMetatable) if (bool(leftMetatable) != bool(rightMetatable) && leftMetatable != rightMetatable)
{ {
reportError(expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", reportError(expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable",
toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())});
return errorType; return errorRecoveryType(booleanType);
} }
if (leftMetatable) if (leftMetatable)
@ -2188,7 +2218,7 @@ TypeId TypeChecker::checkRelationalOperation(
if (!state.errors.empty()) if (!state.errors.empty())
{ {
reportError(expr.location, GenericError{format("Metamethod '%s' must return type 'boolean'", metamethodName.c_str())}); reportError(expr.location, GenericError{format("Metamethod '%s' must return type 'boolean'", metamethodName.c_str())});
return errorType; return errorRecoveryType(booleanType);
} }
} }
} }
@ -2206,7 +2236,7 @@ TypeId TypeChecker::checkRelationalOperation(
{ {
reportError( reportError(
expr.location, GenericError{format("Table %s does not offer metamethod %s", toString(lhsType).c_str(), metamethodName.c_str())}); expr.location, GenericError{format("Table %s does not offer metamethod %s", toString(lhsType).c_str(), metamethodName.c_str())});
return errorType; return errorRecoveryType(booleanType);
} }
} }
@ -2214,14 +2244,14 @@ TypeId TypeChecker::checkRelationalOperation(
{ {
auto name = getIdentifierOfBaseVar(expr.left); auto name = getIdentifierOfBaseVar(expr.left);
reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Comparison}); reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Comparison});
return errorType; return errorRecoveryType(booleanType);
} }
if (needsMetamethod) if (needsMetamethod)
{ {
reportError(expr.location, GenericError{format("Type %s cannot be compared with %s because it has no metatable", reportError(expr.location, GenericError{format("Type %s cannot be compared with %s because it has no metatable",
toString(lhsType).c_str(), toString(expr.op).c_str())}); toString(lhsType).c_str(), toString(expr.op).c_str())});
return errorType; return errorRecoveryType(booleanType);
} }
return booleanType; return booleanType;
@ -2266,7 +2296,8 @@ TypeId TypeChecker::checkBinaryOperation(
{ {
auto name = getIdentifierOfBaseVar(expr.left); auto name = getIdentifierOfBaseVar(expr.left);
reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation}); reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation});
return errorType; if (!FFlag::LuauErrorRecoveryType)
return errorRecoveryType(scope);
} }
// If we know nothing at all about the lhs type, we can usually say nothing about the result. // If we know nothing at all about the lhs type, we can usually say nothing about the result.
@ -2296,18 +2327,33 @@ TypeId TypeChecker::checkBinaryOperation(
auto checkMetatableCall = [this, &scope, &expr](TypeId fnt, TypeId lhst, TypeId rhst) -> TypeId { auto checkMetatableCall = [this, &scope, &expr](TypeId fnt, TypeId lhst, TypeId rhst) -> TypeId {
TypeId actualFunctionType = instantiate(scope, fnt, expr.location); TypeId actualFunctionType = instantiate(scope, fnt, expr.location);
TypePackId arguments = addTypePack({lhst, rhst}); TypePackId arguments = addTypePack({lhst, rhst});
TypePackId retType = freshTypePack(scope); TypePackId retTypePack = freshTypePack(scope);
TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retType)); TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack));
Unifier state = mkUnifier(expr.location); Unifier state = mkUnifier(expr.location);
state.tryUnify(expectedFunctionType, actualFunctionType, /*isFunctionCall*/ true); state.tryUnify(expectedFunctionType, actualFunctionType, /*isFunctionCall*/ true);
reportErrors(state.errors); reportErrors(state.errors);
bool hasErrors = !state.errors.empty();
if (FFlag::LuauErrorRecoveryType && hasErrors)
{
// If there are unification errors, the return type may still be unknown
// so we loosen the argument types to see if that helps.
TypePackId fallbackArguments = freshTypePack(scope);
TypeId fallbackFunctionType = addType(FunctionTypeVar(scope->level, fallbackArguments, retTypePack));
state.log.rollback();
state.errors.clear();
state.tryUnify(fallbackFunctionType, actualFunctionType, /*isFunctionCall*/ true);
if (!state.errors.empty()) if (!state.errors.empty())
return errorType; state.log.rollback();
}
return first(retType).value_or(nilType); TypeId retType = first(retTypePack).value_or(nilType);
if (hasErrors)
retType = errorRecoveryType(retType);
return retType;
}; };
std::string op = opToMetaTableEntry(expr.op); std::string op = opToMetaTableEntry(expr.op);
@ -2321,7 +2367,8 @@ TypeId TypeChecker::checkBinaryOperation(
reportError(expr.location, GenericError{format("Binary operator '%s' not supported by types '%s' and '%s'", toString(expr.op).c_str(), reportError(expr.location, GenericError{format("Binary operator '%s' not supported by types '%s' and '%s'", toString(expr.op).c_str(),
toString(lhsType).c_str(), toString(rhsType).c_str())}); toString(lhsType).c_str(), toString(rhsType).c_str())});
return errorType;
return errorRecoveryType(scope);
} }
switch (expr.op) switch (expr.op)
@ -2414,11 +2461,9 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTy
ExprResult<TypeId> result = checkExpr(scope, *expr.expr, annotationType); ExprResult<TypeId> result = checkExpr(scope, *expr.expr, annotationType);
ErrorVec errorVec = canUnify(result.type, annotationType, expr.location); ErrorVec errorVec = canUnify(result.type, annotationType, expr.location);
if (!errorVec.empty())
{
reportErrors(errorVec); reportErrors(errorVec);
return {errorType, std::move(result.predicates)}; if (!errorVec.empty())
} annotationType = errorRecoveryType(annotationType);
return {annotationType, std::move(result.predicates)}; return {annotationType, std::move(result.predicates)};
} }
@ -2434,7 +2479,7 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprEr
// any type errors that may arise from it are going to be useless. // any type errors that may arise from it are going to be useless.
currentModule->errors.resize(oldSize); currentModule->errors.resize(oldSize);
return {errorType}; return {errorRecoveryType(scope)};
} }
ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr) ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr)
@ -2476,7 +2521,7 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
{ {
for (AstExpr* expr : a->expressions) for (AstExpr* expr : a->expressions)
checkExpr(scope, *expr); checkExpr(scope, *expr);
return std::pair(errorType, nullptr); return {errorRecoveryType(scope), nullptr};
} }
else else
ice("Unexpected AST node in checkLValue", expr.location); ice("Unexpected AST node in checkLValue", expr.location);
@ -2488,7 +2533,7 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
return {*ty, nullptr}; return {*ty, nullptr};
reportError(expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}); reportError(expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding});
return {errorType, nullptr}; return {errorRecoveryType(scope), nullptr};
} }
std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr) std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr)
@ -2545,24 +2590,25 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
{ {
Unifier state = mkUnifier(expr.location); Unifier state = mkUnifier(expr.location);
state.tryUnify(indexer->indexType, stringType); state.tryUnify(indexer->indexType, stringType);
TypeId retType = indexer->indexResultType;
if (!state.errors.empty()) if (!state.errors.empty())
{ {
state.log.rollback(); state.log.rollback();
reportError(expr.location, UnknownProperty{lhs, name}); reportError(expr.location, UnknownProperty{lhs, name});
return std::pair(errorType, nullptr); retType = errorRecoveryType(retType);
} }
return std::pair(indexer->indexResultType, nullptr); return std::pair(retType, nullptr);
} }
else if (lhsTable->state == TableState::Sealed) else if (lhsTable->state == TableState::Sealed)
{ {
reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}}); reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}});
return std::pair(errorType, nullptr); return std::pair(errorRecoveryType(scope), nullptr);
} }
else else
{ {
reportError(TypeError{expr.location, GenericError{"Internal error: generic tables are not lvalues"}}); reportError(TypeError{expr.location, GenericError{"Internal error: generic tables are not lvalues"}});
return std::pair(errorType, nullptr); return std::pair(errorRecoveryType(scope), nullptr);
} }
} }
else if (const ClassTypeVar* lhsClass = get<ClassTypeVar>(lhs)) else if (const ClassTypeVar* lhsClass = get<ClassTypeVar>(lhs))
@ -2571,7 +2617,7 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
if (!prop) if (!prop)
{ {
reportError(TypeError{expr.location, UnknownProperty{lhs, name}}); reportError(TypeError{expr.location, UnknownProperty{lhs, name}});
return std::pair(errorType, nullptr); return std::pair(errorRecoveryType(scope), nullptr);
} }
return std::pair(prop->type, nullptr); return std::pair(prop->type, nullptr);
@ -2585,12 +2631,12 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
if (isTableIntersection(lhs)) if (isTableIntersection(lhs))
{ {
reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}}); reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}});
return std::pair(errorType, nullptr); return std::pair(errorRecoveryType(scope), nullptr);
} }
} }
reportError(TypeError{expr.location, NotATable{lhs}}); reportError(TypeError{expr.location, NotATable{lhs}});
return std::pair(errorType, nullptr); return std::pair(errorRecoveryType(scope), nullptr);
} }
std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr) std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr)
@ -2615,7 +2661,7 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
if (!prop) if (!prop)
{ {
reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}}); reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}});
return std::pair(errorType, nullptr); return std::pair(errorRecoveryType(scope), nullptr);
} }
return std::pair(prop->type, nullptr); return std::pair(prop->type, nullptr);
} }
@ -2626,7 +2672,7 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
if (!exprTable) if (!exprTable)
{ {
reportError(TypeError{expr.expr->location, NotATable{exprType}}); reportError(TypeError{expr.expr->location, NotATable{exprType}});
return std::pair(errorType, nullptr); return std::pair(errorRecoveryType(scope), nullptr);
} }
if (value) if (value)
@ -2678,7 +2724,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName)
if (isNonstrictMode()) if (isNonstrictMode())
return globalScope->bindings[name].typeId; return globalScope->bindings[name].typeId;
return errorType; return errorRecoveryType(scope);
} }
else else
{ {
@ -2705,20 +2751,21 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName)
TableTypeVar* ttv = getMutableTableType(lhsType); TableTypeVar* ttv = getMutableTableType(lhsType);
if (!ttv) if (!ttv)
{ {
if (!isTableIntersection(lhsType)) if (!FFlag::LuauErrorRecoveryType && !isTableIntersection(lhsType))
// This error now gets reported when we check the function body.
reportError(TypeError{funName.location, OnlyTablesCanHaveMethods{lhsType}}); reportError(TypeError{funName.location, OnlyTablesCanHaveMethods{lhsType}});
return errorType; return errorRecoveryType(scope);
} }
// Cannot extend sealed table, but we dont report an error here because it will be reported during AstStatFunction check // Cannot extend sealed table, but we dont report an error here because it will be reported during AstStatFunction check
if (lhsType->persistent || ttv->state == TableState::Sealed) if (lhsType->persistent || ttv->state == TableState::Sealed)
return errorType; return errorRecoveryType(scope);
Name name = indexName->index.value; Name name = indexName->index.value;
if (ttv->props.count(name)) if (ttv->props.count(name))
return errorType; return errorRecoveryType(scope);
Property& property = ttv->props[name]; Property& property = ttv->props[name];
@ -2728,9 +2775,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName)
return property.type; return property.type;
} }
else if (funName.is<AstExprError>()) else if (funName.is<AstExprError>())
{ return errorRecoveryType(scope);
return errorType;
}
else else
{ {
ice("Unexpected AST node type", funName.location); ice("Unexpected AST node type", funName.location);
@ -2991,7 +3036,7 @@ ExprResult<TypePackId> TypeChecker::checkExprPack(const ScopePtr& scope, const A
else if (expr.is<AstExprVarargs>()) else if (expr.is<AstExprVarargs>())
{ {
if (!scope->varargPack) if (!scope->varargPack)
return {addTypePack({addType(ErrorTypeVar())})}; return {errorRecoveryTypePack(scope)};
return {*scope->varargPack}; return {*scope->varargPack};
} }
@ -3095,10 +3140,9 @@ void TypeChecker::checkArgumentList(
if (get<Unifiable::Error>(tail)) if (get<Unifiable::Error>(tail))
{ {
// Unify remaining parameters so we don't leave any free-types hanging around. // Unify remaining parameters so we don't leave any free-types hanging around.
TypeId argTy = errorType;
while (paramIter != endIter) while (paramIter != endIter)
{ {
state.tryUnify(*paramIter, argTy); state.tryUnify(*paramIter, errorRecoveryType(anyType));
++paramIter; ++paramIter;
} }
return; return;
@ -3157,7 +3201,7 @@ void TypeChecker::checkArgumentList(
{ {
while (argIter != endIter) while (argIter != endIter)
{ {
unify(*argIter, errorType, state.location); unify(*argIter, errorRecoveryType(scope), state.location);
++argIter; ++argIter;
} }
// For this case, we want the error span to cover every errant extra parameter // For this case, we want the error span to cover every errant extra parameter
@ -3246,7 +3290,8 @@ ExprResult<TypePackId> TypeChecker::checkExprPack(const ScopePtr& scope, const A
// For each overload // For each overload
// Compare parameter and argument types // Compare parameter and argument types
// Report any errors (also speculate dot vs colon warnings!) // Report any errors (also speculate dot vs colon warnings!)
// If there are no errors, return the resulting return type // Return the resulting return type (even if there are errors)
// If there are no matching overloads, unify with (a...) -> (b...) and return b...
TypeId selfType = nullptr; TypeId selfType = nullptr;
TypeId functionType = nullptr; TypeId functionType = nullptr;
@ -3268,8 +3313,8 @@ ExprResult<TypePackId> TypeChecker::checkExprPack(const ScopePtr& scope, const A
} }
else else
{ {
functionType = errorType; functionType = errorRecoveryType(scope);
actualFunctionType = errorType; actualFunctionType = functionType;
} }
} }
else else
@ -3296,7 +3341,7 @@ ExprResult<TypePackId> TypeChecker::checkExprPack(const ScopePtr& scope, const A
TypePackId argPack = argListResult.type; TypePackId argPack = argListResult.type;
if (get<Unifiable::Error>(argPack)) if (get<Unifiable::Error>(argPack))
return ExprResult<TypePackId>{errorTypePack}; return {errorRecoveryTypePack(scope)};
TypePack* args = getMutable<TypePack>(argPack); TypePack* args = getMutable<TypePack>(argPack);
LUAU_ASSERT(args != nullptr); LUAU_ASSERT(args != nullptr);
@ -3314,19 +3359,34 @@ ExprResult<TypePackId> TypeChecker::checkExprPack(const ScopePtr& scope, const A
std::vector<OverloadErrorEntry> errors; // errors encountered for each overload std::vector<OverloadErrorEntry> errors; // errors encountered for each overload
std::vector<TypeId> overloadsThatMatchArgCount; std::vector<TypeId> overloadsThatMatchArgCount;
std::vector<TypeId> overloadsThatDont;
for (TypeId fn : overloads) for (TypeId fn : overloads)
{ {
fn = follow(fn); fn = follow(fn);
if (auto ret = checkCallOverload(scope, expr, fn, retPack, argPack, args, argLocations, argListResult, overloadsThatMatchArgCount, errors)) if (auto ret = checkCallOverload(
scope, expr, fn, retPack, argPack, args, argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors))
return *ret; return *ret;
} }
if (handleSelfCallMismatch(scope, expr, args, argLocations, errors)) if (handleSelfCallMismatch(scope, expr, args, argLocations, errors))
return {retPack}; return {retPack};
return reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors); reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors);
if (FFlag::LuauErrorRecoveryType)
{
const FunctionTypeVar* overload = nullptr;
if (!overloadsThatMatchArgCount.empty())
overload = get<FunctionTypeVar>(overloadsThatMatchArgCount[0]);
if (!overload && !overloadsThatDont.empty())
overload = get<FunctionTypeVar>(overloadsThatDont[0]);
if (overload)
return {errorRecoveryTypePack(overload->retType)};
}
return {errorRecoveryTypePack(retPack)};
} }
std::vector<std::optional<TypeId>> TypeChecker::getExpectedTypesForCall(const std::vector<TypeId>& overloads, size_t argumentCount, bool selfCall) std::vector<std::optional<TypeId>> TypeChecker::getExpectedTypesForCall(const std::vector<TypeId>& overloads, size_t argumentCount, bool selfCall)
@ -3382,7 +3442,7 @@ std::vector<std::optional<TypeId>> TypeChecker::getExpectedTypesForCall(const st
std::optional<ExprResult<TypePackId>> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, std::optional<ExprResult<TypePackId>> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack,
TypePackId argPack, TypePack* args, const std::vector<Location>& argLocations, const ExprResult<TypePackId>& argListResult, TypePackId argPack, TypePack* args, const std::vector<Location>& argLocations, const ExprResult<TypePackId>& argListResult,
std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<OverloadErrorEntry>& errors) std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<TypeId>& overloadsThatDont, std::vector<OverloadErrorEntry>& errors)
{ {
fn = stripFromNilAndReport(fn, expr.func->location); fn = stripFromNilAndReport(fn, expr.func->location);
@ -3394,7 +3454,7 @@ std::optional<ExprResult<TypePackId>> TypeChecker::checkCallOverload(const Scope
if (get<ErrorTypeVar>(fn)) if (get<ErrorTypeVar>(fn))
{ {
return {{addTypePack(TypePackVar{Unifiable::Error{}})}}; return {{errorRecoveryTypePack(scope)}};
} }
if (get<FreeTypeVar>(fn)) if (get<FreeTypeVar>(fn))
@ -3427,14 +3487,14 @@ std::optional<ExprResult<TypePackId>> TypeChecker::checkCallOverload(const Scope
TypeId fn = *ty; TypeId fn = *ty;
fn = instantiate(scope, fn, expr.func->location); fn = instantiate(scope, fn, expr.func->location);
return checkCallOverload( return checkCallOverload(scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, metaArgLocations, argListResult,
scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, metaArgLocations, argListResult, overloadsThatMatchArgCount, errors); overloadsThatMatchArgCount, overloadsThatDont, errors);
} }
} }
reportError(TypeError{expr.func->location, CannotCallNonFunction{fn}}); reportError(TypeError{expr.func->location, CannotCallNonFunction{fn}});
unify(retPack, errorTypePack, expr.func->location); unify(retPack, errorRecoveryTypePack(scope), expr.func->location);
return {{errorTypePack}}; return {{errorRecoveryTypePack(retPack)}};
} }
// When this function type has magic functions and did return something, we select that overload instead. // When this function type has magic functions and did return something, we select that overload instead.
@ -3476,6 +3536,8 @@ std::optional<ExprResult<TypePackId>> TypeChecker::checkCallOverload(const Scope
if (!argMismatch) if (!argMismatch)
overloadsThatMatchArgCount.push_back(fn); overloadsThatMatchArgCount.push_back(fn);
else if (FFlag::LuauErrorRecoveryType)
overloadsThatDont.push_back(fn);
errors.emplace_back(std::move(state.errors), args->head, ftv); errors.emplace_back(std::move(state.errors), args->head, ftv);
state.log.rollback(); state.log.rollback();
@ -3586,14 +3648,14 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal
return false; return false;
} }
ExprResult<TypePackId> TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack,
TypePackId argPack, const std::vector<Location>& argLocations, const std::vector<TypeId>& overloads, const std::vector<Location>& argLocations, const std::vector<TypeId>& overloads, const std::vector<TypeId>& overloadsThatMatchArgCount,
const std::vector<TypeId>& overloadsThatMatchArgCount, const std::vector<OverloadErrorEntry>& errors) const std::vector<OverloadErrorEntry>& errors)
{ {
if (overloads.size() == 1) if (overloads.size() == 1)
{ {
reportErrors(std::get<0>(errors.front())); reportErrors(std::get<0>(errors.front()));
return {errorTypePack}; return;
} }
std::vector<TypeId> overloadTypes = overloadsThatMatchArgCount; std::vector<TypeId> overloadTypes = overloadsThatMatchArgCount;
@ -3622,7 +3684,7 @@ ExprResult<TypePackId> TypeChecker::reportOverloadResolutionError(const ScopePtr
// If only one overload matched, we don't need this error because we provided the previous errors. // If only one overload matched, we don't need this error because we provided the previous errors.
if (overloadsThatMatchArgCount.size() == 1) if (overloadsThatMatchArgCount.size() == 1)
return {errorTypePack}; return;
} }
std::string s; std::string s;
@ -3655,7 +3717,7 @@ ExprResult<TypePackId> TypeChecker::reportOverloadResolutionError(const ScopePtr
reportError(expr.func->location, ExtraInformation{"Other overloads are also not viable: " + s}); reportError(expr.func->location, ExtraInformation{"Other overloads are also not viable: " + s});
// No viable overload // No viable overload
return {errorTypePack}; return;
} }
ExprResult<TypePackId> TypeChecker::checkExprList(const ScopePtr& scope, const Location& location, const AstArray<AstExpr*>& exprs, ExprResult<TypePackId> TypeChecker::checkExprList(const ScopePtr& scope, const Location& location, const AstArray<AstExpr*>& exprs,
@ -3740,7 +3802,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module
if (FFlag::LuauStrictRequire && currentModule->mode == Mode::Strict) if (FFlag::LuauStrictRequire && currentModule->mode == Mode::Strict)
{ {
reportError(TypeError{location, UnknownRequire{}}); reportError(TypeError{location, UnknownRequire{}});
return errorType; return errorRecoveryType(anyType);
} }
return anyType; return anyType;
@ -3758,14 +3820,14 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module
reportError(TypeError{location, UnknownRequire{reportedModulePath}}); reportError(TypeError{location, UnknownRequire{reportedModulePath}});
} }
return errorType; return errorRecoveryType(scope);
} }
if (module->type != SourceCode::Module) if (module->type != SourceCode::Module)
{ {
std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name);
reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}); reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."});
return errorType; return errorRecoveryType(scope);
} }
std::optional<TypeId> moduleType = first(module->getModuleScope()->returnType); std::optional<TypeId> moduleType = first(module->getModuleScope()->returnType);
@ -3773,7 +3835,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module
{ {
std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name);
reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}); reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."});
return errorType; return errorRecoveryType(scope);
} }
SeenTypes seenTypes; SeenTypes seenTypes;
@ -4078,7 +4140,7 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location
if (!qty.has_value()) if (!qty.has_value())
{ {
reportError(location, UnificationTooComplex{}); reportError(location, UnificationTooComplex{});
return errorType; return errorRecoveryType(scope);
} }
if (ty == *qty) if (ty == *qty)
@ -4101,7 +4163,7 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat
else else
{ {
reportError(location, UnificationTooComplex{}); reportError(location, UnificationTooComplex{});
return errorType; return errorRecoveryType(scope);
} }
} }
@ -4116,7 +4178,7 @@ TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location)
else else
{ {
reportError(location, UnificationTooComplex{}); reportError(location, UnificationTooComplex{});
return errorType; return errorRecoveryType(anyType);
} }
} }
@ -4131,7 +4193,7 @@ TypePackId TypeChecker::anyify(const ScopePtr& scope, TypePackId ty, Location lo
else else
{ {
reportError(location, UnificationTooComplex{}); reportError(location, UnificationTooComplex{});
return errorTypePack; return errorRecoveryTypePack(anyTypePack);
} }
} }
@ -4279,6 +4341,38 @@ TypeId TypeChecker::freshType(TypeLevel level)
return currentModule->internalTypes.addType(TypeVar(FreeTypeVar(level))); return currentModule->internalTypes.addType(TypeVar(FreeTypeVar(level)));
} }
TypeId TypeChecker::singletonType(bool value)
{
// TODO: cache singleton types
return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(BoolSingleton{value})));
}
TypeId TypeChecker::singletonType(std::string value)
{
// TODO: cache singleton types
return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(StringSingleton{std::move(value)})));
}
TypeId TypeChecker::errorRecoveryType(const ScopePtr& scope)
{
return singletonTypes.errorRecoveryType();
}
TypeId TypeChecker::errorRecoveryType(TypeId guess)
{
return singletonTypes.errorRecoveryType(guess);
}
TypePackId TypeChecker::errorRecoveryTypePack(const ScopePtr& scope)
{
return singletonTypes.errorRecoveryTypePack();
}
TypePackId TypeChecker::errorRecoveryTypePack(TypePackId guess)
{
return singletonTypes.errorRecoveryTypePack(guess);
}
std::optional<TypeId> TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) std::optional<TypeId> TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate)
{ {
std::vector<TypeId> types = Luau::filterMap(type, predicate); std::vector<TypeId> types = Luau::filterMap(type, predicate);
@ -4350,7 +4444,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
if (lit->parameters.size != 1 || !lit->parameters.data[0].type) if (lit->parameters.size != 1 || !lit->parameters.data[0].type)
{ {
reportError(TypeError{annotation.location, GenericError{"_luau_print requires one generic parameter"}}); reportError(TypeError{annotation.location, GenericError{"_luau_print requires one generic parameter"}});
return addType(ErrorTypeVar{}); return errorRecoveryType(anyType);
} }
ToStringOptions opts; ToStringOptions opts;
@ -4368,7 +4462,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
if (!tf) if (!tf)
{ {
if (lit->name == Parser::errorName) if (lit->name == Parser::errorName)
return addType(ErrorTypeVar{}); return errorRecoveryType(scope);
std::string typeName; std::string typeName;
if (lit->hasPrefix) if (lit->hasPrefix)
@ -4380,7 +4474,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
else else
reportError(TypeError{annotation.location, UnknownSymbol{typeName, UnknownSymbol::Type}}); reportError(TypeError{annotation.location, UnknownSymbol{typeName, UnknownSymbol::Type}});
return addType(ErrorTypeVar{}); return errorRecoveryType(scope);
} }
if (lit->parameters.size == 0 && tf->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || tf->typePackParams.empty())) if (lit->parameters.size == 0 && tf->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || tf->typePackParams.empty()))
@ -4390,14 +4484,17 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
else if (!FFlag::LuauTypeAliasPacks && lit->parameters.size != tf->typeParams.size()) else if (!FFlag::LuauTypeAliasPacks && lit->parameters.size != tf->typeParams.size())
{ {
reportError(TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, lit->parameters.size, 0}}); reportError(TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, lit->parameters.size, 0}});
return addType(ErrorTypeVar{}); if (!FFlag::LuauErrorRecoveryType)
return errorRecoveryType(scope);
} }
else if (FFlag::LuauTypeAliasPacks)
if (FFlag::LuauTypeAliasPacks)
{ {
if (!lit->hasParameterList && !tf->typePackParams.empty()) if (!lit->hasParameterList && !tf->typePackParams.empty())
{ {
reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}});
return addType(ErrorTypeVar{}); if (!FFlag::LuauErrorRecoveryType)
return errorRecoveryType(scope);
} }
std::vector<TypeId> typeParams; std::vector<TypeId> typeParams;
@ -4445,7 +4542,17 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
{ {
reportError( reportError(
TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}});
return addType(ErrorTypeVar{});
if (FFlag::LuauErrorRecoveryType)
{
// Pad the types out with error recovery types
while (typeParams.size() < tf->typeParams.size())
typeParams.push_back(errorRecoveryType(scope));
while (typePackParams.size() < tf->typePackParams.size())
typePackParams.push_back(errorRecoveryTypePack(scope));
}
else
return errorRecoveryType(scope);
} }
if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams && typePackParams == tf->typePackParams) if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams && typePackParams == tf->typePackParams)
@ -4464,6 +4571,14 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
for (const auto& param : lit->parameters) for (const auto& param : lit->parameters)
typeParams.push_back(resolveType(scope, *param.type)); typeParams.push_back(resolveType(scope, *param.type));
if (FFlag::LuauErrorRecoveryType)
{
// If there aren't enough type parameters, pad them out with error recovery types
// (we've already reported the error)
while (typeParams.size() < lit->parameters.size)
typeParams.push_back(errorRecoveryType(scope));
}
if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams) if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams)
{ {
// If the generic parameters and the type arguments are the same, we are about to // If the generic parameters and the type arguments are the same, we are about to
@ -4483,8 +4598,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
props[prop.name.value] = {resolveType(scope, *prop.type)}; props[prop.name.value] = {resolveType(scope, *prop.type)};
if (const auto& indexer = table->indexer) if (const auto& indexer = table->indexer)
tableIndexer = TableIndexer( tableIndexer = TableIndexer(resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType));
resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType));
return addType(TableTypeVar{ return addType(TableTypeVar{
props, tableIndexer, scope->level, props, tableIndexer, scope->level,
@ -4536,14 +4650,20 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
return addType(IntersectionTypeVar{types}); return addType(IntersectionTypeVar{types});
} }
else if (annotation.is<AstTypeError>()) else if (const auto& tsb = annotation.as<AstTypeSingletonBool>())
{ {
return addType(ErrorTypeVar{}); return singletonType(tsb->value);
} }
else if (const auto& tss = annotation.as<AstTypeSingletonString>())
{
return singletonType(std::string(tss->value.data, tss->value.size));
}
else if (annotation.is<AstTypeError>())
return errorRecoveryType(scope);
else else
{ {
reportError(TypeError{annotation.location, GenericError{"Unknown type annotation?"}}); reportError(TypeError{annotation.location, GenericError{"Unknown type annotation?"}});
return addType(ErrorTypeVar{}); return errorRecoveryType(scope);
} }
} }
@ -4584,7 +4704,7 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack
else else
reportError(TypeError{generic->location, UnknownSymbol{genericName, UnknownSymbol::Type}}); reportError(TypeError{generic->location, UnknownSymbol{genericName, UnknownSymbol::Type}});
return addTypePack(TypePackVar{Unifiable::Error{}}); return errorRecoveryTypePack(scope);
} }
return *genericTy; return *genericTy;
@ -4706,12 +4826,12 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf,
if (!maybeInstantiated.has_value()) if (!maybeInstantiated.has_value())
{ {
reportError(location, UnificationTooComplex{}); reportError(location, UnificationTooComplex{});
return errorType; return errorRecoveryType(scope);
} }
if (FFlag::LuauRecursiveTypeParameterRestriction && applyTypeFunction.encounteredForwardedType) if (FFlag::LuauRecursiveTypeParameterRestriction && applyTypeFunction.encounteredForwardedType)
{ {
reportError(TypeError{location, GenericError{"Recursive type being used with different parameters"}}); reportError(TypeError{location, GenericError{"Recursive type being used with different parameters"}});
return errorType; return errorRecoveryType(scope);
} }
TypeId instantiated = *maybeInstantiated; TypeId instantiated = *maybeInstantiated;
@ -4773,8 +4893,8 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf,
return instantiated; return instantiated;
} }
std::pair<std::vector<TypeId>, std::vector<TypePackId>> TypeChecker::createGenericTypes( std::pair<std::vector<TypeId>, std::vector<TypePackId>> TypeChecker::createGenericTypes(const ScopePtr& scope, std::optional<TypeLevel> levelOpt,
const ScopePtr& scope, std::optional<TypeLevel> levelOpt, const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames) const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames)
{ {
LUAU_ASSERT(scope->parent); LUAU_ASSERT(scope->parent);
@ -5043,7 +5163,7 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement
addRefinement(refis, isaP.lvalue, *result); addRefinement(refis, isaP.lvalue, *result);
else else
{ {
addRefinement(refis, isaP.lvalue, errorType); addRefinement(refis, isaP.lvalue, errorRecoveryType(scope));
errVec.push_back(TypeError{isaP.location, TypeMismatch{isaP.ty, *ty}}); errVec.push_back(TypeError{isaP.location, TypeMismatch{isaP.ty, *ty}});
} }
} }
@ -5107,7 +5227,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec
addRefinement(refis, typeguardP.lvalue, *result); addRefinement(refis, typeguardP.lvalue, *result);
else else
{ {
addRefinement(refis, typeguardP.lvalue, errorType); addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope));
if (sense) if (sense)
errVec.push_back( errVec.push_back(
TypeError{typeguardP.location, GenericError{"Type '" + toString(*ty) + "' has no overlap with '" + typeguardP.kind + "'"}}); TypeError{typeguardP.location, GenericError{"Type '" + toString(*ty) + "' has no overlap with '" + typeguardP.kind + "'"}});
@ -5118,7 +5238,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec
auto fail = [&](const TypeErrorData& err) { auto fail = [&](const TypeErrorData& err) {
errVec.push_back(TypeError{typeguardP.location, err}); errVec.push_back(TypeError{typeguardP.location, err});
addRefinement(refis, typeguardP.lvalue, errorType); addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope));
}; };
if (!typeguardP.isTypeof) if (!typeguardP.isTypeof)
@ -5139,28 +5259,6 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec
return resolve(IsAPredicate{std::move(typeguardP.lvalue), typeguardP.location, type}, errVec, refis, scope, sense); return resolve(IsAPredicate{std::move(typeguardP.lvalue), typeguardP.location, type}, errVec, refis, scope, sense);
} }
void TypeChecker::DEPRECATED_resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense)
{
if (!sense)
return;
static std::vector<std::string> primitives{
"string", "number", "boolean", "nil", "thread",
"table", // no op. Requires special handling.
"function", // no op. Requires special handling.
"userdata", // no op. Requires special handling.
};
if (auto typeFun = globalScope->lookupType(typeguardP.kind);
typeFun && typeFun->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || typeFun->typePackParams.empty()))
{
if (auto it = std::find(primitives.begin(), primitives.end(), typeguardP.kind); it != primitives.end())
addRefinement(refis, typeguardP.lvalue, typeFun->type);
else if (typeguardP.isTypeof)
addRefinement(refis, typeguardP.lvalue, typeFun->type);
}
}
void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense)
{ {
// This refinement will require success typing to do everything correctly. For now, we can get most of the way there. // This refinement will require success typing to do everything correctly. For now, we can get most of the way there.

View file

@ -286,5 +286,4 @@ TypePack* asMutable(const TypePack* tp)
{ {
return const_cast<TypePack*>(tp); return const_cast<TypePack*>(tp);
} }
} // namespace Luau } // namespace Luau

View file

@ -21,6 +21,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500)
LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0)
LUAU_FASTFLAG(LuauTypeAliasPacks) LUAU_FASTFLAG(LuauTypeAliasPacks)
LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false)
LUAU_FASTFLAG(LuauErrorRecoveryType)
namespace Luau namespace Luau
{ {
@ -305,6 +306,18 @@ bool maybeGeneric(TypeId ty)
return isGeneric(ty); return isGeneric(ty);
} }
bool maybeSingleton(TypeId ty)
{
ty = follow(ty);
if (get<SingletonTypeVar>(ty))
return true;
if (const UnionTypeVar* utv = get<UnionTypeVar>(ty))
for (TypeId option : utv)
if (get<SingletonTypeVar>(follow(option)))
return true;
return false;
}
FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retType, std::optional<FunctionDefinition> defn, bool hasSelf) FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retType, std::optional<FunctionDefinition> defn, bool hasSelf)
: argTypes(argTypes) : argTypes(argTypes)
, retType(retType) , retType(retType)
@ -562,10 +575,8 @@ SingletonTypes::SingletonTypes()
, booleanType(&booleanType_) , booleanType(&booleanType_)
, threadType(&threadType_) , threadType(&threadType_)
, anyType(&anyType_) , anyType(&anyType_)
, errorType(&errorType_)
, optionalNumberType(&optionalNumberType_) , optionalNumberType(&optionalNumberType_)
, anyTypePack(&anyTypePack_) , anyTypePack(&anyTypePack_)
, errorTypePack(&errorTypePack_)
, arena(new TypeArena) , arena(new TypeArena)
{ {
TypeId stringMetatable = makeStringMetatable(); TypeId stringMetatable = makeStringMetatable();
@ -634,6 +645,32 @@ TypeId SingletonTypes::makeStringMetatable()
return arena->addType(TableTypeVar{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); return arena->addType(TableTypeVar{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed});
} }
TypeId SingletonTypes::errorRecoveryType()
{
return &errorType_;
}
TypePackId SingletonTypes::errorRecoveryTypePack()
{
return &errorTypePack_;
}
TypeId SingletonTypes::errorRecoveryType(TypeId guess)
{
if (FFlag::LuauErrorRecoveryType)
return guess;
else
return &errorType_;
}
TypePackId SingletonTypes::errorRecoveryTypePack(TypePackId guess)
{
if (FFlag::LuauErrorRecoveryType)
return guess;
else
return &errorTypePack_;
}
SingletonTypes singletonTypes; SingletonTypes singletonTypes;
void persist(TypeId ty) void persist(TypeId ty)
@ -1141,6 +1178,11 @@ struct QVarFinder
return false; return false;
} }
bool operator()(const SingletonTypeVar&) const
{
return false;
}
bool operator()(const FunctionTypeVar& ftv) const bool operator()(const FunctionTypeVar& ftv) const
{ {
if (hasGeneric(ftv.argTypes)) if (hasGeneric(ftv.argTypes))
@ -1412,7 +1454,7 @@ static std::vector<TypeId> parseFormatString(TypeChecker& typechecker, const cha
else if (strchr(options, data[i])) else if (strchr(options, data[i]))
result.push_back(typechecker.numberType); result.push_back(typechecker.numberType);
else else
result.push_back(typechecker.errorType); result.push_back(typechecker.errorRecoveryType(typechecker.anyType));
} }
} }

View file

@ -22,7 +22,9 @@ LUAU_FASTFLAGVARIABLE(LuauTypecheckOpts, false)
LUAU_FASTFLAG(LuauShareTxnSeen); LUAU_FASTFLAG(LuauShareTxnSeen);
LUAU_FASTFLAGVARIABLE(LuauCacheUnifyTableResults, false) LUAU_FASTFLAGVARIABLE(LuauCacheUnifyTableResults, false)
LUAU_FASTFLAGVARIABLE(LuauExtendedTypeMismatchError, false) LUAU_FASTFLAGVARIABLE(LuauExtendedTypeMismatchError, false)
LUAU_FASTFLAG(LuauSingletonTypes)
LUAU_FASTFLAGVARIABLE(LuauExtendedClassMismatchError, false) LUAU_FASTFLAGVARIABLE(LuauExtendedClassMismatchError, false)
LUAU_FASTFLAG(LuauErrorRecoveryType);
namespace Luau namespace Luau
{ {
@ -211,6 +213,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
{ {
occursCheck(subTy, superTy); occursCheck(subTy, superTy);
// The occurrence check might have caused superTy no longer to be a free type
if (!get<ErrorTypeVar>(subTy)) if (!get<ErrorTypeVar>(subTy))
{ {
log(subTy); log(subTy);
@ -221,10 +224,20 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
} }
else if (l && r) else if (l && r)
{ {
if (!FFlag::LuauErrorRecoveryType)
log(superTy); log(superTy);
occursCheck(superTy, subTy); occursCheck(superTy, subTy);
r->level = min(r->level, l->level); r->level = min(r->level, l->level);
// The occurrence check might have caused superTy no longer to be a free type
if (!FFlag::LuauErrorRecoveryType)
*asMutable(superTy) = BoundTypeVar(subTy); *asMutable(superTy) = BoundTypeVar(subTy);
else if (!get<ErrorTypeVar>(superTy))
{
log(superTy);
*asMutable(superTy) = BoundTypeVar(subTy);
}
return; return;
} }
else if (l) else if (l)
@ -240,6 +253,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
return; return;
} }
// The occurrence check might have caused superTy no longer to be a free type
if (!get<ErrorTypeVar>(superTy)) if (!get<ErrorTypeVar>(superTy))
{ {
if (auto rightLevel = getMutableLevel(subTy)) if (auto rightLevel = getMutableLevel(subTy))
@ -251,6 +265,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
log(superTy); log(superTy);
*asMutable(superTy) = BoundTypeVar(subTy); *asMutable(superTy) = BoundTypeVar(subTy);
} }
return; return;
} }
else if (r) else if (r)
@ -512,6 +527,9 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
else if (get<PrimitiveTypeVar>(superTy) && get<PrimitiveTypeVar>(subTy)) else if (get<PrimitiveTypeVar>(superTy) && get<PrimitiveTypeVar>(subTy))
tryUnifyPrimitives(superTy, subTy); tryUnifyPrimitives(superTy, subTy);
else if (FFlag::LuauSingletonTypes && (get<PrimitiveTypeVar>(superTy) || get<SingletonTypeVar>(superTy)) && get<SingletonTypeVar>(subTy))
tryUnifySingletons(superTy, subTy);
else if (get<FunctionTypeVar>(superTy) && get<FunctionTypeVar>(subTy)) else if (get<FunctionTypeVar>(superTy) && get<FunctionTypeVar>(subTy))
tryUnifyFunctions(superTy, subTy, isFunctionCall); tryUnifyFunctions(superTy, subTy, isFunctionCall);
@ -723,17 +741,18 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal
{ {
occursCheck(superTp, subTp); occursCheck(superTp, subTp);
// The occurrence check might have caused superTp no longer to be a free type
if (!get<ErrorTypeVar>(superTp)) if (!get<ErrorTypeVar>(superTp))
{ {
log(superTp); log(superTp);
*asMutable(superTp) = Unifiable::Bound<TypePackId>(subTp); *asMutable(superTp) = Unifiable::Bound<TypePackId>(subTp);
} }
} }
else if (get<Unifiable::Free>(subTp)) else if (get<Unifiable::Free>(subTp))
{ {
occursCheck(subTp, superTp); occursCheck(subTp, superTp);
// The occurrence check might have caused superTp no longer to be a free type
if (!get<ErrorTypeVar>(subTp)) if (!get<ErrorTypeVar>(subTp))
{ {
log(subTp); log(subTp);
@ -874,13 +893,13 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal
while (superIter.good()) while (superIter.good())
{ {
tryUnify_(singletonTypes.errorType, *superIter); tryUnify_(singletonTypes.errorRecoveryType(), *superIter);
superIter.advance(); superIter.advance();
} }
while (subIter.good()) while (subIter.good())
{ {
tryUnify_(singletonTypes.errorType, *subIter); tryUnify_(singletonTypes.errorRecoveryType(), *subIter);
subIter.advance(); subIter.advance();
} }
@ -906,6 +925,27 @@ void Unifier::tryUnifyPrimitives(TypeId superTy, TypeId subTy)
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}});
} }
void Unifier::tryUnifySingletons(TypeId superTy, TypeId subTy)
{
const PrimitiveTypeVar* lp = get<PrimitiveTypeVar>(superTy);
const SingletonTypeVar* ls = get<SingletonTypeVar>(superTy);
const SingletonTypeVar* rs = get<SingletonTypeVar>(subTy);
if ((!lp && !ls) || !rs)
ice("passed non singleton/primitive types to unifySingletons");
if (ls && *ls == *rs)
return;
if (lp && lp->type == PrimitiveTypeVar::Boolean && get<BoolSingleton>(rs) && variance == Covariant)
return;
if (lp && lp->type == PrimitiveTypeVar::String && get<StringSingleton>(rs) && variance == Covariant)
return;
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}});
}
void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall) void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall)
{ {
FunctionTypeVar* lf = getMutable<FunctionTypeVar>(superTy); FunctionTypeVar* lf = getMutable<FunctionTypeVar>(superTy);
@ -1023,7 +1063,8 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection)
} }
// And vice versa if we're invariant // And vice versa if we're invariant
if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !lt->indexer && lt->state != TableState::Unsealed && lt->state != TableState::Free) if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !lt->indexer && lt->state != TableState::Unsealed &&
lt->state != TableState::Free)
{ {
for (const auto& [propName, subProp] : rt->props) for (const auto& [propName, subProp] : rt->props)
{ {
@ -1634,9 +1675,8 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed)
{ {
ok = false; ok = false;
errors.push_back(TypeError{location, UnknownProperty{superTy, propName}}); errors.push_back(TypeError{location, UnknownProperty{superTy, propName}});
if (!FFlag::LuauExtendedClassMismatchError) if (!FFlag::LuauExtendedClassMismatchError)
tryUnify_(prop.type, singletonTypes.errorType); tryUnify_(prop.type, singletonTypes.errorRecoveryType());
} }
else else
{ {
@ -1952,7 +1992,7 @@ void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty)
{ {
LUAU_ASSERT(get<Unifiable::Error>(any)); LUAU_ASSERT(get<Unifiable::Error>(any));
const TypeId anyTy = singletonTypes.errorType; const TypeId anyTy = singletonTypes.errorRecoveryType();
if (FFlag::LuauTypecheckOpts) if (FFlag::LuauTypecheckOpts)
{ {
@ -2046,7 +2086,7 @@ void Unifier::occursCheck(std::unordered_set<TypeId>& seen_DEPRECATED, DenseHash
{ {
errors.push_back(TypeError{location, OccursCheckFailed{}}); errors.push_back(TypeError{location, OccursCheckFailed{}});
log(needle); log(needle);
*asMutable(needle) = ErrorTypeVar{}; *asMutable(needle) = *singletonTypes.errorRecoveryType();
return; return;
} }
@ -2134,7 +2174,7 @@ void Unifier::occursCheck(std::unordered_set<TypePackId>& seen_DEPRECATED, Dense
{ {
errors.push_back(TypeError{location, OccursCheckFailed{}}); errors.push_back(TypeError{location, OccursCheckFailed{}});
log(needle); log(needle);
*asMutable(needle) = ErrorTypeVar{}; *asMutable(needle) = *singletonTypes.errorRecoveryTypePack();
return; return;
} }

View file

@ -255,6 +255,14 @@ public:
{ {
return visit((class AstType*)node); return visit((class AstType*)node);
} }
virtual bool visit(class AstTypeSingletonBool* node)
{
return visit((class AstType*)node);
}
virtual bool visit(class AstTypeSingletonString* node)
{
return visit((class AstType*)node);
}
virtual bool visit(class AstTypeError* node) virtual bool visit(class AstTypeError* node)
{ {
return visit((class AstType*)node); return visit((class AstType*)node);
@ -1158,6 +1166,30 @@ public:
unsigned messageIndex; unsigned messageIndex;
}; };
class AstTypeSingletonBool : public AstType
{
public:
LUAU_RTTI(AstTypeSingletonBool)
AstTypeSingletonBool(const Location& location, bool value);
void visit(AstVisitor* visitor) override;
bool value;
};
class AstTypeSingletonString : public AstType
{
public:
LUAU_RTTI(AstTypeSingletonString)
AstTypeSingletonString(const Location& location, const AstArray<char>& value);
void visit(AstVisitor* visitor) override;
const AstArray<char> value;
};
class AstTypePack : public AstNode class AstTypePack : public AstNode
{ {
public: public:

View file

@ -286,6 +286,7 @@ private:
// `<' typeAnnotation[, ...] `>' // `<' typeAnnotation[, ...] `>'
AstArray<AstTypeOrPack> parseTypeParams(); AstArray<AstTypeOrPack> parseTypeParams();
std::optional<AstArray<char>> parseCharArray();
AstExpr* parseString(); AstExpr* parseString();
AstLocal* pushLocal(const Binding& binding); AstLocal* pushLocal(const Binding& binding);

View file

@ -34,4 +34,6 @@ bool equalsLower(std::string_view lhs, std::string_view rhs);
size_t hashRange(const char* data, size_t size); size_t hashRange(const char* data, size_t size);
std::string escape(std::string_view s);
bool isIdentifier(std::string_view s);
} // namespace Luau } // namespace Luau

View file

@ -841,6 +841,28 @@ void AstTypeIntersection::visit(AstVisitor* visitor)
} }
} }
AstTypeSingletonBool::AstTypeSingletonBool(const Location& location, bool value)
: AstType(ClassIndex(), location)
, value(value)
{
}
void AstTypeSingletonBool::visit(AstVisitor* visitor)
{
visitor->visit(this);
}
AstTypeSingletonString::AstTypeSingletonString(const Location& location, const AstArray<char>& value)
: AstType(ClassIndex(), location)
, value(value)
{
}
void AstTypeSingletonString::visit(AstVisitor* visitor)
{
visitor->visit(this);
}
AstTypeError::AstTypeError(const Location& location, const AstArray<AstType*>& types, bool isMissing, unsigned messageIndex) AstTypeError::AstTypeError(const Location& location, const AstArray<AstType*>& types, bool isMissing, unsigned messageIndex)
: AstType(ClassIndex(), location) : AstType(ClassIndex(), location)
, types(types) , types(types)

View file

@ -16,6 +16,7 @@ LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false)
LUAU_FASTFLAGVARIABLE(LuauTypeAliasPacks, false) LUAU_FASTFLAGVARIABLE(LuauTypeAliasPacks, false)
LUAU_FASTFLAGVARIABLE(LuauParseTypePackTypeParameters, false) LUAU_FASTFLAGVARIABLE(LuauParseTypePackTypeParameters, false)
LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false)
LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false)
LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctionTypeBegin, false) LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctionTypeBegin, false)
namespace Luau namespace Luau
@ -1278,7 +1279,27 @@ AstType* Parser::parseTableTypeAnnotation()
while (lexer.current().type != '}') while (lexer.current().type != '}')
{ {
if (lexer.current().type == '[') if (FFlag::LuauParseSingletonTypes && lexer.current().type == '[' &&
(lexer.lookahead().type == Lexeme::RawString || lexer.lookahead().type == Lexeme::QuotedString))
{
const Lexeme begin = lexer.current();
nextLexeme(); // [
std::optional<AstArray<char>> chars = parseCharArray();
expectMatchAndConsume(']', begin);
expectAndConsume(':', "table field");
AstType* type = parseTypeAnnotation();
// TODO: since AstName conains a char*, it can't contain null
bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size);
if (chars && !containsNull)
props.push_back({AstName(chars->data), begin.location, type});
else
report(begin.location, "String literal contains malformed escape sequence");
}
else if (lexer.current().type == '[')
{ {
if (indexer) if (indexer)
{ {
@ -1528,6 +1549,32 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack)
nextLexeme(); nextLexeme();
return {allocator.alloc<AstTypeReference>(begin, std::nullopt, nameNil), {}}; return {allocator.alloc<AstTypeReference>(begin, std::nullopt, nameNil), {}};
} }
else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::ReservedTrue)
{
nextLexeme();
return {allocator.alloc<AstTypeSingletonBool>(begin, true)};
}
else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::ReservedFalse)
{
nextLexeme();
return {allocator.alloc<AstTypeSingletonBool>(begin, false)};
}
else if (FFlag::LuauParseSingletonTypes && (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString))
{
if (std::optional<AstArray<char>> value = parseCharArray())
{
AstArray<char> svalue = *value;
return {allocator.alloc<AstTypeSingletonString>(begin, svalue)};
}
else
return {reportTypeAnnotationError(begin, {}, /*isMissing*/ false, "String literal contains malformed escape sequence")};
}
else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::BrokenString)
{
Location location = lexer.current().location;
nextLexeme();
return {reportTypeAnnotationError(location, {}, /*isMissing*/ false, "Malformed string")};
}
else if (lexer.current().type == Lexeme::Name) else if (lexer.current().type == Lexeme::Name)
{ {
std::optional<AstName> prefix; std::optional<AstName> prefix;
@ -2416,7 +2463,7 @@ AstArray<AstTypeOrPack> Parser::parseTypeParams()
return copy(parameters); return copy(parameters);
} }
AstExpr* Parser::parseString() std::optional<AstArray<char>> Parser::parseCharArray()
{ {
LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString); LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString);
@ -2426,11 +2473,8 @@ AstExpr* Parser::parseString()
{ {
if (!Lexer::fixupQuotedString(scratchData)) if (!Lexer::fixupQuotedString(scratchData))
{ {
Location location = lexer.current().location;
nextLexeme(); nextLexeme();
return std::nullopt;
return reportExprError(location, {}, "String literal contains malformed escape sequence");
} }
} }
else else
@ -2438,12 +2482,18 @@ AstExpr* Parser::parseString()
Lexer::fixupMultilineString(scratchData); Lexer::fixupMultilineString(scratchData);
} }
Location start = lexer.current().location;
AstArray<char> value = copy(scratchData); AstArray<char> value = copy(scratchData);
nextLexeme(); nextLexeme();
return value;
}
return allocator.alloc<AstExprConstantString>(start, value); AstExpr* Parser::parseString()
{
Location location = lexer.current().location;
if (std::optional<AstArray<char>> value = parseCharArray())
return allocator.alloc<AstExprConstantString>(location, *value);
else
return reportExprError(location, {}, "String literal contains malformed escape sequence");
} }
AstLocal* Parser::pushLocal(const Binding& binding) AstLocal* Parser::pushLocal(const Binding& binding)

View file

@ -225,4 +225,62 @@ size_t hashRange(const char* data, size_t size)
return hash; return hash;
} }
bool isIdentifier(std::string_view s)
{
return (s.find_first_not_of("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ01234567890_") == std::string::npos);
}
std::string escape(std::string_view s)
{
std::string r;
r.reserve(s.size() + 50); // arbitrary number to guess how many characters we'll be inserting
for (uint8_t c : s)
{
if (c >= ' ' && c != '\\' && c != '\'' && c != '\"')
r += c;
else
{
r += '\\';
switch (c)
{
case '\a':
r += 'a';
break;
case '\b':
r += 'b';
break;
case '\f':
r += 'f';
break;
case '\n':
r += 'n';
break;
case '\r':
r += 'r';
break;
case '\t':
r += 't';
break;
case '\v':
r += 'v';
break;
case '\'':
r += '\'';
break;
case '\"':
r += '\"';
break;
case '\\':
r += '\\';
break;
default:
Luau::formatAppend(r, "%03u", c);
}
}
}
return r;
}
} // namespace Luau } // namespace Luau

View file

@ -236,32 +236,12 @@ int main(int argc, char** argv)
Luau::registerBuiltinTypes(frontend.typeChecker); Luau::registerBuiltinTypes(frontend.typeChecker);
Luau::freeze(frontend.typeChecker.globalTypes); Luau::freeze(frontend.typeChecker.globalTypes);
std::vector<std::string> files = getSourceFiles(argc, argv);
int failed = 0; int failed = 0;
for (int i = 1; i < argc; ++i) for (const std::string& path : files)
{ failed += !analyzeFile(frontend, path.c_str(), format, annotate);
if (argv[i][0] == '-')
continue;
if (isDirectory(argv[i]))
{
traverseDirectory(argv[i], [&](const std::string& name) {
// Look for .luau first and if absent, fall back to .lua
if (name.length() > 5 && name.rfind(".luau") == name.length() - 5)
{
failed += !analyzeFile(frontend, name.c_str(), format, annotate);
}
else if (name.length() > 4 && name.rfind(".lua") == name.length() - 4)
{
failed += !analyzeFile(frontend, name.c_str(), format, annotate);
}
});
}
else
{
failed += !analyzeFile(frontend, argv[i], format, annotate);
}
}
if (!configResolver.configErrors.empty()) if (!configResolver.configErrors.empty())
{ {

View file

@ -223,3 +223,40 @@ std::optional<std::string> getParentPath(const std::string& path)
return ""; return "";
} }
static std::string getExtension(const std::string& path)
{
std::string::size_type dot = path.find_last_of(".\\/");
if (dot == std::string::npos || path[dot] != '.')
return "";
return path.substr(dot);
}
std::vector<std::string> getSourceFiles(int argc, char** argv)
{
std::vector<std::string> files;
for (int i = 1; i < argc; ++i)
{
if (argv[i][0] == '-')
continue;
if (isDirectory(argv[i]))
{
traverseDirectory(argv[i], [&](const std::string& name) {
std::string ext = getExtension(name);
if (ext == ".lua" || ext == ".luau")
files.push_back(name);
});
}
else
{
files.push_back(argv[i]);
}
}
return files;
}

View file

@ -4,6 +4,7 @@
#include <optional> #include <optional>
#include <string> #include <string>
#include <functional> #include <functional>
#include <vector>
std::optional<std::string> readFile(const std::string& name); std::optional<std::string> readFile(const std::string& name);
@ -12,3 +13,5 @@ bool traverseDirectory(const std::string& path, const std::function<void(const s
std::string joinPaths(const std::string& lhs, const std::string& rhs); std::string joinPaths(const std::string& lhs, const std::string& rhs);
std::optional<std::string> getParentPath(const std::string& path); std::optional<std::string> getParentPath(const std::string& path);
std::vector<std::string> getSourceFiles(int argc, char** argv);

View file

@ -20,7 +20,7 @@
enum class CompileFormat enum class CompileFormat
{ {
Default, Text,
Binary Binary
}; };
@ -33,7 +33,7 @@ static int lua_loadstring(lua_State* L)
lua_setsafeenv(L, LUA_ENVIRONINDEX, false); lua_setsafeenv(L, LUA_ENVIRONINDEX, false);
std::string bytecode = Luau::compile(std::string(s, l)); std::string bytecode = Luau::compile(std::string(s, l));
if (luau_load(L, chunkname, bytecode.data(), bytecode.size()) == 0) if (luau_load(L, chunkname, bytecode.data(), bytecode.size(), 0) == 0)
return 1; return 1;
lua_pushnil(L); lua_pushnil(L);
@ -80,7 +80,7 @@ static int lua_require(lua_State* L)
// now we can compile & run module on the new thread // now we can compile & run module on the new thread
std::string bytecode = Luau::compile(*source); std::string bytecode = Luau::compile(*source);
if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0) if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0)
{ {
int status = lua_resume(ML, L, 0); int status = lua_resume(ML, L, 0);
@ -151,7 +151,7 @@ static std::string runCode(lua_State* L, const std::string& source)
{ {
std::string bytecode = Luau::compile(source); std::string bytecode = Luau::compile(source);
if (luau_load(L, "=stdin", bytecode.data(), bytecode.size()) != 0) if (luau_load(L, "=stdin", bytecode.data(), bytecode.size(), 0) != 0)
{ {
size_t len; size_t len;
const char* msg = lua_tolstring(L, -1, &len); const char* msg = lua_tolstring(L, -1, &len);
@ -370,7 +370,7 @@ static bool runFile(const char* name, lua_State* GL)
std::string bytecode = Luau::compile(*source); std::string bytecode = Luau::compile(*source);
int status = 0; int status = 0;
if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0) if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0)
{ {
status = lua_resume(L, NULL, 0); status = lua_resume(L, NULL, 0);
} }
@ -379,11 +379,7 @@ static bool runFile(const char* name, lua_State* GL)
status = LUA_ERRSYNTAX; status = LUA_ERRSYNTAX;
} }
if (status == 0) if (status != 0)
{
return true;
}
else
{ {
std::string error; std::string error;
@ -400,8 +396,10 @@ static bool runFile(const char* name, lua_State* GL)
error += lua_debugtrace(L); error += lua_debugtrace(L);
fprintf(stderr, "%s", error.c_str()); fprintf(stderr, "%s", error.c_str());
return false;
} }
lua_pop(GL, 1);
return status == 0;
} }
static void report(const char* name, const Luau::Location& location, const char* type, const char* message) static void report(const char* name, const Luau::Location& location, const char* type, const char* message)
@ -431,14 +429,18 @@ static bool compileFile(const char* name, CompileFormat format)
try try
{ {
Luau::BytecodeBuilder bcb; Luau::BytecodeBuilder bcb;
if (format == CompileFormat::Text)
{
bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source); bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source);
bcb.setDumpSource(*source); bcb.setDumpSource(*source);
}
Luau::compileOrThrow(bcb, *source); Luau::compileOrThrow(bcb, *source);
switch (format) switch (format)
{ {
case CompileFormat::Default: case CompileFormat::Text:
printf("%s", bcb.dumpEverything().c_str()); printf("%s", bcb.dumpEverything().c_str());
break; break;
case CompileFormat::Binary: case CompileFormat::Binary:
@ -504,7 +506,7 @@ int main(int argc, char** argv)
if (argc >= 2 && strncmp(argv[1], "--compile", strlen("--compile")) == 0) if (argc >= 2 && strncmp(argv[1], "--compile", strlen("--compile")) == 0)
{ {
CompileFormat format = CompileFormat::Default; CompileFormat format = CompileFormat::Text;
if (strcmp(argv[1], "--compile=binary") == 0) if (strcmp(argv[1], "--compile=binary") == 0)
format = CompileFormat::Binary; format = CompileFormat::Binary;
@ -514,27 +516,12 @@ int main(int argc, char** argv)
_setmode(_fileno(stdout), _O_BINARY); _setmode(_fileno(stdout), _O_BINARY);
#endif #endif
std::vector<std::string> files = getSourceFiles(argc, argv);
int failed = 0; int failed = 0;
for (int i = 2; i < argc; ++i) for (const std::string& path : files)
{ failed += !compileFile(path.c_str(), format);
if (argv[i][0] == '-')
continue;
if (isDirectory(argv[i]))
{
traverseDirectory(argv[i], [&](const std::string& name) {
if (name.length() > 5 && name.rfind(".luau") == name.length() - 5)
failed += !compileFile(name.c_str(), format);
else if (name.length() > 4 && name.rfind(".lua") == name.length() - 4)
failed += !compileFile(name.c_str(), format);
});
}
else
{
failed += !compileFile(argv[i], format);
}
}
return failed; return failed;
} }
@ -548,33 +535,25 @@ int main(int argc, char** argv)
int profile = 0; int profile = 0;
for (int i = 1; i < argc; ++i) for (int i = 1; i < argc; ++i)
{
if (argv[i][0] != '-')
continue;
if (strcmp(argv[i], "--profile") == 0) if (strcmp(argv[i], "--profile") == 0)
profile = 10000; // default to 10 KHz profile = 10000; // default to 10 KHz
else if (strncmp(argv[i], "--profile=", 10) == 0) else if (strncmp(argv[i], "--profile=", 10) == 0)
profile = atoi(argv[i] + 10); profile = atoi(argv[i] + 10);
}
if (profile) if (profile)
profilerStart(L, profile); profilerStart(L, profile);
std::vector<std::string> files = getSourceFiles(argc, argv);
int failed = 0; int failed = 0;
for (int i = 1; i < argc; ++i) for (const std::string& path : files)
{ failed += !runFile(path.c_str(), L);
if (argv[i][0] == '-')
continue;
if (isDirectory(argv[i]))
{
traverseDirectory(argv[i], [&](const std::string& name) {
if (name.length() > 4 && name.rfind(".lua") == name.length() - 4)
failed += !runFile(name.c_str(), L);
});
}
else
{
failed += !runFile(argv[i], L);
}
}
if (profile) if (profile)
{ {

View file

@ -13,11 +13,9 @@ class AstNameTable;
class BytecodeBuilder; class BytecodeBuilder;
class BytecodeEncoder; class BytecodeEncoder;
// Note: this structure is duplicated in luacode.h, don't forget to change these in sync!
struct CompileOptions struct CompileOptions
{ {
// default bytecode version target; can be used to compile code for older clients
int bytecodeVersion = 1;
// 0 - no optimization // 0 - no optimization
// 1 - baseline optimization level that doesn't prevent debuggability // 1 - baseline optimization level that doesn't prevent debuggability
// 2 - includes optimizations that harm debuggability such as inlining // 2 - includes optimizations that harm debuggability such as inlining

View file

@ -0,0 +1,39 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include <stddef.h>
/* Can be used to reconfigure visibility/exports for public APIs */
#ifndef LUACODE_API
#define LUACODE_API extern
#endif
typedef struct lua_CompileOptions lua_CompileOptions;
struct lua_CompileOptions
{
// 0 - no optimization
// 1 - baseline optimization level that doesn't prevent debuggability
// 2 - includes optimizations that harm debuggability such as inlining
int optimizationLevel; // default=1
// 0 - no debugging support
// 1 - line info & function names only; sufficient for backtraces
// 2 - full debug info with local & upvalue names; necessary for debugger
int debugLevel; // default=1
// 0 - no code coverage support
// 1 - statement coverage
// 2 - statement and expression coverage (verbose)
int coverageLevel; // default=0
// global builtin to construct vectors; disabled by default
const char* vectorLib;
const char* vectorCtor;
// null-terminated array of globals that are mutable; disables the import optimization for fields accessed through these
const char** mutableGlobals;
};
/* compile source to bytecode; when source compilation fails, the resulting bytecode contains the encoded error. use free() to destroy */
LUACODE_API char* luau_compile(const char* source, size_t size, lua_CompileOptions* options, size_t* outsize);

View file

@ -11,9 +11,6 @@
#include <math.h> #include <math.h>
LUAU_FASTFLAGVARIABLE(LuauPreloadClosures, false) LUAU_FASTFLAGVARIABLE(LuauPreloadClosures, false)
LUAU_FASTFLAGVARIABLE(LuauPreloadClosuresFenv, false)
LUAU_FASTFLAGVARIABLE(LuauPreloadClosuresUpval, false)
LUAU_FASTFLAGVARIABLE(LuauGenericSpecialGlobals, false)
LUAU_FASTFLAG(LuauIfElseExpressionBaseSupport) LUAU_FASTFLAG(LuauIfElseExpressionBaseSupport)
LUAU_FASTFLAGVARIABLE(LuauBit32CountBuiltin, false) LUAU_FASTFLAGVARIABLE(LuauBit32CountBuiltin, false)
@ -24,9 +21,6 @@ static const uint32_t kMaxRegisterCount = 255;
static const uint32_t kMaxUpvalueCount = 200; static const uint32_t kMaxUpvalueCount = 200;
static const uint32_t kMaxLocalCount = 200; static const uint32_t kMaxLocalCount = 200;
// TODO: Remove with LuauGenericSpecialGlobals
static const char* kSpecialGlobals[] = {"Game", "Workspace", "_G", "game", "plugin", "script", "shared", "workspace"};
CompileError::CompileError(const Location& location, const std::string& message) CompileError::CompileError(const Location& location, const std::string& message)
: location(location) : location(location)
, message(message) , message(message)
@ -466,7 +460,7 @@ struct Compiler
bool shared = false; bool shared = false;
if (FFlag::LuauPreloadClosuresUpval) if (FFlag::LuauPreloadClosures)
{ {
// Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure
// objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it
@ -482,18 +476,6 @@ struct Compiler
} }
} }
} }
// Optimization: when closure has no upvalues, instead of allocating it every time we can share closure objects
// (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it is used)
else if (FFlag::LuauPreloadClosures && options.optimizationLevel >= 1 && f->upvals.empty() && !setfenvUsed)
{
int32_t cid = bytecode.addConstantClosure(f->id);
if (cid >= 0 && cid < 32768)
{
bytecode.emitAD(LOP_DUPCLOSURE, target, cid);
return;
}
}
if (!shared) if (!shared)
bytecode.emitAD(LOP_NEWCLOSURE, target, pid); bytecode.emitAD(LOP_NEWCLOSURE, target, pid);
@ -3298,7 +3280,6 @@ struct Compiler
bool visit(AstStatLocalFunction* node) override bool visit(AstStatLocalFunction* node) override
{ {
// record local->function association for some optimizations // record local->function association for some optimizations
if (FFlag::LuauPreloadClosuresUpval)
self->locals[node->name].func = node->func; self->locals[node->name].func = node->func;
return true; return true;
@ -3711,8 +3692,6 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName
Compiler compiler(bytecode, options); Compiler compiler(bytecode, options);
// since access to some global objects may result in values that change over time, we block imports from non-readonly tables // since access to some global objects may result in values that change over time, we block imports from non-readonly tables
if (FFlag::LuauGenericSpecialGlobals)
{
if (AstName name = names.get("_G"); name.value) if (AstName name = names.get("_G"); name.value)
compiler.globals[name].writable = true; compiler.globals[name].writable = true;
@ -3720,15 +3699,6 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName
for (const char** ptr = options.mutableGlobals; *ptr; ++ptr) for (const char** ptr = options.mutableGlobals; *ptr; ++ptr)
if (AstName name = names.get(*ptr); name.value) if (AstName name = names.get(*ptr); name.value)
compiler.globals[name].writable = true; compiler.globals[name].writable = true;
}
else
{
for (const char* global : kSpecialGlobals)
{
if (AstName name = names.get(global); name.value)
compiler.globals[name].writable = true;
}
}
// this visitor traverses the AST to analyze mutability of locals/globals, filling Local::written and Global::written // this visitor traverses the AST to analyze mutability of locals/globals, filling Local::written and Global::written
Compiler::AssignmentVisitor assignmentVisitor(&compiler); Compiler::AssignmentVisitor assignmentVisitor(&compiler);
@ -3742,7 +3712,7 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName
} }
// this visitor tracks calls to getfenv/setfenv and disables some optimizations when they are found // this visitor tracks calls to getfenv/setfenv and disables some optimizations when they are found
if (FFlag::LuauPreloadClosuresFenv && options.optimizationLevel >= 1 && (names.get("getfenv").value || names.get("setfenv").value)) if (options.optimizationLevel >= 1 && (names.get("getfenv").value || names.get("setfenv").value))
{ {
Compiler::FenvVisitor fenvVisitor(compiler.getfenvUsed, compiler.setfenvUsed); Compiler::FenvVisitor fenvVisitor(compiler.getfenvUsed, compiler.setfenvUsed);
root->visit(&fenvVisitor); root->visit(&fenvVisitor);

29
Compiler/src/lcode.cpp Normal file
View file

@ -0,0 +1,29 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "luacode.h"
#include "Luau/Compiler.h"
#include <string.h>
char* luau_compile(const char* source, size_t size, lua_CompileOptions* options, size_t* outsize)
{
LUAU_ASSERT(outsize);
Luau::CompileOptions opts;
if (options)
{
static_assert(sizeof(lua_CompileOptions) == sizeof(Luau::CompileOptions), "C and C++ interface must match");
memcpy(static_cast<void*>(&opts), options, sizeof(opts));
}
std::string result = compile(std::string(source, size), opts);
char* copy = static_cast<char*>(malloc(result.size()));
if (!copy)
return nullptr;
memcpy(copy, result.data(), result.size());
*outsize = result.size();
return copy;
}

View file

@ -40,8 +40,13 @@ make config=release luau luau-analyze
To integrate Luau into your CMake application projects, at the minimum you'll need to depend on `Luau.Compiler` and `Luau.VM` projects. From there you need to create a new Luau state (using Lua 5.x API such as `lua_newstate`), compile source to bytecode and load it into the VM like this: To integrate Luau into your CMake application projects, at the minimum you'll need to depend on `Luau.Compiler` and `Luau.VM` projects. From there you need to create a new Luau state (using Lua 5.x API such as `lua_newstate`), compile source to bytecode and load it into the VM like this:
```cpp ```cpp
std::string bytecode = Luau::compile(source); // needs Luau/Compiler.h include // needs lua.h and luacode.h
if (luau_load(L, chunkname, bytecode.data(), bytecode.size()) == 0) size_t bytecodeSize = 0;
char* bytecode = luau_compile(source, strlen(source), NULL, &bytecodeSize);
int result = luau_load(L, chunkname, bytecode, bytecodeSize, 0);
free(bytecode);
if (result == 0)
return 1; /* return chunk main function */ return 1; /* return chunk main function */
``` ```

View file

@ -7,7 +7,7 @@ Any source code can not result in memory safety errors or crashes during its com
Note that Luau does not provide termination guarantees - some code may exhaust CPU or RAM resources on the system during compilation or execution. Note that Luau does not provide termination guarantees - some code may exhaust CPU or RAM resources on the system during compilation or execution.
The runtime expects valid bytecode as an input. Feeding bytecode that was not produced by Luau compiler into the VM is not supported and The runtime expects valid bytecode as an input. Feeding bytecode that was not produced by Luau compiler into the VM is not supported and
doesn't come with any security guarantees; make sure to sign the bytecode when it crosses a network or file system boundary to avoid tampering. doesn't come with any security guarantees; make sure to sign and/or encrypt the bytecode when it crosses a network or file system boundary to avoid tampering.
# Reporting a Vulnerability # Reporting a Vulnerability

View file

@ -25,9 +25,11 @@ target_sources(Luau.Compiler PRIVATE
Compiler/include/Luau/Bytecode.h Compiler/include/Luau/Bytecode.h
Compiler/include/Luau/BytecodeBuilder.h Compiler/include/Luau/BytecodeBuilder.h
Compiler/include/Luau/Compiler.h Compiler/include/Luau/Compiler.h
Compiler/include/luacode.h
Compiler/src/BytecodeBuilder.cpp Compiler/src/BytecodeBuilder.cpp
Compiler/src/Compiler.cpp Compiler/src/Compiler.cpp
Compiler/src/lcode.cpp
) )
# Luau.Analysis Sources # Luau.Analysis Sources
@ -204,6 +206,7 @@ if(TARGET Luau.UnitTest)
tests/TypeInfer.intersectionTypes.test.cpp tests/TypeInfer.intersectionTypes.test.cpp
tests/TypeInfer.provisional.test.cpp tests/TypeInfer.provisional.test.cpp
tests/TypeInfer.refinements.test.cpp tests/TypeInfer.refinements.test.cpp
tests/TypeInfer.singletons.test.cpp
tests/TypeInfer.tables.test.cpp tests/TypeInfer.tables.test.cpp
tests/TypeInfer.test.cpp tests/TypeInfer.test.cpp
tests/TypeInfer.tryUnify.test.cpp tests/TypeInfer.tryUnify.test.cpp

View file

@ -102,6 +102,8 @@ LUA_API lua_State* lua_newstate(lua_Alloc f, void* ud);
LUA_API void lua_close(lua_State* L); LUA_API void lua_close(lua_State* L);
LUA_API lua_State* lua_newthread(lua_State* L); LUA_API lua_State* lua_newthread(lua_State* L);
LUA_API lua_State* lua_mainthread(lua_State* L); LUA_API lua_State* lua_mainthread(lua_State* L);
LUA_API void lua_resetthread(lua_State* L);
LUA_API int lua_isthreadreset(lua_State* L);
/* /*
** basic stack manipulation ** basic stack manipulation
@ -162,8 +164,7 @@ LUA_API void lua_pushlstring(lua_State* L, const char* s, size_t l);
LUA_API void lua_pushstring(lua_State* L, const char* s); LUA_API void lua_pushstring(lua_State* L, const char* s);
LUA_API const char* lua_pushvfstring(lua_State* L, const char* fmt, va_list argp); LUA_API const char* lua_pushvfstring(lua_State* L, const char* fmt, va_list argp);
LUA_API LUA_PRINTF_ATTR(2, 3) const char* lua_pushfstringL(lua_State* L, const char* fmt, ...); LUA_API LUA_PRINTF_ATTR(2, 3) const char* lua_pushfstringL(lua_State* L, const char* fmt, ...);
LUA_API void lua_pushcfunction( LUA_API void lua_pushcclosurek(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont);
lua_State* L, lua_CFunction fn, const char* debugname = NULL, int nup = 0, lua_Continuation cont = NULL);
LUA_API void lua_pushboolean(lua_State* L, int b); LUA_API void lua_pushboolean(lua_State* L, int b);
LUA_API void lua_pushlightuserdata(lua_State* L, void* p); LUA_API void lua_pushlightuserdata(lua_State* L, void* p);
LUA_API int lua_pushthread(lua_State* L); LUA_API int lua_pushthread(lua_State* L);
@ -178,9 +179,9 @@ LUA_API void lua_rawget(lua_State* L, int idx);
LUA_API void lua_rawgeti(lua_State* L, int idx, int n); LUA_API void lua_rawgeti(lua_State* L, int idx, int n);
LUA_API void lua_createtable(lua_State* L, int narr, int nrec); LUA_API void lua_createtable(lua_State* L, int narr, int nrec);
LUA_API void lua_setreadonly(lua_State* L, int idx, bool value); LUA_API void lua_setreadonly(lua_State* L, int idx, int enabled);
LUA_API int lua_getreadonly(lua_State* L, int idx); LUA_API int lua_getreadonly(lua_State* L, int idx);
LUA_API void lua_setsafeenv(lua_State* L, int idx, bool value); LUA_API void lua_setsafeenv(lua_State* L, int idx, int enabled);
LUA_API void* lua_newuserdata(lua_State* L, size_t sz, int tag); LUA_API void* lua_newuserdata(lua_State* L, size_t sz, int tag);
LUA_API void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)); LUA_API void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*));
@ -200,7 +201,7 @@ LUA_API int lua_setfenv(lua_State* L, int idx);
/* /*
** `load' and `call' functions (load and run Luau bytecode) ** `load' and `call' functions (load and run Luau bytecode)
*/ */
LUA_API int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size, int env = 0); LUA_API int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size, int env);
LUA_API void lua_call(lua_State* L, int nargs, int nresults); LUA_API void lua_call(lua_State* L, int nargs, int nresults);
LUA_API int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc); LUA_API int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc);
@ -293,6 +294,8 @@ LUA_API void lua_unref(lua_State* L, int ref);
#define lua_isnoneornil(L, n) (lua_type(L, (n)) <= LUA_TNIL) #define lua_isnoneornil(L, n) (lua_type(L, (n)) <= LUA_TNIL)
#define lua_pushliteral(L, s) lua_pushlstring(L, "" s, (sizeof(s) / sizeof(char)) - 1) #define lua_pushliteral(L, s) lua_pushlstring(L, "" s, (sizeof(s) / sizeof(char)) - 1)
#define lua_pushcfunction(L, fn, debugname) lua_pushcclosurek(L, fn, debugname, 0, NULL)
#define lua_pushcclosure(L, fn, debugname, nup) lua_pushcclosurek(L, fn, debugname, nup, NULL)
#define lua_setglobal(L, s) lua_setfield(L, LUA_GLOBALSINDEX, (s)) #define lua_setglobal(L, s) lua_setfield(L, LUA_GLOBALSINDEX, (s))
#define lua_getglobal(L, s) lua_getfield(L, LUA_GLOBALSINDEX, (s)) #define lua_getglobal(L, s) lua_getfield(L, LUA_GLOBALSINDEX, (s))
@ -319,8 +322,8 @@ LUA_API const char* lua_setlocal(lua_State* L, int level, int n);
LUA_API const char* lua_getupvalue(lua_State* L, int funcindex, int n); LUA_API const char* lua_getupvalue(lua_State* L, int funcindex, int n);
LUA_API const char* lua_setupvalue(lua_State* L, int funcindex, int n); LUA_API const char* lua_setupvalue(lua_State* L, int funcindex, int n);
LUA_API void lua_singlestep(lua_State* L, bool singlestep); LUA_API void lua_singlestep(lua_State* L, int enabled);
LUA_API void lua_breakpoint(lua_State* L, int funcindex, int line, bool enable); LUA_API void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled);
/* Warning: this function is not thread-safe since it stores the result in a shared global array! Only use for debugging. */ /* Warning: this function is not thread-safe since it stores the result in a shared global array! Only use for debugging. */
LUA_API const char* lua_debugtrace(lua_State* L); LUA_API const char* lua_debugtrace(lua_State* L);
@ -361,6 +364,7 @@ struct lua_Callbacks
void (*debuginterrupt)(lua_State* L, lua_Debug* ar); /* gets called when thread execution is interrupted by break in another thread */ void (*debuginterrupt)(lua_State* L, lua_Debug* ar); /* gets called when thread execution is interrupted by break in another thread */
void (*debugprotectederror)(lua_State* L); /* gets called when protected call results in an error */ void (*debugprotectederror)(lua_State* L); /* gets called when protected call results in an error */
}; };
typedef struct lua_Callbacks lua_Callbacks;
LUA_API lua_Callbacks* lua_callbacks(lua_State* L); LUA_API lua_Callbacks* lua_callbacks(lua_State* L);

View file

@ -8,11 +8,12 @@
#define luaL_typeerror(L, narg, tname) luaL_typeerrorL(L, narg, tname) #define luaL_typeerror(L, narg, tname) luaL_typeerrorL(L, narg, tname)
#define luaL_argerror(L, narg, extramsg) luaL_argerrorL(L, narg, extramsg) #define luaL_argerror(L, narg, extramsg) luaL_argerrorL(L, narg, extramsg)
typedef struct luaL_Reg struct luaL_Reg
{ {
const char* name; const char* name;
lua_CFunction func; lua_CFunction func;
} luaL_Reg; };
typedef struct luaL_Reg luaL_Reg;
LUALIB_API void luaL_register(lua_State* L, const char* libname, const luaL_Reg* l); LUALIB_API void luaL_register(lua_State* L, const char* libname, const luaL_Reg* l);
LUALIB_API int luaL_getmetafield(lua_State* L, int obj, const char* e); LUALIB_API int luaL_getmetafield(lua_State* L, int obj, const char* e);
@ -78,6 +79,7 @@ struct luaL_Buffer
struct TString* storage; struct TString* storage;
char buffer[LUA_BUFFERSIZE]; char buffer[LUA_BUFFERSIZE];
}; };
typedef struct luaL_Buffer luaL_Buffer;
// when internal buffer storage is exhausted, a mutable string value 'storage' will be placed on the stack // when internal buffer storage is exhausted, a mutable string value 'storage' will be placed on the stack
// in general, functions expect the mutable string buffer to be placed on top of the stack (top-1) // in general, functions expect the mutable string buffer to be placed on top of the stack (top-1)

View file

@ -593,7 +593,7 @@ const char* lua_pushfstringL(lua_State* L, const char* fmt, ...)
return ret; return ret;
} }
void lua_pushcfunction(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont) void lua_pushcclosurek(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont)
{ {
luaC_checkGC(L); luaC_checkGC(L);
luaC_checkthreadsleep(L); luaC_checkthreadsleep(L);
@ -698,13 +698,13 @@ void lua_createtable(lua_State* L, int narray, int nrec)
return; return;
} }
void lua_setreadonly(lua_State* L, int objindex, bool value) void lua_setreadonly(lua_State* L, int objindex, int enabled)
{ {
const TValue* o = index2adr(L, objindex); const TValue* o = index2adr(L, objindex);
api_check(L, ttistable(o)); api_check(L, ttistable(o));
Table* t = hvalue(o); Table* t = hvalue(o);
api_check(L, t != hvalue(registry(L))); api_check(L, t != hvalue(registry(L)));
t->readonly = value; t->readonly = bool(enabled);
return; return;
} }
@ -717,12 +717,12 @@ int lua_getreadonly(lua_State* L, int objindex)
return res; return res;
} }
void lua_setsafeenv(lua_State* L, int objindex, bool value) void lua_setsafeenv(lua_State* L, int objindex, int enabled)
{ {
const TValue* o = index2adr(L, objindex); const TValue* o = index2adr(L, objindex);
api_check(L, ttistable(o)); api_check(L, ttistable(o));
Table* t = hvalue(o); Table* t = hvalue(o);
t->safeenv = value; t->safeenv = bool(enabled);
return; return;
} }

View file

@ -436,8 +436,8 @@ static const luaL_Reg base_funcs[] = {
static void auxopen(lua_State* L, const char* name, lua_CFunction f, lua_CFunction u) static void auxopen(lua_State* L, const char* name, lua_CFunction f, lua_CFunction u)
{ {
lua_pushcfunction(L, u); lua_pushcfunction(L, u, NULL);
lua_pushcfunction(L, f, name, 1); lua_pushcclosure(L, f, name, 1);
lua_setfield(L, -2, name); lua_setfield(L, -2, name);
} }
@ -456,10 +456,10 @@ LUALIB_API int luaopen_base(lua_State* L)
auxopen(L, "ipairs", luaB_ipairs, luaB_inext); auxopen(L, "ipairs", luaB_ipairs, luaB_inext);
auxopen(L, "pairs", luaB_pairs, luaB_next); auxopen(L, "pairs", luaB_pairs, luaB_next);
lua_pushcfunction(L, luaB_pcally, "pcall", 0, luaB_pcallcont); lua_pushcclosurek(L, luaB_pcally, "pcall", 0, luaB_pcallcont);
lua_setfield(L, -2, "pcall"); lua_setfield(L, -2, "pcall");
lua_pushcfunction(L, luaB_xpcally, "xpcall", 0, luaB_xpcallcont); lua_pushcclosurek(L, luaB_xpcally, "xpcall", 0, luaB_xpcallcont);
lua_setfield(L, -2, "xpcall"); lua_setfield(L, -2, "xpcall");
return 1; return 1;

View file

@ -5,6 +5,8 @@
#include "lstate.h" #include "lstate.h"
#include "lvm.h" #include "lvm.h"
LUAU_FASTFLAGVARIABLE(LuauCoroutineClose, false)
#define CO_RUN 0 /* running */ #define CO_RUN 0 /* running */
#define CO_SUS 1 /* suspended */ #define CO_SUS 1 /* suspended */
#define CO_NOR 2 /* 'normal' (it resumed another coroutine) */ #define CO_NOR 2 /* 'normal' (it resumed another coroutine) */
@ -208,8 +210,7 @@ static int cowrap(lua_State* L)
{ {
cocreate(L); cocreate(L);
lua_pushcfunction(L, auxwrapy, NULL, 1, auxwrapcont); lua_pushcclosurek(L, auxwrapy, NULL, 1, auxwrapcont);
return 1; return 1;
} }
@ -232,6 +233,34 @@ static int coyieldable(lua_State* L)
return 1; return 1;
} }
static int coclose(lua_State* L)
{
if (!FFlag::LuauCoroutineClose)
luaL_error(L, "coroutine.close is not enabled");
lua_State* co = lua_tothread(L, 1);
luaL_argexpected(L, co, 1, "thread");
int status = auxstatus(L, co);
if (status != CO_DEAD && status != CO_SUS)
luaL_error(L, "cannot close %s coroutine", statnames[status]);
if (co->status == LUA_OK || co->status == LUA_YIELD)
{
lua_pushboolean(L, true);
lua_resetthread(co);
return 1;
}
else
{
lua_pushboolean(L, false);
if (lua_gettop(co))
lua_xmove(co, L, 1); /* move error message */
lua_resetthread(co);
return 2;
}
}
static const luaL_Reg co_funcs[] = { static const luaL_Reg co_funcs[] = {
{"create", cocreate}, {"create", cocreate},
{"running", corunning}, {"running", corunning},
@ -239,6 +268,7 @@ static const luaL_Reg co_funcs[] = {
{"wrap", cowrap}, {"wrap", cowrap},
{"yield", coyield}, {"yield", coyield},
{"isyieldable", coyieldable}, {"isyieldable", coyieldable},
{"close", coclose},
{NULL, NULL}, {NULL, NULL},
}; };
@ -246,7 +276,7 @@ LUALIB_API int luaopen_coroutine(lua_State* L)
{ {
luaL_register(L, LUA_COLIBNAME, co_funcs); luaL_register(L, LUA_COLIBNAME, co_funcs);
lua_pushcfunction(L, coresumey, "resume", 0, coresumecont); lua_pushcclosurek(L, coresumey, "resume", 0, coresumecont);
lua_setfield(L, -2, "resume"); lua_setfield(L, -2, "resume");
return 1; return 1;

View file

@ -316,7 +316,7 @@ void luaG_breakpoint(lua_State* L, Proto* p, int line, bool enable)
p->debuginsn[j] = LUAU_INSN_OP(p->code[j]); p->debuginsn[j] = LUAU_INSN_OP(p->code[j]);
} }
uint8_t op = enable ? LOP_BREAK : LUAU_INSN_OP(p->code[i]); uint8_t op = enable ? LOP_BREAK : LUAU_INSN_OP(p->debuginsn[i]);
// patch just the opcode byte, leave arguments alone // patch just the opcode byte, leave arguments alone
p->code[i] &= ~0xff; p->code[i] &= ~0xff;
@ -357,17 +357,17 @@ int luaG_getline(Proto* p, int pc)
return p->abslineinfo[pc >> p->linegaplog2] + p->lineinfo[pc]; return p->abslineinfo[pc >> p->linegaplog2] + p->lineinfo[pc];
} }
void lua_singlestep(lua_State* L, bool singlestep) void lua_singlestep(lua_State* L, int enabled)
{ {
L->singlestep = singlestep; L->singlestep = bool(enabled);
} }
void lua_breakpoint(lua_State* L, int funcindex, int line, bool enable) void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled)
{ {
const TValue* func = luaA_toobject(L, funcindex); const TValue* func = luaA_toobject(L, funcindex);
api_check(L, ttisfunction(func) && !clvalue(func)->isC); api_check(L, ttisfunction(func) && !clvalue(func)->isC);
luaG_breakpoint(L, clvalue(func)->l.p, line, enable); luaG_breakpoint(L, clvalue(func)->l.p, line, bool(enabled));
} }
static size_t append(char* buf, size_t bufsize, size_t offset, const char* data) static size_t append(char* buf, size_t bufsize, size_t offset, const char* data)

View file

@ -19,6 +19,7 @@
LUAU_FASTFLAGVARIABLE(LuauExceptionMessageFix, false) LUAU_FASTFLAGVARIABLE(LuauExceptionMessageFix, false)
LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false) LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false)
LUAU_FASTFLAG(LuauCoroutineClose)
/* /*
** {====================================================== ** {======================================================
@ -300,7 +301,10 @@ static void resume(lua_State* L, void* ud)
if (L->status == 0) if (L->status == 0)
{ {
// start coroutine // start coroutine
LUAU_ASSERT(L->ci == L->base_ci && firstArg > L->base); LUAU_ASSERT(L->ci == L->base_ci && firstArg >= L->base);
if (FFlag::LuauCoroutineClose && firstArg == L->base)
luaG_runerror(L, "cannot resume dead coroutine");
if (luau_precall(L, firstArg - 1, LUA_MULTRET) != PCRLUA) if (luau_precall(L, firstArg - 1, LUA_MULTRET) != PCRLUA)
return; return;

View file

@ -22,7 +22,7 @@ LUALIB_API void luaL_openlibs(lua_State* L)
const luaL_Reg* lib = lualibs; const luaL_Reg* lib = lualibs;
for (; lib->func; lib++) for (; lib->func; lib++)
{ {
lua_pushcfunction(L, lib->func); lua_pushcfunction(L, lib->func, NULL);
lua_pushstring(L, lib->name); lua_pushstring(L, lib->name);
lua_call(L, 1, 0); lua_call(L, 1, 0);
} }

View file

@ -124,6 +124,34 @@ void luaE_freethread(lua_State* L, lua_State* L1)
luaM_free(L, L1, sizeof(lua_State), L1->memcat); luaM_free(L, L1, sizeof(lua_State), L1->memcat);
} }
void lua_resetthread(lua_State* L)
{
/* close upvalues before clearing anything */
luaF_close(L, L->stack);
/* clear call frames */
CallInfo* ci = L->base_ci;
ci->func = L->stack;
ci->base = ci->func + 1;
ci->top = ci->base + LUA_MINSTACK;
setnilvalue(ci->func);
L->ci = ci;
luaD_reallocCI(L, BASIC_CI_SIZE);
/* clear thread state */
L->status = LUA_OK;
L->base = L->ci->base;
L->top = L->ci->base;
L->nCcalls = L->baseCcalls = 0;
/* clear thread stack */
luaD_reallocstack(L, BASIC_STACK_SIZE);
for (int i = 0; i < L->stacksize; i++)
setnilvalue(L->stack + i);
}
int lua_isthreadreset(lua_State* L)
{
return L->ci == L->base_ci && L->base == L->top && L->status == LUA_OK;
}
lua_State* lua_newstate(lua_Alloc f, void* ud) lua_State* lua_newstate(lua_Alloc f, void* ud)
{ {
int i; int i;

View file

@ -748,7 +748,7 @@ static int gmatch(lua_State* L)
luaL_checkstring(L, 2); luaL_checkstring(L, 2);
lua_settop(L, 2); lua_settop(L, 2);
lua_pushinteger(L, 0); lua_pushinteger(L, 0);
lua_pushcfunction(L, gmatch_aux, NULL, 3); lua_pushcclosure(L, gmatch_aux, NULL, 3);
return 1; return 1;
} }

View file

@ -265,7 +265,7 @@ static int iter_aux(lua_State* L)
static int iter_codes(lua_State* L) static int iter_codes(lua_State* L)
{ {
luaL_checkstring(L, 1); luaL_checkstring(L, 1);
lua_pushcfunction(L, iter_aux); lua_pushcfunction(L, iter_aux, NULL);
lua_pushvalue(L, 1); lua_pushvalue(L, 1);
lua_pushinteger(L, 0); lua_pushinteger(L, 0);
return 3; return 3;

View file

@ -25,8 +25,8 @@ try:
import scipy import scipy
from scipy import stats from scipy import stats
except ModuleNotFoundError: except ModuleNotFoundError:
print("scipy package is required") print("Warning: scipy package is not installed, confidence values will not be available")
exit(1) stats = None
scriptdir = os.path.dirname(os.path.realpath(__file__)) scriptdir = os.path.dirname(os.path.realpath(__file__))
defaultVm = 'luau.exe' if os.name == "nt" else './luau' defaultVm = 'luau.exe' if os.name == "nt" else './luau'
@ -200,11 +200,14 @@ def finalizeResult(result):
result.sampleStdDev = math.sqrt(sumOfSquares / (result.count - 1)) result.sampleStdDev = math.sqrt(sumOfSquares / (result.count - 1))
result.unbiasedEst = result.sampleStdDev * result.sampleStdDev result.unbiasedEst = result.sampleStdDev * result.sampleStdDev
if stats:
# Two-tailed distribution with 95% conf. # Two-tailed distribution with 95% conf.
tValue = stats.t.ppf(1 - 0.05 / 2, result.count - 1) tValue = stats.t.ppf(1 - 0.05 / 2, result.count - 1)
# Compute confidence interval # Compute confidence interval
result.sampleConfidenceInterval = tValue * result.sampleStdDev / math.sqrt(result.count) result.sampleConfidenceInterval = tValue * result.sampleStdDev / math.sqrt(result.count)
else:
result.sampleConfidenceInterval = result.sampleStdDev
else: else:
result.sampleStdDev = 0 result.sampleStdDev = 0
result.unbiasedEst = 0 result.unbiasedEst = 0
@ -377,14 +380,19 @@ def analyzeResult(subdir, main, comparisons):
tStat = abs(main.avg - compare.avg) / (pooledStdDev * math.sqrt(2 / main.count)) tStat = abs(main.avg - compare.avg) / (pooledStdDev * math.sqrt(2 / main.count))
degreesOfFreedom = 2 * main.count - 2 degreesOfFreedom = 2 * main.count - 2
if stats:
# Two-tailed distribution with 95% conf. # Two-tailed distribution with 95% conf.
tCritical = stats.t.ppf(1 - 0.05 / 2, degreesOfFreedom) tCritical = stats.t.ppf(1 - 0.05 / 2, degreesOfFreedom)
noSignificantDifference = tStat < tCritical noSignificantDifference = tStat < tCritical
pValue = 2 * (1 - stats.t.cdf(tStat, df = degreesOfFreedom)) pValue = 2 * (1 - stats.t.cdf(tStat, df = degreesOfFreedom))
else:
noSignificantDifference = None
pValue = -1
if noSignificantDifference: if noSignificantDifference is None:
verdict = ""
elif noSignificantDifference:
verdict = "likely same" verdict = "likely same"
elif main.avg < compare.avg: elif main.avg < compare.avg:
verdict = "likely worse" verdict = "likely worse"

View file

@ -88,7 +88,7 @@ for i=1,N do
local y=ymin+(j-1)*dy local y=ymin+(j-1)*dy
S = S + level(x,y) S = S + level(x,y)
end end
-- if i % 10 == 0 then print(collectgarbage"count") end -- if i % 10 == 0 then print(collectgarbage("count")) end
end end
print(S) print(S)

View file

@ -88,7 +88,7 @@ for i=1,N do
local y=ymin+(j-1)*dy local y=ymin+(j-1)*dy
S = S + level(x,y) S = S + level(x,y)
end end
-- if i % 10 == 0 then print(collectgarbage"count") end -- if i % 10 == 0 then print(collectgarbage("count")) end
end end
print(S) print(S)

View file

@ -275,7 +275,7 @@ local function memory(s)
local t=os.clock() local t=os.clock()
--local dt=string.format("%f",t-t0) --local dt=string.format("%f",t-t0)
local dt=t-t0 local dt=t-t0
--io.stdout:write(s,"\t",dt," sec\t",t," sec\t",math.floor(collectgarbage"count"/1024),"M\n") --io.stdout:write(s,"\t",dt," sec\t",t," sec\t",math.floor(collectgarbage("count")/1024),"M\n")
t0=t t0=t
end end
@ -286,7 +286,7 @@ local function do_(f,s)
end end
local function julia(l,a,b) local function julia(l,a,b)
memory"begin" memory("begin")
cx=a cy=b cx=a cy=b
root=newcell() root=newcell()
exterior=newcell() exterior.color=white exterior=newcell() exterior.color=white
@ -297,14 +297,14 @@ memory"begin"
do_(update,"update") do_(update,"update")
repeat repeat
N=0 color(root,Rxmin,Rxmax,Rymin,Rymax) --print("color",N) N=0 color(root,Rxmin,Rxmax,Rymin,Rymax) --print("color",N)
until N==0 memory"color" until N==0 memory("color")
repeat repeat
N=0 prewhite(root,Rxmin,Rxmax,Rymin,Rymax) --print("prewhite",N) N=0 prewhite(root,Rxmin,Rxmax,Rymin,Rymax) --print("prewhite",N)
until N==0 memory"prewhite" until N==0 memory("prewhite")
do_(recolor,"recolor") do_(recolor,"recolor")
do_(colorup,"colorup") --print("colorup",N) do_(colorup,"colorup") --print("colorup",N)
local g,b=do_(area,"area") --print("area",g,b,g+b) local g,b=do_(area,"area") --print("area",g,b,g+b)
show(i) memory"output" show(i) memory("output")
--print("edges",nE) --print("edges",nE)
end end
end end

View file

@ -759,7 +759,7 @@ Otherwise, `s` is interpreted as a [date format string](https://www.cplusplus.co
function os.difftime(a: number, b: number): number function os.difftime(a: number, b: number): number
``` ```
Calculates the difference in seconds between `a` and `b`; provided for compatibility. Calculates the difference in seconds between `a` and `b`; provided for compatibility only. Please use `a - b` instead.
``` ```
function os.time(t: table?): number function os.time(t: table?): number

View file

@ -257,7 +257,7 @@ DEFINE_PROTO_FUZZER(const luau::StatBlock& message)
lua_State* L = lua_newthread(globalState); lua_State* L = lua_newthread(globalState);
luaL_sandboxthread(L); luaL_sandboxthread(L);
if (luau_load(L, "=fuzz", bytecode.data(), bytecode.size()) == 0) if (luau_load(L, "=fuzz", bytecode.data(), bytecode.size(), 0) == 0)
{ {
interruptDeadline = std::chrono::system_clock::now() + kInterruptTimeout; interruptDeadline = std::chrono::system_clock::now() + kInterruptTimeout;

View file

@ -0,0 +1,34 @@
# coroutine.close
## Summary
Add `coroutine.close` function from Lua 5.4 that takes a suspended coroutine and makes it "dead" (non-runnable).
## Motivation
When implementing various higher level objects on top of coroutines, such as promises, it can be useful to cancel the coroutine execution externally - when the caller is not
interested in getting the results anymore, execution can be aborted. Since coroutines don't provide a way to do that externally, this requires the framework to implement
cancellation on top of coroutines by keeping extra status/token and checking that token in all places where the coroutine is resumed.
Since coroutine execution can be aborted with an error at any point, coroutines already implement support for "dead" status. If it were possible to externally transition a coroutine
to that status, it would be easier to implement cancellable promises on top of coroutines.
## Design
We implement Lua 5.4 behavior exactly with the exception of to-be-closed variables that we don't support. Quoting Lua 5.4 manual:
> coroutine.close (co)
> Closes coroutine co, that is, puts the coroutine in a dead state. The given coroutine must be dead or suspended. In case of error (either the original error that stopped the coroutine or errors in closing methods), returns false plus the error object; otherwise returns true.
The `co` argument must be a coroutine object (of type `thread`).
After closing the coroutine, it gets transitioned to dead state which means that `coroutine.status` will return `"dead"` and attempts to resume the coroutine will fail. In addition, the coroutine stack (which can be accessed via `debug.traceback` or `debug.info`) will become empty. Calling `coroutine.close` on a closed coroutine will return `true` - after closing, the coroutine transitions into a "dead" state with no error information.
## Drawbacks
None known, as this function doesn't introduce any existing states to coroutines, and is similar to running the coroutine to completion/error.
## Alternatives
Lua's name for this function is likely in part motivated by to-be-closed variables that we don't support. As such, a more appropriate name could be `coroutine.cancel` which also
aligns with use cases better. However, since the semantics is otherwise the same, using the same name as Lua 5.4 reduces library fragmentation.

View file

@ -48,6 +48,18 @@ type Animals = "Dog" | "Cat" | "Bird"
type TrueOrNil = true? type TrueOrNil = true?
``` ```
Adding constant strings as type means that it is now legal to write
`{["foo"]:T}` as a table type. This should be parsed as a property,
not an indexer. For example:
```lua
type T = {
["foo"]: number,
["$$bar"]: string,
baz: boolean,
}
```
The table type `T` is a table with three properties and no indexer.
### Semantics ### Semantics
You are allowed to provide a constant value to the generic primitive type. You are allowed to provide a constant value to the generic primitive type.

View file

@ -91,10 +91,6 @@ struct ACFixture : ACFixtureImpl<Fixture>
{ {
}; };
struct UnfrozenACFixture : ACFixtureImpl<UnfrozenFixture>
{
};
TEST_SUITE_BEGIN("AutocompleteTest"); TEST_SUITE_BEGIN("AutocompleteTest");
TEST_CASE_FIXTURE(ACFixture, "empty_program") TEST_CASE_FIXTURE(ACFixture, "empty_program")
@ -1919,9 +1915,10 @@ local bar: @1= foo
CHECK(!ac.entryMap.count("foo")); CHECK(!ac.entryMap.count("foo"));
} }
// CLI-45692: Remove UnfrozenACFixture here TEST_CASE_FIXTURE(ACFixture, "type_correct_function_no_parenthesis")
TEST_CASE_FIXTURE(UnfrozenACFixture, "type_correct_function_no_parenthesis")
{ {
ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true);
check(R"( check(R"(
local function target(a: (number) -> number) return a(4) end local function target(a: (number) -> number) return a(4) end
local function bar1(a: number) return -a end local function bar1(a: number) return -a end
@ -1950,9 +1947,10 @@ local fp: @1= f
CHECK(ac.entryMap.count("({ x: number, y: number }) -> number")); CHECK(ac.entryMap.count("({ x: number, y: number }) -> number"));
} }
// CLI-45692: Remove UnfrozenACFixture here TEST_CASE_FIXTURE(ACFixture, "type_correct_keywords")
TEST_CASE_FIXTURE(UnfrozenACFixture, "type_correct_keywords")
{ {
ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true);
check(R"( check(R"(
local function a(x: boolean) end local function a(x: boolean) end
local function b(x: number?) end local function b(x: number?) end
@ -2484,7 +2482,7 @@ local t = {
CHECK(ac.entryMap.count("second")); CHECK(ac.entryMap.count("second"));
} }
TEST_CASE_FIXTURE(Fixture, "autocomplete_documentation_symbols") TEST_CASE_FIXTURE(UnfrozenFixture, "autocomplete_documentation_symbols")
{ {
loadDefinition(R"( loadDefinition(R"(
declare y: { declare y: {

View file

@ -11,8 +11,6 @@
#include <string_view> #include <string_view>
LUAU_FASTFLAG(LuauPreloadClosures) LUAU_FASTFLAG(LuauPreloadClosures)
LUAU_FASTFLAG(LuauPreloadClosuresFenv)
LUAU_FASTFLAG(LuauPreloadClosuresUpval)
LUAU_FASTFLAG(LuauGenericSpecialGlobals) LUAU_FASTFLAG(LuauGenericSpecialGlobals)
using namespace Luau; using namespace Luau;
@ -2797,7 +2795,7 @@ CAPTURE UPVAL U1
RETURN R0 1 RETURN R0 1
)"); )");
if (FFlag::LuauPreloadClosuresUpval) if (FFlag::LuauPreloadClosures)
{ {
// recursive capture // recursive capture
CHECK_EQ("\n" + compileFunction("local function foo() return foo() end", 1), R"( CHECK_EQ("\n" + compileFunction("local function foo() return foo() end", 1), R"(
@ -3479,8 +3477,6 @@ CAPTURE VAL R0
RETURN R1 1 RETURN R1 1
)"); )");
if (FFlag::LuauPreloadClosuresFenv)
{
// if they don't need upvalues but we sense that environment may be modified, we disable this to avoid fenv-related identity confusion // if they don't need upvalues but we sense that environment may be modified, we disable this to avoid fenv-related identity confusion
CHECK_EQ("\n" + compileFunction(R"( CHECK_EQ("\n" + compileFunction(R"(
setfenv(1, {}) setfenv(1, {})
@ -3506,13 +3502,11 @@ return function() print("hi") end
NEWCLOSURE R0 P0 NEWCLOSURE R0 P0
RETURN R0 1 RETURN R0 1
)"); )");
}
} }
TEST_CASE("SharedClosure") TEST_CASE("SharedClosure")
{ {
ScopedFastFlag sff1("LuauPreloadClosures", true); ScopedFastFlag sff1("LuauPreloadClosures", true);
ScopedFastFlag sff2("LuauPreloadClosuresUpval", true);
// closures can be shared even if functions refer to upvalues, as long as upvalues are top-level // closures can be shared even if functions refer to upvalues, as long as upvalues are top-level
CHECK_EQ("\n" + compileFunction(R"( CHECK_EQ("\n" + compileFunction(R"(
@ -3671,7 +3665,7 @@ RETURN R0 0
)"); )");
} }
TEST_CASE("LuauGenericSpecialGlobals") TEST_CASE("MutableGlobals")
{ {
const char* source = R"( const char* source = R"(
print() print()
@ -3685,43 +3679,6 @@ shared.print()
workspace.print() workspace.print()
)"; )";
{
ScopedFastFlag genericSpecialGlobals{"LuauGenericSpecialGlobals", false};
// Check Roblox globals are here
CHECK_EQ("\n" + compileFunction0(source), R"(
GETIMPORT R0 1
CALL R0 0 0
GETIMPORT R1 3
GETTABLEKS R0 R1 K0
CALL R0 0 0
GETIMPORT R1 5
GETTABLEKS R0 R1 K0
CALL R0 0 0
GETIMPORT R1 7
GETTABLEKS R0 R1 K0
CALL R0 0 0
GETIMPORT R1 9
GETTABLEKS R0 R1 K0
CALL R0 0 0
GETIMPORT R1 11
GETTABLEKS R0 R1 K0
CALL R0 0 0
GETIMPORT R1 13
GETTABLEKS R0 R1 K0
CALL R0 0 0
GETIMPORT R1 15
GETTABLEKS R0 R1 K0
CALL R0 0 0
GETIMPORT R1 17
GETTABLEKS R0 R1 K0
CALL R0 0 0
RETURN R0 0
)");
}
ScopedFastFlag genericSpecialGlobals{"LuauGenericSpecialGlobals", true};
// Check Roblox globals are no longer here // Check Roblox globals are no longer here
CHECK_EQ("\n" + compileFunction0(source), R"( CHECK_EQ("\n" + compileFunction0(source), R"(
GETIMPORT R0 1 GETIMPORT R0 1

View file

@ -1,5 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Compiler.h" #include "lua.h"
#include "lualib.h"
#include "luacode.h"
#include "Luau/BuiltinDefinitions.h" #include "Luau/BuiltinDefinitions.h"
#include "Luau/ModuleResolver.h" #include "Luau/ModuleResolver.h"
@ -10,9 +12,6 @@
#include "doctest.h" #include "doctest.h"
#include "ScopedFlags.h" #include "ScopedFlags.h"
#include "lua.h"
#include "lualib.h"
#include <fstream> #include <fstream>
#include <math.h> #include <math.h>
@ -49,8 +48,12 @@ static int lua_loadstring(lua_State* L)
lua_setsafeenv(L, LUA_ENVIRONINDEX, false); lua_setsafeenv(L, LUA_ENVIRONINDEX, false);
std::string bytecode = Luau::compile(std::string(s, l)); size_t bytecodeSize = 0;
if (luau_load(L, chunkname, bytecode.data(), bytecode.size()) == 0) char* bytecode = luau_compile(s, l, nullptr, &bytecodeSize);
int result = luau_load(L, chunkname, bytecode, bytecodeSize, 0);
free(bytecode);
if (result == 0)
return 1; return 1;
lua_pushnil(L); lua_pushnil(L);
@ -179,21 +182,17 @@ static StateRef runConformance(
std::string chunkname = "=" + std::string(name); std::string chunkname = "=" + std::string(name);
Luau::CompileOptions copts; lua_CompileOptions copts = {};
copts.optimizationLevel = 1; // default
copts.debugLevel = 2; // for debugger tests copts.debugLevel = 2; // for debugger tests
copts.vectorCtor = "vector"; // for vector tests copts.vectorCtor = "vector"; // for vector tests
std::string bytecode = Luau::compile(source, copts); size_t bytecodeSize = 0;
int status = 0; char* bytecode = luau_compile(source.data(), source.size(), &copts, &bytecodeSize);
int result = luau_load(L, chunkname.c_str(), bytecode, bytecodeSize, 0);
free(bytecode);
if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0) int status = (result == 0) ? lua_resume(L, nullptr, 0) : LUA_ERRSYNTAX;
{
status = lua_resume(L, nullptr, 0);
}
else
{
status = LUA_ERRSYNTAX;
}
while (yield && (status == LUA_YIELD || status == LUA_BREAK)) while (yield && (status == LUA_YIELD || status == LUA_BREAK))
{ {
@ -332,34 +331,42 @@ TEST_CASE("UTF8")
TEST_CASE("Coroutine") TEST_CASE("Coroutine")
{ {
ScopedFastFlag sff("LuauCoroutineClose", true);
runConformance("coroutine.lua"); runConformance("coroutine.lua");
} }
TEST_CASE("PCall") static int cxxthrow(lua_State* L)
{ {
runConformance("pcall.lua", [](lua_State* L) {
lua_pushcfunction(L, [](lua_State* L) -> int {
#if LUA_USE_LONGJMP #if LUA_USE_LONGJMP
luaL_error(L, "oops"); luaL_error(L, "oops");
#else #else
throw std::runtime_error("oops"); throw std::runtime_error("oops");
#endif #endif
}); }
TEST_CASE("PCall")
{
runConformance("pcall.lua", [](lua_State* L) {
lua_pushcfunction(L, cxxthrow, "cxxthrow");
lua_setglobal(L, "cxxthrow"); lua_setglobal(L, "cxxthrow");
lua_pushcfunction(L, [](lua_State* L) -> int { lua_pushcfunction(
L,
[](lua_State* L) -> int {
lua_State* co = lua_tothread(L, 1); lua_State* co = lua_tothread(L, 1);
lua_xmove(L, co, 1); lua_xmove(L, co, 1);
lua_resumeerror(co, L); lua_resumeerror(co, L);
return 0; return 0;
}); },
"resumeerror");
lua_setglobal(L, "resumeerror"); lua_setglobal(L, "resumeerror");
}); });
} }
TEST_CASE("Pack") TEST_CASE("Pack")
{ {
ScopedFastFlag sff{ "LuauStrPackUBCastFix", true }; ScopedFastFlag sff{"LuauStrPackUBCastFix", true};
runConformance("tpack.lua"); runConformance("tpack.lua");
} }
@ -367,18 +374,18 @@ TEST_CASE("Pack")
TEST_CASE("Vector") TEST_CASE("Vector")
{ {
runConformance("vector.lua", [](lua_State* L) { runConformance("vector.lua", [](lua_State* L) {
lua_pushcfunction(L, lua_vector); lua_pushcfunction(L, lua_vector, "vector");
lua_setglobal(L, "vector"); lua_setglobal(L, "vector");
lua_pushvector(L, 0.0f, 0.0f, 0.0f); lua_pushvector(L, 0.0f, 0.0f, 0.0f);
luaL_newmetatable(L, "vector"); luaL_newmetatable(L, "vector");
lua_pushstring(L, "__index"); lua_pushstring(L, "__index");
lua_pushcfunction(L, lua_vector_index); lua_pushcfunction(L, lua_vector_index, nullptr);
lua_settable(L, -3); lua_settable(L, -3);
lua_pushstring(L, "__namecall"); lua_pushstring(L, "__namecall");
lua_pushcfunction(L, lua_vector_namecall); lua_pushcfunction(L, lua_vector_namecall, nullptr);
lua_settable(L, -3); lua_settable(L, -3);
lua_setreadonly(L, -1, true); lua_setreadonly(L, -1, true);
@ -513,15 +520,19 @@ TEST_CASE("Debugger")
}; };
// add breakpoint() function // add breakpoint() function
lua_pushcfunction(L, [](lua_State* L) -> int { lua_pushcfunction(
L,
[](lua_State* L) -> int {
int line = luaL_checkinteger(L, 1); int line = luaL_checkinteger(L, 1);
bool enabled = lua_isboolean(L, 2) ? lua_toboolean(L, 2) : true;
lua_Debug ar = {}; lua_Debug ar = {};
lua_getinfo(L, 1, "f", &ar); lua_getinfo(L, 1, "f", &ar);
lua_breakpoint(L, -1, line, true); lua_breakpoint(L, -1, line, enabled);
return 0; return 0;
}); },
"breakpoint");
lua_setglobal(L, "breakpoint"); lua_setglobal(L, "breakpoint");
}, },
[](lua_State* L) { [](lua_State* L) {
@ -744,7 +755,7 @@ TEST_CASE("ExceptionObject")
if (nsize == 0) if (nsize == 0)
{ {
free(ptr); free(ptr);
return NULL; return nullptr;
} }
else if (nsize > 512 * 1024) else if (nsize > 512 * 1024)
{ {

View file

@ -4,7 +4,8 @@
#include <ostream> #include <ostream>
#include <optional> #include <optional>
namespace std { namespace std
{
inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&) inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&)
{ {

View file

@ -203,6 +203,8 @@ TEST_CASE_FIXTURE(Fixture, "clone_class")
TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types") TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types")
{ {
ScopedFastFlag sff{"LuauErrorRecoveryType", true};
TypeVar freeTy(FreeTypeVar{TypeLevel{}}); TypeVar freeTy(FreeTypeVar{TypeLevel{}});
TypePackVar freeTp(FreeTypePack{TypeLevel{}}); TypePackVar freeTp(FreeTypePack{TypeLevel{}});
@ -212,12 +214,12 @@ TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types")
bool encounteredFreeType = false; bool encounteredFreeType = false;
TypeId clonedTy = clone(&freeTy, dest, seenTypes, seenTypePacks, &encounteredFreeType); TypeId clonedTy = clone(&freeTy, dest, seenTypes, seenTypePacks, &encounteredFreeType);
CHECK(Luau::get<ErrorTypeVar>(clonedTy)); CHECK_EQ("any", toString(clonedTy));
CHECK(encounteredFreeType); CHECK(encounteredFreeType);
encounteredFreeType = false; encounteredFreeType = false;
TypePackId clonedTp = clone(&freeTp, dest, seenTypes, seenTypePacks, &encounteredFreeType); TypePackId clonedTp = clone(&freeTp, dest, seenTypes, seenTypePacks, &encounteredFreeType);
CHECK(Luau::get<Unifiable::Error>(clonedTp)); CHECK_EQ("...any", toString(clonedTp));
CHECK(encounteredFreeType); CHECK(encounteredFreeType);
} }

View file

@ -198,7 +198,8 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_correctly_use_matching_table
TypeVar tv{ttv}; TypeVar tv{ttv};
ToStringOptions o{/* exhaustive= */ false, /* useLineBreaks= */ false, /* functionTypeArguments= */ false, /* hideTableKind= */ false, 40}; ToStringOptions o;
o.maxTableLength = 40;
CHECK_EQ(toString(&tv, o), "{| a: number, b: number, c: number, d: number, e: number, ... 5 more ... |}"); CHECK_EQ(toString(&tv, o), "{| a: number, b: number, c: number, d: number, e: number, ... 5 more ... |}");
} }
@ -395,7 +396,7 @@ local function target(callback: nil) return callback(4, "hello") end
)"); )");
LUAU_REQUIRE_ERRORS(result); LUAU_REQUIRE_ERRORS(result);
CHECK_EQ(toString(requireType("target")), "(nil) -> (*unknown*)"); CHECK_EQ("(nil) -> (*unknown*)", toString(requireType("target")));
} }
TEST_CASE_FIXTURE(Fixture, "toStringGenericPack") TEST_CASE_FIXTURE(Fixture, "toStringGenericPack")
@ -469,4 +470,110 @@ TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param")
CHECK_EQ(toString(tableTy), "Table<Table>"); CHECK_EQ(toString(tableTy), "Table<Table>");
} }
TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_id")
{
CheckResult result = check(R"(
local function id(x) return x end
)");
TypeId ty = requireType("id");
const FunctionTypeVar* ftv = get<FunctionTypeVar>(follow(ty));
CHECK_EQ("id<a>(x: a): a", toStringNamedFunction("id", *ftv));
}
TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map")
{
CheckResult result = check(R"(
local function map(arr, fn)
local t = {}
for i = 0, #arr do
t[i] = fn(arr[i])
end
return t
end
)");
TypeId ty = requireType("map");
const FunctionTypeVar* ftv = get<FunctionTypeVar>(follow(ty));
CHECK_EQ("map<a, b>(arr: {a}, fn: (a) -> b): {b}", toStringNamedFunction("map", *ftv));
}
TEST_CASE("toStringNamedFunction_unit_f")
{
TypePackVar empty{TypePack{}};
FunctionTypeVar ftv{&empty, &empty, {}, false};
CHECK_EQ("f(): ()", toStringNamedFunction("f", ftv));
}
TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics")
{
CheckResult result = check(R"(
local function f<a, b...>(x: a, ...): (a, a, b...)
return x, x, ...
end
)");
TypeId ty = requireType("f");
auto ftv = get<FunctionTypeVar>(follow(ty));
CHECK_EQ("f<a, b...>(x: a, ...: any): (a, a, b...)", toStringNamedFunction("f", *ftv));
}
TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics2")
{
CheckResult result = check(R"(
local function f(): ...number
return 1, 2, 3
end
)");
TypeId ty = requireType("f");
auto ftv = get<FunctionTypeVar>(follow(ty));
CHECK_EQ("f(): ...number", toStringNamedFunction("f", *ftv));
}
TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics3")
{
CheckResult result = check(R"(
local function f(): (string, ...number)
return 'a', 1, 2, 3
end
)");
TypeId ty = requireType("f");
auto ftv = get<FunctionTypeVar>(follow(ty));
CHECK_EQ("f(): (string, ...number)", toStringNamedFunction("f", *ftv));
}
TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_type_annotation_has_partial_argnames")
{
CheckResult result = check(R"(
local f: (number, y: number) -> number
)");
TypeId ty = requireType("f");
auto ftv = get<FunctionTypeVar>(follow(ty));
CHECK_EQ("f(_: number, y: number): number", toStringNamedFunction("f", *ftv));
}
TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_type_params")
{
CheckResult result = check(R"(
local function f<T>(x: T, g: <U>(T) -> U)): ()
end
)");
TypeId ty = requireType("f");
auto ftv = get<FunctionTypeVar>(follow(ty));
ToStringOptions opts;
opts.hideNamedFunctionTypeParameters = true;
CHECK_EQ("f(x: T, g: <U>(T) -> U): ()", toStringNamedFunction("f", *ftv, opts));
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -109,7 +109,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_stop_typechecking_after_reporting_duplicate_typ
CheckResult result = check(R"( CheckResult result = check(R"(
type A = number type A = number
type A = string -- Redefinition of type 'A', previously defined at line 1 type A = string -- Redefinition of type 'A', previously defined at line 1
local foo: string = 1 -- No "Type 'number' could not be converted into 'string'" local foo: string = 1 -- "Type 'number' could not be converted into 'string'"
)"); )");
LUAU_REQUIRE_ERROR_COUNT(2, result); LUAU_REQUIRE_ERROR_COUNT(2, result);

View file

@ -381,6 +381,8 @@ TEST_CASE_FIXTURE(Fixture, "typeof_expr")
TEST_CASE_FIXTURE(Fixture, "corecursive_types_error_on_tight_loop") TEST_CASE_FIXTURE(Fixture, "corecursive_types_error_on_tight_loop")
{ {
ScopedFastFlag sff{"LuauErrorRecoveryType", true};
CheckResult result = check(R"( CheckResult result = check(R"(
type A = B type A = B
type B = A type B = A
@ -390,7 +392,7 @@ TEST_CASE_FIXTURE(Fixture, "corecursive_types_error_on_tight_loop")
)"); )");
TypeId fType = requireType("aa"); TypeId fType = requireType("aa");
const ErrorTypeVar* ftv = get<ErrorTypeVar>(follow(fType)); const AnyTypeVar* ftv = get<AnyTypeVar>(follow(fType));
REQUIRE(ftv != nullptr); REQUIRE(ftv != nullptr);
REQUIRE(!result.errors.empty()); REQUIRE(!result.errors.empty());
} }

View file

@ -289,7 +289,7 @@ TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods")
end end
)"); )");
// TODO: Should typecheck but currently errors CLI-39916 // TODO: Should typecheck but currently errors CLI-39916
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERRORS(result);
} }
TEST_CASE_FIXTURE(Fixture, "infer_generic_property") TEST_CASE_FIXTURE(Fixture, "infer_generic_property")
@ -352,7 +352,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_leak_generic_types")
-- so this assignment should fail -- so this assignment should fail
local b: boolean = f(true) local b: boolean = f(true)
)"); )");
LUAU_REQUIRE_ERROR_COUNT(2, result); LUAU_REQUIRE_ERRORS(result);
} }
TEST_CASE_FIXTURE(Fixture, "dont_leak_inferred_generic_types") TEST_CASE_FIXTURE(Fixture, "dont_leak_inferred_generic_types")
@ -368,7 +368,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_leak_inferred_generic_types")
local y: number = id(37) local y: number = id(37)
end end
)"); )");
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERRORS(result);
} }
TEST_CASE_FIXTURE(Fixture, "dont_substitute_bound_types") TEST_CASE_FIXTURE(Fixture, "dont_substitute_bound_types")

View file

@ -704,6 +704,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector")
CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0])); CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0]));
else else
CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0])); CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0]));
CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance"
if (FFlag::LuauQuantifyInPlace2) if (FFlag::LuauQuantifyInPlace2)

View file

@ -0,0 +1,377 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Fixture.h"
#include "doctest.h"
#include "Luau/BuiltinDefinitions.h"
using namespace Luau;
TEST_SUITE_BEGIN("TypeSingletons");
TEST_CASE_FIXTURE(Fixture, "bool_singletons")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
local a: true = true
local b: false = false
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "string_singletons")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
local a: "foo" = "foo"
local b: "bar" = "bar"
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "bool_singletons_mismatch")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
local a: true = false
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Type 'false' could not be converted into 'true'", toString(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "string_singletons_mismatch")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
local a: "foo" = "bar"
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Type '\"bar\"' could not be converted into '\"foo\"'", toString(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "string_singletons_escape_chars")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
local a: "\n" = "\000\r"
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(R"(Type '"\000\r"' could not be converted into '"\n"')", toString(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "bool_singleton_subtype")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
local a: true = true
local b: boolean = a
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "string_singleton_subtype")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
local a: "foo" = "foo"
local b: string = a
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
function f(a: true, b: "foo") end
f(true, "foo")
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons_mismatch")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
function f(a: true, b: "foo") end
f(true, "bar")
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Type '\"bar\"' could not be converted into '\"foo\"'", toString(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
function f(a, b) end
local g : ((true, string) -> ()) & ((false, number) -> ()) = (f::any)
g(true, "foo")
g(false, 37)
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons_mismatch")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
function f(a, b) end
local g : ((true, string) -> ()) & ((false, number) -> ()) = (f::any)
g(true, 37)
)");
LUAU_REQUIRE_ERROR_COUNT(2, result);
CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0]));
CHECK_EQ("Other overloads are also not viable: (false, number) -> ()", toString(result.errors[1]));
}
TEST_CASE_FIXTURE(Fixture, "enums_using_singletons")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
type MyEnum = "foo" | "bar" | "baz"
local a : MyEnum = "foo"
local b : MyEnum = "bar"
local c : MyEnum = "baz"
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_mismatch")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
{"LuauExtendedTypeMismatchError", true},
};
CheckResult result = check(R"(
type MyEnum = "foo" | "bar" | "baz"
local a : MyEnum = "bang"
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Type '\"bang\"' could not be converted into '\"bar\" | \"baz\" | \"foo\"'; none of the union options are compatible",
toString(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_subtyping")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
type MyEnum1 = "foo" | "bar"
type MyEnum2 = MyEnum1 | "baz"
local a : MyEnum1 = "foo"
local b : MyEnum2 = a
local c : string = b
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
{"LuauExpectedTypesOfProperties", true},
};
CheckResult result = check(R"(
type Dog = { tag: "Dog", howls: boolean }
type Cat = { tag: "Cat", meows: boolean }
type Animal = Dog | Cat
local a : Dog = { tag = "Dog", howls = true }
local b : Animal = { tag = "Cat", meows = true }
local c : Animal = a
c = b
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons_mismatch")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
type Dog = { tag: "Dog", howls: boolean }
type Cat = { tag: "Cat", meows: boolean }
type Animal = Dog | Cat
local a : Animal = { tag = "Cat", howls = true }
)");
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "tagged_unions_immutable_tag")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
type Dog = { tag: "Dog", howls: boolean }
type Cat = { tag: "Cat", meows: boolean }
type Animal = Dog | Cat
local a : Animal = { tag = "Cat", meows = true }
a.tag = "Dog"
)");
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "table_properties_singleton_strings")
{
ScopedFastFlag sffs[] = {
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
--!strict
type T = {
["foo"] : number,
["$$bar"] : string,
baz : boolean
}
local t: T = {
["foo"] = 37,
["$$bar"] = "hi",
baz = true
}
local a: number = t.foo
local b: string = t["$$bar"]
local c: boolean = t.baz
t.foo = 5
t["$$bar"] = "lo"
t.baz = false
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "table_properties_singleton_strings_mismatch")
{
ScopedFastFlag sffs[] = {
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
--!strict
type T = {
["$$bar"] : string,
}
local t: T = {
["$$bar"] = "hi",
}
t["$$bar"] = 5
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "table_properties_alias_or_parens_is_indexer")
{
ScopedFastFlag sffs[] = {
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
--!strict
type S = "bar"
type T = {
[("foo")] : number,
[S] : string,
}
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Syntax error: Cannot have more than one table indexer", toString(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes")
{
ScopedFastFlag sffs[] = {
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
--!strict
local x: { ["<>"] : number }
x = { ["\n"] = 5 }
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(R"(Table type '{| ["\n"]: number |}' not compatible with type '{| ["<>"]: number |}' because the former is missing field '<>')",
toString(result.errors[0]));
}
TEST_SUITE_END();

View file

@ -362,7 +362,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_error")
CHECK_EQ(2, result.errors.size()); CHECK_EQ(2, result.errors.size());
TypeId p = requireType("p"); TypeId p = requireType("p");
CHECK_EQ(*p, *typeChecker.errorType); CHECK_EQ("*unknown*", toString(p));
} }
TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function") TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function")
@ -480,7 +480,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any2")
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(typeChecker.anyType, requireType("a")); CHECK_EQ("any", toString(requireType("a")));
} }
TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any") TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any")
@ -496,7 +496,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any")
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(typeChecker.anyType, requireType("a")); CHECK_EQ("any", toString(requireType("a")));
} }
TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2") TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2")
@ -512,7 +512,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2")
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(typeChecker.anyType, requireType("a")); CHECK_EQ("any", toString(requireType("a")));
} }
TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error")
@ -526,7 +526,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error")
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(typeChecker.errorType, requireType("a")); CHECK_EQ("*unknown*", toString(requireType("a")));
} }
TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2")
@ -542,7 +542,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2")
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(typeChecker.errorType, requireType("a")); CHECK_EQ("*unknown*", toString(requireType("a")));
} }
TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_custom_iterator") TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_custom_iterator")
@ -673,7 +673,7 @@ TEST_CASE_FIXTURE(Fixture, "string_index")
REQUIRE(nat); REQUIRE(nat);
CHECK_EQ("string", toString(nat->ty)); CHECK_EQ("string", toString(nat->ty));
CHECK(get<ErrorTypeVar>(requireType("t"))); CHECK_EQ("*unknown*", toString(requireType("t")));
} }
TEST_CASE_FIXTURE(Fixture, "length_of_error_type_does_not_produce_an_error") TEST_CASE_FIXTURE(Fixture, "length_of_error_type_does_not_produce_an_error")
@ -1456,7 +1456,7 @@ TEST_CASE_FIXTURE(Fixture, "require_module_that_does_not_export")
auto hootyType = requireType(bModule, "Hooty"); auto hootyType = requireType(bModule, "Hooty");
CHECK_MESSAGE(get<ErrorTypeVar>(follow(hootyType)) != nullptr, "Should be an error: " << toString(hootyType)); CHECK_EQ("*unknown*", toString(hootyType));
} }
TEST_CASE_FIXTURE(Fixture, "warn_on_lowercase_parent_property") TEST_CASE_FIXTURE(Fixture, "warn_on_lowercase_parent_property")
@ -2032,7 +2032,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_4")
CHECK_EQ(*arg0->indexer->indexResultType, *arg1Args[1]); CHECK_EQ(*arg0->indexer->indexResultType, *arg1Args[1]);
} }
TEST_CASE_FIXTURE(Fixture, "error_types_propagate") TEST_CASE_FIXTURE(Fixture, "type_errors_infer_types")
{ {
CheckResult result = check(R"( CheckResult result = check(R"(
local err = (true).x local err = (true).x
@ -2049,10 +2049,10 @@ TEST_CASE_FIXTURE(Fixture, "error_types_propagate")
CHECK_EQ("boolean", toString(err->table)); CHECK_EQ("boolean", toString(err->table));
CHECK_EQ("x", err->key); CHECK_EQ("x", err->key);
CHECK(nullptr != get<ErrorTypeVar>(requireType("c"))); CHECK_EQ("*unknown*", toString(requireType("c")));
CHECK(nullptr != get<ErrorTypeVar>(requireType("d"))); CHECK_EQ("*unknown*", toString(requireType("d")));
CHECK(nullptr != get<ErrorTypeVar>(requireType("e"))); CHECK_EQ("*unknown*", toString(requireType("e")));
CHECK(nullptr != get<ErrorTypeVar>(requireType("f"))); CHECK_EQ("*unknown*", toString(requireType("f")));
} }
TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error")
@ -2068,7 +2068,7 @@ TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error")
CHECK_EQ("unknown", err->name); CHECK_EQ("unknown", err->name);
CHECK(nullptr != get<ErrorTypeVar>(requireType("a"))); CHECK_EQ("*unknown*", toString(requireType("a")));
} }
TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error")
@ -2077,9 +2077,7 @@ TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error")
local a = Utility.Create "Foo" {} local a = Utility.Create "Foo" {}
)"); )");
TypeId aType = requireType("a"); CHECK_EQ("*unknown*", toString(requireType("a")));
REQUIRE_MESSAGE(nullptr != get<ErrorTypeVar>(aType), "Not an error: " << toString(aType));
} }
TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable") TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable")
@ -2146,6 +2144,8 @@ TEST_CASE_FIXTURE(Fixture, "some_primitive_binary_ops")
TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection") TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection")
{ {
ScopedFastFlag sff{"LuauErrorRecoveryType", true};
CheckResult result = check(R"( CheckResult result = check(R"(
--!strict --!strict
local Vec3 = {} local Vec3 = {}
@ -2175,11 +2175,13 @@ TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersectio
CHECK_EQ("Vec3", toString(requireType("b"))); CHECK_EQ("Vec3", toString(requireType("b")));
CHECK_EQ("Vec3", toString(requireType("c"))); CHECK_EQ("Vec3", toString(requireType("c")));
CHECK_EQ("Vec3", toString(requireType("d"))); CHECK_EQ("Vec3", toString(requireType("d")));
CHECK(get<ErrorTypeVar>(requireType("e"))); CHECK_EQ("Vec3", toString(requireType("e")));
} }
TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection_on_rhs") TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection_on_rhs")
{ {
ScopedFastFlag sff{"LuauErrorRecoveryType", true};
CheckResult result = check(R"( CheckResult result = check(R"(
--!strict --!strict
local Vec3 = {} local Vec3 = {}
@ -2209,7 +2211,7 @@ TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersectio
CHECK_EQ("Vec3", toString(requireType("b"))); CHECK_EQ("Vec3", toString(requireType("b")));
CHECK_EQ("Vec3", toString(requireType("c"))); CHECK_EQ("Vec3", toString(requireType("c")));
CHECK_EQ("Vec3", toString(requireType("d"))); CHECK_EQ("Vec3", toString(requireType("d")));
CHECK(get<ErrorTypeVar>(requireType("e"))); CHECK_EQ("Vec3", toString(requireType("e")));
} }
TEST_CASE_FIXTURE(Fixture, "compare_numbers") TEST_CASE_FIXTURE(Fixture, "compare_numbers")
@ -2901,6 +2903,8 @@ end
TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfNumber") TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfNumber")
{ {
ScopedFastFlag sff{"LuauErrorRecoveryType", true};
CheckResult result = check(R"( CheckResult result = check(R"(
local x: number = 9999 local x: number = 9999
function x:y(z: number) function x:y(z: number)
@ -2908,7 +2912,7 @@ function x:y(z: number)
end end
)"); )");
LUAU_REQUIRE_ERROR_COUNT(3, result); LUAU_REQUIRE_ERROR_COUNT(2, result);
} }
TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfError") TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfError")
@ -2920,7 +2924,7 @@ function x:y(z: number)
end end
)"); )");
LUAU_REQUIRE_ERROR_COUNT(2, result); LUAU_REQUIRE_ERRORS(result);
} }
TEST_CASE_FIXTURE(Fixture, "CallOrOfFunctions") TEST_CASE_FIXTURE(Fixture, "CallOrOfFunctions")
@ -3799,7 +3803,7 @@ TEST_CASE_FIXTURE(Fixture, "UnknownGlobalCompoundAssign")
print(a) print(a)
)"); )");
LUAU_REQUIRE_ERROR_COUNT(2, result); LUAU_REQUIRE_ERRORS(result);
CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'");
} }
@ -4215,7 +4219,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying")
std::optional<TypeFun> t0 = getMainModule()->getModuleScope()->lookupType("t0"); std::optional<TypeFun> t0 = getMainModule()->getModuleScope()->lookupType("t0");
REQUIRE(t0); REQUIRE(t0);
CHECK(get<ErrorTypeVar>(t0->type)); CHECK_EQ("*unknown*", toString(t0->type));
auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) {
return get<OccursCheckFailed>(err); return get<OccursCheckFailed>(err);
@ -4238,7 +4242,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional")
std::optional<TypeFun> t0 = getMainModule()->getModuleScope()->lookupType("t0"); std::optional<TypeFun> t0 = getMainModule()->getModuleScope()->lookupType("t0");
REQUIRE(t0); REQUIRE(t0);
CHECK(get<ErrorTypeVar>(t0->type)); CHECK_EQ("*unknown*", toString(t0->type));
auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) {
return get<OccursCheckFailed>(err); return get<OccursCheckFailed>(err);
@ -4394,6 +4398,25 @@ TEST_CASE_FIXTURE(Fixture, "record_matching_overload")
CHECK_EQ(toString(*it), "(number) -> number"); CHECK_EQ(toString(*it), "(number) -> number");
} }
TEST_CASE_FIXTURE(Fixture, "return_type_by_overload")
{
ScopedFastFlag sff{"LuauErrorRecoveryType", true};
CheckResult result = check(R"(
type Overload = ((string) -> string) & ((number, number) -> number)
local abc: Overload
local x = abc(true)
local y = abc(true,true)
local z = abc(true,true,true)
)");
LUAU_REQUIRE_ERRORS(result);
CHECK_EQ("string", toString(requireType("x")));
CHECK_EQ("number", toString(requireType("y")));
// Should this be string|number?
CHECK_EQ("string", toString(requireType("z")));
}
TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments")
{ {
// Simple direct arg to arg propagation // Simple direct arg to arg propagation
@ -4740,4 +4763,20 @@ TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions3")
} }
} }
TEST_CASE_FIXTURE(Fixture, "type_error_addition")
{
CheckResult result = check(R"(
--!strict
local foo = makesandwich()
local bar = foo.nutrition + 100
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
// We should definitely get this error
CHECK_EQ("Unknown global 'makesandwich'", toString(result.errors[0]));
// We get this error if makesandwich() returns a free type
// CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'foo'", toString(result.errors[1]));
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -121,9 +121,26 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "members_of_failed_typepack_unification_are_u
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeId bType = requireType("b"); CHECK_EQ("a", toString(requireType("a")));
CHECK_EQ("*unknown*", toString(requireType("b")));
}
CHECK_MESSAGE(get<ErrorTypeVar>(bType), "Should be an error: " << toString(bType)); TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_constrained")
{
ScopedFastFlag sff{"LuauErrorRecoveryType", true};
CheckResult result = check(R"(
function f(arg: number) return arg end
local a
local b
local c = f(a, b)
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("a", toString(requireType("a")));
CHECK_EQ("*unknown*", toString(requireType("b")));
CHECK_EQ("number", toString(requireType("c")));
} }
TEST_CASE_FIXTURE(TryUnifyFixture, "typepack_unification_should_trim_free_tails") TEST_CASE_FIXTURE(TryUnifyFixture, "typepack_unification_should_trim_free_tails")
@ -167,15 +184,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_tails_respect_progress")
CHECK(state.errors.empty()); CHECK(state.errors.empty());
} }
TEST_CASE_FIXTURE(TryUnifyFixture, "unifying_variadic_pack_with_error_should_work")
{
TypePackId variadicPack = arena.addTypePack(TypePackVar{VariadicTypePack{typeChecker.numberType}});
TypePackId errorPack = arena.addTypePack(TypePack{{typeChecker.numberType}, arena.addTypePack(TypePackVar{Unifiable::Error{}})});
state.tryUnify(variadicPack, errorPack);
REQUIRE_EQ(0, state.errors.size());
}
TEST_CASE_FIXTURE(TryUnifyFixture, "variadics_should_use_reversed_properly") TEST_CASE_FIXTURE(TryUnifyFixture, "variadics_should_use_reversed_properly")
{ {
CheckResult result = check(R"( CheckResult result = check(R"(

View file

@ -200,8 +200,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property")
CHECK_EQ(mup->missing[0], *bTy); CHECK_EQ(mup->missing[0], *bTy);
CHECK_EQ(mup->key, "x"); CHECK_EQ(mup->key, "x");
TypeId r = requireType("r"); CHECK_EQ("*unknown*", toString(requireType("r")));
CHECK_MESSAGE(get<ErrorTypeVar>(r), "Expected error, got " << toString(r));
} }
TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_property_of_type_any") TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_property_of_type_any")
@ -283,7 +282,7 @@ local c = b:foo(1, 2)
CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0]));
} }
TEST_CASE_FIXTURE(Fixture, "optional_union_follow") TEST_CASE_FIXTURE(UnfrozenFixture, "optional_union_follow")
{ {
CheckResult result = check(R"( CheckResult result = check(R"(
local y: number? = 2 local y: number? = 2

View file

@ -876,4 +876,4 @@ assert(concat(typeof(5), typeof(nil), typeof({}), typeof(newproxy())) == "number
testgetfenv() -- DONT MOVE THIS LINE testgetfenv() -- DONT MOVE THIS LINE
return'OK' return 'OK'

View file

@ -419,11 +419,5 @@ co = coroutine.create(function ()
return loadstring("return a")() return loadstring("return a")()
end) end)
a = {a = 15}
-- debug.setfenv(co, a)
-- assert(debug.getfenv(co) == a)
-- assert(select(2, coroutine.resume(co)) == a)
-- assert(select(2, coroutine.resume(co)) == a.a)
return 'OK'
return'OK'

View file

@ -237,4 +237,4 @@ repeat
i = i+1 i = i+1
until i==c until i==c
return'OK' return 'OK'

View file

@ -319,4 +319,58 @@ for i=0,30 do
assert(#T2 == 1 or T2[#T2] == 42) assert(#T2 == 1 or T2[#T2] == 42)
end end
return'OK' -- test coroutine.close
do
-- ok to close a dead coroutine
local co = coroutine.create(type)
assert(coroutine.resume(co, "testing 'coroutine.close'"))
assert(coroutine.status(co) == "dead")
local st, msg = coroutine.close(co)
assert(st and msg == nil)
-- also ok to close it again
st, msg = coroutine.close(co)
assert(st and msg == nil)
-- cannot close the running coroutine
coroutine.wrap(function()
local st, msg = pcall(coroutine.close, coroutine.running())
assert(not st and string.find(msg, "running"))
end)()
-- cannot close a "normal" coroutine
coroutine.wrap(function()
local co = coroutine.running()
coroutine.wrap(function ()
local st, msg = pcall(coroutine.close, co)
assert(not st and string.find(msg, "normal"))
end)()
end)()
-- closing a coroutine after an error
local co = coroutine.create(error)
local obj = {42}
local st, msg = coroutine.resume(co, obj)
assert(not st and msg == obj)
st, msg = coroutine.close(co)
assert(not st and msg == obj)
-- after closing, no more errors
st, msg = coroutine.close(co)
assert(st and msg == nil)
-- closing a coroutine that has outstanding upvalues
local f
local co = coroutine.create(function()
local a = 42
f = function() return a end
coroutine.yield()
a = 20
end)
coroutine.resume(co)
assert(f() == 42)
st, msg = coroutine.close(co)
assert(st and msg == nil)
assert(f() == 42)
end
return 'OK'

View file

@ -74,4 +74,4 @@ assert(os.difftime(t1,t2) == 60*2-19)
assert(os.time({ year = 1970, day = 1, month = 1, hour = 0}) == 0) assert(os.time({ year = 1970, day = 1, month = 1, hour = 0}) == 0)
return'OK' return 'OK'

View file

@ -98,4 +98,4 @@ assert(quuz(function(...) end) == "0 true")
assert(quuz(function(a, b) end) == "2 false") assert(quuz(function(a, b) end) == "2 false")
assert(quuz(function(a, b, ...) end) == "2 true") assert(quuz(function(a, b, ...) end) == "2 true")
return'OK' return 'OK'

View file

@ -45,4 +45,13 @@ breakpoint(38) -- break inside corobad()
local co = coroutine.create(corobad) local co = coroutine.create(corobad)
assert(coroutine.resume(co) == false) -- this breaks, resumes and dies! assert(coroutine.resume(co) == false) -- this breaks, resumes and dies!
function bar()
print("in bar")
end
breakpoint(49)
breakpoint(49, false) -- validate that disabling breakpoints works
bar()
return 'OK' return 'OK'

View file

@ -34,15 +34,15 @@ assert(doit("error('hi', 0)") == 'hi')
assert(doit("unpack({}, 1, n=2^30)")) assert(doit("unpack({}, 1, n=2^30)"))
assert(doit("a=math.sin()")) assert(doit("a=math.sin()"))
assert(not doit("tostring(1)") and doit("tostring()")) assert(not doit("tostring(1)") and doit("tostring()"))
assert(doit"tonumber()") assert(doit("tonumber()"))
assert(doit"repeat until 1; a") assert(doit("repeat until 1; a"))
checksyntax("break label", "", "label", 1) checksyntax("break label", "", "label", 1)
assert(doit";") assert(doit(";"))
assert(doit"a=1;;") assert(doit("a=1;;"))
assert(doit"return;;") assert(doit("return;;"))
assert(doit"assert(false)") assert(doit("assert(false)"))
assert(doit"assert(nil)") assert(doit("assert(nil)"))
assert(doit"a=math.sin\n(3)") assert(doit("a=math.sin\n(3)"))
assert(doit("function a (... , ...) end")) assert(doit("function a (... , ...) end"))
assert(doit("function a (, ...) end")) assert(doit("function a (, ...) end"))
@ -59,7 +59,7 @@ checkmessage("a=1; local a,bbbb=2,3; a = math.sin(1) and bbbb(3)",
"local 'bbbb'") "local 'bbbb'")
checkmessage("a={}; do local a=1 end a:bbbb(3)", "method 'bbbb'") checkmessage("a={}; do local a=1 end a:bbbb(3)", "method 'bbbb'")
checkmessage("local a={}; a.bbbb(3)", "field 'bbbb'") checkmessage("local a={}; a.bbbb(3)", "field 'bbbb'")
assert(not string.find(doit"a={13}; local bbbb=1; a[bbbb](3)", "'bbbb'")) assert(not string.find(doit("a={13}; local bbbb=1; a[bbbb](3)"), "'bbbb'"))
checkmessage("a={13}; local bbbb=1; a[bbbb](3)", "number") checkmessage("a={13}; local bbbb=1; a[bbbb](3)", "number")
aaa = nil aaa = nil
@ -67,14 +67,14 @@ checkmessage("aaa.bbb:ddd(9)", "global 'aaa'")
checkmessage("local aaa={bbb=1}; aaa.bbb:ddd(9)", "field 'bbb'") checkmessage("local aaa={bbb=1}; aaa.bbb:ddd(9)", "field 'bbb'")
checkmessage("local aaa={bbb={}}; aaa.bbb:ddd(9)", "method 'ddd'") checkmessage("local aaa={bbb={}}; aaa.bbb:ddd(9)", "method 'ddd'")
checkmessage("local a,b,c; (function () a = b+1 end)()", "upvalue 'b'") checkmessage("local a,b,c; (function () a = b+1 end)()", "upvalue 'b'")
assert(not doit"local aaa={bbb={ddd=next}}; aaa.bbb:ddd(nil)") assert(not doit("local aaa={bbb={ddd=next}}; aaa.bbb:ddd(nil)"))
checkmessage("b=1; local aaa='a'; x=aaa+b", "local 'aaa'") checkmessage("b=1; local aaa='a'; x=aaa+b", "local 'aaa'")
checkmessage("aaa={}; x=3/aaa", "global 'aaa'") checkmessage("aaa={}; x=3/aaa", "global 'aaa'")
checkmessage("aaa='2'; b=nil;x=aaa*b", "global 'b'") checkmessage("aaa='2'; b=nil;x=aaa*b", "global 'b'")
checkmessage("aaa={}; x=-aaa", "global 'aaa'") checkmessage("aaa={}; x=-aaa", "global 'aaa'")
assert(not string.find(doit"aaa={}; x=(aaa or aaa)+(aaa and aaa)", "'aaa'")) assert(not string.find(doit("aaa={}; x=(aaa or aaa)+(aaa and aaa)"), "'aaa'"))
assert(not string.find(doit"aaa={}; (aaa or aaa)()", "'aaa'")) assert(not string.find(doit("aaa={}; (aaa or aaa)()"), "'aaa'"))
checkmessage([[aaa=9 checkmessage([[aaa=9
repeat until 3==3 repeat until 3==3
@ -122,10 +122,10 @@ function lineerror (s)
return line and line+0 return line and line+0
end end
assert(lineerror"local a\n for i=1,'a' do \n print(i) \n end" == 2) assert(lineerror("local a\n for i=1,'a' do \n print(i) \n end") == 2)
-- assert(lineerror"\n local a \n for k,v in 3 \n do \n print(k) \n end" == 3) -- assert(lineerror("\n local a \n for k,v in 3 \n do \n print(k) \n end") == 3)
-- assert(lineerror"\n\n for k,v in \n 3 \n do \n print(k) \n end" == 4) -- assert(lineerror("\n\n for k,v in \n 3 \n do \n print(k) \n end") == 4)
assert(lineerror"function a.x.y ()\na=a+1\nend" == 1) assert(lineerror("function a.x.y ()\na=a+1\nend") == 1)
local p = [[ local p = [[
function g() f() end function g() f() end

View file

@ -77,7 +77,7 @@ end
local function dosteps (siz) local function dosteps (siz)
collectgarbage() collectgarbage()
collectgarbage"stop" collectgarbage("stop")
local a = {} local a = {}
for i=1,100 do a[i] = {{}}; local b = {} end for i=1,100 do a[i] = {{}}; local b = {} end
local x = gcinfo() local x = gcinfo()
@ -99,11 +99,11 @@ assert(dosteps(10000) == 1)
do do
local x = gcinfo() local x = gcinfo()
collectgarbage() collectgarbage()
collectgarbage"stop" collectgarbage("stop")
repeat repeat
local a = {} local a = {}
until gcinfo() > 1000 until gcinfo() > 1000
collectgarbage"restart" collectgarbage("restart")
repeat repeat
local a = {} local a = {}
until gcinfo() < 1000 until gcinfo() < 1000
@ -123,7 +123,7 @@ for n in pairs(b) do
end end
b = nil b = nil
collectgarbage() collectgarbage()
for n in pairs(a) do error'cannot be here' end for n in pairs(a) do error("cannot be here") end
for i=1,lim do a[i] = i end for i=1,lim do a[i] = i end
for i=1,lim do assert(a[i] == i) end for i=1,lim do assert(a[i] == i) end

View file

@ -368,9 +368,9 @@ assert(next(a,nil) == 1000 and next(a,1000) == nil)
assert(next({}) == nil) assert(next({}) == nil)
assert(next({}, nil) == nil) assert(next({}, nil) == nil)
for a,b in pairs{} do error"not here" end for a,b in pairs{} do error("not here") end
for i=1,0 do error'not here' end for i=1,0 do error("not here") end
for i=0,1,-1 do error'not here' end for i=0,1,-1 do error("not here") end
a = nil; for i=1,1 do assert(not a); a=1 end; assert(a) a = nil; for i=1,1 do assert(not a); a=1 end; assert(a)
a = nil; for i=1,1,-1 do assert(not a); a=1 end; assert(a) a = nil; for i=1,1,-1 do assert(not a); a=1 end; assert(a)

View file

@ -144,4 +144,4 @@ coroutine.resume(co)
resumeerror(co, "fail") resumeerror(co, "fail")
checkresults({ true, false, "fail" }, coroutine.resume(co)) checkresults({ true, false, "fail" }, coroutine.resume(co))
return'OK' return 'OK'

View file

@ -205,4 +205,4 @@ for p, c in string.gmatch(x, "()(" .. utf8.charpattern .. ")") do
end end
end end
return'OK' return 'OK'