Sync to upstream/release/505 (#216)

- Improve error recovery during type checking
- Initial (not fully complete) implementation for singleton types (RFC RFC: Singleton types #37)
- Implement a C-friendly interface for compiler (luacode.h)
- Remove C++ features from lua.h (removed default arguments from luau_load and lua_pushcfunction)
- Fix lua_breakpoint behavior when enabled=false
- Implement coroutine.close (RFC RFC: coroutine.close #88)

Note, this introduces small breaking changes in lua.h:

- luau_load env argument is now required, pass an extra 0
- lua_pushcfunction now must be called with 3 arguments; if you were calling it with 2 arguments, pass an extra NULL; if you were calling it with 4, use lua_pushcclosure.

These changes are necessary to make sure lua.h can be used from pure C - the future release will make it possible by adding an option to luaconf.h to change function name mangling to be C-compatible. We don't anticipate breaking the FFI interface in the future, but this change was necessary to restore C compatibility.

Closes #121
Fixes #213
This commit is contained in:
Arseny Kapoulkine 2021-11-19 08:10:07 -08:00 committed by GitHub
parent 4265e58ad1
commit 3f1508c83a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
62 changed files with 1785 additions and 566 deletions

View file

@ -12,10 +12,17 @@ namespace Luau
struct FunctionDocumentation;
struct TableDocumentation;
struct OverloadedFunctionDocumentation;
struct BasicDocumentation;
using Documentation = Luau::Variant<std::string, FunctionDocumentation, TableDocumentation, OverloadedFunctionDocumentation>;
using Documentation = Luau::Variant<BasicDocumentation, FunctionDocumentation, TableDocumentation, OverloadedFunctionDocumentation>;
using DocumentationSymbol = std::string;
struct BasicDocumentation
{
std::string documentation;
std::string learnMoreLink;
};
struct FunctionParameterDocumentation
{
std::string name;
@ -29,6 +36,7 @@ struct FunctionDocumentation
std::string documentation;
std::vector<FunctionParameterDocumentation> parameters;
std::vector<DocumentationSymbol> returns;
std::string learnMoreLink;
};
struct OverloadedFunctionDocumentation
@ -43,6 +51,7 @@ struct TableDocumentation
{
std::string documentation;
Luau::DenseHashMap<std::string, DocumentationSymbol> keys;
std::string learnMoreLink;
};
using DocumentationDatabase = Luau::DenseHashMap<DocumentationSymbol, Documentation>;

View file

@ -27,6 +27,7 @@ struct ToStringOptions
bool useLineBreaks = false; // If true, we insert new lines to separate long results such as table entries/metatable.
bool functionTypeArguments = false; // If true, output function type argument names when they are available
bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}'
bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level.
size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars
size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength);
std::optional<ToStringNameMap> nameMap;
@ -64,6 +65,8 @@ inline std::string toString(TypePackId ty)
std::string toString(const TypeVar& tv, const ToStringOptions& opts = {});
std::string toString(const TypePackVar& tp, const ToStringOptions& opts = {});
std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts = {});
// It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class
// These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression
void dump(TypeId ty);

View file

@ -175,10 +175,10 @@ struct TypeChecker
std::vector<std::optional<TypeId>> getExpectedTypesForCall(const std::vector<TypeId>& overloads, size_t argumentCount, bool selfCall);
std::optional<ExprResult<TypePackId>> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack,
TypePackId argPack, TypePack* args, const std::vector<Location>& argLocations, const ExprResult<TypePackId>& argListResult,
std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<OverloadErrorEntry>& errors);
std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<TypeId>& overloadsThatDont, std::vector<OverloadErrorEntry>& errors);
bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector<Location>& argLocations,
const std::vector<OverloadErrorEntry>& errors);
ExprResult<TypePackId> reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack,
void reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack,
const std::vector<Location>& argLocations, const std::vector<TypeId>& overloads, const std::vector<TypeId>& overloadsThatMatchArgCount,
const std::vector<OverloadErrorEntry>& errors);
@ -282,6 +282,14 @@ public:
// Wrapper for merge(l, r, toUnion) but without the lambda junk.
void merge(RefinementMap& l, const RefinementMap& r);
// Produce an "emergency backup type" for recovery from type errors.
// This comes in two flavours, depening on whether or not we can make a good guess
// for an error recovery type.
TypeId errorRecoveryType(TypeId guess);
TypePackId errorRecoveryTypePack(TypePackId guess);
TypeId errorRecoveryType(const ScopePtr& scope);
TypePackId errorRecoveryTypePack(const ScopePtr& scope);
private:
void prepareErrorsForDisplay(ErrorVec& errVec);
void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data);
@ -297,6 +305,10 @@ private:
TypeId freshType(const ScopePtr& scope);
TypeId freshType(TypeLevel level);
// Produce a new singleton type var.
TypeId singletonType(bool value);
TypeId singletonType(std::string value);
// Returns nullopt if the predicate filters down the TypeId to 0 options.
std::optional<TypeId> filterMap(TypeId type, TypeIdPredicate predicate);
@ -330,8 +342,8 @@ private:
const std::vector<TypePackId>& typePackParams, const Location& location);
// Note: `scope` must be a fresh scope.
std::pair<std::vector<TypeId>, std::vector<TypePackId>> createGenericTypes(
const ScopePtr& scope, std::optional<TypeLevel> levelOpt, const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames);
std::pair<std::vector<TypeId>, std::vector<TypePackId>> createGenericTypes(const ScopePtr& scope, std::optional<TypeLevel> levelOpt,
const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames);
public:
ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense);
@ -347,7 +359,6 @@ private:
void resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense);
void resolve(const IsAPredicate& isaP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense);
void resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense);
void DEPRECATED_resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense);
void resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense);
bool isNonstrictMode() const;
@ -387,12 +398,9 @@ public:
const TypeId booleanType;
const TypeId threadType;
const TypeId anyType;
const TypeId errorType;
const TypeId optionalNumberType;
const TypePackId anyTypePack;
const TypePackId errorTypePack;
private:
int checkRecursionCount = 0;

View file

@ -108,6 +108,79 @@ struct PrimitiveTypeVar
}
};
// Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md
// Types for true and false
struct BoolSingleton
{
bool value;
bool operator==(const BoolSingleton& rhs) const
{
return value == rhs.value;
}
bool operator!=(const BoolSingleton& rhs) const
{
return !(*this == rhs);
}
};
// Types for "foo", "bar" etc.
struct StringSingleton
{
std::string value;
bool operator==(const StringSingleton& rhs) const
{
return value == rhs.value;
}
bool operator!=(const StringSingleton& rhs) const
{
return !(*this == rhs);
}
};
// No type for float singletons, partly because === isn't any equalivalence on floats
// (NaN != NaN).
using SingletonVariant = Luau::Variant<BoolSingleton, StringSingleton>;
struct SingletonTypeVar
{
explicit SingletonTypeVar(const SingletonVariant& variant)
: variant(variant)
{
}
explicit SingletonTypeVar(SingletonVariant&& variant)
: variant(std::move(variant))
{
}
// Default operator== is C++20.
bool operator==(const SingletonTypeVar& rhs) const
{
return variant == rhs.variant;
}
bool operator!=(const SingletonTypeVar& rhs) const
{
return !(*this == rhs);
}
SingletonVariant variant;
};
template<typename T>
const T* get(const SingletonTypeVar* stv)
{
if (stv)
return get_if<T>(&stv->variant);
else
return nullptr;
}
struct FunctionArgument
{
Name name;
@ -332,8 +405,8 @@ struct LazyTypeVar
using ErrorTypeVar = Unifiable::Error;
using TypeVariant = Unifiable::Variant<TypeId, PrimitiveTypeVar, FunctionTypeVar, TableTypeVar, MetatableTypeVar, ClassTypeVar, AnyTypeVar,
UnionTypeVar, IntersectionTypeVar, LazyTypeVar>;
using TypeVariant = Unifiable::Variant<TypeId, PrimitiveTypeVar, SingletonTypeVar, FunctionTypeVar, TableTypeVar, MetatableTypeVar, ClassTypeVar,
AnyTypeVar, UnionTypeVar, IntersectionTypeVar, LazyTypeVar>;
struct TypeVar final
{
@ -410,6 +483,9 @@ bool isGeneric(const TypeId ty);
// Checks if a type may be instantiated to one containing generic type binders
bool maybeGeneric(const TypeId ty);
// Checks if a type is of the form T1|...|Tn where one of the Ti is a singleton
bool maybeSingleton(TypeId ty);
struct SingletonTypes
{
const TypeId nilType;
@ -418,16 +494,19 @@ struct SingletonTypes
const TypeId booleanType;
const TypeId threadType;
const TypeId anyType;
const TypeId errorType;
const TypeId optionalNumberType;
const TypePackId anyTypePack;
const TypePackId errorTypePack;
SingletonTypes();
SingletonTypes(const SingletonTypes&) = delete;
void operator=(const SingletonTypes&) = delete;
TypeId errorRecoveryType(TypeId guess);
TypePackId errorRecoveryTypePack(TypePackId guess);
TypeId errorRecoveryType();
TypePackId errorRecoveryTypePack();
private:
std::unique_ptr<struct TypeArena> arena;
TypeId makeStringMetatable();

View file

@ -105,6 +105,8 @@ private:
struct Error
{
// This constructor has to be public, since it's used in TypeVar and TypePack,
// but shouldn't be called directly. Please use errorRecoveryType() instead.
Error();
int index;

View file

@ -65,6 +65,7 @@ struct Unifier
private:
void tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall = false, bool isIntersection = false);
void tryUnifyPrimitives(TypeId superTy, TypeId subTy);
void tryUnifySingletons(TypeId superTy, TypeId subTy);
void tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall = false);
void tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false);
void DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false);

View file

@ -14,6 +14,7 @@
LUAU_FASTFLAGVARIABLE(ElseElseIfCompletionImprovements, false);
LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport)
LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false);
static const std::unordered_set<std::string> kStatementStartingKeywords = {
"while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"};
@ -198,11 +199,24 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ
UnifierSharedState unifierState(&iceReporter);
Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, unifierState);
if (FFlag::LuauAutocompleteAvoidMutation)
{
SeenTypes seenTypes;
SeenTypePacks seenTypePacks;
expectedType = clone(expectedType, *typeArena, seenTypes, seenTypePacks, nullptr);
actualType = clone(actualType, *typeArena, seenTypes, seenTypePacks, nullptr);
auto errors = unifier.canUnify(expectedType, actualType);
return errors.empty();
}
else
{
unifier.tryUnify(expectedType, actualType);
bool ok = unifier.errors.empty();
unifier.log.rollback();
return ok;
}
};
auto expr = node->asExpr();
@ -1496,10 +1510,8 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName
if (!sourceModule)
return {};
TypeChecker& typeChecker =
(frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker);
ModulePtr module =
(frontend.options.typecheckTwice ? frontend.moduleResolverForAutocomplete.getModule(moduleName)
TypeChecker& typeChecker = (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker);
ModulePtr module = (frontend.options.typecheckTwice ? frontend.moduleResolverForAutocomplete.getModule(moduleName)
: frontend.moduleResolver.getModule(moduleName));
if (!module)
@ -1527,8 +1539,7 @@ OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view
sourceModule->mode = Mode::Strict;
sourceModule->commentLocations = std::move(result.commentLocations);
TypeChecker& typeChecker =
(frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker);
TypeChecker& typeChecker = (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker);
ModulePtr module = typeChecker.check(*sourceModule, Mode::Strict);

View file

@ -153,6 +153,7 @@ declare function gcinfo(): number
wrap: <A..., R...>((A...) -> R...) -> any,
yield: <A..., R...>(A...) -> R...,
isyieldable: () -> boolean,
close: (thread) -> (boolean, any?)
}
declare table: {

View file

@ -180,13 +180,13 @@ struct ErrorConverter
switch (e.context)
{
case CountMismatch::Return:
return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " +
std::to_string(e.actual) + " " + actualVerb + " returned here";
return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " + std::to_string(e.actual) + " " +
actualVerb + " returned here";
case CountMismatch::Result:
// It is alright if right hand side produces more values than the
// left hand side accepts. In this context consider only the opposite case.
return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " +
std::to_string(e.actual) + " are required here";
return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " + std::to_string(e.actual) +
" are required here";
case CountMismatch::Arg:
if (FFlag::LuauTypeAliasPacks)
return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual);

View file

@ -22,7 +22,6 @@ LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false)
LUAU_FASTFLAG(LuauTraceRequireLookupChild)
LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false)
LUAU_FASTFLAG(LuauNewRequireTrace2)
LUAU_FASTFLAGVARIABLE(LuauClearScopes, false)
namespace Luau
{
@ -458,7 +457,6 @@ CheckResult Frontend::check(const ModuleName& name)
module->astTypes.clear();
module->astExpectedTypes.clear();
module->astOriginalCallTypes.clear();
if (FFlag::LuauClearScopes)
module->scopes.resize(1);
}

View file

@ -161,6 +161,7 @@ struct TypeCloner
void operator()(const Unifiable::Bound<TypeId>& t);
void operator()(const Unifiable::Error& t);
void operator()(const PrimitiveTypeVar& t);
void operator()(const SingletonTypeVar& t);
void operator()(const FunctionTypeVar& t);
void operator()(const TableTypeVar& t);
void operator()(const MetatableTypeVar& t);
@ -199,7 +200,9 @@ struct TypePackCloner
if (encounteredFreeType)
*encounteredFreeType = true;
seenTypePacks[typePackId] = dest.addTypePack(TypePackVar{Unifiable::Error{}});
TypePackId err = singletonTypes.errorRecoveryTypePack(singletonTypes.anyTypePack);
TypePackId cloned = dest.addTypePack(*err);
seenTypePacks[typePackId] = cloned;
}
void operator()(const Unifiable::Generic& t)
@ -251,8 +254,9 @@ void TypeCloner::operator()(const Unifiable::Free& t)
{
if (encounteredFreeType)
*encounteredFreeType = true;
seenTypes[typeId] = dest.addType(ErrorTypeVar{});
TypeId err = singletonTypes.errorRecoveryType(singletonTypes.anyType);
TypeId cloned = dest.addType(*err);
seenTypes[typeId] = cloned;
}
void TypeCloner::operator()(const Unifiable::Generic& t)
@ -270,11 +274,17 @@ void TypeCloner::operator()(const Unifiable::Error& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const PrimitiveTypeVar& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const SingletonTypeVar& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const FunctionTypeVar& t)
{
TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf});

View file

@ -350,6 +350,23 @@ struct TypeVarStringifier
}
}
void operator()(TypeId, const SingletonTypeVar& stv)
{
if (const BoolSingleton* bs = Luau::get<BoolSingleton>(&stv))
state.emit(bs->value ? "true" : "false");
else if (const StringSingleton* ss = Luau::get<StringSingleton>(&stv))
{
state.emit("\"");
state.emit(escape(ss->value));
state.emit("\"");
}
else
{
LUAU_ASSERT(!"Unknown singleton type");
throw std::runtime_error("Unknown singleton type");
}
}
void operator()(TypeId, const FunctionTypeVar& ftv)
{
if (state.hasSeen(&ftv))
@ -359,6 +376,7 @@ struct TypeVarStringifier
return;
}
// We should not be respecting opts.hideNamedFunctionTypeParameters here.
if (ftv.generics.size() > 0 || ftv.genericPacks.size() > 0)
{
state.emit("<");
@ -514,7 +532,14 @@ struct TypeVarStringifier
break;
}
if (isIdentifier(name))
state.emit(name);
else
{
state.emit("[\"");
state.emit(escape(name));
state.emit("\"]");
}
state.emit(": ");
stringify(prop.type);
comma = true;
@ -1084,6 +1109,94 @@ std::string toString(const TypePackVar& tp, const ToStringOptions& opts)
return toString(const_cast<TypePackId>(&tp), std::move(opts));
}
std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts)
{
std::string s = prefix;
auto toString_ = [&opts](TypeId ty) -> std::string {
ToStringResult res = toStringDetailed(ty, opts);
opts.nameMap = std::move(res.nameMap);
return res.name;
};
auto toStringPack_ = [&opts](TypePackId ty) -> std::string {
ToStringResult res = toStringDetailed(ty, opts);
opts.nameMap = std::move(res.nameMap);
return res.name;
};
if (!opts.hideNamedFunctionTypeParameters && (!ftv.generics.empty() || !ftv.genericPacks.empty()))
{
s += "<";
bool first = true;
for (TypeId g : ftv.generics)
{
if (!first)
s += ", ";
first = false;
s += toString_(g);
}
for (TypePackId gp : ftv.genericPacks)
{
if (!first)
s += ", ";
first = false;
s += toStringPack_(gp);
}
s += ">";
}
s += "(";
auto argPackIter = begin(ftv.argTypes);
auto argNameIter = ftv.argNames.begin();
bool first = true;
while (argPackIter != end(ftv.argTypes))
{
if (!first)
s += ", ";
first = false;
// argNames is guaranteed to be equal to argTypes iff argNames is not empty.
// We don't currently respect opts.functionTypeArguments. I don't think this function should.
if (!ftv.argNames.empty())
s += (*argNameIter ? (*argNameIter)->name : "_") + ": ";
s += toString_(*argPackIter);
++argPackIter;
if (!ftv.argNames.empty())
{
LUAU_ASSERT(argNameIter != ftv.argNames.end());
++argNameIter;
}
}
if (argPackIter.tail())
{
if (auto vtp = get<VariadicTypePack>(*argPackIter.tail()))
s += ", ...: " + toString_(vtp->ty);
else
s += ", ...: " + toStringPack_(*argPackIter.tail());
}
s += "): ";
size_t retSize = size(ftv.retType);
bool hasTail = !finite(ftv.retType);
if (retSize == 0 && !hasTail)
s += "()";
else if ((retSize == 0 && hasTail) || (retSize == 1 && !hasTail))
s += toStringPack_(ftv.retType);
else
s += "(" + toStringPack_(ftv.retType) + ")";
return s;
}
void dump(TypeId ty)
{
ToStringOptions opts;

View file

@ -14,61 +14,6 @@ LUAU_FASTFLAG(LuauTypeAliasPacks)
namespace
{
std::string escape(std::string_view s)
{
std::string r;
r.reserve(s.size() + 50); // arbitrary number to guess how many characters we'll be inserting
for (uint8_t c : s)
{
if (c >= ' ' && c != '\\' && c != '\'' && c != '\"')
r += c;
else
{
r += '\\';
switch (c)
{
case '\a':
r += 'a';
break;
case '\b':
r += 'b';
break;
case '\f':
r += 'f';
break;
case '\n':
r += 'n';
break;
case '\r':
r += 'r';
break;
case '\t':
r += 't';
break;
case '\v':
r += 'v';
break;
case '\'':
r += '\'';
break;
case '\"':
r += '\"';
break;
case '\\':
r += '\\';
break;
default:
Luau::formatAppend(r, "%03u", c);
}
}
}
return r;
}
bool isIdentifierStartChar(char c)
{
return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || c == '_';

View file

@ -96,6 +96,22 @@ public:
return nullptr;
}
}
AstType* operator()(const SingletonTypeVar& stv)
{
if (const BoolSingleton* bs = get<BoolSingleton>(&stv))
return allocator->alloc<AstTypeSingletonBool>(Location(), bs->value);
else if (const StringSingleton* ss = get<StringSingleton>(&stv))
{
AstArray<char> value;
value.data = const_cast<char*>(ss->value.c_str());
value.size = strlen(value.data);
return allocator->alloc<AstTypeSingletonString>(Location(), value);
}
else
return nullptr;
}
AstType* operator()(const AnyTypeVar&)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("any"));

View file

@ -36,6 +36,9 @@ LUAU_FASTFLAG(LuauSubstitutionDontReplaceIgnoredTypes)
LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false)
LUAU_FASTFLAG(LuauNewRequireTrace2)
LUAU_FASTFLAG(LuauTypeAliasPacks)
LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false)
LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false)
LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false)
namespace Luau
{
@ -211,10 +214,8 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan
, booleanType(singletonTypes.booleanType)
, threadType(singletonTypes.threadType)
, anyType(singletonTypes.anyType)
, errorType(singletonTypes.errorType)
, optionalNumberType(singletonTypes.optionalNumberType)
, anyTypePack(singletonTypes.anyTypePack)
, errorTypePack(singletonTypes.errorTypePack)
{
globalScope = std::make_shared<Scope>(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}}));
@ -484,7 +485,7 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std
TypeId type = bindings[name].type;
if (get<FreeTypeVar>(follow(type)))
{
*asMutable(type) = ErrorTypeVar{};
*asMutable(type) = *errorRecoveryType(anyType);
reportError(TypeError{typealias->location, OccursCheckFailed{}});
}
}
@ -719,7 +720,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign)
else if (auto tail = valueIter.tail())
{
if (get<Unifiable::Error>(*tail))
right = errorType;
right = errorRecoveryType(scope);
else if (auto vtp = get<VariadicTypePack>(*tail))
right = vtp->ty;
else if (get<Unifiable::Free>(*tail))
@ -961,7 +962,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin)
else if (get<Unifiable::Error>(callRetPack) || !first(callRetPack))
{
for (TypeId var : varTypes)
unify(var, errorType, forin.location);
unify(var, errorRecoveryType(scope), forin.location);
return check(loopScope, *forin.body);
}
@ -979,7 +980,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin)
const FunctionTypeVar* iterFunc = get<FunctionTypeVar>(iterTy);
if (!iterFunc)
{
TypeId varTy = get<AnyTypeVar>(iterTy) ? anyType : errorType;
TypeId varTy = get<AnyTypeVar>(iterTy) ? anyType : errorRecoveryType(loopScope);
for (TypeId var : varTypes)
unify(var, varTy, forin.location);
@ -1152,9 +1153,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias
reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}});
if (FFlag::LuauTypeAliasPacks)
bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorType};
bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorRecoveryType(anyType)};
else
bindingsMap[name] = TypeFun{binding->typeParams, errorType};
bindingsMap[name] = TypeFun{binding->typeParams, errorRecoveryType(anyType)};
}
else
{
@ -1398,7 +1399,7 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr&
if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit)
{
reportErrorCodeTooComplex(expr.location);
return {errorType};
return {errorRecoveryType(scope)};
}
ExprResult<TypeId> result;
@ -1407,12 +1408,22 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr&
result = checkExpr(scope, *a->expr);
else if (expr.is<AstExprConstantNil>())
result = {nilType};
else if (expr.is<AstExprConstantBool>())
else if (const AstExprConstantBool* bexpr = expr.as<AstExprConstantBool>())
{
if (FFlag::LuauSingletonTypes && expectedType && maybeSingleton(*expectedType))
result = {singletonType(bexpr->value)};
else
result = {booleanType};
}
else if (const AstExprConstantString* sexpr = expr.as<AstExprConstantString>())
{
if (FFlag::LuauSingletonTypes && expectedType && maybeSingleton(*expectedType))
result = {singletonType(std::string(sexpr->value.data, sexpr->value.size))};
else
result = {stringType};
}
else if (expr.is<AstExprConstantNumber>())
result = {numberType};
else if (expr.is<AstExprConstantString>())
result = {stringType};
else if (auto a = expr.as<AstExprLocal>())
result = checkExpr(scope, *a);
else if (auto a = expr.as<AstExprGlobal>())
@ -1485,7 +1496,7 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprLo
// TODO: tempting to ice here, but this breaks very often because our toposort doesn't enforce this constraint
// ice("AstExprLocal exists but no binding definition for it?", expr.location);
reportError(TypeError{expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}});
return {errorType};
return {errorRecoveryType(scope)};
}
ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGlobal& expr)
@ -1497,7 +1508,7 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGl
return {*ty, {TruthyPredicate{std::move(*lvalue), expr.location}}};
reportError(TypeError{expr.location, UnknownSymbol{expr.name.value, UnknownSymbol::Binding}});
return {errorType};
return {errorRecoveryType(scope)};
}
ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVarargs& expr)
@ -1517,14 +1528,14 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVa
return {head};
}
if (get<ErrorTypeVar>(varargPack))
return {errorType};
return {errorRecoveryType(scope)};
else if (auto vtp = get<VariadicTypePack>(varargPack))
return {vtp->ty};
else if (get<Unifiable::Generic>(varargPack))
{
// TODO: Better error?
reportError(expr.location, GenericError{"Trying to get a type from a variadic type parameter"});
return {errorType};
return {errorRecoveryType(scope)};
}
else
ice("Unknown TypePack type in checkExpr(AstExprVarargs)!");
@ -1547,7 +1558,7 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa
return {head, std::move(result.predicates)};
}
if (get<Unifiable::Error>(retPack))
return {errorType, std::move(result.predicates)};
return {errorRecoveryType(scope), std::move(result.predicates)};
else if (auto vtp = get<VariadicTypePack>(retPack))
return {vtp->ty, std::move(result.predicates)};
else if (get<Unifiable::Generic>(retPack))
@ -1572,7 +1583,7 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIn
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, lhsType, name, expr.location, true))
return {*ty};
return {errorType};
return {errorRecoveryType(scope)};
}
std::optional<TypeId> TypeChecker::findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location)
@ -1876,6 +1887,7 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa
std::vector<std::pair<TypeId, TypeId>> fieldTypes(expr.items.size);
const TableTypeVar* expectedTable = nullptr;
const UnionTypeVar* expectedUnion = nullptr;
std::optional<TypeId> expectedIndexType;
std::optional<TypeId> expectedIndexResultType;
@ -1894,6 +1906,9 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa
}
}
}
else if (FFlag::LuauExpectedTypesOfProperties)
if (const UnionTypeVar* utv = get<UnionTypeVar>(follow(*expectedType)))
expectedUnion = utv;
}
for (size_t i = 0; i < expr.items.size; ++i)
@ -1916,6 +1931,18 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa
if (auto prop = expectedTable->props.find(key->value.data); prop != expectedTable->props.end())
expectedResultType = prop->second.type;
}
else if (FFlag::LuauExpectedTypesOfProperties && expectedUnion)
{
std::vector<TypeId> expectedResultTypes;
for (TypeId expectedOption : expectedUnion)
if (const TableTypeVar* ttv = get<TableTypeVar>(follow(expectedOption)))
if (auto prop = ttv->props.find(key->value.data); prop != ttv->props.end())
expectedResultTypes.push_back(prop->second.type);
if (expectedResultTypes.size() == 1)
expectedResultType = expectedResultTypes[0];
else if (expectedResultTypes.size() > 1)
expectedResultType = addType(UnionTypeVar{expectedResultTypes});
}
}
else
{
@ -1958,21 +1985,22 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn
{
TypeId actualFunctionType = instantiate(scope, *fnt, expr.location);
TypePackId arguments = addTypePack({operandType});
TypePackId retType = freshTypePack(scope);
TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retType));
TypePackId retTypePack = freshTypePack(scope);
TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack));
Unifier state = mkUnifier(expr.location);
state.tryUnify(expectedFunctionType, actualFunctionType, /*isFunctionCall*/ true);
TypeId retType = first(retTypePack).value_or(nilType);
if (!state.errors.empty())
return {errorType};
retType = errorRecoveryType(retType);
return {first(retType).value_or(nilType)};
return {retType};
}
reportError(expr.location,
GenericError{format("Unary operator '%s' not supported by type '%s'", toString(expr.op).c_str(), toString(operandType).c_str())});
return {errorType};
return {errorRecoveryType(scope)};
}
reportErrors(tryUnify(numberType, operandType, expr.location));
@ -1984,7 +2012,7 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn
operandType = stripFromNilAndReport(operandType, expr.location);
if (get<ErrorTypeVar>(operandType))
return {errorType};
return {errorRecoveryType(scope)};
if (get<AnyTypeVar>(operandType))
return {numberType}; // Not strictly correct: metatables permit overriding this
@ -2044,7 +2072,7 @@ TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const Location& location, b
if (unify(a, b, location))
return a;
return errorType;
return errorRecoveryType(anyType);
}
if (*a == *b)
@ -2166,11 +2194,13 @@ TypeId TypeChecker::checkRelationalOperation(
std::optional<TypeId> leftMetatable = isString(lhsType) ? std::nullopt : getMetatable(follow(lhsType));
std::optional<TypeId> rightMetatable = isString(rhsType) ? std::nullopt : getMetatable(follow(rhsType));
// TODO: this check seems odd, the second part is redundant
// is it meant to be if (leftMetatable && rightMetatable && leftMetatable != rightMetatable)
if (bool(leftMetatable) != bool(rightMetatable) && leftMetatable != rightMetatable)
{
reportError(expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable",
toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())});
return errorType;
return errorRecoveryType(booleanType);
}
if (leftMetatable)
@ -2188,7 +2218,7 @@ TypeId TypeChecker::checkRelationalOperation(
if (!state.errors.empty())
{
reportError(expr.location, GenericError{format("Metamethod '%s' must return type 'boolean'", metamethodName.c_str())});
return errorType;
return errorRecoveryType(booleanType);
}
}
}
@ -2206,7 +2236,7 @@ TypeId TypeChecker::checkRelationalOperation(
{
reportError(
expr.location, GenericError{format("Table %s does not offer metamethod %s", toString(lhsType).c_str(), metamethodName.c_str())});
return errorType;
return errorRecoveryType(booleanType);
}
}
@ -2214,14 +2244,14 @@ TypeId TypeChecker::checkRelationalOperation(
{
auto name = getIdentifierOfBaseVar(expr.left);
reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Comparison});
return errorType;
return errorRecoveryType(booleanType);
}
if (needsMetamethod)
{
reportError(expr.location, GenericError{format("Type %s cannot be compared with %s because it has no metatable",
toString(lhsType).c_str(), toString(expr.op).c_str())});
return errorType;
return errorRecoveryType(booleanType);
}
return booleanType;
@ -2266,7 +2296,8 @@ TypeId TypeChecker::checkBinaryOperation(
{
auto name = getIdentifierOfBaseVar(expr.left);
reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation});
return errorType;
if (!FFlag::LuauErrorRecoveryType)
return errorRecoveryType(scope);
}
// If we know nothing at all about the lhs type, we can usually say nothing about the result.
@ -2296,18 +2327,33 @@ TypeId TypeChecker::checkBinaryOperation(
auto checkMetatableCall = [this, &scope, &expr](TypeId fnt, TypeId lhst, TypeId rhst) -> TypeId {
TypeId actualFunctionType = instantiate(scope, fnt, expr.location);
TypePackId arguments = addTypePack({lhst, rhst});
TypePackId retType = freshTypePack(scope);
TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retType));
TypePackId retTypePack = freshTypePack(scope);
TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack));
Unifier state = mkUnifier(expr.location);
state.tryUnify(expectedFunctionType, actualFunctionType, /*isFunctionCall*/ true);
reportErrors(state.errors);
bool hasErrors = !state.errors.empty();
if (FFlag::LuauErrorRecoveryType && hasErrors)
{
// If there are unification errors, the return type may still be unknown
// so we loosen the argument types to see if that helps.
TypePackId fallbackArguments = freshTypePack(scope);
TypeId fallbackFunctionType = addType(FunctionTypeVar(scope->level, fallbackArguments, retTypePack));
state.log.rollback();
state.errors.clear();
state.tryUnify(fallbackFunctionType, actualFunctionType, /*isFunctionCall*/ true);
if (!state.errors.empty())
return errorType;
state.log.rollback();
}
return first(retType).value_or(nilType);
TypeId retType = first(retTypePack).value_or(nilType);
if (hasErrors)
retType = errorRecoveryType(retType);
return retType;
};
std::string op = opToMetaTableEntry(expr.op);
@ -2321,7 +2367,8 @@ TypeId TypeChecker::checkBinaryOperation(
reportError(expr.location, GenericError{format("Binary operator '%s' not supported by types '%s' and '%s'", toString(expr.op).c_str(),
toString(lhsType).c_str(), toString(rhsType).c_str())});
return errorType;
return errorRecoveryType(scope);
}
switch (expr.op)
@ -2414,11 +2461,9 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTy
ExprResult<TypeId> result = checkExpr(scope, *expr.expr, annotationType);
ErrorVec errorVec = canUnify(result.type, annotationType, expr.location);
if (!errorVec.empty())
{
reportErrors(errorVec);
return {errorType, std::move(result.predicates)};
}
if (!errorVec.empty())
annotationType = errorRecoveryType(annotationType);
return {annotationType, std::move(result.predicates)};
}
@ -2434,7 +2479,7 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprEr
// any type errors that may arise from it are going to be useless.
currentModule->errors.resize(oldSize);
return {errorType};
return {errorRecoveryType(scope)};
}
ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr)
@ -2476,7 +2521,7 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
{
for (AstExpr* expr : a->expressions)
checkExpr(scope, *expr);
return std::pair(errorType, nullptr);
return {errorRecoveryType(scope), nullptr};
}
else
ice("Unexpected AST node in checkLValue", expr.location);
@ -2488,7 +2533,7 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
return {*ty, nullptr};
reportError(expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding});
return {errorType, nullptr};
return {errorRecoveryType(scope), nullptr};
}
std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr)
@ -2545,24 +2590,25 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
{
Unifier state = mkUnifier(expr.location);
state.tryUnify(indexer->indexType, stringType);
TypeId retType = indexer->indexResultType;
if (!state.errors.empty())
{
state.log.rollback();
reportError(expr.location, UnknownProperty{lhs, name});
return std::pair(errorType, nullptr);
retType = errorRecoveryType(retType);
}
return std::pair(indexer->indexResultType, nullptr);
return std::pair(retType, nullptr);
}
else if (lhsTable->state == TableState::Sealed)
{
reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}});
return std::pair(errorType, nullptr);
return std::pair(errorRecoveryType(scope), nullptr);
}
else
{
reportError(TypeError{expr.location, GenericError{"Internal error: generic tables are not lvalues"}});
return std::pair(errorType, nullptr);
return std::pair(errorRecoveryType(scope), nullptr);
}
}
else if (const ClassTypeVar* lhsClass = get<ClassTypeVar>(lhs))
@ -2571,7 +2617,7 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
if (!prop)
{
reportError(TypeError{expr.location, UnknownProperty{lhs, name}});
return std::pair(errorType, nullptr);
return std::pair(errorRecoveryType(scope), nullptr);
}
return std::pair(prop->type, nullptr);
@ -2585,12 +2631,12 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
if (isTableIntersection(lhs))
{
reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}});
return std::pair(errorType, nullptr);
return std::pair(errorRecoveryType(scope), nullptr);
}
}
reportError(TypeError{expr.location, NotATable{lhs}});
return std::pair(errorType, nullptr);
return std::pair(errorRecoveryType(scope), nullptr);
}
std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr)
@ -2615,7 +2661,7 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
if (!prop)
{
reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}});
return std::pair(errorType, nullptr);
return std::pair(errorRecoveryType(scope), nullptr);
}
return std::pair(prop->type, nullptr);
}
@ -2626,7 +2672,7 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
if (!exprTable)
{
reportError(TypeError{expr.expr->location, NotATable{exprType}});
return std::pair(errorType, nullptr);
return std::pair(errorRecoveryType(scope), nullptr);
}
if (value)
@ -2678,7 +2724,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName)
if (isNonstrictMode())
return globalScope->bindings[name].typeId;
return errorType;
return errorRecoveryType(scope);
}
else
{
@ -2705,20 +2751,21 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName)
TableTypeVar* ttv = getMutableTableType(lhsType);
if (!ttv)
{
if (!isTableIntersection(lhsType))
if (!FFlag::LuauErrorRecoveryType && !isTableIntersection(lhsType))
// This error now gets reported when we check the function body.
reportError(TypeError{funName.location, OnlyTablesCanHaveMethods{lhsType}});
return errorType;
return errorRecoveryType(scope);
}
// Cannot extend sealed table, but we dont report an error here because it will be reported during AstStatFunction check
if (lhsType->persistent || ttv->state == TableState::Sealed)
return errorType;
return errorRecoveryType(scope);
Name name = indexName->index.value;
if (ttv->props.count(name))
return errorType;
return errorRecoveryType(scope);
Property& property = ttv->props[name];
@ -2728,9 +2775,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName)
return property.type;
}
else if (funName.is<AstExprError>())
{
return errorType;
}
return errorRecoveryType(scope);
else
{
ice("Unexpected AST node type", funName.location);
@ -2991,7 +3036,7 @@ ExprResult<TypePackId> TypeChecker::checkExprPack(const ScopePtr& scope, const A
else if (expr.is<AstExprVarargs>())
{
if (!scope->varargPack)
return {addTypePack({addType(ErrorTypeVar())})};
return {errorRecoveryTypePack(scope)};
return {*scope->varargPack};
}
@ -3095,10 +3140,9 @@ void TypeChecker::checkArgumentList(
if (get<Unifiable::Error>(tail))
{
// Unify remaining parameters so we don't leave any free-types hanging around.
TypeId argTy = errorType;
while (paramIter != endIter)
{
state.tryUnify(*paramIter, argTy);
state.tryUnify(*paramIter, errorRecoveryType(anyType));
++paramIter;
}
return;
@ -3157,7 +3201,7 @@ void TypeChecker::checkArgumentList(
{
while (argIter != endIter)
{
unify(*argIter, errorType, state.location);
unify(*argIter, errorRecoveryType(scope), state.location);
++argIter;
}
// For this case, we want the error span to cover every errant extra parameter
@ -3246,7 +3290,8 @@ ExprResult<TypePackId> TypeChecker::checkExprPack(const ScopePtr& scope, const A
// For each overload
// Compare parameter and argument types
// Report any errors (also speculate dot vs colon warnings!)
// If there are no errors, return the resulting return type
// Return the resulting return type (even if there are errors)
// If there are no matching overloads, unify with (a...) -> (b...) and return b...
TypeId selfType = nullptr;
TypeId functionType = nullptr;
@ -3268,8 +3313,8 @@ ExprResult<TypePackId> TypeChecker::checkExprPack(const ScopePtr& scope, const A
}
else
{
functionType = errorType;
actualFunctionType = errorType;
functionType = errorRecoveryType(scope);
actualFunctionType = functionType;
}
}
else
@ -3296,7 +3341,7 @@ ExprResult<TypePackId> TypeChecker::checkExprPack(const ScopePtr& scope, const A
TypePackId argPack = argListResult.type;
if (get<Unifiable::Error>(argPack))
return ExprResult<TypePackId>{errorTypePack};
return {errorRecoveryTypePack(scope)};
TypePack* args = getMutable<TypePack>(argPack);
LUAU_ASSERT(args != nullptr);
@ -3314,19 +3359,34 @@ ExprResult<TypePackId> TypeChecker::checkExprPack(const ScopePtr& scope, const A
std::vector<OverloadErrorEntry> errors; // errors encountered for each overload
std::vector<TypeId> overloadsThatMatchArgCount;
std::vector<TypeId> overloadsThatDont;
for (TypeId fn : overloads)
{
fn = follow(fn);
if (auto ret = checkCallOverload(scope, expr, fn, retPack, argPack, args, argLocations, argListResult, overloadsThatMatchArgCount, errors))
if (auto ret = checkCallOverload(
scope, expr, fn, retPack, argPack, args, argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors))
return *ret;
}
if (handleSelfCallMismatch(scope, expr, args, argLocations, errors))
return {retPack};
return reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors);
reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors);
if (FFlag::LuauErrorRecoveryType)
{
const FunctionTypeVar* overload = nullptr;
if (!overloadsThatMatchArgCount.empty())
overload = get<FunctionTypeVar>(overloadsThatMatchArgCount[0]);
if (!overload && !overloadsThatDont.empty())
overload = get<FunctionTypeVar>(overloadsThatDont[0]);
if (overload)
return {errorRecoveryTypePack(overload->retType)};
}
return {errorRecoveryTypePack(retPack)};
}
std::vector<std::optional<TypeId>> TypeChecker::getExpectedTypesForCall(const std::vector<TypeId>& overloads, size_t argumentCount, bool selfCall)
@ -3382,7 +3442,7 @@ std::vector<std::optional<TypeId>> TypeChecker::getExpectedTypesForCall(const st
std::optional<ExprResult<TypePackId>> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack,
TypePackId argPack, TypePack* args, const std::vector<Location>& argLocations, const ExprResult<TypePackId>& argListResult,
std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<OverloadErrorEntry>& errors)
std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<TypeId>& overloadsThatDont, std::vector<OverloadErrorEntry>& errors)
{
fn = stripFromNilAndReport(fn, expr.func->location);
@ -3394,7 +3454,7 @@ std::optional<ExprResult<TypePackId>> TypeChecker::checkCallOverload(const Scope
if (get<ErrorTypeVar>(fn))
{
return {{addTypePack(TypePackVar{Unifiable::Error{}})}};
return {{errorRecoveryTypePack(scope)}};
}
if (get<FreeTypeVar>(fn))
@ -3427,14 +3487,14 @@ std::optional<ExprResult<TypePackId>> TypeChecker::checkCallOverload(const Scope
TypeId fn = *ty;
fn = instantiate(scope, fn, expr.func->location);
return checkCallOverload(
scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, metaArgLocations, argListResult, overloadsThatMatchArgCount, errors);
return checkCallOverload(scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, metaArgLocations, argListResult,
overloadsThatMatchArgCount, overloadsThatDont, errors);
}
}
reportError(TypeError{expr.func->location, CannotCallNonFunction{fn}});
unify(retPack, errorTypePack, expr.func->location);
return {{errorTypePack}};
unify(retPack, errorRecoveryTypePack(scope), expr.func->location);
return {{errorRecoveryTypePack(retPack)}};
}
// When this function type has magic functions and did return something, we select that overload instead.
@ -3476,6 +3536,8 @@ std::optional<ExprResult<TypePackId>> TypeChecker::checkCallOverload(const Scope
if (!argMismatch)
overloadsThatMatchArgCount.push_back(fn);
else if (FFlag::LuauErrorRecoveryType)
overloadsThatDont.push_back(fn);
errors.emplace_back(std::move(state.errors), args->head, ftv);
state.log.rollback();
@ -3586,14 +3648,14 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal
return false;
}
ExprResult<TypePackId> TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack,
TypePackId argPack, const std::vector<Location>& argLocations, const std::vector<TypeId>& overloads,
const std::vector<TypeId>& overloadsThatMatchArgCount, const std::vector<OverloadErrorEntry>& errors)
void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack,
const std::vector<Location>& argLocations, const std::vector<TypeId>& overloads, const std::vector<TypeId>& overloadsThatMatchArgCount,
const std::vector<OverloadErrorEntry>& errors)
{
if (overloads.size() == 1)
{
reportErrors(std::get<0>(errors.front()));
return {errorTypePack};
return;
}
std::vector<TypeId> overloadTypes = overloadsThatMatchArgCount;
@ -3622,7 +3684,7 @@ ExprResult<TypePackId> TypeChecker::reportOverloadResolutionError(const ScopePtr
// If only one overload matched, we don't need this error because we provided the previous errors.
if (overloadsThatMatchArgCount.size() == 1)
return {errorTypePack};
return;
}
std::string s;
@ -3655,7 +3717,7 @@ ExprResult<TypePackId> TypeChecker::reportOverloadResolutionError(const ScopePtr
reportError(expr.func->location, ExtraInformation{"Other overloads are also not viable: " + s});
// No viable overload
return {errorTypePack};
return;
}
ExprResult<TypePackId> TypeChecker::checkExprList(const ScopePtr& scope, const Location& location, const AstArray<AstExpr*>& exprs,
@ -3740,7 +3802,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module
if (FFlag::LuauStrictRequire && currentModule->mode == Mode::Strict)
{
reportError(TypeError{location, UnknownRequire{}});
return errorType;
return errorRecoveryType(anyType);
}
return anyType;
@ -3758,14 +3820,14 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module
reportError(TypeError{location, UnknownRequire{reportedModulePath}});
}
return errorType;
return errorRecoveryType(scope);
}
if (module->type != SourceCode::Module)
{
std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name);
reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."});
return errorType;
return errorRecoveryType(scope);
}
std::optional<TypeId> moduleType = first(module->getModuleScope()->returnType);
@ -3773,7 +3835,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module
{
std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name);
reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."});
return errorType;
return errorRecoveryType(scope);
}
SeenTypes seenTypes;
@ -4078,7 +4140,7 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location
if (!qty.has_value())
{
reportError(location, UnificationTooComplex{});
return errorType;
return errorRecoveryType(scope);
}
if (ty == *qty)
@ -4101,7 +4163,7 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat
else
{
reportError(location, UnificationTooComplex{});
return errorType;
return errorRecoveryType(scope);
}
}
@ -4116,7 +4178,7 @@ TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location)
else
{
reportError(location, UnificationTooComplex{});
return errorType;
return errorRecoveryType(anyType);
}
}
@ -4131,7 +4193,7 @@ TypePackId TypeChecker::anyify(const ScopePtr& scope, TypePackId ty, Location lo
else
{
reportError(location, UnificationTooComplex{});
return errorTypePack;
return errorRecoveryTypePack(anyTypePack);
}
}
@ -4279,6 +4341,38 @@ TypeId TypeChecker::freshType(TypeLevel level)
return currentModule->internalTypes.addType(TypeVar(FreeTypeVar(level)));
}
TypeId TypeChecker::singletonType(bool value)
{
// TODO: cache singleton types
return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(BoolSingleton{value})));
}
TypeId TypeChecker::singletonType(std::string value)
{
// TODO: cache singleton types
return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(StringSingleton{std::move(value)})));
}
TypeId TypeChecker::errorRecoveryType(const ScopePtr& scope)
{
return singletonTypes.errorRecoveryType();
}
TypeId TypeChecker::errorRecoveryType(TypeId guess)
{
return singletonTypes.errorRecoveryType(guess);
}
TypePackId TypeChecker::errorRecoveryTypePack(const ScopePtr& scope)
{
return singletonTypes.errorRecoveryTypePack();
}
TypePackId TypeChecker::errorRecoveryTypePack(TypePackId guess)
{
return singletonTypes.errorRecoveryTypePack(guess);
}
std::optional<TypeId> TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate)
{
std::vector<TypeId> types = Luau::filterMap(type, predicate);
@ -4350,7 +4444,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
if (lit->parameters.size != 1 || !lit->parameters.data[0].type)
{
reportError(TypeError{annotation.location, GenericError{"_luau_print requires one generic parameter"}});
return addType(ErrorTypeVar{});
return errorRecoveryType(anyType);
}
ToStringOptions opts;
@ -4368,7 +4462,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
if (!tf)
{
if (lit->name == Parser::errorName)
return addType(ErrorTypeVar{});
return errorRecoveryType(scope);
std::string typeName;
if (lit->hasPrefix)
@ -4380,7 +4474,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
else
reportError(TypeError{annotation.location, UnknownSymbol{typeName, UnknownSymbol::Type}});
return addType(ErrorTypeVar{});
return errorRecoveryType(scope);
}
if (lit->parameters.size == 0 && tf->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || tf->typePackParams.empty()))
@ -4390,14 +4484,17 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
else if (!FFlag::LuauTypeAliasPacks && lit->parameters.size != tf->typeParams.size())
{
reportError(TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, lit->parameters.size, 0}});
return addType(ErrorTypeVar{});
if (!FFlag::LuauErrorRecoveryType)
return errorRecoveryType(scope);
}
else if (FFlag::LuauTypeAliasPacks)
if (FFlag::LuauTypeAliasPacks)
{
if (!lit->hasParameterList && !tf->typePackParams.empty())
{
reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}});
return addType(ErrorTypeVar{});
if (!FFlag::LuauErrorRecoveryType)
return errorRecoveryType(scope);
}
std::vector<TypeId> typeParams;
@ -4445,7 +4542,17 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
{
reportError(
TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}});
return addType(ErrorTypeVar{});
if (FFlag::LuauErrorRecoveryType)
{
// Pad the types out with error recovery types
while (typeParams.size() < tf->typeParams.size())
typeParams.push_back(errorRecoveryType(scope));
while (typePackParams.size() < tf->typePackParams.size())
typePackParams.push_back(errorRecoveryTypePack(scope));
}
else
return errorRecoveryType(scope);
}
if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams && typePackParams == tf->typePackParams)
@ -4464,6 +4571,14 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
for (const auto& param : lit->parameters)
typeParams.push_back(resolveType(scope, *param.type));
if (FFlag::LuauErrorRecoveryType)
{
// If there aren't enough type parameters, pad them out with error recovery types
// (we've already reported the error)
while (typeParams.size() < lit->parameters.size)
typeParams.push_back(errorRecoveryType(scope));
}
if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams)
{
// If the generic parameters and the type arguments are the same, we are about to
@ -4483,8 +4598,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
props[prop.name.value] = {resolveType(scope, *prop.type)};
if (const auto& indexer = table->indexer)
tableIndexer = TableIndexer(
resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType));
tableIndexer = TableIndexer(resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType));
return addType(TableTypeVar{
props, tableIndexer, scope->level,
@ -4536,14 +4650,20 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
return addType(IntersectionTypeVar{types});
}
else if (annotation.is<AstTypeError>())
else if (const auto& tsb = annotation.as<AstTypeSingletonBool>())
{
return addType(ErrorTypeVar{});
return singletonType(tsb->value);
}
else if (const auto& tss = annotation.as<AstTypeSingletonString>())
{
return singletonType(std::string(tss->value.data, tss->value.size));
}
else if (annotation.is<AstTypeError>())
return errorRecoveryType(scope);
else
{
reportError(TypeError{annotation.location, GenericError{"Unknown type annotation?"}});
return addType(ErrorTypeVar{});
return errorRecoveryType(scope);
}
}
@ -4584,7 +4704,7 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack
else
reportError(TypeError{generic->location, UnknownSymbol{genericName, UnknownSymbol::Type}});
return addTypePack(TypePackVar{Unifiable::Error{}});
return errorRecoveryTypePack(scope);
}
return *genericTy;
@ -4706,12 +4826,12 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf,
if (!maybeInstantiated.has_value())
{
reportError(location, UnificationTooComplex{});
return errorType;
return errorRecoveryType(scope);
}
if (FFlag::LuauRecursiveTypeParameterRestriction && applyTypeFunction.encounteredForwardedType)
{
reportError(TypeError{location, GenericError{"Recursive type being used with different parameters"}});
return errorType;
return errorRecoveryType(scope);
}
TypeId instantiated = *maybeInstantiated;
@ -4773,8 +4893,8 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf,
return instantiated;
}
std::pair<std::vector<TypeId>, std::vector<TypePackId>> TypeChecker::createGenericTypes(
const ScopePtr& scope, std::optional<TypeLevel> levelOpt, const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames)
std::pair<std::vector<TypeId>, std::vector<TypePackId>> TypeChecker::createGenericTypes(const ScopePtr& scope, std::optional<TypeLevel> levelOpt,
const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames)
{
LUAU_ASSERT(scope->parent);
@ -5043,7 +5163,7 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement
addRefinement(refis, isaP.lvalue, *result);
else
{
addRefinement(refis, isaP.lvalue, errorType);
addRefinement(refis, isaP.lvalue, errorRecoveryType(scope));
errVec.push_back(TypeError{isaP.location, TypeMismatch{isaP.ty, *ty}});
}
}
@ -5107,7 +5227,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec
addRefinement(refis, typeguardP.lvalue, *result);
else
{
addRefinement(refis, typeguardP.lvalue, errorType);
addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope));
if (sense)
errVec.push_back(
TypeError{typeguardP.location, GenericError{"Type '" + toString(*ty) + "' has no overlap with '" + typeguardP.kind + "'"}});
@ -5118,7 +5238,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec
auto fail = [&](const TypeErrorData& err) {
errVec.push_back(TypeError{typeguardP.location, err});
addRefinement(refis, typeguardP.lvalue, errorType);
addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope));
};
if (!typeguardP.isTypeof)
@ -5139,28 +5259,6 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec
return resolve(IsAPredicate{std::move(typeguardP.lvalue), typeguardP.location, type}, errVec, refis, scope, sense);
}
void TypeChecker::DEPRECATED_resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense)
{
if (!sense)
return;
static std::vector<std::string> primitives{
"string", "number", "boolean", "nil", "thread",
"table", // no op. Requires special handling.
"function", // no op. Requires special handling.
"userdata", // no op. Requires special handling.
};
if (auto typeFun = globalScope->lookupType(typeguardP.kind);
typeFun && typeFun->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || typeFun->typePackParams.empty()))
{
if (auto it = std::find(primitives.begin(), primitives.end(), typeguardP.kind); it != primitives.end())
addRefinement(refis, typeguardP.lvalue, typeFun->type);
else if (typeguardP.isTypeof)
addRefinement(refis, typeguardP.lvalue, typeFun->type);
}
}
void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense)
{
// This refinement will require success typing to do everything correctly. For now, we can get most of the way there.

View file

@ -286,5 +286,4 @@ TypePack* asMutable(const TypePack* tp)
{
return const_cast<TypePack*>(tp);
}
} // namespace Luau

View file

@ -21,6 +21,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500)
LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0)
LUAU_FASTFLAG(LuauTypeAliasPacks)
LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false)
LUAU_FASTFLAG(LuauErrorRecoveryType)
namespace Luau
{
@ -305,6 +306,18 @@ bool maybeGeneric(TypeId ty)
return isGeneric(ty);
}
bool maybeSingleton(TypeId ty)
{
ty = follow(ty);
if (get<SingletonTypeVar>(ty))
return true;
if (const UnionTypeVar* utv = get<UnionTypeVar>(ty))
for (TypeId option : utv)
if (get<SingletonTypeVar>(follow(option)))
return true;
return false;
}
FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retType, std::optional<FunctionDefinition> defn, bool hasSelf)
: argTypes(argTypes)
, retType(retType)
@ -562,10 +575,8 @@ SingletonTypes::SingletonTypes()
, booleanType(&booleanType_)
, threadType(&threadType_)
, anyType(&anyType_)
, errorType(&errorType_)
, optionalNumberType(&optionalNumberType_)
, anyTypePack(&anyTypePack_)
, errorTypePack(&errorTypePack_)
, arena(new TypeArena)
{
TypeId stringMetatable = makeStringMetatable();
@ -634,6 +645,32 @@ TypeId SingletonTypes::makeStringMetatable()
return arena->addType(TableTypeVar{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed});
}
TypeId SingletonTypes::errorRecoveryType()
{
return &errorType_;
}
TypePackId SingletonTypes::errorRecoveryTypePack()
{
return &errorTypePack_;
}
TypeId SingletonTypes::errorRecoveryType(TypeId guess)
{
if (FFlag::LuauErrorRecoveryType)
return guess;
else
return &errorType_;
}
TypePackId SingletonTypes::errorRecoveryTypePack(TypePackId guess)
{
if (FFlag::LuauErrorRecoveryType)
return guess;
else
return &errorTypePack_;
}
SingletonTypes singletonTypes;
void persist(TypeId ty)
@ -1141,6 +1178,11 @@ struct QVarFinder
return false;
}
bool operator()(const SingletonTypeVar&) const
{
return false;
}
bool operator()(const FunctionTypeVar& ftv) const
{
if (hasGeneric(ftv.argTypes))
@ -1412,7 +1454,7 @@ static std::vector<TypeId> parseFormatString(TypeChecker& typechecker, const cha
else if (strchr(options, data[i]))
result.push_back(typechecker.numberType);
else
result.push_back(typechecker.errorType);
result.push_back(typechecker.errorRecoveryType(typechecker.anyType));
}
}

View file

@ -22,7 +22,9 @@ LUAU_FASTFLAGVARIABLE(LuauTypecheckOpts, false)
LUAU_FASTFLAG(LuauShareTxnSeen);
LUAU_FASTFLAGVARIABLE(LuauCacheUnifyTableResults, false)
LUAU_FASTFLAGVARIABLE(LuauExtendedTypeMismatchError, false)
LUAU_FASTFLAG(LuauSingletonTypes)
LUAU_FASTFLAGVARIABLE(LuauExtendedClassMismatchError, false)
LUAU_FASTFLAG(LuauErrorRecoveryType);
namespace Luau
{
@ -211,6 +213,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
{
occursCheck(subTy, superTy);
// The occurrence check might have caused superTy no longer to be a free type
if (!get<ErrorTypeVar>(subTy))
{
log(subTy);
@ -221,10 +224,20 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
}
else if (l && r)
{
if (!FFlag::LuauErrorRecoveryType)
log(superTy);
occursCheck(superTy, subTy);
r->level = min(r->level, l->level);
// The occurrence check might have caused superTy no longer to be a free type
if (!FFlag::LuauErrorRecoveryType)
*asMutable(superTy) = BoundTypeVar(subTy);
else if (!get<ErrorTypeVar>(superTy))
{
log(superTy);
*asMutable(superTy) = BoundTypeVar(subTy);
}
return;
}
else if (l)
@ -240,6 +253,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
return;
}
// The occurrence check might have caused superTy no longer to be a free type
if (!get<ErrorTypeVar>(superTy))
{
if (auto rightLevel = getMutableLevel(subTy))
@ -251,6 +265,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
log(superTy);
*asMutable(superTy) = BoundTypeVar(subTy);
}
return;
}
else if (r)
@ -512,6 +527,9 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
else if (get<PrimitiveTypeVar>(superTy) && get<PrimitiveTypeVar>(subTy))
tryUnifyPrimitives(superTy, subTy);
else if (FFlag::LuauSingletonTypes && (get<PrimitiveTypeVar>(superTy) || get<SingletonTypeVar>(superTy)) && get<SingletonTypeVar>(subTy))
tryUnifySingletons(superTy, subTy);
else if (get<FunctionTypeVar>(superTy) && get<FunctionTypeVar>(subTy))
tryUnifyFunctions(superTy, subTy, isFunctionCall);
@ -723,17 +741,18 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal
{
occursCheck(superTp, subTp);
// The occurrence check might have caused superTp no longer to be a free type
if (!get<ErrorTypeVar>(superTp))
{
log(superTp);
*asMutable(superTp) = Unifiable::Bound<TypePackId>(subTp);
}
}
else if (get<Unifiable::Free>(subTp))
{
occursCheck(subTp, superTp);
// The occurrence check might have caused superTp no longer to be a free type
if (!get<ErrorTypeVar>(subTp))
{
log(subTp);
@ -874,13 +893,13 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal
while (superIter.good())
{
tryUnify_(singletonTypes.errorType, *superIter);
tryUnify_(singletonTypes.errorRecoveryType(), *superIter);
superIter.advance();
}
while (subIter.good())
{
tryUnify_(singletonTypes.errorType, *subIter);
tryUnify_(singletonTypes.errorRecoveryType(), *subIter);
subIter.advance();
}
@ -906,6 +925,27 @@ void Unifier::tryUnifyPrimitives(TypeId superTy, TypeId subTy)
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}});
}
void Unifier::tryUnifySingletons(TypeId superTy, TypeId subTy)
{
const PrimitiveTypeVar* lp = get<PrimitiveTypeVar>(superTy);
const SingletonTypeVar* ls = get<SingletonTypeVar>(superTy);
const SingletonTypeVar* rs = get<SingletonTypeVar>(subTy);
if ((!lp && !ls) || !rs)
ice("passed non singleton/primitive types to unifySingletons");
if (ls && *ls == *rs)
return;
if (lp && lp->type == PrimitiveTypeVar::Boolean && get<BoolSingleton>(rs) && variance == Covariant)
return;
if (lp && lp->type == PrimitiveTypeVar::String && get<StringSingleton>(rs) && variance == Covariant)
return;
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}});
}
void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall)
{
FunctionTypeVar* lf = getMutable<FunctionTypeVar>(superTy);
@ -1023,7 +1063,8 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection)
}
// And vice versa if we're invariant
if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !lt->indexer && lt->state != TableState::Unsealed && lt->state != TableState::Free)
if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !lt->indexer && lt->state != TableState::Unsealed &&
lt->state != TableState::Free)
{
for (const auto& [propName, subProp] : rt->props)
{
@ -1634,9 +1675,8 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed)
{
ok = false;
errors.push_back(TypeError{location, UnknownProperty{superTy, propName}});
if (!FFlag::LuauExtendedClassMismatchError)
tryUnify_(prop.type, singletonTypes.errorType);
tryUnify_(prop.type, singletonTypes.errorRecoveryType());
}
else
{
@ -1952,7 +1992,7 @@ void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty)
{
LUAU_ASSERT(get<Unifiable::Error>(any));
const TypeId anyTy = singletonTypes.errorType;
const TypeId anyTy = singletonTypes.errorRecoveryType();
if (FFlag::LuauTypecheckOpts)
{
@ -2046,7 +2086,7 @@ void Unifier::occursCheck(std::unordered_set<TypeId>& seen_DEPRECATED, DenseHash
{
errors.push_back(TypeError{location, OccursCheckFailed{}});
log(needle);
*asMutable(needle) = ErrorTypeVar{};
*asMutable(needle) = *singletonTypes.errorRecoveryType();
return;
}
@ -2134,7 +2174,7 @@ void Unifier::occursCheck(std::unordered_set<TypePackId>& seen_DEPRECATED, Dense
{
errors.push_back(TypeError{location, OccursCheckFailed{}});
log(needle);
*asMutable(needle) = ErrorTypeVar{};
*asMutable(needle) = *singletonTypes.errorRecoveryTypePack();
return;
}

View file

@ -255,6 +255,14 @@ public:
{
return visit((class AstType*)node);
}
virtual bool visit(class AstTypeSingletonBool* node)
{
return visit((class AstType*)node);
}
virtual bool visit(class AstTypeSingletonString* node)
{
return visit((class AstType*)node);
}
virtual bool visit(class AstTypeError* node)
{
return visit((class AstType*)node);
@ -1158,6 +1166,30 @@ public:
unsigned messageIndex;
};
class AstTypeSingletonBool : public AstType
{
public:
LUAU_RTTI(AstTypeSingletonBool)
AstTypeSingletonBool(const Location& location, bool value);
void visit(AstVisitor* visitor) override;
bool value;
};
class AstTypeSingletonString : public AstType
{
public:
LUAU_RTTI(AstTypeSingletonString)
AstTypeSingletonString(const Location& location, const AstArray<char>& value);
void visit(AstVisitor* visitor) override;
const AstArray<char> value;
};
class AstTypePack : public AstNode
{
public:

View file

@ -286,6 +286,7 @@ private:
// `<' typeAnnotation[, ...] `>'
AstArray<AstTypeOrPack> parseTypeParams();
std::optional<AstArray<char>> parseCharArray();
AstExpr* parseString();
AstLocal* pushLocal(const Binding& binding);

View file

@ -34,4 +34,6 @@ bool equalsLower(std::string_view lhs, std::string_view rhs);
size_t hashRange(const char* data, size_t size);
std::string escape(std::string_view s);
bool isIdentifier(std::string_view s);
} // namespace Luau

View file

@ -841,6 +841,28 @@ void AstTypeIntersection::visit(AstVisitor* visitor)
}
}
AstTypeSingletonBool::AstTypeSingletonBool(const Location& location, bool value)
: AstType(ClassIndex(), location)
, value(value)
{
}
void AstTypeSingletonBool::visit(AstVisitor* visitor)
{
visitor->visit(this);
}
AstTypeSingletonString::AstTypeSingletonString(const Location& location, const AstArray<char>& value)
: AstType(ClassIndex(), location)
, value(value)
{
}
void AstTypeSingletonString::visit(AstVisitor* visitor)
{
visitor->visit(this);
}
AstTypeError::AstTypeError(const Location& location, const AstArray<AstType*>& types, bool isMissing, unsigned messageIndex)
: AstType(ClassIndex(), location)
, types(types)

View file

@ -16,6 +16,7 @@ LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false)
LUAU_FASTFLAGVARIABLE(LuauTypeAliasPacks, false)
LUAU_FASTFLAGVARIABLE(LuauParseTypePackTypeParameters, false)
LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false)
LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false)
LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctionTypeBegin, false)
namespace Luau
@ -1278,7 +1279,27 @@ AstType* Parser::parseTableTypeAnnotation()
while (lexer.current().type != '}')
{
if (lexer.current().type == '[')
if (FFlag::LuauParseSingletonTypes && lexer.current().type == '[' &&
(lexer.lookahead().type == Lexeme::RawString || lexer.lookahead().type == Lexeme::QuotedString))
{
const Lexeme begin = lexer.current();
nextLexeme(); // [
std::optional<AstArray<char>> chars = parseCharArray();
expectMatchAndConsume(']', begin);
expectAndConsume(':', "table field");
AstType* type = parseTypeAnnotation();
// TODO: since AstName conains a char*, it can't contain null
bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size);
if (chars && !containsNull)
props.push_back({AstName(chars->data), begin.location, type});
else
report(begin.location, "String literal contains malformed escape sequence");
}
else if (lexer.current().type == '[')
{
if (indexer)
{
@ -1528,6 +1549,32 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack)
nextLexeme();
return {allocator.alloc<AstTypeReference>(begin, std::nullopt, nameNil), {}};
}
else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::ReservedTrue)
{
nextLexeme();
return {allocator.alloc<AstTypeSingletonBool>(begin, true)};
}
else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::ReservedFalse)
{
nextLexeme();
return {allocator.alloc<AstTypeSingletonBool>(begin, false)};
}
else if (FFlag::LuauParseSingletonTypes && (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString))
{
if (std::optional<AstArray<char>> value = parseCharArray())
{
AstArray<char> svalue = *value;
return {allocator.alloc<AstTypeSingletonString>(begin, svalue)};
}
else
return {reportTypeAnnotationError(begin, {}, /*isMissing*/ false, "String literal contains malformed escape sequence")};
}
else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::BrokenString)
{
Location location = lexer.current().location;
nextLexeme();
return {reportTypeAnnotationError(location, {}, /*isMissing*/ false, "Malformed string")};
}
else if (lexer.current().type == Lexeme::Name)
{
std::optional<AstName> prefix;
@ -2416,7 +2463,7 @@ AstArray<AstTypeOrPack> Parser::parseTypeParams()
return copy(parameters);
}
AstExpr* Parser::parseString()
std::optional<AstArray<char>> Parser::parseCharArray()
{
LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString);
@ -2426,11 +2473,8 @@ AstExpr* Parser::parseString()
{
if (!Lexer::fixupQuotedString(scratchData))
{
Location location = lexer.current().location;
nextLexeme();
return reportExprError(location, {}, "String literal contains malformed escape sequence");
return std::nullopt;
}
}
else
@ -2438,12 +2482,18 @@ AstExpr* Parser::parseString()
Lexer::fixupMultilineString(scratchData);
}
Location start = lexer.current().location;
AstArray<char> value = copy(scratchData);
nextLexeme();
return value;
}
return allocator.alloc<AstExprConstantString>(start, value);
AstExpr* Parser::parseString()
{
Location location = lexer.current().location;
if (std::optional<AstArray<char>> value = parseCharArray())
return allocator.alloc<AstExprConstantString>(location, *value);
else
return reportExprError(location, {}, "String literal contains malformed escape sequence");
}
AstLocal* Parser::pushLocal(const Binding& binding)

View file

@ -225,4 +225,62 @@ size_t hashRange(const char* data, size_t size)
return hash;
}
bool isIdentifier(std::string_view s)
{
return (s.find_first_not_of("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ01234567890_") == std::string::npos);
}
std::string escape(std::string_view s)
{
std::string r;
r.reserve(s.size() + 50); // arbitrary number to guess how many characters we'll be inserting
for (uint8_t c : s)
{
if (c >= ' ' && c != '\\' && c != '\'' && c != '\"')
r += c;
else
{
r += '\\';
switch (c)
{
case '\a':
r += 'a';
break;
case '\b':
r += 'b';
break;
case '\f':
r += 'f';
break;
case '\n':
r += 'n';
break;
case '\r':
r += 'r';
break;
case '\t':
r += 't';
break;
case '\v':
r += 'v';
break;
case '\'':
r += '\'';
break;
case '\"':
r += '\"';
break;
case '\\':
r += '\\';
break;
default:
Luau::formatAppend(r, "%03u", c);
}
}
}
return r;
}
} // namespace Luau

View file

@ -236,32 +236,12 @@ int main(int argc, char** argv)
Luau::registerBuiltinTypes(frontend.typeChecker);
Luau::freeze(frontend.typeChecker.globalTypes);
std::vector<std::string> files = getSourceFiles(argc, argv);
int failed = 0;
for (int i = 1; i < argc; ++i)
{
if (argv[i][0] == '-')
continue;
if (isDirectory(argv[i]))
{
traverseDirectory(argv[i], [&](const std::string& name) {
// Look for .luau first and if absent, fall back to .lua
if (name.length() > 5 && name.rfind(".luau") == name.length() - 5)
{
failed += !analyzeFile(frontend, name.c_str(), format, annotate);
}
else if (name.length() > 4 && name.rfind(".lua") == name.length() - 4)
{
failed += !analyzeFile(frontend, name.c_str(), format, annotate);
}
});
}
else
{
failed += !analyzeFile(frontend, argv[i], format, annotate);
}
}
for (const std::string& path : files)
failed += !analyzeFile(frontend, path.c_str(), format, annotate);
if (!configResolver.configErrors.empty())
{

View file

@ -223,3 +223,40 @@ std::optional<std::string> getParentPath(const std::string& path)
return "";
}
static std::string getExtension(const std::string& path)
{
std::string::size_type dot = path.find_last_of(".\\/");
if (dot == std::string::npos || path[dot] != '.')
return "";
return path.substr(dot);
}
std::vector<std::string> getSourceFiles(int argc, char** argv)
{
std::vector<std::string> files;
for (int i = 1; i < argc; ++i)
{
if (argv[i][0] == '-')
continue;
if (isDirectory(argv[i]))
{
traverseDirectory(argv[i], [&](const std::string& name) {
std::string ext = getExtension(name);
if (ext == ".lua" || ext == ".luau")
files.push_back(name);
});
}
else
{
files.push_back(argv[i]);
}
}
return files;
}

View file

@ -4,6 +4,7 @@
#include <optional>
#include <string>
#include <functional>
#include <vector>
std::optional<std::string> readFile(const std::string& name);
@ -12,3 +13,5 @@ bool traverseDirectory(const std::string& path, const std::function<void(const s
std::string joinPaths(const std::string& lhs, const std::string& rhs);
std::optional<std::string> getParentPath(const std::string& path);
std::vector<std::string> getSourceFiles(int argc, char** argv);

View file

@ -20,7 +20,7 @@
enum class CompileFormat
{
Default,
Text,
Binary
};
@ -33,7 +33,7 @@ static int lua_loadstring(lua_State* L)
lua_setsafeenv(L, LUA_ENVIRONINDEX, false);
std::string bytecode = Luau::compile(std::string(s, l));
if (luau_load(L, chunkname, bytecode.data(), bytecode.size()) == 0)
if (luau_load(L, chunkname, bytecode.data(), bytecode.size(), 0) == 0)
return 1;
lua_pushnil(L);
@ -80,7 +80,7 @@ static int lua_require(lua_State* L)
// now we can compile & run module on the new thread
std::string bytecode = Luau::compile(*source);
if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0)
if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0)
{
int status = lua_resume(ML, L, 0);
@ -151,7 +151,7 @@ static std::string runCode(lua_State* L, const std::string& source)
{
std::string bytecode = Luau::compile(source);
if (luau_load(L, "=stdin", bytecode.data(), bytecode.size()) != 0)
if (luau_load(L, "=stdin", bytecode.data(), bytecode.size(), 0) != 0)
{
size_t len;
const char* msg = lua_tolstring(L, -1, &len);
@ -370,7 +370,7 @@ static bool runFile(const char* name, lua_State* GL)
std::string bytecode = Luau::compile(*source);
int status = 0;
if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0)
if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0)
{
status = lua_resume(L, NULL, 0);
}
@ -379,11 +379,7 @@ static bool runFile(const char* name, lua_State* GL)
status = LUA_ERRSYNTAX;
}
if (status == 0)
{
return true;
}
else
if (status != 0)
{
std::string error;
@ -400,8 +396,10 @@ static bool runFile(const char* name, lua_State* GL)
error += lua_debugtrace(L);
fprintf(stderr, "%s", error.c_str());
return false;
}
lua_pop(GL, 1);
return status == 0;
}
static void report(const char* name, const Luau::Location& location, const char* type, const char* message)
@ -431,14 +429,18 @@ static bool compileFile(const char* name, CompileFormat format)
try
{
Luau::BytecodeBuilder bcb;
if (format == CompileFormat::Text)
{
bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source);
bcb.setDumpSource(*source);
}
Luau::compileOrThrow(bcb, *source);
switch (format)
{
case CompileFormat::Default:
case CompileFormat::Text:
printf("%s", bcb.dumpEverything().c_str());
break;
case CompileFormat::Binary:
@ -504,7 +506,7 @@ int main(int argc, char** argv)
if (argc >= 2 && strncmp(argv[1], "--compile", strlen("--compile")) == 0)
{
CompileFormat format = CompileFormat::Default;
CompileFormat format = CompileFormat::Text;
if (strcmp(argv[1], "--compile=binary") == 0)
format = CompileFormat::Binary;
@ -514,27 +516,12 @@ int main(int argc, char** argv)
_setmode(_fileno(stdout), _O_BINARY);
#endif
std::vector<std::string> files = getSourceFiles(argc, argv);
int failed = 0;
for (int i = 2; i < argc; ++i)
{
if (argv[i][0] == '-')
continue;
if (isDirectory(argv[i]))
{
traverseDirectory(argv[i], [&](const std::string& name) {
if (name.length() > 5 && name.rfind(".luau") == name.length() - 5)
failed += !compileFile(name.c_str(), format);
else if (name.length() > 4 && name.rfind(".lua") == name.length() - 4)
failed += !compileFile(name.c_str(), format);
});
}
else
{
failed += !compileFile(argv[i], format);
}
}
for (const std::string& path : files)
failed += !compileFile(path.c_str(), format);
return failed;
}
@ -548,33 +535,25 @@ int main(int argc, char** argv)
int profile = 0;
for (int i = 1; i < argc; ++i)
{
if (argv[i][0] != '-')
continue;
if (strcmp(argv[i], "--profile") == 0)
profile = 10000; // default to 10 KHz
else if (strncmp(argv[i], "--profile=", 10) == 0)
profile = atoi(argv[i] + 10);
}
if (profile)
profilerStart(L, profile);
std::vector<std::string> files = getSourceFiles(argc, argv);
int failed = 0;
for (int i = 1; i < argc; ++i)
{
if (argv[i][0] == '-')
continue;
if (isDirectory(argv[i]))
{
traverseDirectory(argv[i], [&](const std::string& name) {
if (name.length() > 4 && name.rfind(".lua") == name.length() - 4)
failed += !runFile(name.c_str(), L);
});
}
else
{
failed += !runFile(argv[i], L);
}
}
for (const std::string& path : files)
failed += !runFile(path.c_str(), L);
if (profile)
{

View file

@ -13,11 +13,9 @@ class AstNameTable;
class BytecodeBuilder;
class BytecodeEncoder;
// Note: this structure is duplicated in luacode.h, don't forget to change these in sync!
struct CompileOptions
{
// default bytecode version target; can be used to compile code for older clients
int bytecodeVersion = 1;
// 0 - no optimization
// 1 - baseline optimization level that doesn't prevent debuggability
// 2 - includes optimizations that harm debuggability such as inlining

View file

@ -0,0 +1,39 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include <stddef.h>
/* Can be used to reconfigure visibility/exports for public APIs */
#ifndef LUACODE_API
#define LUACODE_API extern
#endif
typedef struct lua_CompileOptions lua_CompileOptions;
struct lua_CompileOptions
{
// 0 - no optimization
// 1 - baseline optimization level that doesn't prevent debuggability
// 2 - includes optimizations that harm debuggability such as inlining
int optimizationLevel; // default=1
// 0 - no debugging support
// 1 - line info & function names only; sufficient for backtraces
// 2 - full debug info with local & upvalue names; necessary for debugger
int debugLevel; // default=1
// 0 - no code coverage support
// 1 - statement coverage
// 2 - statement and expression coverage (verbose)
int coverageLevel; // default=0
// global builtin to construct vectors; disabled by default
const char* vectorLib;
const char* vectorCtor;
// null-terminated array of globals that are mutable; disables the import optimization for fields accessed through these
const char** mutableGlobals;
};
/* compile source to bytecode; when source compilation fails, the resulting bytecode contains the encoded error. use free() to destroy */
LUACODE_API char* luau_compile(const char* source, size_t size, lua_CompileOptions* options, size_t* outsize);

View file

@ -11,9 +11,6 @@
#include <math.h>
LUAU_FASTFLAGVARIABLE(LuauPreloadClosures, false)
LUAU_FASTFLAGVARIABLE(LuauPreloadClosuresFenv, false)
LUAU_FASTFLAGVARIABLE(LuauPreloadClosuresUpval, false)
LUAU_FASTFLAGVARIABLE(LuauGenericSpecialGlobals, false)
LUAU_FASTFLAG(LuauIfElseExpressionBaseSupport)
LUAU_FASTFLAGVARIABLE(LuauBit32CountBuiltin, false)
@ -24,9 +21,6 @@ static const uint32_t kMaxRegisterCount = 255;
static const uint32_t kMaxUpvalueCount = 200;
static const uint32_t kMaxLocalCount = 200;
// TODO: Remove with LuauGenericSpecialGlobals
static const char* kSpecialGlobals[] = {"Game", "Workspace", "_G", "game", "plugin", "script", "shared", "workspace"};
CompileError::CompileError(const Location& location, const std::string& message)
: location(location)
, message(message)
@ -466,7 +460,7 @@ struct Compiler
bool shared = false;
if (FFlag::LuauPreloadClosuresUpval)
if (FFlag::LuauPreloadClosures)
{
// Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure
// objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it
@ -482,18 +476,6 @@ struct Compiler
}
}
}
// Optimization: when closure has no upvalues, instead of allocating it every time we can share closure objects
// (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it is used)
else if (FFlag::LuauPreloadClosures && options.optimizationLevel >= 1 && f->upvals.empty() && !setfenvUsed)
{
int32_t cid = bytecode.addConstantClosure(f->id);
if (cid >= 0 && cid < 32768)
{
bytecode.emitAD(LOP_DUPCLOSURE, target, cid);
return;
}
}
if (!shared)
bytecode.emitAD(LOP_NEWCLOSURE, target, pid);
@ -3298,7 +3280,6 @@ struct Compiler
bool visit(AstStatLocalFunction* node) override
{
// record local->function association for some optimizations
if (FFlag::LuauPreloadClosuresUpval)
self->locals[node->name].func = node->func;
return true;
@ -3711,8 +3692,6 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName
Compiler compiler(bytecode, options);
// since access to some global objects may result in values that change over time, we block imports from non-readonly tables
if (FFlag::LuauGenericSpecialGlobals)
{
if (AstName name = names.get("_G"); name.value)
compiler.globals[name].writable = true;
@ -3720,15 +3699,6 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName
for (const char** ptr = options.mutableGlobals; *ptr; ++ptr)
if (AstName name = names.get(*ptr); name.value)
compiler.globals[name].writable = true;
}
else
{
for (const char* global : kSpecialGlobals)
{
if (AstName name = names.get(global); name.value)
compiler.globals[name].writable = true;
}
}
// this visitor traverses the AST to analyze mutability of locals/globals, filling Local::written and Global::written
Compiler::AssignmentVisitor assignmentVisitor(&compiler);
@ -3742,7 +3712,7 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName
}
// this visitor tracks calls to getfenv/setfenv and disables some optimizations when they are found
if (FFlag::LuauPreloadClosuresFenv && options.optimizationLevel >= 1 && (names.get("getfenv").value || names.get("setfenv").value))
if (options.optimizationLevel >= 1 && (names.get("getfenv").value || names.get("setfenv").value))
{
Compiler::FenvVisitor fenvVisitor(compiler.getfenvUsed, compiler.setfenvUsed);
root->visit(&fenvVisitor);

29
Compiler/src/lcode.cpp Normal file
View file

@ -0,0 +1,29 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "luacode.h"
#include "Luau/Compiler.h"
#include <string.h>
char* luau_compile(const char* source, size_t size, lua_CompileOptions* options, size_t* outsize)
{
LUAU_ASSERT(outsize);
Luau::CompileOptions opts;
if (options)
{
static_assert(sizeof(lua_CompileOptions) == sizeof(Luau::CompileOptions), "C and C++ interface must match");
memcpy(static_cast<void*>(&opts), options, sizeof(opts));
}
std::string result = compile(std::string(source, size), opts);
char* copy = static_cast<char*>(malloc(result.size()));
if (!copy)
return nullptr;
memcpy(copy, result.data(), result.size());
*outsize = result.size();
return copy;
}

View file

@ -25,9 +25,11 @@ target_sources(Luau.Compiler PRIVATE
Compiler/include/Luau/Bytecode.h
Compiler/include/Luau/BytecodeBuilder.h
Compiler/include/Luau/Compiler.h
Compiler/include/luacode.h
Compiler/src/BytecodeBuilder.cpp
Compiler/src/Compiler.cpp
Compiler/src/lcode.cpp
)
# Luau.Analysis Sources
@ -204,6 +206,7 @@ if(TARGET Luau.UnitTest)
tests/TypeInfer.intersectionTypes.test.cpp
tests/TypeInfer.provisional.test.cpp
tests/TypeInfer.refinements.test.cpp
tests/TypeInfer.singletons.test.cpp
tests/TypeInfer.tables.test.cpp
tests/TypeInfer.test.cpp
tests/TypeInfer.tryUnify.test.cpp

View file

@ -102,6 +102,8 @@ LUA_API lua_State* lua_newstate(lua_Alloc f, void* ud);
LUA_API void lua_close(lua_State* L);
LUA_API lua_State* lua_newthread(lua_State* L);
LUA_API lua_State* lua_mainthread(lua_State* L);
LUA_API void lua_resetthread(lua_State* L);
LUA_API int lua_isthreadreset(lua_State* L);
/*
** basic stack manipulation
@ -162,8 +164,7 @@ LUA_API void lua_pushlstring(lua_State* L, const char* s, size_t l);
LUA_API void lua_pushstring(lua_State* L, const char* s);
LUA_API const char* lua_pushvfstring(lua_State* L, const char* fmt, va_list argp);
LUA_API LUA_PRINTF_ATTR(2, 3) const char* lua_pushfstringL(lua_State* L, const char* fmt, ...);
LUA_API void lua_pushcfunction(
lua_State* L, lua_CFunction fn, const char* debugname = NULL, int nup = 0, lua_Continuation cont = NULL);
LUA_API void lua_pushcclosurek(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont);
LUA_API void lua_pushboolean(lua_State* L, int b);
LUA_API void lua_pushlightuserdata(lua_State* L, void* p);
LUA_API int lua_pushthread(lua_State* L);
@ -178,9 +179,9 @@ LUA_API void lua_rawget(lua_State* L, int idx);
LUA_API void lua_rawgeti(lua_State* L, int idx, int n);
LUA_API void lua_createtable(lua_State* L, int narr, int nrec);
LUA_API void lua_setreadonly(lua_State* L, int idx, bool value);
LUA_API void lua_setreadonly(lua_State* L, int idx, int enabled);
LUA_API int lua_getreadonly(lua_State* L, int idx);
LUA_API void lua_setsafeenv(lua_State* L, int idx, bool value);
LUA_API void lua_setsafeenv(lua_State* L, int idx, int enabled);
LUA_API void* lua_newuserdata(lua_State* L, size_t sz, int tag);
LUA_API void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*));
@ -200,7 +201,7 @@ LUA_API int lua_setfenv(lua_State* L, int idx);
/*
** `load' and `call' functions (load and run Luau bytecode)
*/
LUA_API int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size, int env = 0);
LUA_API int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size, int env);
LUA_API void lua_call(lua_State* L, int nargs, int nresults);
LUA_API int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc);
@ -293,6 +294,8 @@ LUA_API void lua_unref(lua_State* L, int ref);
#define lua_isnoneornil(L, n) (lua_type(L, (n)) <= LUA_TNIL)
#define lua_pushliteral(L, s) lua_pushlstring(L, "" s, (sizeof(s) / sizeof(char)) - 1)
#define lua_pushcfunction(L, fn, debugname) lua_pushcclosurek(L, fn, debugname, 0, NULL)
#define lua_pushcclosure(L, fn, debugname, nup) lua_pushcclosurek(L, fn, debugname, nup, NULL)
#define lua_setglobal(L, s) lua_setfield(L, LUA_GLOBALSINDEX, (s))
#define lua_getglobal(L, s) lua_getfield(L, LUA_GLOBALSINDEX, (s))
@ -319,8 +322,8 @@ LUA_API const char* lua_setlocal(lua_State* L, int level, int n);
LUA_API const char* lua_getupvalue(lua_State* L, int funcindex, int n);
LUA_API const char* lua_setupvalue(lua_State* L, int funcindex, int n);
LUA_API void lua_singlestep(lua_State* L, bool singlestep);
LUA_API void lua_breakpoint(lua_State* L, int funcindex, int line, bool enable);
LUA_API void lua_singlestep(lua_State* L, int enabled);
LUA_API void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled);
/* Warning: this function is not thread-safe since it stores the result in a shared global array! Only use for debugging. */
LUA_API const char* lua_debugtrace(lua_State* L);
@ -361,6 +364,7 @@ struct lua_Callbacks
void (*debuginterrupt)(lua_State* L, lua_Debug* ar); /* gets called when thread execution is interrupted by break in another thread */
void (*debugprotectederror)(lua_State* L); /* gets called when protected call results in an error */
};
typedef struct lua_Callbacks lua_Callbacks;
LUA_API lua_Callbacks* lua_callbacks(lua_State* L);

View file

@ -8,11 +8,12 @@
#define luaL_typeerror(L, narg, tname) luaL_typeerrorL(L, narg, tname)
#define luaL_argerror(L, narg, extramsg) luaL_argerrorL(L, narg, extramsg)
typedef struct luaL_Reg
struct luaL_Reg
{
const char* name;
lua_CFunction func;
} luaL_Reg;
};
typedef struct luaL_Reg luaL_Reg;
LUALIB_API void luaL_register(lua_State* L, const char* libname, const luaL_Reg* l);
LUALIB_API int luaL_getmetafield(lua_State* L, int obj, const char* e);
@ -75,6 +76,7 @@ struct luaL_Buffer
struct TString* storage;
char buffer[LUA_BUFFERSIZE];
};
typedef struct luaL_Buffer luaL_Buffer;
// when internal buffer storage is exhausted, a mutable string value 'storage' will be placed on the stack
// in general, functions expect the mutable string buffer to be placed on top of the stack (top-1)

View file

@ -593,7 +593,7 @@ const char* lua_pushfstringL(lua_State* L, const char* fmt, ...)
return ret;
}
void lua_pushcfunction(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont)
void lua_pushcclosurek(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont)
{
luaC_checkGC(L);
luaC_checkthreadsleep(L);
@ -698,13 +698,13 @@ void lua_createtable(lua_State* L, int narray, int nrec)
return;
}
void lua_setreadonly(lua_State* L, int objindex, bool value)
void lua_setreadonly(lua_State* L, int objindex, int enabled)
{
const TValue* o = index2adr(L, objindex);
api_check(L, ttistable(o));
Table* t = hvalue(o);
api_check(L, t != hvalue(registry(L)));
t->readonly = value;
t->readonly = bool(enabled);
return;
}
@ -717,12 +717,12 @@ int lua_getreadonly(lua_State* L, int objindex)
return res;
}
void lua_setsafeenv(lua_State* L, int objindex, bool value)
void lua_setsafeenv(lua_State* L, int objindex, int enabled)
{
const TValue* o = index2adr(L, objindex);
api_check(L, ttistable(o));
Table* t = hvalue(o);
t->safeenv = value;
t->safeenv = bool(enabled);
return;
}

View file

@ -436,8 +436,8 @@ static const luaL_Reg base_funcs[] = {
static void auxopen(lua_State* L, const char* name, lua_CFunction f, lua_CFunction u)
{
lua_pushcfunction(L, u);
lua_pushcfunction(L, f, name, 1);
lua_pushcfunction(L, u, NULL);
lua_pushcclosure(L, f, name, 1);
lua_setfield(L, -2, name);
}
@ -456,10 +456,10 @@ LUALIB_API int luaopen_base(lua_State* L)
auxopen(L, "ipairs", luaB_ipairs, luaB_inext);
auxopen(L, "pairs", luaB_pairs, luaB_next);
lua_pushcfunction(L, luaB_pcally, "pcall", 0, luaB_pcallcont);
lua_pushcclosurek(L, luaB_pcally, "pcall", 0, luaB_pcallcont);
lua_setfield(L, -2, "pcall");
lua_pushcfunction(L, luaB_xpcally, "xpcall", 0, luaB_xpcallcont);
lua_pushcclosurek(L, luaB_xpcally, "xpcall", 0, luaB_xpcallcont);
lua_setfield(L, -2, "xpcall");
return 1;

View file

@ -5,6 +5,8 @@
#include "lstate.h"
#include "lvm.h"
LUAU_FASTFLAGVARIABLE(LuauCoroutineClose, false)
#define CO_RUN 0 /* running */
#define CO_SUS 1 /* suspended */
#define CO_NOR 2 /* 'normal' (it resumed another coroutine) */
@ -208,8 +210,7 @@ static int cowrap(lua_State* L)
{
cocreate(L);
lua_pushcfunction(L, auxwrapy, NULL, 1, auxwrapcont);
lua_pushcclosurek(L, auxwrapy, NULL, 1, auxwrapcont);
return 1;
}
@ -232,6 +233,34 @@ static int coyieldable(lua_State* L)
return 1;
}
static int coclose(lua_State* L)
{
if (!FFlag::LuauCoroutineClose)
luaL_error(L, "coroutine.close is not enabled");
lua_State* co = lua_tothread(L, 1);
luaL_argexpected(L, co, 1, "thread");
int status = auxstatus(L, co);
if (status != CO_DEAD && status != CO_SUS)
luaL_error(L, "cannot close %s coroutine", statnames[status]);
if (co->status == LUA_OK || co->status == LUA_YIELD)
{
lua_pushboolean(L, true);
lua_resetthread(co);
return 1;
}
else
{
lua_pushboolean(L, false);
if (lua_gettop(co))
lua_xmove(co, L, 1); /* move error message */
lua_resetthread(co);
return 2;
}
}
static const luaL_Reg co_funcs[] = {
{"create", cocreate},
{"running", corunning},
@ -239,6 +268,7 @@ static const luaL_Reg co_funcs[] = {
{"wrap", cowrap},
{"yield", coyield},
{"isyieldable", coyieldable},
{"close", coclose},
{NULL, NULL},
};
@ -246,7 +276,7 @@ LUALIB_API int luaopen_coroutine(lua_State* L)
{
luaL_register(L, LUA_COLIBNAME, co_funcs);
lua_pushcfunction(L, coresumey, "resume", 0, coresumecont);
lua_pushcclosurek(L, coresumey, "resume", 0, coresumecont);
lua_setfield(L, -2, "resume");
return 1;

View file

@ -316,7 +316,7 @@ void luaG_breakpoint(lua_State* L, Proto* p, int line, bool enable)
p->debuginsn[j] = LUAU_INSN_OP(p->code[j]);
}
uint8_t op = enable ? LOP_BREAK : LUAU_INSN_OP(p->code[i]);
uint8_t op = enable ? LOP_BREAK : LUAU_INSN_OP(p->debuginsn[i]);
// patch just the opcode byte, leave arguments alone
p->code[i] &= ~0xff;
@ -357,17 +357,17 @@ int luaG_getline(Proto* p, int pc)
return p->abslineinfo[pc >> p->linegaplog2] + p->lineinfo[pc];
}
void lua_singlestep(lua_State* L, bool singlestep)
void lua_singlestep(lua_State* L, int enabled)
{
L->singlestep = singlestep;
L->singlestep = bool(enabled);
}
void lua_breakpoint(lua_State* L, int funcindex, int line, bool enable)
void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled)
{
const TValue* func = luaA_toobject(L, funcindex);
api_check(L, ttisfunction(func) && !clvalue(func)->isC);
luaG_breakpoint(L, clvalue(func)->l.p, line, enable);
luaG_breakpoint(L, clvalue(func)->l.p, line, bool(enabled));
}
static size_t append(char* buf, size_t bufsize, size_t offset, const char* data)

View file

@ -19,6 +19,7 @@
LUAU_FASTFLAGVARIABLE(LuauExceptionMessageFix, false)
LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false)
LUAU_FASTFLAG(LuauCoroutineClose)
/*
** {======================================================
@ -300,7 +301,10 @@ static void resume(lua_State* L, void* ud)
if (L->status == 0)
{
// start coroutine
LUAU_ASSERT(L->ci == L->base_ci && firstArg > L->base);
LUAU_ASSERT(L->ci == L->base_ci && firstArg >= L->base);
if (FFlag::LuauCoroutineClose && firstArg == L->base)
luaG_runerror(L, "cannot resume dead coroutine");
if (luau_precall(L, firstArg - 1, LUA_MULTRET) != PCRLUA)
return;

View file

@ -22,7 +22,7 @@ LUALIB_API void luaL_openlibs(lua_State* L)
const luaL_Reg* lib = lualibs;
for (; lib->func; lib++)
{
lua_pushcfunction(L, lib->func);
lua_pushcfunction(L, lib->func, NULL);
lua_pushstring(L, lib->name);
lua_call(L, 1, 0);
}

View file

@ -124,6 +124,34 @@ void luaE_freethread(lua_State* L, lua_State* L1)
luaM_free(L, L1, sizeof(lua_State), L1->memcat);
}
void lua_resetthread(lua_State* L)
{
/* close upvalues before clearing anything */
luaF_close(L, L->stack);
/* clear call frames */
CallInfo* ci = L->base_ci;
ci->func = L->stack;
ci->base = ci->func + 1;
ci->top = ci->base + LUA_MINSTACK;
setnilvalue(ci->func);
L->ci = ci;
luaD_reallocCI(L, BASIC_CI_SIZE);
/* clear thread state */
L->status = LUA_OK;
L->base = L->ci->base;
L->top = L->ci->base;
L->nCcalls = L->baseCcalls = 0;
/* clear thread stack */
luaD_reallocstack(L, BASIC_STACK_SIZE);
for (int i = 0; i < L->stacksize; i++)
setnilvalue(L->stack + i);
}
int lua_isthreadreset(lua_State* L)
{
return L->ci == L->base_ci && L->base == L->top && L->status == LUA_OK;
}
lua_State* lua_newstate(lua_Alloc f, void* ud)
{
int i;

View file

@ -748,7 +748,7 @@ static int gmatch(lua_State* L)
luaL_checkstring(L, 2);
lua_settop(L, 2);
lua_pushinteger(L, 0);
lua_pushcfunction(L, gmatch_aux, NULL, 3);
lua_pushcclosure(L, gmatch_aux, NULL, 3);
return 1;
}

View file

@ -265,7 +265,7 @@ static int iter_aux(lua_State* L)
static int iter_codes(lua_State* L)
{
luaL_checkstring(L, 1);
lua_pushcfunction(L, iter_aux);
lua_pushcfunction(L, iter_aux, NULL);
lua_pushvalue(L, 1);
lua_pushinteger(L, 0);
return 3;

View file

@ -25,8 +25,8 @@ try:
import scipy
from scipy import stats
except ModuleNotFoundError:
print("scipy package is required")
exit(1)
print("Warning: scipy package is not installed, confidence values will not be available")
stats = None
scriptdir = os.path.dirname(os.path.realpath(__file__))
defaultVm = 'luau.exe' if os.name == "nt" else './luau'
@ -200,11 +200,14 @@ def finalizeResult(result):
result.sampleStdDev = math.sqrt(sumOfSquares / (result.count - 1))
result.unbiasedEst = result.sampleStdDev * result.sampleStdDev
if stats:
# Two-tailed distribution with 95% conf.
tValue = stats.t.ppf(1 - 0.05 / 2, result.count - 1)
# Compute confidence interval
result.sampleConfidenceInterval = tValue * result.sampleStdDev / math.sqrt(result.count)
else:
result.sampleConfidenceInterval = result.sampleStdDev
else:
result.sampleStdDev = 0
result.unbiasedEst = 0
@ -377,14 +380,19 @@ def analyzeResult(subdir, main, comparisons):
tStat = abs(main.avg - compare.avg) / (pooledStdDev * math.sqrt(2 / main.count))
degreesOfFreedom = 2 * main.count - 2
if stats:
# Two-tailed distribution with 95% conf.
tCritical = stats.t.ppf(1 - 0.05 / 2, degreesOfFreedom)
noSignificantDifference = tStat < tCritical
pValue = 2 * (1 - stats.t.cdf(tStat, df = degreesOfFreedom))
else:
noSignificantDifference = None
pValue = -1
if noSignificantDifference:
if noSignificantDifference is None:
verdict = ""
elif noSignificantDifference:
verdict = "likely same"
elif main.avg < compare.avg:
verdict = "likely worse"

View file

@ -257,7 +257,7 @@ DEFINE_PROTO_FUZZER(const luau::StatBlock& message)
lua_State* L = lua_newthread(globalState);
luaL_sandboxthread(L);
if (luau_load(L, "=fuzz", bytecode.data(), bytecode.size()) == 0)
if (luau_load(L, "=fuzz", bytecode.data(), bytecode.size(), 0) == 0)
{
interruptDeadline = std::chrono::system_clock::now() + kInterruptTimeout;

View file

@ -91,10 +91,6 @@ struct ACFixture : ACFixtureImpl<Fixture>
{
};
struct UnfrozenACFixture : ACFixtureImpl<UnfrozenFixture>
{
};
TEST_SUITE_BEGIN("AutocompleteTest");
TEST_CASE_FIXTURE(ACFixture, "empty_program")
@ -1919,9 +1915,10 @@ local bar: @1= foo
CHECK(!ac.entryMap.count("foo"));
}
// CLI-45692: Remove UnfrozenACFixture here
TEST_CASE_FIXTURE(UnfrozenACFixture, "type_correct_function_no_parenthesis")
TEST_CASE_FIXTURE(ACFixture, "type_correct_function_no_parenthesis")
{
ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true);
check(R"(
local function target(a: (number) -> number) return a(4) end
local function bar1(a: number) return -a end
@ -1950,9 +1947,10 @@ local fp: @1= f
CHECK(ac.entryMap.count("({ x: number, y: number }) -> number"));
}
// CLI-45692: Remove UnfrozenACFixture here
TEST_CASE_FIXTURE(UnfrozenACFixture, "type_correct_keywords")
TEST_CASE_FIXTURE(ACFixture, "type_correct_keywords")
{
ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true);
check(R"(
local function a(x: boolean) end
local function b(x: number?) end
@ -2484,7 +2482,7 @@ local t = {
CHECK(ac.entryMap.count("second"));
}
TEST_CASE_FIXTURE(Fixture, "autocomplete_documentation_symbols")
TEST_CASE_FIXTURE(UnfrozenFixture, "autocomplete_documentation_symbols")
{
loadDefinition(R"(
declare y: {

View file

@ -11,8 +11,6 @@
#include <string_view>
LUAU_FASTFLAG(LuauPreloadClosures)
LUAU_FASTFLAG(LuauPreloadClosuresFenv)
LUAU_FASTFLAG(LuauPreloadClosuresUpval)
LUAU_FASTFLAG(LuauGenericSpecialGlobals)
using namespace Luau;
@ -2797,7 +2795,7 @@ CAPTURE UPVAL U1
RETURN R0 1
)");
if (FFlag::LuauPreloadClosuresUpval)
if (FFlag::LuauPreloadClosures)
{
// recursive capture
CHECK_EQ("\n" + compileFunction("local function foo() return foo() end", 1), R"(
@ -3479,8 +3477,6 @@ CAPTURE VAL R0
RETURN R1 1
)");
if (FFlag::LuauPreloadClosuresFenv)
{
// if they don't need upvalues but we sense that environment may be modified, we disable this to avoid fenv-related identity confusion
CHECK_EQ("\n" + compileFunction(R"(
setfenv(1, {})
@ -3507,12 +3503,10 @@ NEWCLOSURE R0 P0
RETURN R0 1
)");
}
}
TEST_CASE("SharedClosure")
{
ScopedFastFlag sff1("LuauPreloadClosures", true);
ScopedFastFlag sff2("LuauPreloadClosuresUpval", true);
// closures can be shared even if functions refer to upvalues, as long as upvalues are top-level
CHECK_EQ("\n" + compileFunction(R"(
@ -3671,7 +3665,7 @@ RETURN R0 0
)");
}
TEST_CASE("LuauGenericSpecialGlobals")
TEST_CASE("MutableGlobals")
{
const char* source = R"(
print()
@ -3685,43 +3679,6 @@ shared.print()
workspace.print()
)";
{
ScopedFastFlag genericSpecialGlobals{"LuauGenericSpecialGlobals", false};
// Check Roblox globals are here
CHECK_EQ("\n" + compileFunction0(source), R"(
GETIMPORT R0 1
CALL R0 0 0
GETIMPORT R1 3
GETTABLEKS R0 R1 K0
CALL R0 0 0
GETIMPORT R1 5
GETTABLEKS R0 R1 K0
CALL R0 0 0
GETIMPORT R1 7
GETTABLEKS R0 R1 K0
CALL R0 0 0
GETIMPORT R1 9
GETTABLEKS R0 R1 K0
CALL R0 0 0
GETIMPORT R1 11
GETTABLEKS R0 R1 K0
CALL R0 0 0
GETIMPORT R1 13
GETTABLEKS R0 R1 K0
CALL R0 0 0
GETIMPORT R1 15
GETTABLEKS R0 R1 K0
CALL R0 0 0
GETIMPORT R1 17
GETTABLEKS R0 R1 K0
CALL R0 0 0
RETURN R0 0
)");
}
ScopedFastFlag genericSpecialGlobals{"LuauGenericSpecialGlobals", true};
// Check Roblox globals are no longer here
CHECK_EQ("\n" + compileFunction0(source), R"(
GETIMPORT R0 1

View file

@ -1,5 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Compiler.h"
#include "lua.h"
#include "lualib.h"
#include "luacode.h"
#include "Luau/BuiltinDefinitions.h"
#include "Luau/ModuleResolver.h"
@ -10,9 +12,6 @@
#include "doctest.h"
#include "ScopedFlags.h"
#include "lua.h"
#include "lualib.h"
#include <fstream>
#include <math.h>
@ -49,8 +48,12 @@ static int lua_loadstring(lua_State* L)
lua_setsafeenv(L, LUA_ENVIRONINDEX, false);
std::string bytecode = Luau::compile(std::string(s, l));
if (luau_load(L, chunkname, bytecode.data(), bytecode.size()) == 0)
size_t bytecodeSize = 0;
char* bytecode = luau_compile(s, l, nullptr, &bytecodeSize);
int result = luau_load(L, chunkname, bytecode, bytecodeSize, 0);
free(bytecode);
if (result == 0)
return 1;
lua_pushnil(L);
@ -179,21 +182,17 @@ static StateRef runConformance(
std::string chunkname = "=" + std::string(name);
Luau::CompileOptions copts;
lua_CompileOptions copts = {};
copts.optimizationLevel = 1; // default
copts.debugLevel = 2; // for debugger tests
copts.vectorCtor = "vector"; // for vector tests
std::string bytecode = Luau::compile(source, copts);
int status = 0;
size_t bytecodeSize = 0;
char* bytecode = luau_compile(source.data(), source.size(), &copts, &bytecodeSize);
int result = luau_load(L, chunkname.c_str(), bytecode, bytecodeSize, 0);
free(bytecode);
if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0)
{
status = lua_resume(L, nullptr, 0);
}
else
{
status = LUA_ERRSYNTAX;
}
int status = (result == 0) ? lua_resume(L, nullptr, 0) : LUA_ERRSYNTAX;
while (yield && (status == LUA_YIELD || status == LUA_BREAK))
{
@ -332,27 +331,35 @@ TEST_CASE("UTF8")
TEST_CASE("Coroutine")
{
ScopedFastFlag sff("LuauCoroutineClose", true);
runConformance("coroutine.lua");
}
TEST_CASE("PCall")
static int cxxthrow(lua_State* L)
{
runConformance("pcall.lua", [](lua_State* L) {
lua_pushcfunction(L, [](lua_State* L) -> int {
#if LUA_USE_LONGJMP
luaL_error(L, "oops");
#else
throw std::runtime_error("oops");
#endif
});
}
TEST_CASE("PCall")
{
runConformance("pcall.lua", [](lua_State* L) {
lua_pushcfunction(L, cxxthrow, "cxxthrow");
lua_setglobal(L, "cxxthrow");
lua_pushcfunction(L, [](lua_State* L) -> int {
lua_pushcfunction(
L,
[](lua_State* L) -> int {
lua_State* co = lua_tothread(L, 1);
lua_xmove(L, co, 1);
lua_resumeerror(co, L);
return 0;
});
},
"resumeerror");
lua_setglobal(L, "resumeerror");
});
}
@ -367,18 +374,18 @@ TEST_CASE("Pack")
TEST_CASE("Vector")
{
runConformance("vector.lua", [](lua_State* L) {
lua_pushcfunction(L, lua_vector);
lua_pushcfunction(L, lua_vector, "vector");
lua_setglobal(L, "vector");
lua_pushvector(L, 0.0f, 0.0f, 0.0f);
luaL_newmetatable(L, "vector");
lua_pushstring(L, "__index");
lua_pushcfunction(L, lua_vector_index);
lua_pushcfunction(L, lua_vector_index, nullptr);
lua_settable(L, -3);
lua_pushstring(L, "__namecall");
lua_pushcfunction(L, lua_vector_namecall);
lua_pushcfunction(L, lua_vector_namecall, nullptr);
lua_settable(L, -3);
lua_setreadonly(L, -1, true);
@ -513,15 +520,19 @@ TEST_CASE("Debugger")
};
// add breakpoint() function
lua_pushcfunction(L, [](lua_State* L) -> int {
lua_pushcfunction(
L,
[](lua_State* L) -> int {
int line = luaL_checkinteger(L, 1);
bool enabled = lua_isboolean(L, 2) ? lua_toboolean(L, 2) : true;
lua_Debug ar = {};
lua_getinfo(L, 1, "f", &ar);
lua_breakpoint(L, -1, line, true);
lua_breakpoint(L, -1, line, enabled);
return 0;
});
},
"breakpoint");
lua_setglobal(L, "breakpoint");
},
[](lua_State* L) {
@ -744,7 +755,7 @@ TEST_CASE("ExceptionObject")
if (nsize == 0)
{
free(ptr);
return NULL;
return nullptr;
}
else if (nsize > 512 * 1024)
{

View file

@ -4,7 +4,8 @@
#include <ostream>
#include <optional>
namespace std {
namespace std
{
inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&)
{

View file

@ -203,6 +203,8 @@ TEST_CASE_FIXTURE(Fixture, "clone_class")
TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types")
{
ScopedFastFlag sff{"LuauErrorRecoveryType", true};
TypeVar freeTy(FreeTypeVar{TypeLevel{}});
TypePackVar freeTp(FreeTypePack{TypeLevel{}});
@ -212,12 +214,12 @@ TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types")
bool encounteredFreeType = false;
TypeId clonedTy = clone(&freeTy, dest, seenTypes, seenTypePacks, &encounteredFreeType);
CHECK(Luau::get<ErrorTypeVar>(clonedTy));
CHECK_EQ("any", toString(clonedTy));
CHECK(encounteredFreeType);
encounteredFreeType = false;
TypePackId clonedTp = clone(&freeTp, dest, seenTypes, seenTypePacks, &encounteredFreeType);
CHECK(Luau::get<Unifiable::Error>(clonedTp));
CHECK_EQ("...any", toString(clonedTp));
CHECK(encounteredFreeType);
}

View file

@ -198,7 +198,8 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_correctly_use_matching_table
TypeVar tv{ttv};
ToStringOptions o{/* exhaustive= */ false, /* useLineBreaks= */ false, /* functionTypeArguments= */ false, /* hideTableKind= */ false, 40};
ToStringOptions o;
o.maxTableLength = 40;
CHECK_EQ(toString(&tv, o), "{| a: number, b: number, c: number, d: number, e: number, ... 5 more ... |}");
}
@ -395,7 +396,7 @@ local function target(callback: nil) return callback(4, "hello") end
)");
LUAU_REQUIRE_ERRORS(result);
CHECK_EQ(toString(requireType("target")), "(nil) -> (*unknown*)");
CHECK_EQ("(nil) -> (*unknown*)", toString(requireType("target")));
}
TEST_CASE_FIXTURE(Fixture, "toStringGenericPack")
@ -469,4 +470,110 @@ TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param")
CHECK_EQ(toString(tableTy), "Table<Table>");
}
TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_id")
{
CheckResult result = check(R"(
local function id(x) return x end
)");
TypeId ty = requireType("id");
const FunctionTypeVar* ftv = get<FunctionTypeVar>(follow(ty));
CHECK_EQ("id<a>(x: a): a", toStringNamedFunction("id", *ftv));
}
TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map")
{
CheckResult result = check(R"(
local function map(arr, fn)
local t = {}
for i = 0, #arr do
t[i] = fn(arr[i])
end
return t
end
)");
TypeId ty = requireType("map");
const FunctionTypeVar* ftv = get<FunctionTypeVar>(follow(ty));
CHECK_EQ("map<a, b>(arr: {a}, fn: (a) -> b): {b}", toStringNamedFunction("map", *ftv));
}
TEST_CASE("toStringNamedFunction_unit_f")
{
TypePackVar empty{TypePack{}};
FunctionTypeVar ftv{&empty, &empty, {}, false};
CHECK_EQ("f(): ()", toStringNamedFunction("f", ftv));
}
TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics")
{
CheckResult result = check(R"(
local function f<a, b...>(x: a, ...): (a, a, b...)
return x, x, ...
end
)");
TypeId ty = requireType("f");
auto ftv = get<FunctionTypeVar>(follow(ty));
CHECK_EQ("f<a, b...>(x: a, ...: any): (a, a, b...)", toStringNamedFunction("f", *ftv));
}
TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics2")
{
CheckResult result = check(R"(
local function f(): ...number
return 1, 2, 3
end
)");
TypeId ty = requireType("f");
auto ftv = get<FunctionTypeVar>(follow(ty));
CHECK_EQ("f(): ...number", toStringNamedFunction("f", *ftv));
}
TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics3")
{
CheckResult result = check(R"(
local function f(): (string, ...number)
return 'a', 1, 2, 3
end
)");
TypeId ty = requireType("f");
auto ftv = get<FunctionTypeVar>(follow(ty));
CHECK_EQ("f(): (string, ...number)", toStringNamedFunction("f", *ftv));
}
TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_type_annotation_has_partial_argnames")
{
CheckResult result = check(R"(
local f: (number, y: number) -> number
)");
TypeId ty = requireType("f");
auto ftv = get<FunctionTypeVar>(follow(ty));
CHECK_EQ("f(_: number, y: number): number", toStringNamedFunction("f", *ftv));
}
TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_type_params")
{
CheckResult result = check(R"(
local function f<T>(x: T, g: <U>(T) -> U)): ()
end
)");
TypeId ty = requireType("f");
auto ftv = get<FunctionTypeVar>(follow(ty));
ToStringOptions opts;
opts.hideNamedFunctionTypeParameters = true;
CHECK_EQ("f(x: T, g: <U>(T) -> U): ()", toStringNamedFunction("f", *ftv, opts));
}
TEST_SUITE_END();

View file

@ -109,7 +109,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_stop_typechecking_after_reporting_duplicate_typ
CheckResult result = check(R"(
type A = number
type A = string -- Redefinition of type 'A', previously defined at line 1
local foo: string = 1 -- No "Type 'number' could not be converted into 'string'"
local foo: string = 1 -- "Type 'number' could not be converted into 'string'"
)");
LUAU_REQUIRE_ERROR_COUNT(2, result);

View file

@ -381,6 +381,8 @@ TEST_CASE_FIXTURE(Fixture, "typeof_expr")
TEST_CASE_FIXTURE(Fixture, "corecursive_types_error_on_tight_loop")
{
ScopedFastFlag sff{"LuauErrorRecoveryType", true};
CheckResult result = check(R"(
type A = B
type B = A
@ -390,7 +392,7 @@ TEST_CASE_FIXTURE(Fixture, "corecursive_types_error_on_tight_loop")
)");
TypeId fType = requireType("aa");
const ErrorTypeVar* ftv = get<ErrorTypeVar>(follow(fType));
const AnyTypeVar* ftv = get<AnyTypeVar>(follow(fType));
REQUIRE(ftv != nullptr);
REQUIRE(!result.errors.empty());
}

View file

@ -289,7 +289,7 @@ TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods")
end
)");
// TODO: Should typecheck but currently errors CLI-39916
LUAU_REQUIRE_ERROR_COUNT(1, result);
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "infer_generic_property")
@ -352,7 +352,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_leak_generic_types")
-- so this assignment should fail
local b: boolean = f(true)
)");
LUAU_REQUIRE_ERROR_COUNT(2, result);
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "dont_leak_inferred_generic_types")
@ -368,7 +368,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_leak_inferred_generic_types")
local y: number = id(37)
end
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "dont_substitute_bound_types")

View file

@ -704,6 +704,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector")
CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0]));
else
CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0]));
CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance"
if (FFlag::LuauQuantifyInPlace2)

View file

@ -0,0 +1,377 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Fixture.h"
#include "doctest.h"
#include "Luau/BuiltinDefinitions.h"
using namespace Luau;
TEST_SUITE_BEGIN("TypeSingletons");
TEST_CASE_FIXTURE(Fixture, "bool_singletons")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
local a: true = true
local b: false = false
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "string_singletons")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
local a: "foo" = "foo"
local b: "bar" = "bar"
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "bool_singletons_mismatch")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
local a: true = false
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Type 'false' could not be converted into 'true'", toString(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "string_singletons_mismatch")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
local a: "foo" = "bar"
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Type '\"bar\"' could not be converted into '\"foo\"'", toString(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "string_singletons_escape_chars")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
local a: "\n" = "\000\r"
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(R"(Type '"\000\r"' could not be converted into '"\n"')", toString(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "bool_singleton_subtype")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
local a: true = true
local b: boolean = a
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "string_singleton_subtype")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
local a: "foo" = "foo"
local b: string = a
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
function f(a: true, b: "foo") end
f(true, "foo")
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons_mismatch")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
function f(a: true, b: "foo") end
f(true, "bar")
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Type '\"bar\"' could not be converted into '\"foo\"'", toString(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
function f(a, b) end
local g : ((true, string) -> ()) & ((false, number) -> ()) = (f::any)
g(true, "foo")
g(false, 37)
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons_mismatch")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
function f(a, b) end
local g : ((true, string) -> ()) & ((false, number) -> ()) = (f::any)
g(true, 37)
)");
LUAU_REQUIRE_ERROR_COUNT(2, result);
CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0]));
CHECK_EQ("Other overloads are also not viable: (false, number) -> ()", toString(result.errors[1]));
}
TEST_CASE_FIXTURE(Fixture, "enums_using_singletons")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
type MyEnum = "foo" | "bar" | "baz"
local a : MyEnum = "foo"
local b : MyEnum = "bar"
local c : MyEnum = "baz"
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_mismatch")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
{"LuauExtendedTypeMismatchError", true},
};
CheckResult result = check(R"(
type MyEnum = "foo" | "bar" | "baz"
local a : MyEnum = "bang"
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Type '\"bang\"' could not be converted into '\"bar\" | \"baz\" | \"foo\"'; none of the union options are compatible",
toString(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_subtyping")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
type MyEnum1 = "foo" | "bar"
type MyEnum2 = MyEnum1 | "baz"
local a : MyEnum1 = "foo"
local b : MyEnum2 = a
local c : string = b
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
{"LuauExpectedTypesOfProperties", true},
};
CheckResult result = check(R"(
type Dog = { tag: "Dog", howls: boolean }
type Cat = { tag: "Cat", meows: boolean }
type Animal = Dog | Cat
local a : Dog = { tag = "Dog", howls = true }
local b : Animal = { tag = "Cat", meows = true }
local c : Animal = a
c = b
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons_mismatch")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
type Dog = { tag: "Dog", howls: boolean }
type Cat = { tag: "Cat", meows: boolean }
type Animal = Dog | Cat
local a : Animal = { tag = "Cat", howls = true }
)");
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "tagged_unions_immutable_tag")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
type Dog = { tag: "Dog", howls: boolean }
type Cat = { tag: "Cat", meows: boolean }
type Animal = Dog | Cat
local a : Animal = { tag = "Cat", meows = true }
a.tag = "Dog"
)");
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "table_properties_singleton_strings")
{
ScopedFastFlag sffs[] = {
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
--!strict
type T = {
["foo"] : number,
["$$bar"] : string,
baz : boolean
}
local t: T = {
["foo"] = 37,
["$$bar"] = "hi",
baz = true
}
local a: number = t.foo
local b: string = t["$$bar"]
local c: boolean = t.baz
t.foo = 5
t["$$bar"] = "lo"
t.baz = false
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "table_properties_singleton_strings_mismatch")
{
ScopedFastFlag sffs[] = {
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
--!strict
type T = {
["$$bar"] : string,
}
local t: T = {
["$$bar"] = "hi",
}
t["$$bar"] = 5
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "table_properties_alias_or_parens_is_indexer")
{
ScopedFastFlag sffs[] = {
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
--!strict
type S = "bar"
type T = {
[("foo")] : number,
[S] : string,
}
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Syntax error: Cannot have more than one table indexer", toString(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes")
{
ScopedFastFlag sffs[] = {
{"LuauParseSingletonTypes", true},
};
CheckResult result = check(R"(
--!strict
local x: { ["<>"] : number }
x = { ["\n"] = 5 }
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(R"(Table type '{| ["\n"]: number |}' not compatible with type '{| ["<>"]: number |}' because the former is missing field '<>')",
toString(result.errors[0]));
}
TEST_SUITE_END();

View file

@ -362,7 +362,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_error")
CHECK_EQ(2, result.errors.size());
TypeId p = requireType("p");
CHECK_EQ(*p, *typeChecker.errorType);
CHECK_EQ("*unknown*", toString(p));
}
TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function")
@ -480,7 +480,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any2")
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(typeChecker.anyType, requireType("a"));
CHECK_EQ("any", toString(requireType("a")));
}
TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any")
@ -496,7 +496,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any")
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(typeChecker.anyType, requireType("a"));
CHECK_EQ("any", toString(requireType("a")));
}
TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2")
@ -512,7 +512,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2")
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(typeChecker.anyType, requireType("a"));
CHECK_EQ("any", toString(requireType("a")));
}
TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error")
@ -526,7 +526,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error")
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(typeChecker.errorType, requireType("a"));
CHECK_EQ("*unknown*", toString(requireType("a")));
}
TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2")
@ -542,7 +542,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2")
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(typeChecker.errorType, requireType("a"));
CHECK_EQ("*unknown*", toString(requireType("a")));
}
TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_custom_iterator")
@ -673,7 +673,7 @@ TEST_CASE_FIXTURE(Fixture, "string_index")
REQUIRE(nat);
CHECK_EQ("string", toString(nat->ty));
CHECK(get<ErrorTypeVar>(requireType("t")));
CHECK_EQ("*unknown*", toString(requireType("t")));
}
TEST_CASE_FIXTURE(Fixture, "length_of_error_type_does_not_produce_an_error")
@ -1456,7 +1456,7 @@ TEST_CASE_FIXTURE(Fixture, "require_module_that_does_not_export")
auto hootyType = requireType(bModule, "Hooty");
CHECK_MESSAGE(get<ErrorTypeVar>(follow(hootyType)) != nullptr, "Should be an error: " << toString(hootyType));
CHECK_EQ("*unknown*", toString(hootyType));
}
TEST_CASE_FIXTURE(Fixture, "warn_on_lowercase_parent_property")
@ -2032,7 +2032,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_4")
CHECK_EQ(*arg0->indexer->indexResultType, *arg1Args[1]);
}
TEST_CASE_FIXTURE(Fixture, "error_types_propagate")
TEST_CASE_FIXTURE(Fixture, "type_errors_infer_types")
{
CheckResult result = check(R"(
local err = (true).x
@ -2049,10 +2049,10 @@ TEST_CASE_FIXTURE(Fixture, "error_types_propagate")
CHECK_EQ("boolean", toString(err->table));
CHECK_EQ("x", err->key);
CHECK(nullptr != get<ErrorTypeVar>(requireType("c")));
CHECK(nullptr != get<ErrorTypeVar>(requireType("d")));
CHECK(nullptr != get<ErrorTypeVar>(requireType("e")));
CHECK(nullptr != get<ErrorTypeVar>(requireType("f")));
CHECK_EQ("*unknown*", toString(requireType("c")));
CHECK_EQ("*unknown*", toString(requireType("d")));
CHECK_EQ("*unknown*", toString(requireType("e")));
CHECK_EQ("*unknown*", toString(requireType("f")));
}
TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error")
@ -2068,7 +2068,7 @@ TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error")
CHECK_EQ("unknown", err->name);
CHECK(nullptr != get<ErrorTypeVar>(requireType("a")));
CHECK_EQ("*unknown*", toString(requireType("a")));
}
TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error")
@ -2077,9 +2077,7 @@ TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error")
local a = Utility.Create "Foo" {}
)");
TypeId aType = requireType("a");
REQUIRE_MESSAGE(nullptr != get<ErrorTypeVar>(aType), "Not an error: " << toString(aType));
CHECK_EQ("*unknown*", toString(requireType("a")));
}
TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable")
@ -2146,6 +2144,8 @@ TEST_CASE_FIXTURE(Fixture, "some_primitive_binary_ops")
TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection")
{
ScopedFastFlag sff{"LuauErrorRecoveryType", true};
CheckResult result = check(R"(
--!strict
local Vec3 = {}
@ -2175,11 +2175,13 @@ TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersectio
CHECK_EQ("Vec3", toString(requireType("b")));
CHECK_EQ("Vec3", toString(requireType("c")));
CHECK_EQ("Vec3", toString(requireType("d")));
CHECK(get<ErrorTypeVar>(requireType("e")));
CHECK_EQ("Vec3", toString(requireType("e")));
}
TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection_on_rhs")
{
ScopedFastFlag sff{"LuauErrorRecoveryType", true};
CheckResult result = check(R"(
--!strict
local Vec3 = {}
@ -2209,7 +2211,7 @@ TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersectio
CHECK_EQ("Vec3", toString(requireType("b")));
CHECK_EQ("Vec3", toString(requireType("c")));
CHECK_EQ("Vec3", toString(requireType("d")));
CHECK(get<ErrorTypeVar>(requireType("e")));
CHECK_EQ("Vec3", toString(requireType("e")));
}
TEST_CASE_FIXTURE(Fixture, "compare_numbers")
@ -2901,6 +2903,8 @@ end
TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfNumber")
{
ScopedFastFlag sff{"LuauErrorRecoveryType", true};
CheckResult result = check(R"(
local x: number = 9999
function x:y(z: number)
@ -2908,7 +2912,7 @@ function x:y(z: number)
end
)");
LUAU_REQUIRE_ERROR_COUNT(3, result);
LUAU_REQUIRE_ERROR_COUNT(2, result);
}
TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfError")
@ -2920,7 +2924,7 @@ function x:y(z: number)
end
)");
LUAU_REQUIRE_ERROR_COUNT(2, result);
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "CallOrOfFunctions")
@ -3799,7 +3803,7 @@ TEST_CASE_FIXTURE(Fixture, "UnknownGlobalCompoundAssign")
print(a)
)");
LUAU_REQUIRE_ERROR_COUNT(2, result);
LUAU_REQUIRE_ERRORS(result);
CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'");
}
@ -4215,7 +4219,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying")
std::optional<TypeFun> t0 = getMainModule()->getModuleScope()->lookupType("t0");
REQUIRE(t0);
CHECK(get<ErrorTypeVar>(t0->type));
CHECK_EQ("*unknown*", toString(t0->type));
auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) {
return get<OccursCheckFailed>(err);
@ -4238,7 +4242,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional")
std::optional<TypeFun> t0 = getMainModule()->getModuleScope()->lookupType("t0");
REQUIRE(t0);
CHECK(get<ErrorTypeVar>(t0->type));
CHECK_EQ("*unknown*", toString(t0->type));
auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) {
return get<OccursCheckFailed>(err);
@ -4394,6 +4398,25 @@ TEST_CASE_FIXTURE(Fixture, "record_matching_overload")
CHECK_EQ(toString(*it), "(number) -> number");
}
TEST_CASE_FIXTURE(Fixture, "return_type_by_overload")
{
ScopedFastFlag sff{"LuauErrorRecoveryType", true};
CheckResult result = check(R"(
type Overload = ((string) -> string) & ((number, number) -> number)
local abc: Overload
local x = abc(true)
local y = abc(true,true)
local z = abc(true,true,true)
)");
LUAU_REQUIRE_ERRORS(result);
CHECK_EQ("string", toString(requireType("x")));
CHECK_EQ("number", toString(requireType("y")));
// Should this be string|number?
CHECK_EQ("string", toString(requireType("z")));
}
TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments")
{
// Simple direct arg to arg propagation
@ -4740,4 +4763,20 @@ TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions3")
}
}
TEST_CASE_FIXTURE(Fixture, "type_error_addition")
{
CheckResult result = check(R"(
--!strict
local foo = makesandwich()
local bar = foo.nutrition + 100
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
// We should definitely get this error
CHECK_EQ("Unknown global 'makesandwich'", toString(result.errors[0]));
// We get this error if makesandwich() returns a free type
// CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'foo'", toString(result.errors[1]));
}
TEST_SUITE_END();

View file

@ -121,9 +121,26 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "members_of_failed_typepack_unification_are_u
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeId bType = requireType("b");
CHECK_EQ("a", toString(requireType("a")));
CHECK_EQ("*unknown*", toString(requireType("b")));
}
CHECK_MESSAGE(get<ErrorTypeVar>(bType), "Should be an error: " << toString(bType));
TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_constrained")
{
ScopedFastFlag sff{"LuauErrorRecoveryType", true};
CheckResult result = check(R"(
function f(arg: number) return arg end
local a
local b
local c = f(a, b)
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("a", toString(requireType("a")));
CHECK_EQ("*unknown*", toString(requireType("b")));
CHECK_EQ("number", toString(requireType("c")));
}
TEST_CASE_FIXTURE(TryUnifyFixture, "typepack_unification_should_trim_free_tails")
@ -167,15 +184,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_tails_respect_progress")
CHECK(state.errors.empty());
}
TEST_CASE_FIXTURE(TryUnifyFixture, "unifying_variadic_pack_with_error_should_work")
{
TypePackId variadicPack = arena.addTypePack(TypePackVar{VariadicTypePack{typeChecker.numberType}});
TypePackId errorPack = arena.addTypePack(TypePack{{typeChecker.numberType}, arena.addTypePack(TypePackVar{Unifiable::Error{}})});
state.tryUnify(variadicPack, errorPack);
REQUIRE_EQ(0, state.errors.size());
}
TEST_CASE_FIXTURE(TryUnifyFixture, "variadics_should_use_reversed_properly")
{
CheckResult result = check(R"(

View file

@ -200,8 +200,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property")
CHECK_EQ(mup->missing[0], *bTy);
CHECK_EQ(mup->key, "x");
TypeId r = requireType("r");
CHECK_MESSAGE(get<ErrorTypeVar>(r), "Expected error, got " << toString(r));
CHECK_EQ("*unknown*", toString(requireType("r")));
}
TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_property_of_type_any")
@ -283,7 +282,7 @@ local c = b:foo(1, 2)
CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "optional_union_follow")
TEST_CASE_FIXTURE(UnfrozenFixture, "optional_union_follow")
{
CheckResult result = check(R"(
local y: number? = 2

View file

@ -319,4 +319,58 @@ for i=0,30 do
assert(#T2 == 1 or T2[#T2] == 42)
end
-- test coroutine.close
do
-- ok to close a dead coroutine
local co = coroutine.create(type)
assert(coroutine.resume(co, "testing 'coroutine.close'"))
assert(coroutine.status(co) == "dead")
local st, msg = coroutine.close(co)
assert(st and msg == nil)
-- also ok to close it again
st, msg = coroutine.close(co)
assert(st and msg == nil)
-- cannot close the running coroutine
coroutine.wrap(function()
local st, msg = pcall(coroutine.close, coroutine.running())
assert(not st and string.find(msg, "running"))
end)()
-- cannot close a "normal" coroutine
coroutine.wrap(function()
local co = coroutine.running()
coroutine.wrap(function ()
local st, msg = pcall(coroutine.close, co)
assert(not st and string.find(msg, "normal"))
end)()
end)()
-- closing a coroutine after an error
local co = coroutine.create(error)
local obj = {42}
local st, msg = coroutine.resume(co, obj)
assert(not st and msg == obj)
st, msg = coroutine.close(co)
assert(not st and msg == obj)
-- after closing, no more errors
st, msg = coroutine.close(co)
assert(st and msg == nil)
-- closing a coroutine that has outstanding upvalues
local f
local co = coroutine.create(function()
local a = 42
f = function() return a end
coroutine.yield()
a = 20
end)
coroutine.resume(co)
assert(f() == 42)
st, msg = coroutine.close(co)
assert(st and msg == nil)
assert(f() == 42)
end
return'OK'

View file

@ -45,4 +45,13 @@ breakpoint(38) -- break inside corobad()
local co = coroutine.create(corobad)
assert(coroutine.resume(co) == false) -- this breaks, resumes and dies!
function bar()
print("in bar")
end
breakpoint(49)
breakpoint(49, false) -- validate that disabling breakpoints works
bar()
return 'OK'