Merge branch 'upstream' into merge

This commit is contained in:
Andy Friesen 2023-05-05 12:57:26 -07:00
commit 95f0a04e06
56 changed files with 1918 additions and 1058 deletions

View file

@ -107,6 +107,11 @@ struct FunctionCallConstraint
TypePackId result; TypePackId result;
class AstExprCall* callSite; class AstExprCall* callSite;
std::vector<std::optional<TypeId>> discriminantTypes; std::vector<std::optional<TypeId>> discriminantTypes;
// When we dispatch this constraint, we update the key at this map to record
// the overload that we selected.
DenseHashMap<const AstNode*, TypeId>* astOriginalCallTypes;
DenseHashMap<const AstNode*, TypeId>* astOverloadResolvedTypes;
}; };
// result ~ prim ExpectedType SomeSingletonType MultitonType // result ~ prim ExpectedType SomeSingletonType MultitonType

View file

@ -28,6 +28,7 @@ struct FileResolver;
struct ModuleResolver; struct ModuleResolver;
struct ParseResult; struct ParseResult;
struct HotComment; struct HotComment;
struct BuildQueueItem;
struct LoadDefinitionFileResult struct LoadDefinitionFileResult
{ {
@ -171,7 +172,18 @@ struct Frontend
LoadDefinitionFileResult loadDefinitionFile(GlobalTypes& globals, ScopePtr targetScope, std::string_view source, const std::string& packageName, LoadDefinitionFileResult loadDefinitionFile(GlobalTypes& globals, ScopePtr targetScope, std::string_view source, const std::string& packageName,
bool captureComments, bool typeCheckForAutocomplete = false); bool captureComments, bool typeCheckForAutocomplete = false);
// Batch module checking. Queue modules and check them together, retrieve results with 'getCheckResult'
// If provided, 'executeTask' function is allowed to call the 'task' function on any thread and return without waiting for 'task' to complete
void queueModuleCheck(const std::vector<ModuleName>& names);
void queueModuleCheck(const ModuleName& name);
std::vector<ModuleName> checkQueuedModules(std::optional<FrontendOptions> optionOverride = {},
std::function<void(std::function<void()> task)> executeTask = {}, std::function<void(size_t done, size_t total)> progress = {});
std::optional<CheckResult> getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete = false);
private: private:
CheckResult check_DEPRECATED(const ModuleName& name, std::optional<FrontendOptions> optionOverride = {});
struct TypeCheckLimits struct TypeCheckLimits
{ {
std::optional<double> finishTime; std::optional<double> finishTime;
@ -185,7 +197,14 @@ private:
std::pair<SourceNode*, SourceModule*> getSourceNode(const ModuleName& name); std::pair<SourceNode*, SourceModule*> getSourceNode(const ModuleName& name);
SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions); SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions);
bool parseGraph(std::vector<ModuleName>& buildQueue, const ModuleName& root, bool forAutocomplete); bool parseGraph(
std::vector<ModuleName>& buildQueue, const ModuleName& root, bool forAutocomplete, std::function<bool(const ModuleName&)> canSkip = {});
void addBuildQueueItems(std::vector<BuildQueueItem>& items, std::vector<ModuleName>& buildQueue, bool cycleDetected,
std::unordered_set<Luau::ModuleName>& seen, const FrontendOptions& frontendOptions);
void checkBuildQueueItem(BuildQueueItem& item);
void checkBuildQueueItems(std::vector<BuildQueueItem>& items);
void recordItemResult(const BuildQueueItem& item);
static LintResult classifyLints(const std::vector<LintWarning>& warnings, const Config& config); static LintResult classifyLints(const std::vector<LintWarning>& warnings, const Config& config);
@ -212,11 +231,13 @@ public:
InternalErrorReporter iceHandler; InternalErrorReporter iceHandler;
std::function<void(const ModuleName& name, const ScopePtr& scope, bool forAutocomplete)> prepareModuleScope; std::function<void(const ModuleName& name, const ScopePtr& scope, bool forAutocomplete)> prepareModuleScope;
std::unordered_map<ModuleName, SourceNode> sourceNodes; std::unordered_map<ModuleName, std::shared_ptr<SourceNode>> sourceNodes;
std::unordered_map<ModuleName, SourceModule> sourceModules; std::unordered_map<ModuleName, std::shared_ptr<SourceModule>> sourceModules;
std::unordered_map<ModuleName, RequireTraceResult> requireTrace; std::unordered_map<ModuleName, RequireTraceResult> requireTrace;
Stats stats = {}; Stats stats = {};
std::vector<ModuleName> moduleQueue;
}; };
ModulePtr check(const SourceModule& sourceModule, const std::vector<RequireCycle>& requireCycles, NotNull<BuiltinTypes> builtinTypes, ModulePtr check(const SourceModule& sourceModule, const std::vector<RequireCycle>& requireCycles, NotNull<BuiltinTypes> builtinTypes,

View file

@ -226,10 +226,6 @@ struct NormalizedType
NormalizedClassType classes; NormalizedClassType classes;
// The class part of the type.
// Each element of this set is a class, and none of the classes are subclasses of each other.
TypeIds DEPRECATED_classes;
// The error part of the type. // The error part of the type.
// This type is either never or the error type. // This type is either never or the error type.
TypeId errors; TypeId errors;
@ -333,8 +329,6 @@ public:
// ------- Normalizing intersections // ------- Normalizing intersections
TypeId intersectionOfTops(TypeId here, TypeId there); TypeId intersectionOfTops(TypeId here, TypeId there);
TypeId intersectionOfBools(TypeId here, TypeId there); TypeId intersectionOfBools(TypeId here, TypeId there);
void DEPRECATED_intersectClasses(TypeIds& heres, const TypeIds& theres);
void DEPRECATED_intersectClassesWithClass(TypeIds& heres, TypeId there);
void intersectClasses(NormalizedClassType& heres, const NormalizedClassType& theres); void intersectClasses(NormalizedClassType& heres, const NormalizedClassType& theres);
void intersectClassesWithClass(NormalizedClassType& heres, TypeId there); void intersectClassesWithClass(NormalizedClassType& heres, TypeId there);
void intersectStrings(NormalizedStringType& here, const NormalizedStringType& there); void intersectStrings(NormalizedStringType& here, const NormalizedStringType& there);

View file

@ -694,7 +694,7 @@ bool areEqual(SeenSet& seen, const Type& lhs, const Type& rhs);
// Follow BoundTypes until we get to something real // Follow BoundTypes until we get to something real
TypeId follow(TypeId t); TypeId follow(TypeId t);
TypeId follow(TypeId t, std::function<TypeId(TypeId)> mapper); TypeId follow(TypeId t, const void* context, TypeId (*mapper)(const void*, TypeId));
std::vector<TypeId> flattenIntersection(TypeId ty); std::vector<TypeId> flattenIntersection(TypeId ty);

View file

@ -169,7 +169,7 @@ using SeenSet = std::set<std::pair<const void*, const void*>>;
bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs); bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs);
TypePackId follow(TypePackId tp); TypePackId follow(TypePackId tp);
TypePackId follow(TypePackId tp, std::function<TypePackId(TypePackId)> mapper); TypePackId follow(TypePackId t, const void* context, TypePackId (*mapper)(const void*, TypePackId));
size_t size(TypePackId tp, TxnLog* log = nullptr); size_t size(TypePackId tp, TxnLog* log = nullptr);
bool finite(TypePackId tp, TxnLog* log = nullptr); bool finite(TypePackId tp, TxnLog* log = nullptr);

View file

@ -163,5 +163,6 @@ private:
void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, Scope* outerScope, bool useScope, TypePackId tp); void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, Scope* outerScope, bool useScope, TypePackId tp);
std::optional<TypeError> hasUnificationTooComplex(const ErrorVec& errors); std::optional<TypeError> hasUnificationTooComplex(const ErrorVec& errors);
std::optional<TypeError> hasCountMismatch(const ErrorVec& errors);
} // namespace Luau } // namespace Luau

View file

@ -18,7 +18,6 @@
LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTINT(LuauCheckRecursionLimit);
LUAU_FASTFLAG(DebugLuauMagicTypes); LUAU_FASTFLAG(DebugLuauMagicTypes);
LUAU_FASTFLAG(LuauNegatedClassTypes);
namespace Luau namespace Luau
{ {
@ -1016,7 +1015,7 @@ static bool isMetamethod(const Name& name)
ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* declaredClass) ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* declaredClass)
{ {
std::optional<TypeId> superTy = FFlag::LuauNegatedClassTypes ? std::make_optional(builtinTypes->classType) : std::nullopt; std::optional<TypeId> superTy = std::make_optional(builtinTypes->classType);
if (declaredClass->superName) if (declaredClass->superName)
{ {
Name superName = Name(declaredClass->superName->value); Name superName = Name(declaredClass->superName->value);
@ -1420,6 +1419,8 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa
rets, rets,
call, call,
std::move(discriminantTypes), std::move(discriminantTypes),
&module->astOriginalCallTypes,
&module->astOverloadResolvedTypes,
}); });
// We force constraints produced by checking function arguments to wait // We force constraints produced by checking function arguments to wait
@ -1772,7 +1773,7 @@ std::tuple<TypeId, TypeId, RefinementId> ConstraintGraphBuilder::checkBinary(
TypeId ty = follow(typeFun->type); TypeId ty = follow(typeFun->type);
// We're only interested in the root class of any classes. // We're only interested in the root class of any classes.
if (auto ctv = get<ClassType>(ty); !ctv || (FFlag::LuauNegatedClassTypes ? (ctv->parent == builtinTypes->classType) : !ctv->parent)) if (auto ctv = get<ClassType>(ty); !ctv || ctv->parent == builtinTypes->classType)
discriminantTy = ty; discriminantTy = ty;
} }
@ -1786,8 +1787,10 @@ std::tuple<TypeId, TypeId, RefinementId> ConstraintGraphBuilder::checkBinary(
} }
else if (binary->op == AstExprBinary::CompareEq || binary->op == AstExprBinary::CompareNe) else if (binary->op == AstExprBinary::CompareEq || binary->op == AstExprBinary::CompareNe)
{ {
TypeId leftType = check(scope, binary->left, ValueContext::RValue, expectedType, true).ty; // We are checking a binary expression of the form a op b
TypeId rightType = check(scope, binary->right, ValueContext::RValue, expectedType, true).ty; // Just because a op b is epxected to return a bool, doesn't mean a, b are expected to be bools too
TypeId leftType = check(scope, binary->left, ValueContext::RValue, {}, true).ty;
TypeId rightType = check(scope, binary->right, ValueContext::RValue, {}, true).ty;
RefinementId leftRefinement = nullptr; RefinementId leftRefinement = nullptr;
if (auto bc = dfg->getBreadcrumb(binary->left)) if (auto bc = dfg->getBreadcrumb(binary->left))

View file

@ -1172,6 +1172,9 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
else if (auto it = get<IntersectionType>(fn)) else if (auto it = get<IntersectionType>(fn))
fn = collapse(it).value_or(fn); fn = collapse(it).value_or(fn);
if (c.callSite)
(*c.astOriginalCallTypes)[c.callSite] = fn;
// We don't support magic __call metamethods. // We don't support magic __call metamethods.
if (std::optional<TypeId> callMm = findMetatableEntry(builtinTypes, errors, fn, "__call", constraint->location)) if (std::optional<TypeId> callMm = findMetatableEntry(builtinTypes, errors, fn, "__call", constraint->location))
{ {
@ -1219,10 +1222,22 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
TypeId inferredTy = arena->addType(FunctionType{TypeLevel{}, constraint->scope.get(), argsPack, c.result}); TypeId inferredTy = arena->addType(FunctionType{TypeLevel{}, constraint->scope.get(), argsPack, c.result});
std::vector<TypeId> overloads = flattenIntersection(fn); const NormalizedType* normFn = normalizer->normalize(fn);
if (!normFn)
{
reportError(UnificationTooComplex{}, constraint->location);
return true;
}
// TODO: It would be nice to not need to convert the normalized type back to
// an intersection and flatten it.
TypeId normFnTy = normalizer->typeFromNormal(*normFn);
std::vector<TypeId> overloads = flattenIntersection(normFnTy);
Instantiation inst(TxnLog::empty(), arena, TypeLevel{}, constraint->scope); Instantiation inst(TxnLog::empty(), arena, TypeLevel{}, constraint->scope);
std::vector<TypeId> arityMatchingOverloads;
for (TypeId overload : overloads) for (TypeId overload : overloads)
{ {
overload = follow(overload); overload = follow(overload);
@ -1247,8 +1262,17 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
if (const auto& e = hasUnificationTooComplex(u.errors)) if (const auto& e = hasUnificationTooComplex(u.errors))
reportError(*e); reportError(*e);
if (const auto& e = hasCountMismatch(u.errors);
(!e || get<CountMismatch>(*e)->context != CountMismatch::Context::Arg) && get<FunctionType>(*instantiated))
{
arityMatchingOverloads.push_back(*instantiated);
}
if (u.errors.empty()) if (u.errors.empty())
{ {
if (c.callSite)
(*c.astOverloadResolvedTypes)[c.callSite] = *instantiated;
// We found a matching overload. // We found a matching overload.
const auto [changedTypes, changedPacks] = u.log.getChanges(); const auto [changedTypes, changedPacks] = u.log.getChanges();
u.log.commit(); u.log.commit();
@ -1260,6 +1284,15 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
} }
} }
if (arityMatchingOverloads.size() == 1 && c.callSite)
{
// In the name of better error messages in the type checker, we provide
// it with an instantiated function signature that matched arity, but
// not the requisite subtyping requirements. This makes errors better in
// cases where only one overload fit from an arity perspective.
(*c.astOverloadResolvedTypes)[c.callSite] = arityMatchingOverloads.at(0);
}
// We found no matching overloads. // We found no matching overloads.
Unifier u{normalizer, Mode::Strict, constraint->scope, Location{}, Covariant}; Unifier u{normalizer, Mode::Strict, constraint->scope, Location{}, Covariant};
u.useScopes = true; u.useScopes = true;
@ -1267,8 +1300,6 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
u.tryUnify(inferredTy, builtinTypes->anyType); u.tryUnify(inferredTy, builtinTypes->anyType);
u.tryUnify(fn, builtinTypes->anyType); u.tryUnify(fn, builtinTypes->anyType);
LUAU_ASSERT(u.errors.empty()); // unifying with any should never fail
const auto [changedTypes, changedPacks] = u.log.getChanges(); const auto [changedTypes, changedPacks] = u.log.getChanges();
u.log.commit(); u.log.commit();
@ -2166,13 +2197,24 @@ void ConstraintSolver::unblock(NotNull<const Constraint> progressed)
void ConstraintSolver::unblock(TypeId progressed) void ConstraintSolver::unblock(TypeId progressed)
{ {
DenseHashSet<TypeId> seen{nullptr};
while (true)
{
if (seen.find(progressed))
iceReporter.ice("ConstraintSolver::unblock encountered a self-bound type!");
seen.insert(progressed);
if (logger) if (logger)
logger->popBlock(progressed); logger->popBlock(progressed);
unblock_(progressed); unblock_(progressed);
if (auto bt = get<BoundType>(progressed)) if (auto bt = get<BoundType>(progressed))
unblock(bt->boundTo); progressed = bt->boundTo;
else
break;
}
} }
void ConstraintSolver::unblock(TypePackId progressed) void ConstraintSolver::unblock(TypePackId progressed)

View file

@ -21,6 +21,9 @@
#include <algorithm> #include <algorithm>
#include <chrono> #include <chrono>
#include <condition_variable>
#include <exception>
#include <mutex>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
@ -34,10 +37,36 @@ LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false)
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false)
LUAU_FASTFLAG(LuauRequirePathTrueModuleName) LUAU_FASTFLAG(LuauRequirePathTrueModuleName)
LUAU_FASTFLAGVARIABLE(DebugLuauReadWriteProperties, false) LUAU_FASTFLAGVARIABLE(DebugLuauReadWriteProperties, false)
LUAU_FASTFLAGVARIABLE(LuauSplitFrontendProcessing, false)
namespace Luau namespace Luau
{ {
struct BuildQueueItem
{
ModuleName name;
ModuleName humanReadableName;
// Parameters
std::shared_ptr<SourceNode> sourceNode;
std::shared_ptr<SourceModule> sourceModule;
Config config;
ScopePtr environmentScope;
std::vector<RequireCycle> requireCycles;
FrontendOptions options;
bool recordJsonLog = false;
// Queue state
std::vector<size_t> reverseDeps;
int dirtyDependencies = 0;
bool processing = false;
// Result
std::exception_ptr exception;
ModulePtr module;
Frontend::Stats stats;
};
std::optional<Mode> parseMode(const std::vector<HotComment>& hotcomments) std::optional<Mode> parseMode(const std::vector<HotComment>& hotcomments)
{ {
for (const HotComment& hc : hotcomments) for (const HotComment& hc : hotcomments)
@ -220,7 +249,7 @@ namespace
{ {
static ErrorVec accumulateErrors( static ErrorVec accumulateErrors(
const std::unordered_map<ModuleName, SourceNode>& sourceNodes, ModuleResolver& moduleResolver, const ModuleName& name) const std::unordered_map<ModuleName, std::shared_ptr<SourceNode>>& sourceNodes, ModuleResolver& moduleResolver, const ModuleName& name)
{ {
std::unordered_set<ModuleName> seen; std::unordered_set<ModuleName> seen;
std::vector<ModuleName> queue{name}; std::vector<ModuleName> queue{name};
@ -240,7 +269,7 @@ static ErrorVec accumulateErrors(
if (it == sourceNodes.end()) if (it == sourceNodes.end())
continue; continue;
const SourceNode& sourceNode = it->second; const SourceNode& sourceNode = *it->second;
queue.insert(queue.end(), sourceNode.requireSet.begin(), sourceNode.requireSet.end()); queue.insert(queue.end(), sourceNode.requireSet.begin(), sourceNode.requireSet.end());
// FIXME: If a module has a syntax error, we won't be able to re-report it here. // FIXME: If a module has a syntax error, we won't be able to re-report it here.
@ -285,8 +314,8 @@ static void filterLintOptions(LintOptions& lintOptions, const std::vector<HotCom
// For each such path, record the full path and the location of the require in the starting module. // For each such path, record the full path and the location of the require in the starting module.
// Note that this is O(V^2) for a fully connected graph and produces O(V) paths of length O(V) // Note that this is O(V^2) for a fully connected graph and produces O(V) paths of length O(V)
// However, when the graph is acyclic, this is O(V), as well as when only the first cycle is needed (stopAtFirst=true) // However, when the graph is acyclic, this is O(V), as well as when only the first cycle is needed (stopAtFirst=true)
std::vector<RequireCycle> getRequireCycles( std::vector<RequireCycle> getRequireCycles(const FileResolver* resolver,
const FileResolver* resolver, const std::unordered_map<ModuleName, SourceNode>& sourceNodes, const SourceNode* start, bool stopAtFirst = false) const std::unordered_map<ModuleName, std::shared_ptr<SourceNode>>& sourceNodes, const SourceNode* start, bool stopAtFirst = false)
{ {
std::vector<RequireCycle> result; std::vector<RequireCycle> result;
@ -302,7 +331,7 @@ std::vector<RequireCycle> getRequireCycles(
if (dit == sourceNodes.end()) if (dit == sourceNodes.end())
continue; continue;
stack.push_back(&dit->second); stack.push_back(dit->second.get());
while (!stack.empty()) while (!stack.empty())
{ {
@ -343,7 +372,7 @@ std::vector<RequireCycle> getRequireCycles(
auto rit = sourceNodes.find(reqName); auto rit = sourceNodes.find(reqName);
if (rit != sourceNodes.end()) if (rit != sourceNodes.end())
stack.push_back(&rit->second); stack.push_back(rit->second.get());
} }
} }
} }
@ -389,6 +418,52 @@ Frontend::Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, c
} }
CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOptions> optionOverride) CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOptions> optionOverride)
{
if (!FFlag::LuauSplitFrontendProcessing)
return check_DEPRECATED(name, optionOverride);
LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
FrontendOptions frontendOptions = optionOverride.value_or(options);
if (std::optional<CheckResult> result = getCheckResult(name, true, frontendOptions.forAutocomplete))
return std::move(*result);
std::vector<ModuleName> buildQueue;
bool cycleDetected = parseGraph(buildQueue, name, frontendOptions.forAutocomplete);
std::unordered_set<Luau::ModuleName> seen;
std::vector<BuildQueueItem> buildQueueItems;
addBuildQueueItems(buildQueueItems, buildQueue, cycleDetected, seen, frontendOptions);
LUAU_ASSERT(!buildQueueItems.empty());
if (FFlag::DebugLuauLogSolverToJson)
{
LUAU_ASSERT(buildQueueItems.back().name == name);
buildQueueItems.back().recordJsonLog = true;
}
checkBuildQueueItems(buildQueueItems);
// Collect results only for checked modules, 'getCheckResult' produces a different result
CheckResult checkResult;
for (const BuildQueueItem& item : buildQueueItems)
{
if (item.module->timeout)
checkResult.timeoutHits.push_back(item.name);
checkResult.errors.insert(checkResult.errors.end(), item.module->errors.begin(), item.module->errors.end());
if (item.name == name)
checkResult.lintResult = item.module->lintResult;
}
return checkResult;
}
CheckResult Frontend::check_DEPRECATED(const ModuleName& name, std::optional<FrontendOptions> optionOverride)
{ {
LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend"); LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
@ -399,7 +474,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
FrontendModuleResolver& resolver = frontendOptions.forAutocomplete ? moduleResolverForAutocomplete : moduleResolver; FrontendModuleResolver& resolver = frontendOptions.forAutocomplete ? moduleResolverForAutocomplete : moduleResolver;
auto it = sourceNodes.find(name); auto it = sourceNodes.find(name);
if (it != sourceNodes.end() && !it->second.hasDirtyModule(frontendOptions.forAutocomplete)) if (it != sourceNodes.end() && !it->second->hasDirtyModule(frontendOptions.forAutocomplete))
{ {
// No recheck required. // No recheck required.
ModulePtr module = resolver.getModule(name); ModulePtr module = resolver.getModule(name);
@ -421,13 +496,13 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
for (const ModuleName& moduleName : buildQueue) for (const ModuleName& moduleName : buildQueue)
{ {
LUAU_ASSERT(sourceNodes.count(moduleName)); LUAU_ASSERT(sourceNodes.count(moduleName));
SourceNode& sourceNode = sourceNodes[moduleName]; SourceNode& sourceNode = *sourceNodes[moduleName];
if (!sourceNode.hasDirtyModule(frontendOptions.forAutocomplete)) if (!sourceNode.hasDirtyModule(frontendOptions.forAutocomplete))
continue; continue;
LUAU_ASSERT(sourceModules.count(moduleName)); LUAU_ASSERT(sourceModules.count(moduleName));
SourceModule& sourceModule = sourceModules[moduleName]; SourceModule& sourceModule = *sourceModules[moduleName];
const Config& config = configResolver->getConfig(moduleName); const Config& config = configResolver->getConfig(moduleName);
@ -583,7 +658,241 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
return checkResult; return checkResult;
} }
bool Frontend::parseGraph(std::vector<ModuleName>& buildQueue, const ModuleName& root, bool forAutocomplete) void Frontend::queueModuleCheck(const std::vector<ModuleName>& names)
{
moduleQueue.insert(moduleQueue.end(), names.begin(), names.end());
}
void Frontend::queueModuleCheck(const ModuleName& name)
{
moduleQueue.push_back(name);
}
std::vector<ModuleName> Frontend::checkQueuedModules(std::optional<FrontendOptions> optionOverride,
std::function<void(std::function<void()> task)> executeTask, std::function<void(size_t done, size_t total)> progress)
{
FrontendOptions frontendOptions = optionOverride.value_or(options);
// By taking data into locals, we make sure queue is cleared at the end, even if an ICE or a different exception is thrown
std::vector<ModuleName> currModuleQueue;
std::swap(currModuleQueue, moduleQueue);
std::unordered_set<Luau::ModuleName> seen;
std::vector<BuildQueueItem> buildQueueItems;
for (const ModuleName& name : currModuleQueue)
{
if (seen.count(name))
continue;
if (!isDirty(name, frontendOptions.forAutocomplete))
{
seen.insert(name);
continue;
}
std::vector<ModuleName> queue;
bool cycleDetected = parseGraph(queue, name, frontendOptions.forAutocomplete, [&seen](const ModuleName& name) {
return seen.count(name);
});
addBuildQueueItems(buildQueueItems, queue, cycleDetected, seen, frontendOptions);
}
if (buildQueueItems.empty())
return {};
// We need a mapping from modules to build queue slots
std::unordered_map<ModuleName, size_t> moduleNameToQueue;
for (size_t i = 0; i < buildQueueItems.size(); i++)
{
BuildQueueItem& item = buildQueueItems[i];
moduleNameToQueue[item.name] = i;
}
// Default task execution is single-threaded and immediate
if (!executeTask)
{
executeTask = [](std::function<void()> task) {
task();
};
}
std::mutex mtx;
std::condition_variable cv;
std::vector<size_t> readyQueueItems;
size_t processing = 0;
size_t remaining = buildQueueItems.size();
auto itemTask = [&](size_t i) {
BuildQueueItem& item = buildQueueItems[i];
try
{
checkBuildQueueItem(item);
}
catch (...)
{
item.exception = std::current_exception();
}
{
std::unique_lock guard(mtx);
readyQueueItems.push_back(i);
}
cv.notify_one();
};
auto sendItemTask = [&](size_t i) {
BuildQueueItem& item = buildQueueItems[i];
item.processing = true;
processing++;
executeTask([&itemTask, i]() {
itemTask(i);
});
};
auto sendCycleItemTask = [&] {
for (size_t i = 0; i < buildQueueItems.size(); i++)
{
BuildQueueItem& item = buildQueueItems[i];
if (!item.processing)
{
sendItemTask(i);
break;
}
}
};
// In a first pass, check modules that have no dependencies and record info of those modules that wait
for (size_t i = 0; i < buildQueueItems.size(); i++)
{
BuildQueueItem& item = buildQueueItems[i];
for (const ModuleName& dep : item.sourceNode->requireSet)
{
if (auto it = sourceNodes.find(dep); it != sourceNodes.end())
{
if (it->second->hasDirtyModule(frontendOptions.forAutocomplete))
{
item.dirtyDependencies++;
buildQueueItems[moduleNameToQueue[dep]].reverseDeps.push_back(i);
}
}
}
if (item.dirtyDependencies == 0)
sendItemTask(i);
}
// Not a single item was found, a cycle in the graph was hit
if (processing == 0)
sendCycleItemTask();
std::vector<size_t> nextItems;
while (remaining != 0)
{
{
std::unique_lock guard(mtx);
// If nothing is ready yet, wait
if (readyQueueItems.empty())
{
cv.wait(guard, [&readyQueueItems] {
return !readyQueueItems.empty();
});
}
// Handle checked items
for (size_t i : readyQueueItems)
{
const BuildQueueItem& item = buildQueueItems[i];
recordItemResult(item);
// Notify items that were waiting for this dependency
for (size_t reverseDep : item.reverseDeps)
{
BuildQueueItem& reverseDepItem = buildQueueItems[reverseDep];
LUAU_ASSERT(reverseDepItem.dirtyDependencies != 0);
reverseDepItem.dirtyDependencies--;
// In case of a module cycle earlier, check if unlocked an item that was already processed
if (!reverseDepItem.processing && reverseDepItem.dirtyDependencies == 0)
nextItems.push_back(reverseDep);
}
}
LUAU_ASSERT(processing >= readyQueueItems.size());
processing -= readyQueueItems.size();
LUAU_ASSERT(remaining >= readyQueueItems.size());
remaining -= readyQueueItems.size();
readyQueueItems.clear();
}
if (progress)
progress(buildQueueItems.size() - remaining, buildQueueItems.size());
// Items cannot be submitted while holding the lock
for (size_t i : nextItems)
sendItemTask(i);
nextItems.clear();
// If we aren't done, but don't have anything processing, we hit a cycle
if (remaining != 0 && processing == 0)
sendCycleItemTask();
}
std::vector<ModuleName> checkedModules;
checkedModules.reserve(buildQueueItems.size());
for (size_t i = 0; i < buildQueueItems.size(); i++)
checkedModules.push_back(std::move(buildQueueItems[i].name));
return checkedModules;
}
std::optional<CheckResult> Frontend::getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete)
{
auto it = sourceNodes.find(name);
if (it == sourceNodes.end() || it->second->hasDirtyModule(forAutocomplete))
return std::nullopt;
auto& resolver = forAutocomplete ? moduleResolverForAutocomplete : moduleResolver;
ModulePtr module = resolver.getModule(name);
if (module == nullptr)
throw InternalCompilerError("Frontend does not have module: " + name, name);
CheckResult checkResult;
if (module->timeout)
checkResult.timeoutHits.push_back(name);
if (accumulateNested)
checkResult.errors = accumulateErrors(sourceNodes, resolver, name);
else
checkResult.errors.insert(checkResult.errors.end(), module->errors.begin(), module->errors.end());
// Get lint result only for top checked module
checkResult.lintResult = module->lintResult;
return checkResult;
}
bool Frontend::parseGraph(
std::vector<ModuleName>& buildQueue, const ModuleName& root, bool forAutocomplete, std::function<bool(const ModuleName&)> canSkip)
{ {
LUAU_TIMETRACE_SCOPE("Frontend::parseGraph", "Frontend"); LUAU_TIMETRACE_SCOPE("Frontend::parseGraph", "Frontend");
LUAU_TIMETRACE_ARGUMENT("root", root.c_str()); LUAU_TIMETRACE_ARGUMENT("root", root.c_str());
@ -654,14 +963,18 @@ bool Frontend::parseGraph(std::vector<ModuleName>& buildQueue, const ModuleName&
// this relies on the fact that markDirty marks reverse-dependencies dirty as well // this relies on the fact that markDirty marks reverse-dependencies dirty as well
// thus if a node is not dirty, all its transitive deps aren't dirty, which means that they won't ever need // thus if a node is not dirty, all its transitive deps aren't dirty, which means that they won't ever need
// to be built, *and* can't form a cycle with any nodes we did process. // to be built, *and* can't form a cycle with any nodes we did process.
if (!it->second.hasDirtyModule(forAutocomplete)) if (!it->second->hasDirtyModule(forAutocomplete))
continue;
// This module might already be in the outside build queue
if (canSkip && canSkip(dep))
continue; continue;
// note: this check is technically redundant *except* that getSourceNode has somewhat broken memoization // note: this check is technically redundant *except* that getSourceNode has somewhat broken memoization
// calling getSourceNode twice in succession will reparse the file, since getSourceNode leaves dirty flag set // calling getSourceNode twice in succession will reparse the file, since getSourceNode leaves dirty flag set
if (seen.contains(&it->second)) if (seen.contains(it->second.get()))
{ {
stack.push_back(&it->second); stack.push_back(it->second.get());
continue; continue;
} }
} }
@ -681,6 +994,210 @@ bool Frontend::parseGraph(std::vector<ModuleName>& buildQueue, const ModuleName&
return cyclic; return cyclic;
} }
void Frontend::addBuildQueueItems(std::vector<BuildQueueItem>& items, std::vector<ModuleName>& buildQueue, bool cycleDetected,
std::unordered_set<Luau::ModuleName>& seen, const FrontendOptions& frontendOptions)
{
LUAU_ASSERT(FFlag::LuauSplitFrontendProcessing);
for (const ModuleName& moduleName : buildQueue)
{
if (seen.count(moduleName))
continue;
seen.insert(moduleName);
LUAU_ASSERT(sourceNodes.count(moduleName));
std::shared_ptr<SourceNode>& sourceNode = sourceNodes[moduleName];
if (!sourceNode->hasDirtyModule(frontendOptions.forAutocomplete))
continue;
LUAU_ASSERT(sourceModules.count(moduleName));
std::shared_ptr<SourceModule>& sourceModule = sourceModules[moduleName];
BuildQueueItem data{moduleName, fileResolver->getHumanReadableModuleName(moduleName), sourceNode, sourceModule};
data.config = configResolver->getConfig(moduleName);
data.environmentScope = getModuleEnvironment(*sourceModule, data.config, frontendOptions.forAutocomplete);
Mode mode = sourceModule->mode.value_or(data.config.mode);
// in NoCheck mode we only need to compute the value of .cyclic for typeck
// in the future we could replace toposort with an algorithm that can flag cyclic nodes by itself
// however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term
// all correct programs must be acyclic so this code triggers rarely
if (cycleDetected)
data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get(), mode == Mode::NoCheck);
data.options = frontendOptions;
// This is used by the type checker to replace the resulting type of cyclic modules with any
sourceModule->cyclic = !data.requireCycles.empty();
items.push_back(std::move(data));
}
}
void Frontend::checkBuildQueueItem(BuildQueueItem& item)
{
LUAU_ASSERT(FFlag::LuauSplitFrontendProcessing);
SourceNode& sourceNode = *item.sourceNode;
const SourceModule& sourceModule = *item.sourceModule;
const Config& config = item.config;
Mode mode = sourceModule.mode.value_or(config.mode);
ScopePtr environmentScope = item.environmentScope;
double timestamp = getTimestamp();
const std::vector<RequireCycle>& requireCycles = item.requireCycles;
if (item.options.forAutocomplete)
{
double autocompleteTimeLimit = FInt::LuauAutocompleteCheckTimeoutMs / 1000.0;
// The autocomplete typecheck is always in strict mode with DM awareness
// to provide better type information for IDE features
TypeCheckLimits typeCheckLimits;
if (autocompleteTimeLimit != 0.0)
typeCheckLimits.finishTime = TimeTrace::getClock() + autocompleteTimeLimit;
else
typeCheckLimits.finishTime = std::nullopt;
// TODO: This is a dirty ad hoc solution for autocomplete timeouts
// We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit
// so that we'll have type information for the whole file at lower quality instead of a full abort in the middle
if (FInt::LuauTarjanChildLimit > 0)
typeCheckLimits.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult));
else
typeCheckLimits.instantiationChildLimit = std::nullopt;
if (FInt::LuauTypeInferIterationLimit > 0)
typeCheckLimits.unifierIterationLimit = std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult));
else
typeCheckLimits.unifierIterationLimit = std::nullopt;
ModulePtr moduleForAutocomplete = check(sourceModule, Mode::Strict, requireCycles, environmentScope, /*forAutocomplete*/ true,
/*recordJsonLog*/ false, typeCheckLimits);
double duration = getTimestamp() - timestamp;
if (moduleForAutocomplete->timeout)
sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0;
else if (duration < autocompleteTimeLimit / 2.0)
sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0);
item.stats.timeCheck += duration;
item.stats.filesStrict += 1;
item.module = moduleForAutocomplete;
return;
}
ModulePtr module = check(sourceModule, mode, requireCycles, environmentScope, /*forAutocomplete*/ false, item.recordJsonLog, {});
item.stats.timeCheck += getTimestamp() - timestamp;
item.stats.filesStrict += mode == Mode::Strict;
item.stats.filesNonstrict += mode == Mode::Nonstrict;
if (module == nullptr)
throw InternalCompilerError("Frontend::check produced a nullptr module for " + item.name, item.name);
if (FFlag::DebugLuauDeferredConstraintResolution && mode == Mode::NoCheck)
module->errors.clear();
if (item.options.runLintChecks)
{
LUAU_TIMETRACE_SCOPE("lint", "Frontend");
LintOptions lintOptions = item.options.enabledLintWarnings.value_or(config.enabledLint);
filterLintOptions(lintOptions, sourceModule.hotcomments, mode);
double timestamp = getTimestamp();
std::vector<LintWarning> warnings =
Luau::lint(sourceModule.root, *sourceModule.names, environmentScope, module.get(), sourceModule.hotcomments, lintOptions);
item.stats.timeLint += getTimestamp() - timestamp;
module->lintResult = classifyLints(warnings, config);
}
if (!item.options.retainFullTypeGraphs)
{
// copyErrors needs to allocate into interfaceTypes as it copies
// types out of internalTypes, so we unfreeze it here.
unfreeze(module->interfaceTypes);
copyErrors(module->errors, module->interfaceTypes);
freeze(module->interfaceTypes);
module->internalTypes.clear();
module->astTypes.clear();
module->astTypePacks.clear();
module->astExpectedTypes.clear();
module->astOriginalCallTypes.clear();
module->astOverloadResolvedTypes.clear();
module->astResolvedTypes.clear();
module->astOriginalResolvedTypes.clear();
module->astResolvedTypePacks.clear();
module->astScopes.clear();
module->scopes.clear();
}
if (mode != Mode::NoCheck)
{
for (const RequireCycle& cyc : requireCycles)
{
TypeError te{cyc.location, item.name, ModuleHasCyclicDependency{cyc.path}};
module->errors.push_back(te);
}
}
ErrorVec parseErrors;
for (const ParseError& pe : sourceModule.parseErrors)
parseErrors.push_back(TypeError{pe.getLocation(), item.name, SyntaxError{pe.what()}});
module->errors.insert(module->errors.begin(), parseErrors.begin(), parseErrors.end());
item.module = module;
}
void Frontend::checkBuildQueueItems(std::vector<BuildQueueItem>& items)
{
LUAU_ASSERT(FFlag::LuauSplitFrontendProcessing);
for (BuildQueueItem& item : items)
{
checkBuildQueueItem(item);
recordItemResult(item);
}
}
void Frontend::recordItemResult(const BuildQueueItem& item)
{
if (item.exception)
std::rethrow_exception(item.exception);
if (item.options.forAutocomplete)
{
moduleResolverForAutocomplete.setModule(item.name, item.module);
item.sourceNode->dirtyModuleForAutocomplete = false;
}
else
{
moduleResolver.setModule(item.name, item.module);
item.sourceNode->dirtyModule = false;
}
stats.timeCheck += item.stats.timeCheck;
stats.timeLint += item.stats.timeLint;
stats.filesStrict += item.stats.filesStrict;
stats.filesNonstrict += item.stats.filesNonstrict;
}
ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete) const ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete) const
{ {
ScopePtr result; ScopePtr result;
@ -711,7 +1228,7 @@ ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config
bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const
{ {
auto it = sourceNodes.find(name); auto it = sourceNodes.find(name);
return it == sourceNodes.end() || it->second.hasDirtyModule(forAutocomplete); return it == sourceNodes.end() || it->second->hasDirtyModule(forAutocomplete);
} }
/* /*
@ -728,7 +1245,7 @@ void Frontend::markDirty(const ModuleName& name, std::vector<ModuleName>* marked
std::unordered_map<ModuleName, std::vector<ModuleName>> reverseDeps; std::unordered_map<ModuleName, std::vector<ModuleName>> reverseDeps;
for (const auto& module : sourceNodes) for (const auto& module : sourceNodes)
{ {
for (const auto& dep : module.second.requireSet) for (const auto& dep : module.second->requireSet)
reverseDeps[dep].push_back(module.first); reverseDeps[dep].push_back(module.first);
} }
@ -740,7 +1257,7 @@ void Frontend::markDirty(const ModuleName& name, std::vector<ModuleName>* marked
queue.pop_back(); queue.pop_back();
LUAU_ASSERT(sourceNodes.count(next) > 0); LUAU_ASSERT(sourceNodes.count(next) > 0);
SourceNode& sourceNode = sourceNodes[next]; SourceNode& sourceNode = *sourceNodes[next];
if (markedDirty) if (markedDirty)
markedDirty->push_back(next); markedDirty->push_back(next);
@ -766,7 +1283,7 @@ SourceModule* Frontend::getSourceModule(const ModuleName& moduleName)
{ {
auto it = sourceModules.find(moduleName); auto it = sourceModules.find(moduleName);
if (it != sourceModules.end()) if (it != sourceModules.end())
return &it->second; return it->second.get();
else else
return nullptr; return nullptr;
} }
@ -901,22 +1418,22 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, std::vect
// Read AST into sourceModules if necessary. Trace require()s. Report parse errors. // Read AST into sourceModules if necessary. Trace require()s. Report parse errors.
std::pair<SourceNode*, SourceModule*> Frontend::getSourceNode(const ModuleName& name) std::pair<SourceNode*, SourceModule*> Frontend::getSourceNode(const ModuleName& name)
{ {
LUAU_TIMETRACE_SCOPE("Frontend::getSourceNode", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
auto it = sourceNodes.find(name); auto it = sourceNodes.find(name);
if (it != sourceNodes.end() && !it->second.hasDirtySourceModule()) if (it != sourceNodes.end() && !it->second->hasDirtySourceModule())
{ {
auto moduleIt = sourceModules.find(name); auto moduleIt = sourceModules.find(name);
if (moduleIt != sourceModules.end()) if (moduleIt != sourceModules.end())
return {&it->second, &moduleIt->second}; return {it->second.get(), moduleIt->second.get()};
else else
{ {
LUAU_ASSERT(!"Everything in sourceNodes should also be in sourceModules"); LUAU_ASSERT(!"Everything in sourceNodes should also be in sourceModules");
return {&it->second, nullptr}; return {it->second.get(), nullptr};
} }
} }
LUAU_TIMETRACE_SCOPE("Frontend::getSourceNode", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
double timestamp = getTimestamp(); double timestamp = getTimestamp();
std::optional<SourceCode> source = fileResolver->readSource(name); std::optional<SourceCode> source = fileResolver->readSource(name);
@ -939,30 +1456,37 @@ std::pair<SourceNode*, SourceModule*> Frontend::getSourceNode(const ModuleName&
RequireTraceResult& require = requireTrace[name]; RequireTraceResult& require = requireTrace[name];
require = traceRequires(fileResolver, result.root, name); require = traceRequires(fileResolver, result.root, name);
SourceNode& sourceNode = sourceNodes[name]; std::shared_ptr<SourceNode>& sourceNode = sourceNodes[name];
SourceModule& sourceModule = sourceModules[name];
sourceModule = std::move(result); if (!sourceNode)
sourceModule.environmentName = environmentName; sourceNode = std::make_shared<SourceNode>();
sourceNode.name = sourceModule.name; std::shared_ptr<SourceModule>& sourceModule = sourceModules[name];
sourceNode.humanReadableName = sourceModule.humanReadableName;
sourceNode.requireSet.clear(); if (!sourceModule)
sourceNode.requireLocations.clear(); sourceModule = std::make_shared<SourceModule>();
sourceNode.dirtySourceModule = false;
*sourceModule = std::move(result);
sourceModule->environmentName = environmentName;
sourceNode->name = sourceModule->name;
sourceNode->humanReadableName = sourceModule->humanReadableName;
sourceNode->requireSet.clear();
sourceNode->requireLocations.clear();
sourceNode->dirtySourceModule = false;
if (it == sourceNodes.end()) if (it == sourceNodes.end())
{ {
sourceNode.dirtyModule = true; sourceNode->dirtyModule = true;
sourceNode.dirtyModuleForAutocomplete = true; sourceNode->dirtyModuleForAutocomplete = true;
} }
for (const auto& [moduleName, location] : require.requireList) for (const auto& [moduleName, location] : require.requireList)
sourceNode.requireSet.insert(moduleName); sourceNode->requireSet.insert(moduleName);
sourceNode.requireLocations = require.requireList; sourceNode->requireLocations = require.requireList;
return {&sourceNode, &sourceModule}; return {sourceNode.get(), sourceModule.get()};
} }
/** Try to parse a source file into a SourceModule. /** Try to parse a source file into a SourceModule.

View file

@ -17,8 +17,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false)
// This could theoretically be 2000 on amd64, but x86 requires this. // This could theoretically be 2000 on amd64, but x86 requires this.
LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200);
LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000);
LUAU_FASTFLAGVARIABLE(LuauNegatedClassTypes, false);
LUAU_FASTFLAGVARIABLE(LuauNegatedTableTypes, false);
LUAU_FASTFLAGVARIABLE(LuauNormalizeBlockedTypes, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeBlockedTypes, false);
LUAU_FASTFLAGVARIABLE(LuauNormalizeMetatableFixes, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeMetatableFixes, false);
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
@ -232,15 +230,8 @@ NormalizedType::NormalizedType(NotNull<BuiltinTypes> builtinTypes)
static bool isShallowInhabited(const NormalizedType& norm) static bool isShallowInhabited(const NormalizedType& norm)
{ {
bool inhabitedClasses;
if (FFlag::LuauNegatedClassTypes)
inhabitedClasses = !norm.classes.isNever();
else
inhabitedClasses = !norm.DEPRECATED_classes.empty();
// This test is just a shallow check, for example it returns `true` for `{ p : never }` // This test is just a shallow check, for example it returns `true` for `{ p : never }`
return !get<NeverType>(norm.tops) || !get<NeverType>(norm.booleans) || inhabitedClasses || !get<NeverType>(norm.errors) || return !get<NeverType>(norm.tops) || !get<NeverType>(norm.booleans) || !norm.classes.isNever() || !get<NeverType>(norm.errors) ||
!get<NeverType>(norm.nils) || !get<NeverType>(norm.numbers) || !norm.strings.isNever() || !get<NeverType>(norm.threads) || !get<NeverType>(norm.nils) || !get<NeverType>(norm.numbers) || !norm.strings.isNever() || !get<NeverType>(norm.threads) ||
!norm.functions.isNever() || !norm.tables.empty() || !norm.tyvars.empty(); !norm.functions.isNever() || !norm.tables.empty() || !norm.tyvars.empty();
} }
@ -257,14 +248,8 @@ bool Normalizer::isInhabited(const NormalizedType* norm, std::unordered_set<Type
if (!norm) if (!norm)
return true; return true;
bool inhabitedClasses;
if (FFlag::LuauNegatedClassTypes)
inhabitedClasses = !norm->classes.isNever();
else
inhabitedClasses = !norm->DEPRECATED_classes.empty();
if (!get<NeverType>(norm->tops) || !get<NeverType>(norm->booleans) || !get<NeverType>(norm->errors) || !get<NeverType>(norm->nils) || if (!get<NeverType>(norm->tops) || !get<NeverType>(norm->booleans) || !get<NeverType>(norm->errors) || !get<NeverType>(norm->nils) ||
!get<NeverType>(norm->numbers) || !get<NeverType>(norm->threads) || inhabitedClasses || !norm->strings.isNever() || !get<NeverType>(norm->numbers) || !get<NeverType>(norm->threads) || !norm->classes.isNever() || !norm->strings.isNever() ||
!norm->functions.isNever()) !norm->functions.isNever())
return true; return true;
@ -466,7 +451,7 @@ static bool areNormalizedTables(const TypeIds& tys)
if (!pt) if (!pt)
return false; return false;
if (pt->type == PrimitiveType::Table && FFlag::LuauNegatedTableTypes) if (pt->type == PrimitiveType::Table)
continue; continue;
return false; return false;
@ -475,14 +460,6 @@ static bool areNormalizedTables(const TypeIds& tys)
return true; return true;
} }
static bool areNormalizedClasses(const TypeIds& tys)
{
for (TypeId ty : tys)
if (!get<ClassType>(ty))
return false;
return true;
}
static bool areNormalizedClasses(const NormalizedClassType& tys) static bool areNormalizedClasses(const NormalizedClassType& tys)
{ {
for (const auto& [ty, negations] : tys.classes) for (const auto& [ty, negations] : tys.classes)
@ -567,7 +544,6 @@ static void assertInvariant(const NormalizedType& norm)
LUAU_ASSERT(isNormalizedTop(norm.tops)); LUAU_ASSERT(isNormalizedTop(norm.tops));
LUAU_ASSERT(isNormalizedBoolean(norm.booleans)); LUAU_ASSERT(isNormalizedBoolean(norm.booleans));
LUAU_ASSERT(areNormalizedClasses(norm.DEPRECATED_classes));
LUAU_ASSERT(areNormalizedClasses(norm.classes)); LUAU_ASSERT(areNormalizedClasses(norm.classes));
LUAU_ASSERT(isNormalizedError(norm.errors)); LUAU_ASSERT(isNormalizedError(norm.errors));
LUAU_ASSERT(isNormalizedNil(norm.nils)); LUAU_ASSERT(isNormalizedNil(norm.nils));
@ -629,7 +605,6 @@ void Normalizer::clearNormal(NormalizedType& norm)
norm.tops = builtinTypes->neverType; norm.tops = builtinTypes->neverType;
norm.booleans = builtinTypes->neverType; norm.booleans = builtinTypes->neverType;
norm.classes.resetToNever(); norm.classes.resetToNever();
norm.DEPRECATED_classes.clear();
norm.errors = builtinTypes->neverType; norm.errors = builtinTypes->neverType;
norm.nils = builtinTypes->neverType; norm.nils = builtinTypes->neverType;
norm.numbers = builtinTypes->neverType; norm.numbers = builtinTypes->neverType;
@ -1252,8 +1227,6 @@ void Normalizer::unionTablesWithTable(TypeIds& heres, TypeId there)
void Normalizer::unionTables(TypeIds& heres, const TypeIds& theres) void Normalizer::unionTables(TypeIds& heres, const TypeIds& theres)
{ {
for (TypeId there : theres) for (TypeId there : theres)
{
if (FFlag::LuauNegatedTableTypes)
{ {
if (there == builtinTypes->tableType) if (there == builtinTypes->tableType)
{ {
@ -1266,11 +1239,6 @@ void Normalizer::unionTables(TypeIds& heres, const TypeIds& theres)
unionTablesWithTable(heres, there); unionTablesWithTable(heres, there);
} }
} }
else
{
unionTablesWithTable(heres, there);
}
}
} }
// So why `ignoreSmallerTyvars`? // So why `ignoreSmallerTyvars`?
@ -1320,10 +1288,7 @@ bool Normalizer::unionNormals(NormalizedType& here, const NormalizedType& there,
} }
here.booleans = unionOfBools(here.booleans, there.booleans); here.booleans = unionOfBools(here.booleans, there.booleans);
if (FFlag::LuauNegatedClassTypes)
unionClasses(here.classes, there.classes); unionClasses(here.classes, there.classes);
else
unionClasses(here.DEPRECATED_classes, there.DEPRECATED_classes);
here.errors = (get<NeverType>(there.errors) ? here.errors : there.errors); here.errors = (get<NeverType>(there.errors) ? here.errors : there.errors);
here.nils = (get<NeverType>(there.nils) ? here.nils : there.nils); here.nils = (get<NeverType>(there.nils) ? here.nils : there.nils);
@ -1414,16 +1379,7 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor
else if (get<TableType>(there) || get<MetatableType>(there)) else if (get<TableType>(there) || get<MetatableType>(there))
unionTablesWithTable(here.tables, there); unionTablesWithTable(here.tables, there);
else if (get<ClassType>(there)) else if (get<ClassType>(there))
{
if (FFlag::LuauNegatedClassTypes)
{
unionClassesWithClass(here.classes, there); unionClassesWithClass(here.classes, there);
}
else
{
unionClassesWithClass(here.DEPRECATED_classes, there);
}
}
else if (get<ErrorType>(there)) else if (get<ErrorType>(there))
here.errors = there; here.errors = there;
else if (const PrimitiveType* ptv = get<PrimitiveType>(there)) else if (const PrimitiveType* ptv = get<PrimitiveType>(there))
@ -1442,7 +1398,7 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor
{ {
here.functions.resetToTop(); here.functions.resetToTop();
} }
else if (ptv->type == PrimitiveType::Table && FFlag::LuauNegatedTableTypes) else if (ptv->type == PrimitiveType::Table)
{ {
here.tables.clear(); here.tables.clear();
here.tables.insert(there); here.tables.insert(there);
@ -1527,8 +1483,6 @@ std::optional<NormalizedType> Normalizer::negateNormal(const NormalizedType& her
result.booleans = builtinTypes->trueType; result.booleans = builtinTypes->trueType;
} }
if (FFlag::LuauNegatedClassTypes)
{
if (here.classes.isNever()) if (here.classes.isNever())
{ {
resetToTop(builtinTypes, result.classes); resetToTop(builtinTypes, result.classes);
@ -1553,11 +1507,6 @@ std::optional<NormalizedType> Normalizer::negateNormal(const NormalizedType& her
if (!rootNegations.empty()) if (!rootNegations.empty())
result.classes.pushPair(builtinTypes->classType, rootNegations); result.classes.pushPair(builtinTypes->classType, rootNegations);
} }
}
else
{
result.DEPRECATED_classes = negateAll(here.DEPRECATED_classes);
}
result.nils = get<NeverType>(here.nils) ? builtinTypes->nilType : builtinTypes->neverType; result.nils = get<NeverType>(here.nils) ? builtinTypes->nilType : builtinTypes->neverType;
result.numbers = get<NeverType>(here.numbers) ? builtinTypes->numberType : builtinTypes->neverType; result.numbers = get<NeverType>(here.numbers) ? builtinTypes->numberType : builtinTypes->neverType;
@ -1584,15 +1533,12 @@ std::optional<NormalizedType> Normalizer::negateNormal(const NormalizedType& her
* types are not runtime-testable. Thus, we prohibit negation of anything * types are not runtime-testable. Thus, we prohibit negation of anything
* other than `table` and `never`. * other than `table` and `never`.
*/ */
if (FFlag::LuauNegatedTableTypes)
{
if (here.tables.empty()) if (here.tables.empty())
result.tables.insert(builtinTypes->tableType); result.tables.insert(builtinTypes->tableType);
else if (here.tables.size() == 1 && here.tables.front() == builtinTypes->tableType) else if (here.tables.size() == 1 && here.tables.front() == builtinTypes->tableType)
result.tables.clear(); result.tables.clear();
else else
return std::nullopt; return std::nullopt;
}
// TODO: negating tables // TODO: negating tables
// TODO: negating tyvars? // TODO: negating tyvars?
@ -1662,7 +1608,6 @@ void Normalizer::subtractPrimitive(NormalizedType& here, TypeId ty)
here.functions.resetToNever(); here.functions.resetToNever();
break; break;
case PrimitiveType::Table: case PrimitiveType::Table:
LUAU_ASSERT(FFlag::LuauNegatedTableTypes);
here.tables.clear(); here.tables.clear();
break; break;
} }
@ -1734,64 +1679,6 @@ TypeId Normalizer::intersectionOfBools(TypeId here, TypeId there)
return there; return there;
} }
void Normalizer::DEPRECATED_intersectClasses(TypeIds& heres, const TypeIds& theres)
{
TypeIds tmp;
for (auto it = heres.begin(); it != heres.end();)
{
const ClassType* hctv = get<ClassType>(*it);
LUAU_ASSERT(hctv);
bool keep = false;
for (TypeId there : theres)
{
const ClassType* tctv = get<ClassType>(there);
LUAU_ASSERT(tctv);
if (isSubclass(hctv, tctv))
{
keep = true;
break;
}
else if (isSubclass(tctv, hctv))
{
keep = false;
tmp.insert(there);
break;
}
}
if (keep)
it++;
else
it = heres.erase(it);
}
heres.insert(tmp.begin(), tmp.end());
}
void Normalizer::DEPRECATED_intersectClassesWithClass(TypeIds& heres, TypeId there)
{
bool foundSuper = false;
const ClassType* tctv = get<ClassType>(there);
LUAU_ASSERT(tctv);
for (auto it = heres.begin(); it != heres.end();)
{
const ClassType* hctv = get<ClassType>(*it);
LUAU_ASSERT(hctv);
if (isSubclass(hctv, tctv))
it++;
else if (isSubclass(tctv, hctv))
{
foundSuper = true;
break;
}
else
it = heres.erase(it);
}
if (foundSuper)
{
heres.clear();
heres.insert(there);
}
}
void Normalizer::intersectClasses(NormalizedClassType& heres, const NormalizedClassType& theres) void Normalizer::intersectClasses(NormalizedClassType& heres, const NormalizedClassType& theres)
{ {
if (theres.isNever()) if (theres.isNever())
@ -2504,15 +2391,7 @@ bool Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& th
here.booleans = intersectionOfBools(here.booleans, there.booleans); here.booleans = intersectionOfBools(here.booleans, there.booleans);
if (FFlag::LuauNegatedClassTypes)
{
intersectClasses(here.classes, there.classes); intersectClasses(here.classes, there.classes);
}
else
{
DEPRECATED_intersectClasses(here.DEPRECATED_classes, there.DEPRECATED_classes);
}
here.errors = (get<NeverType>(there.errors) ? there.errors : here.errors); here.errors = (get<NeverType>(there.errors) ? there.errors : here.errors);
here.nils = (get<NeverType>(there.nils) ? there.nils : here.nils); here.nils = (get<NeverType>(there.nils) ? there.nils : here.nils);
here.numbers = (get<NeverType>(there.numbers) ? there.numbers : here.numbers); here.numbers = (get<NeverType>(there.numbers) ? there.numbers : here.numbers);
@ -2618,22 +2497,12 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there)
here.tables = std::move(tables); here.tables = std::move(tables);
} }
else if (get<ClassType>(there)) else if (get<ClassType>(there))
{
if (FFlag::LuauNegatedClassTypes)
{ {
NormalizedClassType nct = std::move(here.classes); NormalizedClassType nct = std::move(here.classes);
clearNormal(here); clearNormal(here);
intersectClassesWithClass(nct, there); intersectClassesWithClass(nct, there);
here.classes = std::move(nct); here.classes = std::move(nct);
} }
else
{
TypeIds classes = std::move(here.DEPRECATED_classes);
clearNormal(here);
DEPRECATED_intersectClassesWithClass(classes, there);
here.DEPRECATED_classes = std::move(classes);
}
}
else if (get<ErrorType>(there)) else if (get<ErrorType>(there))
{ {
TypeId errors = here.errors; TypeId errors = here.errors;
@ -2665,10 +2534,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there)
else if (ptv->type == PrimitiveType::Function) else if (ptv->type == PrimitiveType::Function)
here.functions = std::move(functions); here.functions = std::move(functions);
else if (ptv->type == PrimitiveType::Table) else if (ptv->type == PrimitiveType::Table)
{
LUAU_ASSERT(FFlag::LuauNegatedTableTypes);
here.tables = std::move(tables); here.tables = std::move(tables);
}
else else
LUAU_ASSERT(!"Unreachable"); LUAU_ASSERT(!"Unreachable");
} }
@ -2696,7 +2562,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there)
subtractPrimitive(here, ntv->ty); subtractPrimitive(here, ntv->ty);
else if (const SingletonType* stv = get<SingletonType>(t)) else if (const SingletonType* stv = get<SingletonType>(t))
subtractSingleton(here, follow(ntv->ty)); subtractSingleton(here, follow(ntv->ty));
else if (get<ClassType>(t) && FFlag::LuauNegatedClassTypes) else if (get<ClassType>(t))
{ {
const NormalizedType* normal = normalize(t); const NormalizedType* normal = normalize(t);
std::optional<NormalizedType> negated = negateNormal(*normal); std::optional<NormalizedType> negated = negateNormal(*normal);
@ -2730,7 +2596,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there)
LUAU_ASSERT(!"Unimplemented"); LUAU_ASSERT(!"Unimplemented");
} }
} }
else if (get<NeverType>(there) && FFlag::LuauNegatedClassTypes) else if (get<NeverType>(there))
{ {
here.classes.resetToNever(); here.classes.resetToNever();
} }
@ -2756,8 +2622,6 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm)
if (!get<NeverType>(norm.booleans)) if (!get<NeverType>(norm.booleans))
result.push_back(norm.booleans); result.push_back(norm.booleans);
if (FFlag::LuauNegatedClassTypes)
{
if (isTop(builtinTypes, norm.classes)) if (isTop(builtinTypes, norm.classes))
{ {
result.push_back(builtinTypes->classType); result.push_back(builtinTypes->classType);
@ -2799,11 +2663,6 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm)
result.push_back(arena->addType(UnionType{std::move(parts)})); result.push_back(arena->addType(UnionType{std::move(parts)}));
} }
} }
}
else
{
result.insert(result.end(), norm.DEPRECATED_classes.begin(), norm.DEPRECATED_classes.end());
}
if (!get<NeverType>(norm.errors)) if (!get<NeverType>(norm.errors))
result.push_back(norm.errors); result.push_back(norm.errors);

View file

@ -382,8 +382,9 @@ std::optional<TypeLevel> TxnLog::getLevel(TypeId ty) const
TypeId TxnLog::follow(TypeId ty) const TypeId TxnLog::follow(TypeId ty) const
{ {
return Luau::follow(ty, [this](TypeId ty) { return Luau::follow(ty, this, [](const void* ctx, TypeId ty) -> TypeId {
PendingType* state = this->pending(ty); const TxnLog* self = static_cast<const TxnLog*>(ctx);
PendingType* state = self->pending(ty);
if (state == nullptr) if (state == nullptr)
return ty; return ty;
@ -397,8 +398,9 @@ TypeId TxnLog::follow(TypeId ty) const
TypePackId TxnLog::follow(TypePackId tp) const TypePackId TxnLog::follow(TypePackId tp) const
{ {
return Luau::follow(tp, [this](TypePackId tp) { return Luau::follow(tp, this, [](const void* ctx, TypePackId tp) -> TypePackId {
PendingTypePack* state = this->pending(tp); const TxnLog* self = static_cast<const TxnLog*>(ctx);
PendingTypePack* state = self->pending(tp);
if (state == nullptr) if (state == nullptr)
return tp; return tp;

View file

@ -48,28 +48,9 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionFind(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate); TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
static bool dcrMagicFunctionFind(MagicFunctionCallContext context); static bool dcrMagicFunctionFind(MagicFunctionCallContext context);
TypeId follow(TypeId t) // LUAU_NOINLINE prevents unwrapLazy from being inlined into advance below; advance is important to keep inlineable
static LUAU_NOINLINE TypeId unwrapLazy(LazyType* ltv)
{ {
return follow(t, [](TypeId t) {
return t;
});
}
TypeId follow(TypeId t, std::function<TypeId(TypeId)> mapper)
{
auto advance = [&mapper](TypeId ty) -> std::optional<TypeId> {
if (FFlag::LuauBoundLazyTypes2)
{
TypeId mapped = mapper(ty);
if (auto btv = get<Unifiable::Bound<TypeId>>(mapped))
return btv->boundTo;
if (auto ttv = get<TableType>(mapped))
return ttv->boundTo;
if (auto ltv = getMutable<LazyType>(mapped))
{
TypeId unwrapped = ltv->unwrapped.load(); TypeId unwrapped = ltv->unwrapped.load();
if (unwrapped) if (unwrapped)
@ -85,23 +66,48 @@ TypeId follow(TypeId t, std::function<TypeId(TypeId)> mapper)
throw InternalCompilerError("Lazy Type cannot resolve to another Lazy Type"); throw InternalCompilerError("Lazy Type cannot resolve to another Lazy Type");
return unwrapped; return unwrapped;
} }
TypeId follow(TypeId t)
{
return follow(t, nullptr, [](const void*, TypeId t) -> TypeId {
return t;
});
}
TypeId follow(TypeId t, const void* context, TypeId (*mapper)(const void*, TypeId))
{
auto advance = [context, mapper](TypeId ty) -> std::optional<TypeId> {
if (FFlag::LuauBoundLazyTypes2)
{
TypeId mapped = mapper(context, ty);
if (auto btv = get<Unifiable::Bound<TypeId>>(mapped))
return btv->boundTo;
if (auto ttv = get<TableType>(mapped))
return ttv->boundTo;
if (auto ltv = getMutable<LazyType>(mapped))
return unwrapLazy(ltv);
return std::nullopt; return std::nullopt;
} }
else else
{ {
if (auto btv = get<Unifiable::Bound<TypeId>>(mapper(ty))) if (auto btv = get<Unifiable::Bound<TypeId>>(mapper(context, ty)))
return btv->boundTo; return btv->boundTo;
else if (auto ttv = get<TableType>(mapper(ty))) else if (auto ttv = get<TableType>(mapper(context, ty)))
return ttv->boundTo; return ttv->boundTo;
else else
return std::nullopt; return std::nullopt;
} }
}; };
auto force = [&mapper](TypeId ty) { auto force = [context, mapper](TypeId ty) {
if (auto ltv = get_if<LazyType>(&mapper(ty)->ty)) TypeId mapped = mapper(context, ty);
if (auto ltv = get_if<LazyType>(&mapped->ty))
{ {
TypeId res = ltv->thunk_DEPRECATED(); TypeId res = ltv->thunk_DEPRECATED();
if (get<LazyType>(res)) if (get<LazyType>(res))
@ -120,6 +126,12 @@ TypeId follow(TypeId t, std::function<TypeId(TypeId)> mapper)
else else
return t; return t;
if (FFlag::LuauBoundLazyTypes2)
{
if (!advance(cycleTester)) // Short circuit traversal for the rather common case when advance(advance(t)) == null
return cycleTester;
}
while (true) while (true)
{ {
if (!FFlag::LuauBoundLazyTypes2) if (!FFlag::LuauBoundLazyTypes2)

View file

@ -22,8 +22,6 @@
LUAU_FASTFLAG(DebugLuauMagicTypes) LUAU_FASTFLAG(DebugLuauMagicTypes)
LUAU_FASTFLAG(DebugLuauDontReduceTypes) LUAU_FASTFLAG(DebugLuauDontReduceTypes)
LUAU_FASTFLAG(LuauNegatedClassTypes)
namespace Luau namespace Luau
{ {
@ -519,18 +517,39 @@ struct TypeChecker2
auto [minCount, maxCount] = getParameterExtents(TxnLog::empty(), iterFtv->argTypes, /*includeHiddenVariadics*/ true); auto [minCount, maxCount] = getParameterExtents(TxnLog::empty(), iterFtv->argTypes, /*includeHiddenVariadics*/ true);
if (minCount > 2) if (minCount > 2)
reportError(CountMismatch{2, std::nullopt, minCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); {
if (isMm)
reportError(GenericError{"__iter metamethod must return (next[, table[, state]])"}, getLocation(forInStatement->values));
else
reportError(GenericError{"for..in loops must be passed (next[, table[, state]])"}, getLocation(forInStatement->values));
}
if (maxCount && *maxCount < 2) if (maxCount && *maxCount < 2)
reportError(CountMismatch{2, std::nullopt, *maxCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); {
if (isMm)
reportError(GenericError{"__iter metamethod must return (next[, table[, state]])"}, getLocation(forInStatement->values));
else
reportError(GenericError{"for..in loops must be passed (next[, table[, state]])"}, getLocation(forInStatement->values));
}
TypePack flattenedArgTypes = extendTypePack(arena, builtinTypes, iterFtv->argTypes, 2); TypePack flattenedArgTypes = extendTypePack(arena, builtinTypes, iterFtv->argTypes, 2);
size_t firstIterationArgCount = iterTys.empty() ? 0 : iterTys.size() - 1; size_t firstIterationArgCount = iterTys.empty() ? 0 : iterTys.size() - 1;
size_t actualArgCount = expectedVariableTypes.head.size(); size_t actualArgCount = expectedVariableTypes.head.size();
if (firstIterationArgCount < minCount) if (firstIterationArgCount < minCount)
{
if (isMm)
reportError(GenericError{"__iter metamethod must return (next[, table[, state]])"}, getLocation(forInStatement->values));
else
reportError(CountMismatch{2, std::nullopt, firstIterationArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); reportError(CountMismatch{2, std::nullopt, firstIterationArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location);
}
else if (actualArgCount < minCount) else if (actualArgCount < minCount)
reportError(CountMismatch{2, std::nullopt, actualArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); {
if (isMm)
reportError(GenericError{"__iter metamethod must return (next[, table[, state]])"}, getLocation(forInStatement->values));
else
reportError(CountMismatch{2, std::nullopt, firstIterationArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location);
}
if (iterTys.size() >= 2 && flattenedArgTypes.head.size() > 0) if (iterTys.size() >= 2 && flattenedArgTypes.head.size() > 0)
{ {
@ -841,125 +860,31 @@ struct TypeChecker2
// TODO! // TODO!
} }
ErrorVec visitOverload(AstExprCall* call, NotNull<const FunctionType> overloadFunctionType, const std::vector<Location>& argLocs,
TypePackId expectedArgTypes, TypePackId expectedRetType)
{
ErrorVec overloadErrors =
tryUnify(stack.back(), call->location, overloadFunctionType->retTypes, expectedRetType, CountMismatch::FunctionResult);
size_t argIndex = 0;
auto inferredArgIt = begin(overloadFunctionType->argTypes);
auto expectedArgIt = begin(expectedArgTypes);
while (inferredArgIt != end(overloadFunctionType->argTypes) && expectedArgIt != end(expectedArgTypes))
{
Location argLoc = (argIndex >= argLocs.size()) ? argLocs.back() : argLocs[argIndex];
ErrorVec argErrors = tryUnify(stack.back(), argLoc, *expectedArgIt, *inferredArgIt);
for (TypeError e : argErrors)
overloadErrors.emplace_back(e);
++argIndex;
++inferredArgIt;
++expectedArgIt;
}
// piggyback on the unifier for arity checking, but we can't do this for checking the actual arguments since the locations would be bad
ErrorVec argumentErrors = tryUnify(stack.back(), call->location, expectedArgTypes, overloadFunctionType->argTypes);
for (TypeError e : argumentErrors)
if (get<CountMismatch>(e) != nullptr)
overloadErrors.emplace_back(std::move(e));
return overloadErrors;
}
void reportOverloadResolutionErrors(AstExprCall* call, std::vector<TypeId> overloads, TypePackId expectedArgTypes,
const std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<std::pair<ErrorVec, TypeId>> overloadsErrors)
{
if (overloads.size() == 1)
{
reportErrors(std::get<0>(overloadsErrors.front()));
return;
}
std::vector<TypeId> overloadTypes = overloadsThatMatchArgCount;
if (overloadsThatMatchArgCount.size() == 0)
{
reportError(GenericError{"No overload for function accepts " + std::to_string(size(expectedArgTypes)) + " arguments."}, call->location);
// If no overloads match argument count, just list all overloads.
overloadTypes = overloads;
}
else
{
// Report errors of the first argument-count-matching, but failing overload
TypeId overload = overloadsThatMatchArgCount[0];
// Remove the overload we are reporting errors about from the list of alternatives
overloadTypes.erase(std::remove(overloadTypes.begin(), overloadTypes.end(), overload), overloadTypes.end());
const FunctionType* ftv = get<FunctionType>(overload);
LUAU_ASSERT(ftv); // overload must be a function type here
auto error = std::find_if(overloadsErrors.begin(), overloadsErrors.end(), [overload](const std::pair<ErrorVec, TypeId>& e) {
return overload == e.second;
});
LUAU_ASSERT(error != overloadsErrors.end());
reportErrors(std::get<0>(*error));
// If only one overload matched, we don't need this error because we provided the previous errors.
if (overloadsThatMatchArgCount.size() == 1)
return;
}
std::string s;
for (size_t i = 0; i < overloadTypes.size(); ++i)
{
TypeId overload = follow(overloadTypes[i]);
if (i > 0)
s += "; ";
if (i > 0 && i == overloadTypes.size() - 1)
s += "and ";
s += toString(overload);
}
if (overloadsThatMatchArgCount.size() == 0)
reportError(ExtraInformation{"Available overloads: " + s}, call->func->location);
else
reportError(ExtraInformation{"Other overloads are also not viable: " + s}, call->func->location);
}
// Note: this is intentionally separated from `visit(AstExprCall*)` for stack allocation purposes. // Note: this is intentionally separated from `visit(AstExprCall*)` for stack allocation purposes.
void visitCall(AstExprCall* call) void visitCall(AstExprCall* call)
{ {
TypeArena* arena = &testArena; TypePackId expectedRetType = lookupExpectedPack(call, testArena);
Instantiation instantiation{TxnLog::empty(), arena, TypeLevel{}, stack.back()};
TypePackId expectedRetType = lookupExpectedPack(call, *arena);
TypeId functionType = lookupType(call->func);
TypeId testFunctionType = functionType;
TypePack args; TypePack args;
std::vector<Location> argLocs; std::vector<Location> argLocs;
argLocs.reserve(call->args.size + 1); argLocs.reserve(call->args.size + 1);
if (get<AnyType>(functionType) || get<ErrorType>(functionType) || get<NeverType>(functionType)) TypeId* maybeOriginalCallTy = module->astOriginalCallTypes.find(call);
TypeId* maybeSelectedOverload = module->astOverloadResolvedTypes.find(call);
if (!maybeOriginalCallTy)
return; return;
else if (std::optional<TypeId> callMm = findMetatableEntry(builtinTypes, module->errors, functionType, "__call", call->func->location))
TypeId originalCallTy = follow(*maybeOriginalCallTy);
std::vector<TypeId> overloads = flattenIntersection(originalCallTy);
if (get<AnyType>(originalCallTy) || get<ErrorType>(originalCallTy) || get<NeverType>(originalCallTy))
return;
else if (std::optional<TypeId> callMm = findMetatableEntry(builtinTypes, module->errors, originalCallTy, "__call", call->func->location))
{ {
if (get<FunctionType>(follow(*callMm))) if (get<FunctionType>(follow(*callMm)))
{ {
if (std::optional<TypeId> instantiatedCallMm = instantiation.substitute(*callMm)) args.head.push_back(originalCallTy);
{
args.head.push_back(functionType);
argLocs.push_back(call->func->location); argLocs.push_back(call->func->location);
testFunctionType = follow(*instantiatedCallMm);
}
else
{
reportError(UnificationTooComplex{}, call->func->location);
return;
}
} }
else else
{ {
@ -969,29 +894,16 @@ struct TypeChecker2
return; return;
} }
} }
else if (get<FunctionType>(functionType)) else if (get<FunctionType>(originalCallTy) || get<IntersectionType>(originalCallTy))
{ {
if (std::optional<TypeId> instantiatedFunctionType = instantiation.substitute(functionType))
{
testFunctionType = *instantiatedFunctionType;
} }
else else if (auto utv = get<UnionType>(originalCallTy))
{
reportError(UnificationTooComplex{}, call->func->location);
return;
}
}
else if (auto itv = get<IntersectionType>(functionType))
{
// We do nothing here because we'll flatten the intersection later, but we don't want to report it as a non-function.
}
else if (auto utv = get<UnionType>(functionType))
{ {
// Sometimes it's okay to call a union of functions, but only if all of the functions are the same. // Sometimes it's okay to call a union of functions, but only if all of the functions are the same.
// Another scenario we might run into it is if the union has a nil member. In this case, we want to throw an error // Another scenario we might run into it is if the union has a nil member. In this case, we want to throw an error
if (isOptional(functionType)) if (isOptional(originalCallTy))
{ {
reportError(OptionalValueAccess{functionType}, call->location); reportError(OptionalValueAccess{originalCallTy}, call->location);
return; return;
} }
std::optional<TypeId> fst; std::optional<TypeId> fst;
@ -1001,7 +913,7 @@ struct TypeChecker2
fst = follow(ty); fst = follow(ty);
else if (fst != follow(ty)) else if (fst != follow(ty))
{ {
reportError(CannotCallNonFunction{functionType}, call->func->location); reportError(CannotCallNonFunction{originalCallTy}, call->func->location);
return; return;
} }
} }
@ -1009,19 +921,16 @@ struct TypeChecker2
if (!fst) if (!fst)
ice->ice("UnionType had no elements, so fst is nullopt?"); ice->ice("UnionType had no elements, so fst is nullopt?");
if (std::optional<TypeId> instantiatedFunctionType = instantiation.substitute(*fst)) originalCallTy = follow(*fst);
if (!get<FunctionType>(originalCallTy))
{ {
testFunctionType = *instantiatedFunctionType; reportError(CannotCallNonFunction{originalCallTy}, call->func->location);
}
else
{
reportError(UnificationTooComplex{}, call->func->location);
return; return;
} }
} }
else else
{ {
reportError(CannotCallNonFunction{functionType}, call->func->location); reportError(CannotCallNonFunction{originalCallTy}, call->func->location);
return; return;
} }
@ -1054,63 +963,134 @@ struct TypeChecker2
args.head.push_back(builtinTypes->anyType); args.head.push_back(builtinTypes->anyType);
} }
TypePackId expectedArgTypes = arena->addTypePack(args); TypePackId expectedArgTypes = testArena.addTypePack(args);
std::vector<TypeId> overloads = flattenIntersection(testFunctionType); if (maybeSelectedOverload)
std::vector<std::pair<ErrorVec, TypeId>> overloadsErrors; {
overloadsErrors.reserve(overloads.size()); // This overload might not work still: the constraint solver will
// pass the type checker an instantiated function type that matches
// in arity, but not in subtyping, in order to allow the type
// checker to report better error messages.
std::vector<TypeId> overloadsThatMatchArgCount; TypeId selectedOverload = follow(*maybeSelectedOverload);
const FunctionType* ftv;
if (get<AnyType>(selectedOverload) || get<ErrorType>(selectedOverload) || get<NeverType>(selectedOverload))
{
return;
}
else if (const FunctionType* overloadFtv = get<FunctionType>(selectedOverload))
{
ftv = overloadFtv;
}
else
{
reportError(CannotCallNonFunction{selectedOverload}, call->func->location);
return;
}
LUAU_ASSERT(ftv);
reportErrors(tryUnify(stack.back(), call->location, ftv->retTypes, expectedRetType, CountMismatch::Context::Return));
auto it = begin(expectedArgTypes);
size_t i = 0;
std::vector<TypeId> slice;
for (TypeId arg : ftv->argTypes)
{
if (it == end(expectedArgTypes))
{
slice.push_back(arg);
continue;
}
TypeId expectedArg = *it;
Location argLoc = argLocs.at(i >= argLocs.size() ? argLocs.size() - 1 : i);
reportErrors(tryUnify(stack.back(), argLoc, expectedArg, arg));
++it;
++i;
}
if (slice.size() > 0 && it == end(expectedArgTypes))
{
if (auto tail = it.tail())
{
TypePackId remainingArgs = testArena.addTypePack(TypePack{std::move(slice), std::nullopt});
reportErrors(tryUnify(stack.back(), argLocs.back(), *tail, remainingArgs));
}
}
// We do not need to do an arity test because this overload was
// selected based on its arity already matching.
}
else
{
// No overload worked, even when instantiated. We need to filter the
// set of overloads to those that match the arity of the incoming
// argument set, and then report only those as not matching.
std::vector<TypeId> arityMatchingOverloads;
ErrorVec empty;
for (TypeId overload : overloads) for (TypeId overload : overloads)
{ {
overload = follow(overload); overload = follow(overload);
if (const FunctionType* ftv = get<FunctionType>(overload))
const FunctionType* overloadFn = get<FunctionType>(overload);
if (!overloadFn)
{ {
reportError(CannotCallNonFunction{overload}, call->func->location); if (size(ftv->argTypes) == size(expectedArgTypes))
return; {
arityMatchingOverloads.push_back(overload);
}
}
else if (const std::optional<TypeId> callMm = findMetatableEntry(builtinTypes, empty, overload, "__call", call->location))
{
if (const FunctionType* ftv = get<FunctionType>(follow(*callMm)))
{
if (size(ftv->argTypes) == size(expectedArgTypes))
{
arityMatchingOverloads.push_back(overload);
}
} }
else else
{ {
// We may have to instantiate the overload in order for it to typecheck. reportError(CannotCallNonFunction{}, call->location);
if (std::optional<TypeId> instantiatedFunctionType = instantiation.substitute(overload)) }
}
}
if (arityMatchingOverloads.size() == 0)
{ {
overloadFn = get<FunctionType>(*instantiatedFunctionType); reportError(
GenericError{"No overload for function accepts " + std::to_string(size(expectedArgTypes)) + " arguments."}, call->location);
} }
else else
{ {
overloadsErrors.emplace_back(std::vector{TypeError{call->func->location, UnificationTooComplex{}}}, overload); // We have handled the case of a singular arity-matching
return; // overload above, in the case where an overload was selected.
} // LUAU_ASSERT(arityMatchingOverloads.size() > 1);
reportError(GenericError{"None of the overloads for function that accept " + std::to_string(size(expectedArgTypes)) +
" arguments are compatible."},
call->location);
} }
ErrorVec overloadErrors = visitOverload(call, NotNull{overloadFn}, argLocs, expectedArgTypes, expectedRetType); std::string s;
if (overloadErrors.empty()) std::vector<TypeId>& stringifyOverloads = arityMatchingOverloads.size() == 0 ? overloads : arityMatchingOverloads;
return; for (size_t i = 0; i < stringifyOverloads.size(); ++i)
bool argMismatch = false;
for (auto error : overloadErrors)
{ {
CountMismatch* cm = get<CountMismatch>(error); TypeId overload = follow(stringifyOverloads[i]);
if (!cm)
continue;
if (cm->context == CountMismatch::Arg) if (i > 0)
{ s += "; ";
argMismatch = true;
break; if (i > 0 && i == stringifyOverloads.size() - 1)
} s += "and ";
s += toString(overload);
} }
if (!argMismatch) reportError(ExtraInformation{"Available overloads: " + s}, call->func->location);
overloadsThatMatchArgCount.push_back(overload);
overloadsErrors.emplace_back(std::move(overloadErrors), overload);
} }
reportOverloadResolutionErrors(call, overloads, expectedArgTypes, overloadsThatMatchArgCount, overloadsErrors);
} }
void visit(AstExprCall* call) void visit(AstExprCall* call)
@ -2077,18 +2057,10 @@ struct TypeChecker2
fetch(norm->tops); fetch(norm->tops);
fetch(norm->booleans); fetch(norm->booleans);
if (FFlag::LuauNegatedClassTypes)
{
for (const auto& [ty, _negations] : norm->classes.classes) for (const auto& [ty, _negations] : norm->classes.classes)
{ {
fetch(ty); fetch(ty);
} }
}
else
{
for (TypeId ty : norm->DEPRECATED_classes)
fetch(ty);
}
fetch(norm->errors); fetch(norm->errors);
fetch(norm->nils); fetch(norm->nils);
fetch(norm->numbers); fetch(norm->numbers);

View file

@ -35,7 +35,6 @@ LUAU_FASTFLAG(LuauKnowsTheDataModel3)
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false)
LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false)
LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAG(LuauNegatedClassTypes)
LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false)
LUAU_FASTFLAG(LuauUninhabitedSubAnything2) LUAU_FASTFLAG(LuauUninhabitedSubAnything2)
LUAU_FASTFLAG(LuauOccursIsntAlwaysFailure) LUAU_FASTFLAG(LuauOccursIsntAlwaysFailure)
@ -1701,7 +1700,7 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea
void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& declaredClass)
{ {
std::optional<TypeId> superTy = FFlag::LuauNegatedClassTypes ? std::make_optional(builtinTypes->classType) : std::nullopt; std::optional<TypeId> superTy = std::make_optional(builtinTypes->classType);
if (declaredClass.superName) if (declaredClass.superName)
{ {
Name superName = Name(declaredClass.superName->value); Name superName = Name(declaredClass.superName->value);
@ -5968,17 +5967,13 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r
TypeId type = follow(typeFun->type); TypeId type = follow(typeFun->type);
// You cannot refine to the top class type. // You cannot refine to the top class type.
if (FFlag::LuauNegatedClassTypes)
{
if (type == builtinTypes->classType) if (type == builtinTypes->classType)
{ {
return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope));
} }
}
// We're only interested in the root class of any classes. // We're only interested in the root class of any classes.
if (auto ctv = get<ClassType>(type); if (auto ctv = get<ClassType>(type); !ctv || ctv->parent != builtinTypes->classType)
!ctv || (FFlag::LuauNegatedClassTypes ? (ctv->parent != builtinTypes->classType) : (ctv->parent != std::nullopt)))
return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope));
// This probably hints at breaking out type filtering functions from the predicate solver so that typeof is not tightly coupled with IsA. // This probably hints at breaking out type filtering functions from the predicate solver so that typeof is not tightly coupled with IsA.

View file

@ -255,15 +255,17 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs)
TypePackId follow(TypePackId tp) TypePackId follow(TypePackId tp)
{ {
return follow(tp, [](TypePackId t) { return follow(tp, nullptr, [](const void*, TypePackId t) {
return t; return t;
}); });
} }
TypePackId follow(TypePackId tp, std::function<TypePackId(TypePackId)> mapper) TypePackId follow(TypePackId tp, const void* context, TypePackId (*mapper)(const void*, TypePackId))
{ {
auto advance = [&mapper](TypePackId ty) -> std::optional<TypePackId> { auto advance = [context, mapper](TypePackId ty) -> std::optional<TypePackId> {
if (const Unifiable::Bound<TypePackId>* btv = get<Unifiable::Bound<TypePackId>>(mapper(ty))) TypePackId mapped = mapper(context, ty);
if (const Unifiable::Bound<TypePackId>* btv = get<Unifiable::Bound<TypePackId>>(mapped))
return btv->boundTo; return btv->boundTo;
else else
return std::nullopt; return std::nullopt;
@ -275,6 +277,9 @@ TypePackId follow(TypePackId tp, std::function<TypePackId(TypePackId)> mapper)
else else
return tp; return tp;
if (!advance(cycleTester)) // Short circuit traversal for the rather common case when advance(advance(t)) == null
return cycleTester;
while (true) while (true)
{ {
auto a1 = advance(tp); auto a1 = advance(tp);

View file

@ -26,8 +26,6 @@ LUAU_FASTFLAGVARIABLE(LuauOccursIsntAlwaysFailure, false)
LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAG(LuauNormalizeBlockedTypes) LUAU_FASTFLAG(LuauNormalizeBlockedTypes)
LUAU_FASTFLAG(LuauNegatedClassTypes)
LUAU_FASTFLAG(LuauNegatedTableTypes)
namespace Luau namespace Luau
{ {
@ -344,6 +342,19 @@ std::optional<TypeError> hasUnificationTooComplex(const ErrorVec& errors)
return *it; return *it;
} }
std::optional<TypeError> hasCountMismatch(const ErrorVec& errors)
{
auto isCountMismatch = [](const TypeError& te) {
return nullptr != get<CountMismatch>(te);
};
auto it = std::find_if(errors.begin(), errors.end(), isCountMismatch);
if (it == errors.end())
return std::nullopt;
else
return *it;
}
// Used for tagged union matching heuristic, returns first singleton type field // Used for tagged union matching heuristic, returns first singleton type field
static std::optional<std::pair<Luau::Name, const SingletonType*>> getTableMatchTag(TypeId type) static std::optional<std::pair<Luau::Name, const SingletonType*>> getTableMatchTag(TypeId type)
{ {
@ -620,7 +631,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
// Ok. Do nothing. forall functions F, F <: function // Ok. Do nothing. forall functions F, F <: function
} }
else if (FFlag::LuauNegatedTableTypes && isPrim(superTy, PrimitiveType::Table) && (get<TableType>(subTy) || get<MetatableType>(subTy))) else if (isPrim(superTy, PrimitiveType::Table) && (get<TableType>(subTy) || get<MetatableType>(subTy)))
{ {
// Ok, do nothing: forall tables T, T <: table // Ok, do nothing: forall tables T, T <: table
} }
@ -1183,8 +1194,6 @@ void Unifier::tryUnifyNormalizedTypes(
if (!get<PrimitiveType>(superNorm.errors)) if (!get<PrimitiveType>(superNorm.errors))
return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()});
if (FFlag::LuauNegatedClassTypes)
{
for (const auto& [subClass, _] : subNorm.classes.classes) for (const auto& [subClass, _] : subNorm.classes.classes)
{ {
bool found = false; bool found = false;
@ -1240,33 +1249,13 @@ void Unifier::tryUnifyNormalizedTypes(
return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()});
} }
} }
}
else
{
for (TypeId subClass : subNorm.DEPRECATED_classes)
{
bool found = false;
const ClassType* subCtv = get<ClassType>(subClass);
for (TypeId superClass : superNorm.DEPRECATED_classes)
{
const ClassType* superCtv = get<ClassType>(superClass);
if (isSubclass(subCtv, superCtv))
{
found = true;
break;
}
}
if (!found)
return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()});
}
}
for (TypeId subTable : subNorm.tables) for (TypeId subTable : subNorm.tables)
{ {
bool found = false; bool found = false;
for (TypeId superTable : superNorm.tables) for (TypeId superTable : superNorm.tables)
{ {
if (FFlag::LuauNegatedTableTypes && isPrim(superTable, PrimitiveType::Table)) if (isPrim(superTable, PrimitiveType::Table))
{ {
found = true; found = true;
break; break;

View file

@ -14,8 +14,6 @@ enum class Mode
struct ParseOptions struct ParseOptions
{ {
bool allowTypeAnnotations = true;
bool supportContinueStatement = true;
bool allowDeclarationSyntax = false; bool allowDeclarationSyntax = false;
bool captureComments = false; bool captureComments = false;
}; };

View file

@ -14,8 +14,6 @@
LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000)
LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100)
LUAU_FASTFLAGVARIABLE(LuauParserErrorsOnMissingDefaultTypePackArgument, false)
#define ERROR_INVALID_INTERP_DOUBLE_BRACE "Double braces are not permitted within interpolated strings. Did you mean '\\{'?" #define ERROR_INVALID_INTERP_DOUBLE_BRACE "Double braces are not permitted within interpolated strings. Did you mean '\\{'?"
namespace Luau namespace Luau
@ -327,8 +325,6 @@ AstStat* Parser::parseStat()
// we know this isn't a call or an assignment; therefore it must be a context-sensitive keyword such as `type` or `continue` // we know this isn't a call or an assignment; therefore it must be a context-sensitive keyword such as `type` or `continue`
AstName ident = getIdentifier(expr); AstName ident = getIdentifier(expr);
if (options.allowTypeAnnotations)
{
if (ident == "type") if (ident == "type")
return parseTypeAlias(expr->location, /* exported= */ false); return parseTypeAlias(expr->location, /* exported= */ false);
@ -337,12 +333,11 @@ AstStat* Parser::parseStat()
nextLexeme(); nextLexeme();
return parseTypeAlias(expr->location, /* exported= */ true); return parseTypeAlias(expr->location, /* exported= */ true);
} }
}
if (options.supportContinueStatement && ident == "continue") if (ident == "continue")
return parseContinue(expr->location); return parseContinue(expr->location);
if (options.allowTypeAnnotations && options.allowDeclarationSyntax) if (options.allowDeclarationSyntax)
{ {
if (ident == "declare") if (ident == "declare")
return parseDeclaration(expr->location); return parseDeclaration(expr->location);
@ -1123,7 +1118,7 @@ std::tuple<bool, Location, AstTypePack*> Parser::parseBindingList(TempVector<Bin
AstType* Parser::parseOptionalType() AstType* Parser::parseOptionalType()
{ {
if (options.allowTypeAnnotations && lexer.current().type == ':') if (lexer.current().type == ':')
{ {
nextLexeme(); nextLexeme();
return parseType(); return parseType();
@ -1175,7 +1170,7 @@ AstTypePack* Parser::parseTypeList(TempVector<AstType*>& result, TempVector<std:
std::optional<AstTypeList> Parser::parseOptionalReturnType() std::optional<AstTypeList> Parser::parseOptionalReturnType()
{ {
if (options.allowTypeAnnotations && (lexer.current().type == ':' || lexer.current().type == Lexeme::SkinnyArrow)) if (lexer.current().type == ':' || lexer.current().type == Lexeme::SkinnyArrow)
{ {
if (lexer.current().type == Lexeme::SkinnyArrow) if (lexer.current().type == Lexeme::SkinnyArrow)
report(lexer.current().location, "Function return type annotations are written after ':' instead of '->'"); report(lexer.current().location, "Function return type annotations are written after ':' instead of '->'");
@ -2056,7 +2051,7 @@ AstExpr* Parser::parseAssertionExpr()
Location start = lexer.current().location; Location start = lexer.current().location;
AstExpr* expr = parseSimpleExpr(); AstExpr* expr = parseSimpleExpr();
if (options.allowTypeAnnotations && lexer.current().type == Lexeme::DoubleColon) if (lexer.current().type == Lexeme::DoubleColon)
{ {
nextLexeme(); nextLexeme();
AstType* annotation = parseType(); AstType* annotation = parseType();
@ -2449,24 +2444,13 @@ std::pair<AstArray<AstGenericType>, AstArray<AstGenericTypePack>> Parser::parseG
seenDefault = true; seenDefault = true;
nextLexeme(); nextLexeme();
Lexeme packBegin = lexer.current();
if (shouldParseTypePack(lexer)) if (shouldParseTypePack(lexer))
{ {
AstTypePack* typePack = parseTypePack(); AstTypePack* typePack = parseTypePack();
namePacks.push_back({name, nameLocation, typePack}); namePacks.push_back({name, nameLocation, typePack});
} }
else if (!FFlag::LuauParserErrorsOnMissingDefaultTypePackArgument && lexer.current().type == '(') else
{
auto [type, typePack] = parseTypeOrPack();
if (type)
report(Location(packBegin.location.begin, lexer.previousLocation().end), "Expected type pack after '=', got type");
namePacks.push_back({name, nameLocation, typePack});
}
else if (FFlag::LuauParserErrorsOnMissingDefaultTypePackArgument)
{ {
auto [type, typePack] = parseTypeOrPack(); auto [type, typePack] = parseTypeOrPack();

View file

@ -9,6 +9,13 @@
#include "FileUtils.h" #include "FileUtils.h"
#include "Flags.h" #include "Flags.h"
#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#ifdef CALLGRIND #ifdef CALLGRIND
#include <valgrind/callgrind.h> #include <valgrind/callgrind.h>
#endif #endif
@ -64,26 +71,29 @@ static void reportWarning(ReportFormat format, const char* name, const Luau::Lin
report(format, name, warning.location, Luau::LintWarning::getName(warning.code), warning.text.c_str()); report(format, name, warning.location, Luau::LintWarning::getName(warning.code), warning.text.c_str());
} }
static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat format, bool annotate) static bool reportModuleResult(Luau::Frontend& frontend, const Luau::ModuleName& name, ReportFormat format, bool annotate)
{ {
Luau::CheckResult cr; std::optional<Luau::CheckResult> cr = frontend.getCheckResult(name, false);
if (frontend.isDirty(name)) if (!cr)
cr = frontend.check(name);
if (!frontend.getSourceModule(name))
{ {
fprintf(stderr, "Error opening %s\n", name); fprintf(stderr, "Failed to find result for %s\n", name.c_str());
return false; return false;
} }
for (auto& error : cr.errors) if (!frontend.getSourceModule(name))
{
fprintf(stderr, "Error opening %s\n", name.c_str());
return false;
}
for (auto& error : cr->errors)
reportError(frontend, format, error); reportError(frontend, format, error);
std::string humanReadableName = frontend.fileResolver->getHumanReadableModuleName(name); std::string humanReadableName = frontend.fileResolver->getHumanReadableModuleName(name);
for (auto& error : cr.lintResult.errors) for (auto& error : cr->lintResult.errors)
reportWarning(format, humanReadableName.c_str(), error); reportWarning(format, humanReadableName.c_str(), error);
for (auto& warning : cr.lintResult.warnings) for (auto& warning : cr->lintResult.warnings)
reportWarning(format, humanReadableName.c_str(), warning); reportWarning(format, humanReadableName.c_str(), warning);
if (annotate) if (annotate)
@ -98,7 +108,7 @@ static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat
printf("%s", annotated.c_str()); printf("%s", annotated.c_str());
} }
return cr.errors.empty() && cr.lintResult.errors.empty(); return cr->errors.empty() && cr->lintResult.errors.empty();
} }
static void displayHelp(const char* argv0) static void displayHelp(const char* argv0)
@ -216,6 +226,70 @@ struct CliConfigResolver : Luau::ConfigResolver
} }
}; };
struct TaskScheduler
{
TaskScheduler(unsigned threadCount)
: threadCount(threadCount)
{
for (unsigned i = 0; i < threadCount; i++)
{
workers.emplace_back([this] {
workerFunction();
});
}
}
~TaskScheduler()
{
for (unsigned i = 0; i < threadCount; i++)
push({});
for (std::thread& worker : workers)
worker.join();
}
std::function<void()> pop()
{
std::unique_lock guard(mtx);
cv.wait(guard, [this] {
return !tasks.empty();
});
std::function<void()> task = tasks.front();
tasks.pop();
return task;
}
void push(std::function<void()> task)
{
{
std::unique_lock guard(mtx);
tasks.push(std::move(task));
}
cv.notify_one();
}
static unsigned getThreadCount()
{
return std::max(std::thread::hardware_concurrency(), 1u);
}
private:
void workerFunction()
{
while (std::function<void()> task = pop())
task();
}
unsigned threadCount = 1;
std::mutex mtx;
std::condition_variable cv;
std::vector<std::thread> workers;
std::queue<std::function<void()>> tasks;
};
int main(int argc, char** argv) int main(int argc, char** argv)
{ {
Luau::assertHandler() = assertionHandler; Luau::assertHandler() = assertionHandler;
@ -231,6 +305,7 @@ int main(int argc, char** argv)
ReportFormat format = ReportFormat::Default; ReportFormat format = ReportFormat::Default;
Luau::Mode mode = Luau::Mode::Nonstrict; Luau::Mode mode = Luau::Mode::Nonstrict;
bool annotate = false; bool annotate = false;
int threadCount = 0;
for (int i = 1; i < argc; ++i) for (int i = 1; i < argc; ++i)
{ {
@ -249,6 +324,8 @@ int main(int argc, char** argv)
FFlag::DebugLuauTimeTracing.value = true; FFlag::DebugLuauTimeTracing.value = true;
else if (strncmp(argv[i], "--fflags=", 9) == 0) else if (strncmp(argv[i], "--fflags=", 9) == 0)
setLuauFlags(argv[i] + 9); setLuauFlags(argv[i] + 9);
else if (strncmp(argv[i], "-j", 2) == 0)
threadCount = strtol(argv[i] + 2, nullptr, 10);
} }
#if !defined(LUAU_ENABLE_TIME_TRACE) #if !defined(LUAU_ENABLE_TIME_TRACE)
@ -276,10 +353,28 @@ int main(int argc, char** argv)
std::vector<std::string> files = getSourceFiles(argc, argv); std::vector<std::string> files = getSourceFiles(argc, argv);
for (const std::string& path : files)
frontend.queueModuleCheck(path);
std::vector<Luau::ModuleName> checkedModules;
// If thread count is not set, try to use HW thread count, but with an upper limit
// When we improve scalability of typechecking, upper limit can be adjusted/removed
if (threadCount <= 0)
threadCount = std::min(TaskScheduler::getThreadCount(), 8u);
{
TaskScheduler scheduler(threadCount);
checkedModules = frontend.checkQueuedModules(std::nullopt, [&](std::function<void()> f) {
scheduler.push(std::move(f));
});
}
int failed = 0; int failed = 0;
for (const std::string& path : files) for (const Luau::ModuleName& name : checkedModules)
failed += !analyzeFile(frontend, path.c_str(), format, annotate); failed += !reportModuleResult(frontend, name, format, annotate);
if (!configResolver.configErrors.empty()) if (!configResolver.configErrors.empty())
{ {

View file

@ -64,8 +64,6 @@ int main(int argc, char** argv)
Luau::ParseOptions options; Luau::ParseOptions options;
options.captureComments = true; options.captureComments = true;
options.supportContinueStatement = true;
options.allowTypeAnnotations = true;
options.allowDeclarationSyntax = true; options.allowDeclarationSyntax = true;
Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), names, allocator, options); Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), names, allocator, options);

View file

@ -35,6 +35,8 @@ struct RegisterSet
uint8_t varargStart = 0; uint8_t varargStart = 0;
}; };
void requireVariadicSequence(RegisterSet& sourceRs, const RegisterSet& defRs, uint8_t varargStart);
struct CfgInfo struct CfgInfo
{ {
std::vector<uint32_t> predecessors; std::vector<uint32_t> predecessors;
@ -43,10 +45,22 @@ struct CfgInfo
std::vector<uint32_t> successors; std::vector<uint32_t> successors;
std::vector<uint32_t> successorsOffsets; std::vector<uint32_t> successorsOffsets;
// VM registers that are live when the block is entered
// Additionally, an active variadic sequence can exist at the entry of the block
std::vector<RegisterSet> in; std::vector<RegisterSet> in;
// VM registers that are defined inside the block
// It can also contain a variadic sequence definition if that hasn't been consumed inside the block
// Note that this means that checking 'def' set might not be enough to say that register has not been written to
std::vector<RegisterSet> def; std::vector<RegisterSet> def;
// VM registers that are coming out from the block
// These might be registers that are defined inside the block or have been defined at the entry of the block
// Additionally, an active variadic sequence can exist at the exit of the block
std::vector<RegisterSet> out; std::vector<RegisterSet> out;
// VM registers captured by nested closures
// This set can never have an active variadic sequence
RegisterSet captured; RegisterSet captured;
}; };

View file

@ -575,7 +575,7 @@ enum class IrCmd : uint8_t
// Calls native libm function with 1 or 2 arguments // Calls native libm function with 1 or 2 arguments
// A: builtin function ID // A: builtin function ID
// B: double // B: double
// C: double (optional, 2nd argument) // C: double/int (optional, 2nd argument)
INVOKE_LIBM, INVOKE_LIBM,
}; };

View file

@ -30,7 +30,7 @@ void toString(IrToStringContext& ctx, IrOp op);
void toString(std::string& result, IrConst constant); void toString(std::string& result, IrConst constant);
void toStringDetailed(IrToStringContext& ctx, const IrInst& inst, uint32_t index, bool includeUseInfo); void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t blockIdx, const IrInst& inst, uint32_t instIdx, bool includeUseInfo);
void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t index, bool includeUseInfo); // Block title void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t index, bool includeUseInfo); // Block title
std::string toString(const IrFunction& function, bool includeUseInfo); std::string toString(const IrFunction& function, bool includeUseInfo);

View file

@ -114,6 +114,28 @@ inline bool isBlockTerminator(IrCmd cmd)
return false; return false;
} }
inline bool isNonTerminatingJump(IrCmd cmd)
{
switch (cmd)
{
case IrCmd::TRY_NUM_TO_INDEX:
case IrCmd::TRY_CALL_FASTGETTM:
case IrCmd::CHECK_FASTCALL_RES:
case IrCmd::CHECK_TAG:
case IrCmd::CHECK_READONLY:
case IrCmd::CHECK_NO_METATABLE:
case IrCmd::CHECK_SAFE_ENV:
case IrCmd::CHECK_ARRAY_SIZE:
case IrCmd::CHECK_SLOT_MATCH:
case IrCmd::CHECK_NODE_NO_NEXT:
return true;
default:
break;
}
return false;
}
inline bool hasResult(IrCmd cmd) inline bool hasResult(IrCmd cmd)
{ {
switch (cmd) switch (cmd)

View file

@ -1,8 +1,11 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/RegisterA64.h"
#include "Luau/RegisterX64.h" #include "Luau/RegisterX64.h"
#include <initializer_list>
#include <stddef.h> #include <stddef.h>
#include <stdint.h> #include <stdint.h>
@ -17,22 +20,36 @@ static uint32_t kFullBlockFuncton = ~0u;
class UnwindBuilder class UnwindBuilder
{ {
public: public:
enum Arch
{
X64,
A64
};
virtual ~UnwindBuilder() = default; virtual ~UnwindBuilder() = default;
virtual void setBeginOffset(size_t beginOffset) = 0; virtual void setBeginOffset(size_t beginOffset) = 0;
virtual size_t getBeginOffset() const = 0; virtual size_t getBeginOffset() const = 0;
virtual void startInfo() = 0; virtual void startInfo(Arch arch) = 0;
virtual void startFunction() = 0; virtual void startFunction() = 0;
virtual void spill(int espOffset, X64::RegisterX64 reg) = 0;
virtual void save(X64::RegisterX64 reg) = 0;
virtual void allocStack(int size) = 0;
virtual void setupFrameReg(X64::RegisterX64 reg, int espOffset) = 0;
virtual void finishFunction(uint32_t beginOffset, uint32_t endOffset) = 0; virtual void finishFunction(uint32_t beginOffset, uint32_t endOffset) = 0;
virtual void finishInfo() = 0; virtual void finishInfo() = 0;
// A64-specific; prologue must look like this:
// sub sp, sp, stackSize
// store sequence that saves regs to [sp..sp+regs.size*8) in the order specified in regs; regs should start with x29, x30 (fp, lr)
// mov x29, sp
virtual void prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list<A64::RegisterA64> regs) = 0;
// X64-specific; prologue must look like this:
// optional, indicated by setupFrame:
// push rbp
// mov rbp, rsp
// push reg in the order specified in regs
// sub rsp, stackSize
virtual void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list<X64::RegisterX64> regs) = 0;
virtual size_t getSize() const = 0; virtual size_t getSize() const = 0;
virtual size_t getFunctionCount() const = 0; virtual size_t getFunctionCount() const = 0;

View file

@ -24,17 +24,14 @@ public:
void setBeginOffset(size_t beginOffset) override; void setBeginOffset(size_t beginOffset) override;
size_t getBeginOffset() const override; size_t getBeginOffset() const override;
void startInfo() override; void startInfo(Arch arch) override;
void startFunction() override; void startFunction() override;
void spill(int espOffset, X64::RegisterX64 reg) override;
void save(X64::RegisterX64 reg) override;
void allocStack(int size) override;
void setupFrameReg(X64::RegisterX64 reg, int espOffset) override;
void finishFunction(uint32_t beginOffset, uint32_t endOffset) override; void finishFunction(uint32_t beginOffset, uint32_t endOffset) override;
void finishInfo() override; void finishInfo() override;
void prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list<A64::RegisterA64> regs) override;
void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list<X64::RegisterX64> regs) override;
size_t getSize() const override; size_t getSize() const override;
size_t getFunctionCount() const override; size_t getFunctionCount() const override;
@ -49,8 +46,6 @@ private:
uint8_t rawData[kRawDataLimit]; uint8_t rawData[kRawDataLimit];
uint8_t* pos = rawData; uint8_t* pos = rawData;
uint32_t stackOffset = 0;
// We will remember the FDE location to write some of the fields like entry length, function start and size later // We will remember the FDE location to write some of the fields like entry length, function start and size later
uint8_t* fdeEntryStart = nullptr; uint8_t* fdeEntryStart = nullptr;
}; };

View file

@ -44,17 +44,14 @@ public:
void setBeginOffset(size_t beginOffset) override; void setBeginOffset(size_t beginOffset) override;
size_t getBeginOffset() const override; size_t getBeginOffset() const override;
void startInfo() override; void startInfo(Arch arch) override;
void startFunction() override; void startFunction() override;
void spill(int espOffset, X64::RegisterX64 reg) override;
void save(X64::RegisterX64 reg) override;
void allocStack(int size) override;
void setupFrameReg(X64::RegisterX64 reg, int espOffset) override;
void finishFunction(uint32_t beginOffset, uint32_t endOffset) override; void finishFunction(uint32_t beginOffset, uint32_t endOffset) override;
void finishInfo() override; void finishInfo() override;
void prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list<A64::RegisterA64> regs) override;
void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list<X64::RegisterX64> regs) override;
size_t getSize() const override; size_t getSize() const override;
size_t getFunctionCount() const override; size_t getFunctionCount() const override;
@ -75,7 +72,6 @@ private:
uint8_t prologSize = 0; uint8_t prologSize = 0;
X64::RegisterX64 frameReg = X64::noreg; X64::RegisterX64 frameReg = X64::noreg;
uint8_t frameRegOffset = 0; uint8_t frameRegOffset = 0;
uint32_t stackOffset = 0;
}; };
} // namespace CodeGen } // namespace CodeGen

View file

@ -22,12 +22,25 @@
extern "C" void __register_frame(const void*); extern "C" void __register_frame(const void*);
extern "C" void __deregister_frame(const void*); extern "C" void __deregister_frame(const void*);
extern "C" void __unw_add_dynamic_fde() __attribute__((weak));
#endif #endif
#if defined(__APPLE__) namespace Luau
// On Mac, each FDE inside eh_frame section has to be handled separately {
namespace CodeGen
{
#if !defined(_WIN32)
static void visitFdeEntries(char* pos, void (*cb)(const void*)) static void visitFdeEntries(char* pos, void (*cb)(const void*))
{ {
// When using glibc++ unwinder, we need to call __register_frame/__deregister_frame on the entire .eh_frame data
// When using libc++ unwinder (libunwind), each FDE has to be handled separately
// libc++ unwinder is the macOS unwinder, but on Linux the unwinder depends on the library the executable is linked with
// __unw_add_dynamic_fde is specific to libc++ unwinder, as such we determine the library based on its existence
if (__unw_add_dynamic_fde == nullptr)
return cb(pos);
for (;;) for (;;)
{ {
unsigned partLength; unsigned partLength;
@ -47,11 +60,6 @@ static void visitFdeEntries(char* pos, void (*cb)(const void*))
} }
#endif #endif
namespace Luau
{
namespace CodeGen
{
void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, size_t& beginOffset) void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, size_t& beginOffset)
{ {
UnwindBuilder* unwind = (UnwindBuilder*)context; UnwindBuilder* unwind = (UnwindBuilder*)context;
@ -70,10 +78,8 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz
LUAU_ASSERT(!"failed to allocate function table"); LUAU_ASSERT(!"failed to allocate function table");
return nullptr; return nullptr;
} }
#elif defined(__APPLE__)
visitFdeEntries(unwindData, __register_frame);
#elif !defined(_WIN32) #elif !defined(_WIN32)
__register_frame(unwindData); visitFdeEntries(unwindData, __register_frame);
#endif #endif
beginOffset = unwindSize + unwind->getBeginOffset(); beginOffset = unwindSize + unwind->getBeginOffset();
@ -85,10 +91,8 @@ void destroyBlockUnwindInfo(void* context, void* unwindData)
#if defined(_WIN32) && defined(_M_X64) #if defined(_WIN32) && defined(_M_X64)
if (!RtlDeleteFunctionTable((RUNTIME_FUNCTION*)unwindData)) if (!RtlDeleteFunctionTable((RUNTIME_FUNCTION*)unwindData))
LUAU_ASSERT(!"failed to deallocate function table"); LUAU_ASSERT(!"failed to deallocate function table");
#elif defined(__APPLE__)
visitFdeEntries((char*)unwindData, __deregister_frame);
#elif !defined(_WIN32) #elif !defined(_WIN32)
__deregister_frame(unwindData); visitFdeEntries((char*)unwindData, __deregister_frame);
#endif #endif
} }

View file

@ -134,7 +134,6 @@ static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction&
for (size_t i = 0; i < sortedBlocks.size(); ++i) for (size_t i = 0; i < sortedBlocks.size(); ++i)
{ {
uint32_t blockIndex = sortedBlocks[i]; uint32_t blockIndex = sortedBlocks[i];
IrBlock& block = function.blocks[blockIndex]; IrBlock& block = function.blocks[blockIndex];
if (block.kind == IrBlockKind::Dead) if (block.kind == IrBlockKind::Dead)
@ -191,10 +190,13 @@ static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction&
continue; continue;
} }
// Either instruction result value is not referenced or the use count is not zero
LUAU_ASSERT(inst.lastUse == 0 || inst.useCount != 0);
if (options.includeIr) if (options.includeIr)
{ {
build.logAppend("# "); build.logAppend("# ");
toStringDetailed(ctx, inst, index, /* includeUseInfo */ true); toStringDetailed(ctx, block, blockIndex, inst, index, /* includeUseInfo */ true);
} }
IrBlock& next = i + 1 < sortedBlocks.size() ? function.blocks[sortedBlocks[i + 1]] : dummy; IrBlock& next = i + 1 < sortedBlocks.size() ? function.blocks[sortedBlocks[i + 1]] : dummy;
@ -409,9 +411,11 @@ bool isSupported()
if (sizeof(LuaNode) != 32) if (sizeof(LuaNode) != 32)
return false; return false;
// TODO: A64 codegen does not generate correct unwind info at the moment so it requires longjmp instead of C++ exceptions #ifdef _WIN32
// Unwind info is not supported for Windows-on-ARM yet
if (!LUA_USE_LONGJMP) if (!LUA_USE_LONGJMP)
return false; return false;
#endif
return true; return true;
#else #else

View file

@ -123,9 +123,6 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde
// Arguments: x0 = lua_State*, x1 = Proto*, x2 = native code pointer to jump to, x3 = NativeContext* // Arguments: x0 = lua_State*, x1 = Proto*, x2 = native code pointer to jump to, x3 = NativeContext*
locations.start = build.setLabel(); locations.start = build.setLabel();
unwind.startFunction();
unwind.allocStack(8); // TODO: this is just a hack to make UnwindBuilder assertions cooperate
// prologue // prologue
build.sub(sp, sp, kStackSize); build.sub(sp, sp, kStackSize);
@ -140,6 +137,8 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde
locations.prologueEnd = build.setLabel(); locations.prologueEnd = build.setLabel();
uint32_t prologueSize = build.getLabelOffset(locations.prologueEnd) - build.getLabelOffset(locations.start);
// Setup native execution environment // Setup native execution environment
build.mov(rState, x0); build.mov(rState, x0);
build.mov(rNativeContext, x3); build.mov(rNativeContext, x3);
@ -168,6 +167,8 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde
build.ret(); build.ret();
// Our entry function is special, it spans the whole remaining code area // Our entry function is special, it spans the whole remaining code area
unwind.startFunction();
unwind.prologueA64(prologueSize, kStackSize, {x29, x30, x19, x20, x21, x22, x23, x24});
unwind.finishFunction(build.getLabelOffset(locations.start), kFullBlockFuncton); unwind.finishFunction(build.getLabelOffset(locations.start), kFullBlockFuncton);
return locations; return locations;
@ -178,7 +179,7 @@ bool initHeaderFunctions(NativeState& data)
AssemblyBuilderA64 build(/* logText= */ false); AssemblyBuilderA64 build(/* logText= */ false);
UnwindBuilder& unwind = *data.unwindBuilder.get(); UnwindBuilder& unwind = *data.unwindBuilder.get();
unwind.startInfo(); unwind.startInfo(UnwindBuilder::A64);
EntryLocations entryLocations = buildEntryFunction(build, unwind); EntryLocations entryLocations = buildEntryFunction(build, unwind);

View file

@ -58,43 +58,44 @@ static EntryLocations buildEntryFunction(AssemblyBuilderX64& build, UnwindBuilde
unwind.startFunction(); unwind.startFunction();
// Save common non-volatile registers // Save common non-volatile registers
build.push(rbp);
unwind.save(rbp);
if (build.abi == ABIX64::SystemV) if (build.abi == ABIX64::SystemV)
{ {
// We need to use a standard rbp-based frame setup for debuggers to work with JIT code
build.push(rbp);
build.mov(rbp, rsp); build.mov(rbp, rsp);
unwind.setupFrameReg(rbp, 0);
} }
build.push(rbx); build.push(rbx);
unwind.save(rbx);
build.push(r12); build.push(r12);
unwind.save(r12);
build.push(r13); build.push(r13);
unwind.save(r13);
build.push(r14); build.push(r14);
unwind.save(r14);
build.push(r15); build.push(r15);
unwind.save(r15);
if (build.abi == ABIX64::Windows) if (build.abi == ABIX64::Windows)
{ {
// Save non-volatile registers that are specific to Windows x64 ABI // Save non-volatile registers that are specific to Windows x64 ABI
build.push(rdi); build.push(rdi);
unwind.save(rdi);
build.push(rsi); build.push(rsi);
unwind.save(rsi);
// On Windows, rbp is available as a general-purpose non-volatile register; we currently don't use it, but we need to push an even number
// of registers for stack alignment...
build.push(rbp);
// TODO: once we start using non-volatile SIMD registers on Windows, we will save those here // TODO: once we start using non-volatile SIMD registers on Windows, we will save those here
} }
// Allocate stack space (reg home area + local data) // Allocate stack space (reg home area + local data)
build.sub(rsp, kStackSize + kLocalsSize); build.sub(rsp, kStackSize + kLocalsSize);
unwind.allocStack(kStackSize + kLocalsSize);
locations.prologueEnd = build.setLabel(); locations.prologueEnd = build.setLabel();
uint32_t prologueSize = build.getLabelOffset(locations.prologueEnd) - build.getLabelOffset(locations.start);
if (build.abi == ABIX64::SystemV)
unwind.prologueX64(prologueSize, kStackSize + kLocalsSize, /* setupFrame= */ true, {rbx, r12, r13, r14, r15});
else if (build.abi == ABIX64::Windows)
unwind.prologueX64(prologueSize, kStackSize + kLocalsSize, /* setupFrame= */ false, {rbx, r12, r13, r14, r15, rdi, rsi, rbp});
// Setup native execution environment // Setup native execution environment
build.mov(rState, rArg1); build.mov(rState, rArg1);
build.mov(rNativeContext, rArg4); build.mov(rNativeContext, rArg4);
@ -118,6 +119,7 @@ static EntryLocations buildEntryFunction(AssemblyBuilderX64& build, UnwindBuilde
if (build.abi == ABIX64::Windows) if (build.abi == ABIX64::Windows)
{ {
build.pop(rbp);
build.pop(rsi); build.pop(rsi);
build.pop(rdi); build.pop(rdi);
} }
@ -127,7 +129,10 @@ static EntryLocations buildEntryFunction(AssemblyBuilderX64& build, UnwindBuilde
build.pop(r13); build.pop(r13);
build.pop(r12); build.pop(r12);
build.pop(rbx); build.pop(rbx);
if (build.abi == ABIX64::SystemV)
build.pop(rbp); build.pop(rbp);
build.ret(); build.ret();
// Our entry function is special, it spans the whole remaining code area // Our entry function is special, it spans the whole remaining code area
@ -141,7 +146,7 @@ bool initHeaderFunctions(NativeState& data)
AssemblyBuilderX64 build(/* logText= */ false); AssemblyBuilderX64 build(/* logText= */ false);
UnwindBuilder& unwind = *data.unwindBuilder.get(); UnwindBuilder& unwind = *data.unwindBuilder.get();
unwind.startInfo(); unwind.startInfo(UnwindBuilder::X64);
EntryLocations entryLocations = buildEntryFunction(build, unwind); EntryLocations entryLocations = buildEntryFunction(build, unwind);

View file

@ -18,19 +18,6 @@ namespace CodeGen
namespace X64 namespace X64
{ {
static void emitBuiltinMathLdexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg, OperandX64 arg2)
{
ScopedRegX64 tmp{regs, SizeX64::qword};
build.vcvttsd2si(tmp.reg, arg2);
IrCallWrapperX64 callWrap(regs, build);
callWrap.addArgument(SizeX64::xmmword, luauRegValue(arg));
callWrap.addArgument(SizeX64::qword, tmp);
callWrap.call(qword[rNativeContext + offsetof(NativeContext, libm_ldexp)]);
build.vmovsd(luauRegValue(ra), xmm0);
}
static void emitBuiltinMathFrexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg, int nresults) static void emitBuiltinMathFrexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg, int nresults)
{ {
IrCallWrapperX64 callWrap(regs, build); IrCallWrapperX64 callWrap(regs, build);
@ -115,9 +102,6 @@ void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int r
{ {
switch (bfid) switch (bfid)
{ {
case LBF_MATH_LDEXP:
LUAU_ASSERT(nparams == 2 && nresults == 1);
return emitBuiltinMathLdexp(regs, build, ra, arg, arg2);
case LBF_MATH_FREXP: case LBF_MATH_FREXP:
LUAU_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); LUAU_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2));
return emitBuiltinMathFrexp(regs, build, ra, arg, nresults); return emitBuiltinMathFrexp(regs, build, ra, arg, nresults);

View file

@ -162,7 +162,7 @@ uint32_t getLiveOutValueCount(IrFunction& function, IrBlock& block)
return getLiveInOutValueCount(function, block).second; return getLiveInOutValueCount(function, block).second;
} }
static void requireVariadicSequence(RegisterSet& sourceRs, const RegisterSet& defRs, uint8_t varargStart) void requireVariadicSequence(RegisterSet& sourceRs, const RegisterSet& defRs, uint8_t varargStart)
{ {
if (!defRs.varargSeq) if (!defRs.varargSeq)
{ {

View file

@ -62,6 +62,7 @@ static const char* getTagName(uint8_t tag)
case LUA_TTHREAD: case LUA_TTHREAD:
return "tthread"; return "tthread";
default: default:
LUAU_ASSERT(!"Unknown type tag");
LUAU_UNREACHABLE(); LUAU_UNREACHABLE();
} }
} }
@ -410,27 +411,6 @@ void toString(std::string& result, IrConst constant)
} }
} }
void toStringDetailed(IrToStringContext& ctx, const IrInst& inst, uint32_t index, bool includeUseInfo)
{
size_t start = ctx.result.size();
toString(ctx, inst, index);
if (includeUseInfo)
{
padToDetailColumn(ctx.result, start);
if (inst.useCount == 0 && hasSideEffects(inst.cmd))
append(ctx.result, "; %%%u, has side-effects\n", index);
else
append(ctx.result, "; useCount: %d, lastUse: %%%u\n", inst.useCount, inst.lastUse);
}
else
{
ctx.result.append("\n");
}
}
static void appendBlockSet(IrToStringContext& ctx, BlockIteratorWrapper blocks) static void appendBlockSet(IrToStringContext& ctx, BlockIteratorWrapper blocks)
{ {
bool comma = false; bool comma = false;
@ -470,6 +450,86 @@ static void appendRegisterSet(IrToStringContext& ctx, const RegisterSet& rs, con
} }
} }
static RegisterSet getJumpTargetExtraLiveIn(IrToStringContext& ctx, const IrBlock& block, uint32_t blockIdx, const IrInst& inst)
{
RegisterSet extraRs;
if (blockIdx >= ctx.cfg.in.size())
return extraRs;
const RegisterSet& defRs = ctx.cfg.in[blockIdx];
// Find first block argument, for guard instructions (isNonTerminatingJump), that's the first and only one
LUAU_ASSERT(isNonTerminatingJump(inst.cmd));
IrOp op = inst.a;
if (inst.b.kind == IrOpKind::Block)
op = inst.b;
else if (inst.c.kind == IrOpKind::Block)
op = inst.c;
else if (inst.d.kind == IrOpKind::Block)
op = inst.d;
else if (inst.e.kind == IrOpKind::Block)
op = inst.e;
else if (inst.f.kind == IrOpKind::Block)
op = inst.f;
if (op.kind == IrOpKind::Block && op.index < ctx.cfg.in.size())
{
const RegisterSet& inRs = ctx.cfg.in[op.index];
extraRs.regs = inRs.regs & ~defRs.regs;
if (inRs.varargSeq)
requireVariadicSequence(extraRs, defRs, inRs.varargStart);
}
return extraRs;
}
void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t blockIdx, const IrInst& inst, uint32_t instIdx, bool includeUseInfo)
{
size_t start = ctx.result.size();
toString(ctx, inst, instIdx);
if (includeUseInfo)
{
padToDetailColumn(ctx.result, start);
if (inst.useCount == 0 && hasSideEffects(inst.cmd))
{
if (isNonTerminatingJump(inst.cmd))
{
RegisterSet extraRs = getJumpTargetExtraLiveIn(ctx, block, blockIdx, inst);
if (extraRs.regs.any() || extraRs.varargSeq)
{
append(ctx.result, "; %%%u, extra in: ", instIdx);
appendRegisterSet(ctx, extraRs, ", ");
ctx.result.append("\n");
}
else
{
append(ctx.result, "; %%%u\n", instIdx);
}
}
else
{
append(ctx.result, "; %%%u\n", instIdx);
}
}
else
{
append(ctx.result, "; useCount: %d, lastUse: %%%u\n", inst.useCount, inst.lastUse);
}
}
else
{
ctx.result.append("\n");
}
}
void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t index, bool includeUseInfo) void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t index, bool includeUseInfo)
{ {
// Report captured registers for entry block // Report captured registers for entry block
@ -581,7 +641,7 @@ std::string toString(const IrFunction& function, bool includeUseInfo)
continue; continue;
append(ctx.result, " "); append(ctx.result, " ");
toStringDetailed(ctx, inst, index, includeUseInfo); toStringDetailed(ctx, block, uint32_t(i), inst, index, includeUseInfo);
} }
append(ctx.result, "\n"); append(ctx.result, "\n");

View file

@ -122,42 +122,6 @@ static bool emitBuiltin(
{ {
switch (bfid) switch (bfid)
{ {
case LBF_MATH_LDEXP:
LUAU_ASSERT(nparams == 2 && nresults == 1);
if (args.kind == IrOpKind::VmReg)
{
build.ldr(d1, mem(rBase, args.index * sizeof(TValue) + offsetof(TValue, value.n)));
build.fcvtzs(w0, d1);
}
else if (args.kind == IrOpKind::VmConst)
{
size_t constantOffset = args.index * sizeof(TValue) + offsetof(TValue, value.n);
// Note: cumulative offset is guaranteed to be divisible by 8 (since we're loading a double); we can use that to expand the useful range
// that doesn't require temporaries
if (constantOffset / 8 <= AddressA64::kMaxOffset)
{
build.ldr(d1, mem(rConstants, int(constantOffset)));
}
else
{
emitAddOffset(build, x0, rConstants, constantOffset);
build.ldr(d1, x0);
}
build.fcvtzs(w0, d1);
}
else if (args.kind == IrOpKind::Constant)
build.mov(w0, int(function.doubleOp(args)));
else if (args.kind != IrOpKind::Undef)
LUAU_ASSERT(!"Unsupported instruction form");
build.ldr(d0, mem(rBase, arg * sizeof(TValue) + offsetof(TValue, value.n)));
build.ldr(x1, mem(rNativeContext, offsetof(NativeContext, libm_ldexp)));
build.blr(x1);
build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n)));
return true;
case LBF_MATH_FREXP: case LBF_MATH_FREXP:
LUAU_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); LUAU_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2));
emitInvokeLibm1P(build, offsetof(NativeContext, libm_frexp), arg); emitInvokeLibm1P(build, offsetof(NativeContext, libm_frexp), arg);
@ -1610,12 +1574,20 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
{ {
if (inst.c.kind != IrOpKind::None) if (inst.c.kind != IrOpKind::None)
{ {
bool isInt = (inst.c.kind == IrOpKind::Constant) ? constOp(inst.c).kind == IrConstKind::Int
: getCmdValueKind(function.instOp(inst.c).cmd) == IrValueKind::Int;
RegisterA64 temp1 = tempDouble(inst.b); RegisterA64 temp1 = tempDouble(inst.b);
RegisterA64 temp2 = tempDouble(inst.c); RegisterA64 temp2 = isInt ? tempInt(inst.c) : tempDouble(inst.c);
RegisterA64 temp3 = regs.allocTemp(KindA64::d); // note: spill() frees all registers so we need to avoid alloc after spill RegisterA64 temp3 = isInt ? noreg : regs.allocTemp(KindA64::d); // note: spill() frees all registers so we need to avoid alloc after spill
regs.spill(build, index, {temp1, temp2}); regs.spill(build, index, {temp1, temp2});
if (d0 != temp2) if (isInt)
{
build.fmov(d0, temp1);
build.mov(w0, temp2);
}
else if (d0 != temp2)
{ {
build.fmov(d0, temp1); build.fmov(d0, temp1);
build.fmov(d1, temp2); build.fmov(d1, temp2);
@ -1634,8 +1606,8 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
build.fmov(d0, temp1); build.fmov(d0, temp1);
} }
build.ldr(x0, mem(rNativeContext, getNativeContextOffset(uintOp(inst.a)))); build.ldr(x1, mem(rNativeContext, getNativeContextOffset(uintOp(inst.a))));
build.blr(x0); build.blr(x1);
inst.regA64 = regs.takeReg(d0, index); inst.regA64 = regs.takeReg(d0, index);
break; break;
} }

View file

@ -1304,7 +1304,15 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
callWrap.addArgument(SizeX64::xmmword, memRegDoubleOp(inst.b), inst.b); callWrap.addArgument(SizeX64::xmmword, memRegDoubleOp(inst.b), inst.b);
if (inst.c.kind != IrOpKind::None) if (inst.c.kind != IrOpKind::None)
{
bool isInt = (inst.c.kind == IrOpKind::Constant) ? constOp(inst.c).kind == IrConstKind::Int
: getCmdValueKind(function.instOp(inst.c).cmd) == IrValueKind::Int;
if (isInt)
callWrap.addArgument(SizeX64::dword, memRegUintOp(inst.c), inst.c);
else
callWrap.addArgument(SizeX64::xmmword, memRegDoubleOp(inst.c), inst.c); callWrap.addArgument(SizeX64::xmmword, memRegDoubleOp(inst.c), inst.c);
}
callWrap.call(qword[rNativeContext + getNativeContextOffset(uintOp(inst.a))]); callWrap.call(qword[rNativeContext + getNativeContextOffset(uintOp(inst.a))]);
inst.regX64 = regs.takeReg(xmm0, index); inst.regX64 = regs.takeReg(xmm0, index);

View file

@ -71,23 +71,6 @@ static BuiltinImplResult translateBuiltinNumberToNumberLibm(
return {BuiltinImplType::UsesFallback, 1}; return {BuiltinImplType::UsesFallback, 1};
} }
// (number, number, ...) -> number
static BuiltinImplResult translateBuiltin2NumberToNumber(
IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback)
{
if (nparams < 2 || nresults > 1)
return {BuiltinImplType::None, -1};
builtinCheckDouble(build, build.vmReg(arg), fallback);
builtinCheckDouble(build, args, fallback);
build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(2), build.constInt(1));
if (ra != arg)
build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER));
return {BuiltinImplType::UsesFallback, 1};
}
static BuiltinImplResult translateBuiltin2NumberToNumberLibm( static BuiltinImplResult translateBuiltin2NumberToNumberLibm(
IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback)
{ {
@ -110,6 +93,30 @@ static BuiltinImplResult translateBuiltin2NumberToNumberLibm(
return {BuiltinImplType::UsesFallback, 1}; return {BuiltinImplType::UsesFallback, 1};
} }
static BuiltinImplResult translateBuiltinMathLdexp(
IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback)
{
if (nparams < 2 || nresults > 1)
return {BuiltinImplType::None, -1};
builtinCheckDouble(build, build.vmReg(arg), fallback);
builtinCheckDouble(build, args, fallback);
IrOp va = builtinLoadDouble(build, build.vmReg(arg));
IrOp vb = builtinLoadDouble(build, args);
IrOp vbi = build.inst(IrCmd::NUM_TO_INT, vb);
IrOp res = build.inst(IrCmd::INVOKE_LIBM, build.constUint(bfid), va, vbi);
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), res);
if (ra != arg)
build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER));
return {BuiltinImplType::UsesFallback, 1};
}
// (number, ...) -> (number, number) // (number, ...) -> (number, number)
static BuiltinImplResult translateBuiltinNumberTo2Number( static BuiltinImplResult translateBuiltinNumberTo2Number(
IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback)
@ -778,7 +785,7 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg,
case LBF_MATH_ATAN2: case LBF_MATH_ATAN2:
return translateBuiltin2NumberToNumberLibm(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); return translateBuiltin2NumberToNumberLibm(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback);
case LBF_MATH_LDEXP: case LBF_MATH_LDEXP:
return translateBuiltin2NumberToNumber(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); return translateBuiltinMathLdexp(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback);
case LBF_MATH_FREXP: case LBF_MATH_FREXP:
case LBF_MATH_MODF: case LBF_MATH_MODF:
return translateBuiltinNumberTo2Number(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); return translateBuiltinNumberTo2Number(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback);

View file

@ -299,6 +299,9 @@ void replace(IrFunction& function, IrBlock& block, uint32_t instIdx, IrInst repl
removeUse(function, inst.e); removeUse(function, inst.e);
removeUse(function, inst.f); removeUse(function, inst.f);
// Inherit existing use count (last use is skipped as it will be defined later)
replacement.useCount = inst.useCount;
inst = replacement; inst = replacement;
// Removing the earlier extra reference, this might leave the block without users without marking it as dead // Removing the earlier extra reference, this might leave the block without users without marking it as dead
@ -775,6 +778,8 @@ uint32_t getNativeContextOffset(int bfid)
return offsetof(NativeContext, libm_pow); return offsetof(NativeContext, libm_pow);
case LBF_IR_MATH_LOG2: case LBF_IR_MATH_LOG2:
return offsetof(NativeContext, libm_log2); return offsetof(NativeContext, libm_log2);
case LBF_MATH_LDEXP:
return offsetof(NativeContext, libm_ldexp);
default: default:
LUAU_ASSERT(!"Unsupported bfid"); LUAU_ASSERT(!"Unsupported bfid");
} }

View file

@ -290,6 +290,20 @@ struct ConstPropState
valueMap[versionedVmRegLoad(loadCmd, storeInst.a)] = storeInst.b.index; valueMap[versionedVmRegLoad(loadCmd, storeInst.a)] = storeInst.b.index;
} }
void clear()
{
for (int i = 0; i <= maxReg; ++i)
regs[i] = RegisterInfo();
maxReg = 0;
inSafeEnv = false;
checkedGc = false;
instLink.clear();
valueMap.clear();
}
IrFunction& function; IrFunction& function;
bool useValueNumbering = false; bool useValueNumbering = false;
@ -854,12 +868,11 @@ static void constPropInBlock(IrBuilder& build, IrBlock& block, ConstPropState& s
state.valueMap.clear(); state.valueMap.clear();
} }
static void constPropInBlockChain(IrBuilder& build, std::vector<uint8_t>& visited, IrBlock* block, bool useValueNumbering) static void constPropInBlockChain(IrBuilder& build, std::vector<uint8_t>& visited, IrBlock* block, ConstPropState& state)
{ {
IrFunction& function = build.function; IrFunction& function = build.function;
ConstPropState state{function}; state.clear();
state.useValueNumbering = useValueNumbering;
while (block) while (block)
{ {
@ -936,7 +949,7 @@ static std::vector<uint32_t> collectDirectBlockJumpPath(IrFunction& function, st
return path; return path;
} }
static void tryCreateLinearBlock(IrBuilder& build, std::vector<uint8_t>& visited, IrBlock& startingBlock, bool useValueNumbering) static void tryCreateLinearBlock(IrBuilder& build, std::vector<uint8_t>& visited, IrBlock& startingBlock, ConstPropState& state)
{ {
IrFunction& function = build.function; IrFunction& function = build.function;
@ -965,8 +978,9 @@ static void tryCreateLinearBlock(IrBuilder& build, std::vector<uint8_t>& visited
return; return;
// Initialize state with the knowledge of our current block // Initialize state with the knowledge of our current block
ConstPropState state{function}; state.clear();
state.useValueNumbering = useValueNumbering;
// TODO: using values from the first block can cause 'live out' of the linear block predecessor to not have all required registers
constPropInBlock(build, startingBlock, state); constPropInBlock(build, startingBlock, state);
// Veryfy that target hasn't changed // Veryfy that target hasn't changed
@ -981,10 +995,43 @@ static void tryCreateLinearBlock(IrBuilder& build, std::vector<uint8_t>& visited
replace(function, termInst.a, newBlock); replace(function, termInst.a, newBlock);
// Clone the collected path int our fresh block // Clone the collected path into our fresh block
for (uint32_t pathBlockIdx : path) for (uint32_t pathBlockIdx : path)
build.clone(function.blocks[pathBlockIdx], /* removeCurrentTerminator */ true); build.clone(function.blocks[pathBlockIdx], /* removeCurrentTerminator */ true);
// If all live in/out data is defined aside from the new block, generate it
// Note that liveness information is not strictly correct after optimization passes and may need to be recomputed before next passes
// The information generated here is consistent with current state that could be outdated, but still useful in IR inspection
if (function.cfg.in.size() == newBlock.index)
{
LUAU_ASSERT(function.cfg.in.size() == function.cfg.out.size());
LUAU_ASSERT(function.cfg.in.size() == function.cfg.def.size());
// Live in is the same as the input of the original first block
function.cfg.in.push_back(function.cfg.in[path.front()]);
// Live out is the same as the result of the original last block
function.cfg.out.push_back(function.cfg.out[path.back()]);
// Defs are tricky, registers are joined together, but variadic sequences can be consumed inside the block
function.cfg.def.push_back({});
RegisterSet& def = function.cfg.def.back();
for (uint32_t pathBlockIdx : path)
{
const RegisterSet& pathDef = function.cfg.def[pathBlockIdx];
def.regs |= pathDef.regs;
// Taking only the last defined variadic sequence if it's not consumed before before the end
if (pathDef.varargSeq && function.cfg.out.back().varargSeq)
{
def.varargSeq = true;
def.varargStart = pathDef.varargStart;
}
}
}
// Optimize our linear block // Optimize our linear block
IrBlock& linearBlock = function.blockOp(newBlock); IrBlock& linearBlock = function.blockOp(newBlock);
constPropInBlock(build, linearBlock, state); constPropInBlock(build, linearBlock, state);
@ -994,6 +1041,9 @@ void constPropInBlockChains(IrBuilder& build, bool useValueNumbering)
{ {
IrFunction& function = build.function; IrFunction& function = build.function;
ConstPropState state{function};
state.useValueNumbering = useValueNumbering;
std::vector<uint8_t> visited(function.blocks.size(), false); std::vector<uint8_t> visited(function.blocks.size(), false);
for (IrBlock& block : function.blocks) for (IrBlock& block : function.blocks)
@ -1004,7 +1054,7 @@ void constPropInBlockChains(IrBuilder& build, bool useValueNumbering)
if (visited[function.getBlockIndex(block)]) if (visited[function.getBlockIndex(block)])
continue; continue;
constPropInBlockChain(build, visited, &block, useValueNumbering); constPropInBlockChain(build, visited, &block, state);
} }
} }
@ -1015,6 +1065,9 @@ void createLinearBlocks(IrBuilder& build, bool useValueNumbering)
// new 'block' will only be reachable from a single one and all gathered information can be preserved. // new 'block' will only be reachable from a single one and all gathered information can be preserved.
IrFunction& function = build.function; IrFunction& function = build.function;
ConstPropState state{function};
state.useValueNumbering = useValueNumbering;
std::vector<uint8_t> visited(function.blocks.size(), false); std::vector<uint8_t> visited(function.blocks.size(), false);
// This loop can create new 'linear' blocks, so index-based loop has to be used (and it intentionally won't reach those new blocks) // This loop can create new 'linear' blocks, so index-based loop has to be used (and it intentionally won't reach those new blocks)
@ -1029,7 +1082,7 @@ void createLinearBlocks(IrBuilder& build, bool useValueNumbering)
if (visited[function.getBlockIndex(block)]) if (visited[function.getBlockIndex(block)])
continue; continue;
tryCreateLinearBlock(build, visited, block, useValueNumbering); tryCreateLinearBlock(build, visited, block, state);
} }
} }

View file

@ -36,27 +36,25 @@
#define DW_CFA_lo_user 0x1c #define DW_CFA_lo_user 0x1c
#define DW_CFA_hi_user 0x3f #define DW_CFA_hi_user 0x3f
// Register numbers for x64 (System V ABI, page 57, ch. 3.7, figure 3.36) // Register numbers for X64 (System V ABI, page 57, ch. 3.7, figure 3.36)
#define DW_REG_RAX 0 #define DW_REG_X64_RAX 0
#define DW_REG_RDX 1 #define DW_REG_X64_RDX 1
#define DW_REG_RCX 2 #define DW_REG_X64_RCX 2
#define DW_REG_RBX 3 #define DW_REG_X64_RBX 3
#define DW_REG_RSI 4 #define DW_REG_X64_RSI 4
#define DW_REG_RDI 5 #define DW_REG_X64_RDI 5
#define DW_REG_RBP 6 #define DW_REG_X64_RBP 6
#define DW_REG_RSP 7 #define DW_REG_X64_RSP 7
#define DW_REG_R8 8 #define DW_REG_X64_RA 16
#define DW_REG_R9 9
#define DW_REG_R10 10
#define DW_REG_R11 11
#define DW_REG_R12 12
#define DW_REG_R13 13
#define DW_REG_R14 14
#define DW_REG_R15 15
#define DW_REG_RA 16
const int regIndexToDwRegX64[16] = {DW_REG_RAX, DW_REG_RCX, DW_REG_RDX, DW_REG_RBX, DW_REG_RSP, DW_REG_RBP, DW_REG_RSI, DW_REG_RDI, DW_REG_R8, // Register numbers for A64 (DWARF for the Arm 64-bit Architecture, ch. 4.1)
DW_REG_R9, DW_REG_R10, DW_REG_R11, DW_REG_R12, DW_REG_R13, DW_REG_R14, DW_REG_R15}; #define DW_REG_A64_FP 29
#define DW_REG_A64_LR 30
#define DW_REG_A64_SP 31
// X64 register mapping from real register index to DWARF2 (r8..r15 are mapped 1-1, but named registers aren't)
const int regIndexToDwRegX64[16] = {DW_REG_X64_RAX, DW_REG_X64_RCX, DW_REG_X64_RDX, DW_REG_X64_RBX, DW_REG_X64_RSP, DW_REG_X64_RBP, DW_REG_X64_RSI,
DW_REG_X64_RDI, 8, 9, 10, 11, 12, 13, 14, 15};
const int kCodeAlignFactor = 1; const int kCodeAlignFactor = 1;
const int kDataAlignFactor = 8; const int kDataAlignFactor = 8;
@ -85,7 +83,7 @@ static uint8_t* defineSavedRegisterLocation(uint8_t* pos, int dwReg, uint32_t st
{ {
LUAU_ASSERT(stackOffset % kDataAlignFactor == 0 && "stack offsets have to be measured in kDataAlignFactor units"); LUAU_ASSERT(stackOffset % kDataAlignFactor == 0 && "stack offsets have to be measured in kDataAlignFactor units");
if (dwReg <= 15) if (dwReg <= 0x3f)
{ {
pos = writeu8(pos, DW_CFA_offset + dwReg); pos = writeu8(pos, DW_CFA_offset + dwReg);
} }
@ -99,8 +97,9 @@ static uint8_t* defineSavedRegisterLocation(uint8_t* pos, int dwReg, uint32_t st
return pos; return pos;
} }
static uint8_t* advanceLocation(uint8_t* pos, uint8_t offset) static uint8_t* advanceLocation(uint8_t* pos, unsigned int offset)
{ {
LUAU_ASSERT(offset < 256);
pos = writeu8(pos, DW_CFA_advance_loc1); pos = writeu8(pos, DW_CFA_advance_loc1);
pos = writeu8(pos, offset); pos = writeu8(pos, offset);
return pos; return pos;
@ -132,8 +131,10 @@ size_t UnwindBuilderDwarf2::getBeginOffset() const
return beginOffset; return beginOffset;
} }
void UnwindBuilderDwarf2::startInfo() void UnwindBuilderDwarf2::startInfo(Arch arch)
{ {
LUAU_ASSERT(arch == A64 || arch == X64);
uint8_t* cieLength = pos; uint8_t* cieLength = pos;
pos = writeu32(pos, 0); // Length (to be filled later) pos = writeu32(pos, 0); // Length (to be filled later)
@ -142,15 +143,24 @@ void UnwindBuilderDwarf2::startInfo()
pos = writeu8(pos, 0); // CIE augmentation String "" pos = writeu8(pos, 0); // CIE augmentation String ""
int ra = arch == A64 ? DW_REG_A64_LR : DW_REG_X64_RA;
pos = writeuleb128(pos, kCodeAlignFactor); // Code align factor pos = writeuleb128(pos, kCodeAlignFactor); // Code align factor
pos = writeuleb128(pos, -kDataAlignFactor & 0x7f); // Data align factor of (as signed LEB128) pos = writeuleb128(pos, -kDataAlignFactor & 0x7f); // Data align factor of (as signed LEB128)
pos = writeu8(pos, DW_REG_RA); // Return address register pos = writeu8(pos, ra); // Return address register
// Optional CIE augmentation section (not present) // Optional CIE augmentation section (not present)
// Call frame instructions (common for all FDEs, of which we have 1) // Call frame instructions (common for all FDEs)
pos = defineCfaExpression(pos, DW_REG_RSP, 8); // Define CFA to be the rsp + 8 if (arch == A64)
pos = defineSavedRegisterLocation(pos, DW_REG_RA, 8); // Define return address register (RA) to be located at CFA - 8 {
pos = defineCfaExpression(pos, DW_REG_A64_SP, 0); // Define CFA to be the sp
}
else
{
pos = defineCfaExpression(pos, DW_REG_X64_RSP, 8); // Define CFA to be the rsp + 8
pos = defineSavedRegisterLocation(pos, DW_REG_X64_RA, 8); // Define return address register (RA) to be located at CFA - 8
}
pos = alignPosition(cieLength, pos); pos = alignPosition(cieLength, pos);
writeu32(cieLength, unsigned(pos - cieLength - 4)); // Length field itself is excluded from length writeu32(cieLength, unsigned(pos - cieLength - 4)); // Length field itself is excluded from length
@ -165,8 +175,6 @@ void UnwindBuilderDwarf2::startFunction()
func.fdeEntryStartPos = uint32_t(pos - rawData); func.fdeEntryStartPos = uint32_t(pos - rawData);
unwindFunctions.push_back(func); unwindFunctions.push_back(func);
stackOffset = 8; // Return address was pushed by calling the function
fdeEntryStart = pos; // Will be written at the end fdeEntryStart = pos; // Will be written at the end
pos = writeu32(pos, 0); // Length (to be filled later) pos = writeu32(pos, 0); // Length (to be filled later)
pos = writeu32(pos, unsigned(pos - rawData)); // CIE pointer pos = writeu32(pos, unsigned(pos - rawData)); // CIE pointer
@ -178,42 +186,11 @@ void UnwindBuilderDwarf2::startFunction()
// Function call frame instructions to follow // Function call frame instructions to follow
} }
void UnwindBuilderDwarf2::spill(int espOffset, X64::RegisterX64 reg)
{
pos = advanceLocation(pos, 5); // REX.W mov [rsp + imm8], reg
}
void UnwindBuilderDwarf2::save(X64::RegisterX64 reg)
{
stackOffset += 8;
pos = advanceLocation(pos, 2); // REX.W push reg
pos = defineCfaExpressionOffset(pos, stackOffset);
pos = defineSavedRegisterLocation(pos, regIndexToDwRegX64[reg.index], stackOffset);
}
void UnwindBuilderDwarf2::allocStack(int size)
{
stackOffset += size;
pos = advanceLocation(pos, 4); // REX.W sub rsp, imm8
pos = defineCfaExpressionOffset(pos, stackOffset);
}
void UnwindBuilderDwarf2::setupFrameReg(X64::RegisterX64 reg, int espOffset)
{
if (espOffset != 0)
pos = advanceLocation(pos, 5); // REX.W lea rbp, [rsp + imm8]
else
pos = advanceLocation(pos, 3); // REX.W mov rbp, rsp
// Cfa is based on rsp, so no additonal commands are required
}
void UnwindBuilderDwarf2::finishFunction(uint32_t beginOffset, uint32_t endOffset) void UnwindBuilderDwarf2::finishFunction(uint32_t beginOffset, uint32_t endOffset)
{ {
unwindFunctions.back().beginOffset = beginOffset; unwindFunctions.back().beginOffset = beginOffset;
unwindFunctions.back().endOffset = endOffset; unwindFunctions.back().endOffset = endOffset;
LUAU_ASSERT(stackOffset % 16 == 0 && "stack has to be aligned to 16 bytes after prologue");
LUAU_ASSERT(fdeEntryStart != nullptr); LUAU_ASSERT(fdeEntryStart != nullptr);
pos = alignPosition(fdeEntryStart, pos); pos = alignPosition(fdeEntryStart, pos);
@ -228,6 +205,69 @@ void UnwindBuilderDwarf2::finishInfo()
LUAU_ASSERT(getSize() <= kRawDataLimit); LUAU_ASSERT(getSize() <= kRawDataLimit);
} }
void UnwindBuilderDwarf2::prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list<A64::RegisterA64> regs)
{
LUAU_ASSERT(stackSize % 16 == 0);
LUAU_ASSERT(regs.size() >= 2 && regs.begin()[0] == A64::x29 && regs.begin()[1] == A64::x30);
LUAU_ASSERT(regs.size() * 8 <= stackSize);
// sub sp, sp, stackSize
pos = advanceLocation(pos, 4);
pos = defineCfaExpressionOffset(pos, stackSize);
// stp/str to store each register to stack in order
pos = advanceLocation(pos, prologueSize - 4);
for (size_t i = 0; i < regs.size(); ++i)
{
LUAU_ASSERT(regs.begin()[i].kind == A64::KindA64::x);
pos = defineSavedRegisterLocation(pos, regs.begin()[i].index, stackSize - unsigned(i * 8));
}
}
void UnwindBuilderDwarf2::prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list<X64::RegisterX64> regs)
{
LUAU_ASSERT(stackSize > 0 && stackSize <= 128 && stackSize % 8 == 0);
unsigned int stackOffset = 8; // Return address was pushed by calling the function
unsigned int prologueOffset = 0;
if (setupFrame)
{
// push rbp
stackOffset += 8;
prologueOffset += 2;
pos = advanceLocation(pos, 2);
pos = defineCfaExpressionOffset(pos, stackOffset);
pos = defineSavedRegisterLocation(pos, DW_REG_X64_RBP, stackOffset);
// mov rbp, rsp
prologueOffset += 3;
pos = advanceLocation(pos, 3);
}
// push reg
for (X64::RegisterX64 reg : regs)
{
LUAU_ASSERT(reg.size == X64::SizeX64::qword);
stackOffset += 8;
prologueOffset += 2;
pos = advanceLocation(pos, 2);
pos = defineCfaExpressionOffset(pos, stackOffset);
pos = defineSavedRegisterLocation(pos, regIndexToDwRegX64[reg.index], stackOffset);
}
// sub rsp, stackSize
stackOffset += stackSize;
prologueOffset += 4;
pos = advanceLocation(pos, 4);
pos = defineCfaExpressionOffset(pos, stackOffset);
LUAU_ASSERT(stackOffset % 16 == 0);
LUAU_ASSERT(prologueOffset == prologueSize);
}
size_t UnwindBuilderDwarf2::getSize() const size_t UnwindBuilderDwarf2::getSize() const
{ {
return size_t(pos - rawData); return size_t(pos - rawData);
@ -244,14 +284,14 @@ void UnwindBuilderDwarf2::finalize(char* target, size_t offset, void* funcAddres
for (const UnwindFunctionDwarf2& func : unwindFunctions) for (const UnwindFunctionDwarf2& func : unwindFunctions)
{ {
uint8_t* fdeEntryStart = (uint8_t*)target + func.fdeEntryStartPos; uint8_t* fdeEntry = (uint8_t*)target + func.fdeEntryStartPos;
writeu64(fdeEntryStart + kFdeInitialLocationOffset, uintptr_t(funcAddress) + offset + func.beginOffset); writeu64(fdeEntry + kFdeInitialLocationOffset, uintptr_t(funcAddress) + offset + func.beginOffset);
if (func.endOffset == kFullBlockFuncton) if (func.endOffset == kFullBlockFuncton)
writeu64(fdeEntryStart + kFdeAddressRangeOffset, funcSize - offset); writeu64(fdeEntry + kFdeAddressRangeOffset, funcSize - offset);
else else
writeu64(fdeEntryStart + kFdeAddressRangeOffset, func.endOffset - func.beginOffset); writeu64(fdeEntry + kFdeAddressRangeOffset, func.endOffset - func.beginOffset);
} }
} }

View file

@ -31,7 +31,10 @@ size_t UnwindBuilderWin::getBeginOffset() const
return beginOffset; return beginOffset;
} }
void UnwindBuilderWin::startInfo() {} void UnwindBuilderWin::startInfo(Arch arch)
{
LUAU_ASSERT(arch == X64);
}
void UnwindBuilderWin::startFunction() void UnwindBuilderWin::startFunction()
{ {
@ -50,45 +53,6 @@ void UnwindBuilderWin::startFunction()
// rax has register index 0, which in Windows unwind info means that frame register is not used // rax has register index 0, which in Windows unwind info means that frame register is not used
frameReg = X64::rax; frameReg = X64::rax;
frameRegOffset = 0; frameRegOffset = 0;
// Return address was pushed by calling the function
stackOffset = 8;
}
void UnwindBuilderWin::spill(int espOffset, X64::RegisterX64 reg)
{
prologSize += 5; // REX.W mov [rsp + imm8], reg
}
void UnwindBuilderWin::save(X64::RegisterX64 reg)
{
prologSize += 2; // REX.W push reg
stackOffset += 8;
unwindCodes.push_back({prologSize, UWOP_PUSH_NONVOL, reg.index});
}
void UnwindBuilderWin::allocStack(int size)
{
LUAU_ASSERT(size >= 8 && size <= 128 && size % 8 == 0);
prologSize += 4; // REX.W sub rsp, imm8
stackOffset += size;
unwindCodes.push_back({prologSize, UWOP_ALLOC_SMALL, uint8_t((size - 8) / 8)});
}
void UnwindBuilderWin::setupFrameReg(X64::RegisterX64 reg, int espOffset)
{
LUAU_ASSERT(espOffset < 256 && espOffset % 16 == 0);
frameReg = reg;
frameRegOffset = uint8_t(espOffset / 16);
if (espOffset != 0)
prologSize += 5; // REX.W lea rbp, [rsp + imm8]
else
prologSize += 3; // REX.W mov rbp, rsp
unwindCodes.push_back({prologSize, UWOP_SET_FPREG, frameRegOffset});
} }
void UnwindBuilderWin::finishFunction(uint32_t beginOffset, uint32_t endOffset) void UnwindBuilderWin::finishFunction(uint32_t beginOffset, uint32_t endOffset)
@ -99,8 +63,6 @@ void UnwindBuilderWin::finishFunction(uint32_t beginOffset, uint32_t endOffset)
// Windows unwind code count is stored in uint8_t, so we can't have more // Windows unwind code count is stored in uint8_t, so we can't have more
LUAU_ASSERT(unwindCodes.size() < 256); LUAU_ASSERT(unwindCodes.size() < 256);
LUAU_ASSERT(stackOffset % 16 == 0 && "stack has to be aligned to 16 bytes after prologue");
UnwindInfoWin info; UnwindInfoWin info;
info.version = 1; info.version = 1;
info.flags = 0; // No EH info.flags = 0; // No EH
@ -142,6 +104,54 @@ void UnwindBuilderWin::finishFunction(uint32_t beginOffset, uint32_t endOffset)
void UnwindBuilderWin::finishInfo() {} void UnwindBuilderWin::finishInfo() {}
void UnwindBuilderWin::prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list<A64::RegisterA64> regs)
{
LUAU_ASSERT(!"Not implemented");
}
void UnwindBuilderWin::prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list<X64::RegisterX64> regs)
{
LUAU_ASSERT(stackSize > 0 && stackSize <= 128 && stackSize % 8 == 0);
LUAU_ASSERT(prologueSize < 256);
unsigned int stackOffset = 8; // Return address was pushed by calling the function
unsigned int prologueOffset = 0;
if (setupFrame)
{
// push rbp
stackOffset += 8;
prologueOffset += 2;
unwindCodes.push_back({uint8_t(prologueOffset), UWOP_PUSH_NONVOL, X64::rbp.index});
// mov rbp, rsp
prologueOffset += 3;
frameReg = X64::rbp;
frameRegOffset = 0;
unwindCodes.push_back({uint8_t(prologueOffset), UWOP_SET_FPREG, frameRegOffset});
}
// push reg
for (X64::RegisterX64 reg : regs)
{
LUAU_ASSERT(reg.size == X64::SizeX64::qword);
stackOffset += 8;
prologueOffset += 2;
unwindCodes.push_back({uint8_t(prologueOffset), UWOP_PUSH_NONVOL, reg.index});
}
// sub rsp, stackSize
stackOffset += stackSize;
prologueOffset += 4;
unwindCodes.push_back({uint8_t(prologueOffset), UWOP_ALLOC_SMALL, uint8_t((stackSize - 8) / 8)});
LUAU_ASSERT(stackOffset % 16 == 0);
LUAU_ASSERT(prologueOffset == prologueSize);
this->prologSize = prologueSize;
}
size_t UnwindBuilderWin::getSize() const size_t UnwindBuilderWin::getSize() const
{ {
return sizeof(UnwindFunctionWin) * unwindFunctions.size() + size_t(rawDataPos - rawData); return sizeof(UnwindFunctionWin) * unwindFunctions.size() + size_t(rawDataPos - rawData);

View file

@ -1701,8 +1701,6 @@ void BytecodeBuilder::dumpConstant(std::string& result, int k) const
formatAppend(result, "'%s'", func.dumpname.c_str()); formatAppend(result, "'%s'", func.dumpname.c_str());
break; break;
} }
default:
LUAU_UNREACHABLE();
} }
} }

View file

@ -913,7 +913,9 @@ reentry:
// slow-path: not a function call // slow-path: not a function call
if (LUAU_UNLIKELY(!ttisfunction(ra))) if (LUAU_UNLIKELY(!ttisfunction(ra)))
{ {
VM_PROTECT(luaV_tryfuncTM(L, ra)); VM_PROTECT_PC(); // luaV_tryfuncTM may fail
luaV_tryfuncTM(L, ra);
argtop++; // __call adds an extra self argtop++; // __call adds an extra self
} }

View file

@ -135,20 +135,9 @@ TEST_CASE("WindowsUnwindCodesX64")
UnwindBuilderWin unwind; UnwindBuilderWin unwind;
unwind.startInfo(); unwind.startInfo(UnwindBuilder::X64);
unwind.startFunction(); unwind.startFunction();
unwind.spill(16, rdx); unwind.prologueX64(/* prologueSize= */ 23, /* stackSize= */ 72, /* setupFrame= */ true, {rdi, rsi, rbx, r12, r13, r14, r15});
unwind.spill(8, rcx);
unwind.save(rdi);
unwind.save(rsi);
unwind.save(rbx);
unwind.save(rbp);
unwind.save(r12);
unwind.save(r13);
unwind.save(r14);
unwind.save(r15);
unwind.allocStack(72);
unwind.setupFrameReg(rbp, 48);
unwind.finishFunction(0x11223344, 0x55443322); unwind.finishFunction(0x11223344, 0x55443322);
unwind.finishInfo(); unwind.finishInfo();
@ -156,8 +145,8 @@ TEST_CASE("WindowsUnwindCodesX64")
data.resize(unwind.getSize()); data.resize(unwind.getSize());
unwind.finalize(data.data(), 0, nullptr, 0); unwind.finalize(data.data(), 0, nullptr, 0);
std::vector<uint8_t> expected{0x44, 0x33, 0x22, 0x11, 0x22, 0x33, 0x44, 0x55, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x23, 0x0a, 0x35, 0x23, 0x33, 0x1e, std::vector<uint8_t> expected{0x44, 0x33, 0x22, 0x11, 0x22, 0x33, 0x44, 0x55, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x17, 0x0a, 0x05, 0x17, 0x82, 0x13,
0x82, 0x1a, 0xf0, 0x18, 0xe0, 0x16, 0xd0, 0x14, 0xc0, 0x12, 0x50, 0x10, 0x30, 0x0e, 0x60, 0x0c, 0x70}; 0xf0, 0x11, 0xe0, 0x0f, 0xd0, 0x0d, 0xc0, 0x0b, 0x30, 0x09, 0x60, 0x07, 0x70, 0x05, 0x03, 0x02, 0x50};
REQUIRE(data.size() == expected.size()); REQUIRE(data.size() == expected.size());
CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0); CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0);
@ -170,18 +159,9 @@ TEST_CASE("Dwarf2UnwindCodesX64")
UnwindBuilderDwarf2 unwind; UnwindBuilderDwarf2 unwind;
unwind.startInfo(); unwind.startInfo(UnwindBuilder::X64);
unwind.startFunction(); unwind.startFunction();
unwind.save(rdi); unwind.prologueX64(/* prologueSize= */ 23, /* stackSize= */ 72, /* setupFrame= */ true, {rdi, rsi, rbx, r12, r13, r14, r15});
unwind.save(rsi);
unwind.save(rbx);
unwind.save(rbp);
unwind.save(r12);
unwind.save(r13);
unwind.save(r14);
unwind.save(r15);
unwind.allocStack(72);
unwind.setupFrameReg(rbp, 48);
unwind.finishFunction(0, 0); unwind.finishFunction(0, 0);
unwind.finishInfo(); unwind.finishInfo();
@ -189,11 +169,36 @@ TEST_CASE("Dwarf2UnwindCodesX64")
data.resize(unwind.getSize()); data.resize(unwind.getSize());
unwind.finalize(data.data(), 0, nullptr, 0); unwind.finalize(data.data(), 0, nullptr, 0);
std::vector<uint8_t> expected{0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x78, 0x10, 0x0c, 0x07, 0x08, 0x05, 0x10, 0x01, std::vector<uint8_t> expected{0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x78, 0x10, 0x0c, 0x07, 0x08, 0x90, 0x01, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x0e, 0x10, 0x85, 0x02, 0x02, 0x02, 0x0e, 0x18, 0x84, 0x03, 0x02, 0x02, 0x0e, 0x20, 0x83, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x0e, 0x10, 0x86, 0x02, 0x02, 0x03, 0x02, 0x02, 0x0e, 0x18, 0x85, 0x03, 0x02, 0x02, 0x0e,
0x04, 0x02, 0x02, 0x0e, 0x28, 0x86, 0x05, 0x02, 0x02, 0x0e, 0x30, 0x8c, 0x06, 0x02, 0x02, 0x0e, 0x38, 0x8d, 0x07, 0x02, 0x02, 0x0e, 0x40, 0x20, 0x84, 0x04, 0x02, 0x02, 0x0e, 0x28, 0x83, 0x05, 0x02, 0x02, 0x0e, 0x30, 0x8c, 0x06, 0x02, 0x02, 0x0e, 0x38, 0x8d, 0x07, 0x02, 0x02,
0x8e, 0x08, 0x02, 0x02, 0x0e, 0x48, 0x8f, 0x09, 0x02, 0x04, 0x0e, 0x90, 0x01, 0x02, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00}; 0x0e, 0x40, 0x8e, 0x08, 0x02, 0x02, 0x0e, 0x48, 0x8f, 0x09, 0x02, 0x04, 0x0e, 0x90, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00};
REQUIRE(data.size() == expected.size());
CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0);
}
TEST_CASE("Dwarf2UnwindCodesA64")
{
using namespace A64;
UnwindBuilderDwarf2 unwind;
unwind.startInfo(UnwindBuilder::A64);
unwind.startFunction();
unwind.prologueA64(/* prologueSize= */ 28, /* stackSize= */ 64, {x29, x30, x19, x20, x21, x22, x23, x24});
unwind.finishFunction(0, 32);
unwind.finishInfo();
std::vector<char> data;
data.resize(unwind.getSize());
unwind.finalize(data.data(), 0, nullptr, 0);
std::vector<uint8_t> expected{0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x78, 0x1e, 0x0c, 0x1f, 0x00, 0x2c, 0x00, 0x00,
0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x04,
0x0e, 0x40, 0x02, 0x18, 0x9d, 0x08, 0x9e, 0x07, 0x93, 0x06, 0x94, 0x05, 0x95, 0x04, 0x96, 0x03, 0x97, 0x02, 0x98, 0x01, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00};
REQUIRE(data.size() == expected.size()); REQUIRE(data.size() == expected.size());
CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0); CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0);
@ -247,7 +252,7 @@ TEST_CASE("GeneratedCodeExecutionX64")
CHECK(result == 210); CHECK(result == 210);
} }
void throwing(int64_t arg) static void throwing(int64_t arg)
{ {
CHECK(arg == 25); CHECK(arg == 25);
@ -266,27 +271,25 @@ TEST_CASE("GeneratedCodeExecutionWithThrowX64")
std::unique_ptr<UnwindBuilder> unwind = std::make_unique<UnwindBuilderDwarf2>(); std::unique_ptr<UnwindBuilder> unwind = std::make_unique<UnwindBuilderDwarf2>();
#endif #endif
unwind->startInfo(); unwind->startInfo(UnwindBuilder::X64);
Label functionBegin = build.setLabel(); Label functionBegin = build.setLabel();
unwind->startFunction(); unwind->startFunction();
// Prologue // Prologue
build.push(rNonVol1);
unwind->save(rNonVol1);
build.push(rNonVol2);
unwind->save(rNonVol2);
build.push(rbp); build.push(rbp);
unwind->save(rbp); build.mov(rbp, rsp);
build.push(rNonVol1);
build.push(rNonVol2);
int stackSize = 32; int stackSize = 32;
int localsSize = 16; int localsSize = 16;
build.sub(rsp, stackSize + localsSize); build.sub(rsp, stackSize + localsSize);
unwind->allocStack(stackSize + localsSize);
build.lea(rbp, addr[rsp + stackSize]); uint32_t prologueSize = build.setLabel().location;
unwind->setupFrameReg(rbp, stackSize);
unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ true, {rNonVol1, rNonVol2});
// Body // Body
build.mov(rNonVol1, rArg1); build.mov(rNonVol1, rArg1);
@ -297,10 +300,10 @@ TEST_CASE("GeneratedCodeExecutionWithThrowX64")
build.call(rNonVol2); build.call(rNonVol2);
// Epilogue // Epilogue
build.lea(rsp, addr[rbp + localsSize]); build.add(rsp, stackSize + localsSize);
build.pop(rbp);
build.pop(rNonVol2); build.pop(rNonVol2);
build.pop(rNonVol1); build.pop(rNonVol1);
build.pop(rbp);
build.ret(); build.ret();
unwind->finishFunction(build.getLabelOffset(functionBegin), ~0u); unwind->finishFunction(build.getLabelOffset(functionBegin), ~0u);
@ -349,7 +352,7 @@ TEST_CASE("GeneratedCodeExecutionMultipleFunctionsWithThrowX64")
std::unique_ptr<UnwindBuilder> unwind = std::make_unique<UnwindBuilderDwarf2>(); std::unique_ptr<UnwindBuilder> unwind = std::make_unique<UnwindBuilderDwarf2>();
#endif #endif
unwind->startInfo(); unwind->startInfo(UnwindBuilder::X64);
Label start1; Label start1;
Label start2; Label start2;
@ -360,21 +363,19 @@ TEST_CASE("GeneratedCodeExecutionMultipleFunctionsWithThrowX64")
unwind->startFunction(); unwind->startFunction();
// Prologue // Prologue
build.push(rNonVol1);
unwind->save(rNonVol1);
build.push(rNonVol2);
unwind->save(rNonVol2);
build.push(rbp); build.push(rbp);
unwind->save(rbp); build.mov(rbp, rsp);
build.push(rNonVol1);
build.push(rNonVol2);
int stackSize = 32; int stackSize = 32;
int localsSize = 16; int localsSize = 16;
build.sub(rsp, stackSize + localsSize); build.sub(rsp, stackSize + localsSize);
unwind->allocStack(stackSize + localsSize);
build.lea(rbp, addr[rsp + stackSize]); uint32_t prologueSize = build.setLabel().location - start1.location;
unwind->setupFrameReg(rbp, stackSize);
unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ true, {rNonVol1, rNonVol2});
// Body // Body
build.mov(rNonVol1, rArg1); build.mov(rNonVol1, rArg1);
@ -385,41 +386,35 @@ TEST_CASE("GeneratedCodeExecutionMultipleFunctionsWithThrowX64")
build.call(rNonVol2); build.call(rNonVol2);
// Epilogue // Epilogue
build.lea(rsp, addr[rbp + localsSize]); build.add(rsp, stackSize + localsSize);
build.pop(rbp);
build.pop(rNonVol2); build.pop(rNonVol2);
build.pop(rNonVol1); build.pop(rNonVol1);
build.pop(rbp);
build.ret(); build.ret();
Label end1 = build.setLabel(); Label end1 = build.setLabel();
unwind->finishFunction(build.getLabelOffset(start1), build.getLabelOffset(end1)); unwind->finishFunction(build.getLabelOffset(start1), build.getLabelOffset(end1));
} }
// Second function with different layout // Second function with different layout and no frame
{ {
build.setLabel(start2); build.setLabel(start2);
unwind->startFunction(); unwind->startFunction();
// Prologue // Prologue
build.push(rNonVol1); build.push(rNonVol1);
unwind->save(rNonVol1);
build.push(rNonVol2); build.push(rNonVol2);
unwind->save(rNonVol2);
build.push(rNonVol3); build.push(rNonVol3);
unwind->save(rNonVol3);
build.push(rNonVol4); build.push(rNonVol4);
unwind->save(rNonVol4);
build.push(rbp);
unwind->save(rbp);
int stackSize = 32; int stackSize = 32;
int localsSize = 32; int localsSize = 24;
build.sub(rsp, stackSize + localsSize); build.sub(rsp, stackSize + localsSize);
unwind->allocStack(stackSize + localsSize);
build.lea(rbp, addr[rsp + stackSize]); uint32_t prologueSize = build.setLabel().location - start2.location;
unwind->setupFrameReg(rbp, stackSize);
unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ false, {rNonVol1, rNonVol2, rNonVol3, rNonVol4});
// Body // Body
build.mov(rNonVol3, rArg1); build.mov(rNonVol3, rArg1);
@ -430,8 +425,7 @@ TEST_CASE("GeneratedCodeExecutionMultipleFunctionsWithThrowX64")
build.call(rNonVol4); build.call(rNonVol4);
// Epilogue // Epilogue
build.lea(rsp, addr[rbp + localsSize]); build.add(rsp, stackSize + localsSize);
build.pop(rbp);
build.pop(rNonVol4); build.pop(rNonVol4);
build.pop(rNonVol3); build.pop(rNonVol3);
build.pop(rNonVol2); build.pop(rNonVol2);
@ -495,37 +489,29 @@ TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64")
std::unique_ptr<UnwindBuilder> unwind = std::make_unique<UnwindBuilderDwarf2>(); std::unique_ptr<UnwindBuilder> unwind = std::make_unique<UnwindBuilderDwarf2>();
#endif #endif
unwind->startInfo(); unwind->startInfo(UnwindBuilder::X64);
Label functionBegin = build.setLabel(); Label functionBegin = build.setLabel();
unwind->startFunction(); unwind->startFunction();
// Prologue (some of these registers don't have to be saved, but we want to have a big prologue) // Prologue (some of these registers don't have to be saved, but we want to have a big prologue)
build.push(r10);
unwind->save(r10);
build.push(r11);
unwind->save(r11);
build.push(r12);
unwind->save(r12);
build.push(r13);
unwind->save(r13);
build.push(r14);
unwind->save(r14);
build.push(r15);
unwind->save(r15);
build.push(rbp); build.push(rbp);
unwind->save(rbp); build.mov(rbp, rsp);
build.push(r10);
build.push(r11);
build.push(r12);
build.push(r13);
build.push(r14);
build.push(r15);
int stackSize = 64; int stackSize = 64;
int localsSize = 16; int localsSize = 16;
build.sub(rsp, stackSize + localsSize); build.sub(rsp, stackSize + localsSize);
unwind->allocStack(stackSize + localsSize);
build.lea(rbp, addr[rsp + stackSize]); uint32_t prologueSize = build.setLabel().location;
unwind->setupFrameReg(rbp, stackSize);
size_t prologueSize = build.setLabel().location; unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ true, {r10, r11, r12, r13, r14, r15});
// Body // Body
build.mov(rax, rArg1); build.mov(rax, rArg1);
@ -535,14 +521,14 @@ TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64")
Label returnOffset = build.setLabel(); Label returnOffset = build.setLabel();
// Epilogue // Epilogue
build.lea(rsp, addr[rbp + localsSize]); build.add(rsp, stackSize + localsSize);
build.pop(rbp);
build.pop(r15); build.pop(r15);
build.pop(r14); build.pop(r14);
build.pop(r13); build.pop(r13);
build.pop(r12); build.pop(r12);
build.pop(r11); build.pop(r11);
build.pop(r10); build.pop(r10);
build.pop(rbp);
build.ret(); build.ret();
unwind->finishFunction(build.getLabelOffset(functionBegin), ~0u); unwind->finishFunction(build.getLabelOffset(functionBegin), ~0u);
@ -650,6 +636,78 @@ TEST_CASE("GeneratedCodeExecutionA64")
CHECK(result == 42); CHECK(result == 42);
} }
static void throwing(int64_t arg)
{
CHECK(arg == 25);
throw std::runtime_error("testing");
}
TEST_CASE("GeneratedCodeExecutionWithThrowA64")
{
using namespace A64;
AssemblyBuilderA64 build(/* logText= */ false);
std::unique_ptr<UnwindBuilder> unwind = std::make_unique<UnwindBuilderDwarf2>();
unwind->startInfo(UnwindBuilder::A64);
build.sub(sp, sp, 32);
build.stp(x29, x30, mem(sp));
build.str(x28, mem(sp, 16));
build.mov(x29, sp);
Label prologueEnd = build.setLabel();
build.add(x0, x0, 15);
build.blr(x1);
build.ldr(x28, mem(sp, 16));
build.ldp(x29, x30, mem(sp));
build.add(sp, sp, 32);
build.ret();
Label functionEnd = build.setLabel();
unwind->startFunction();
unwind->prologueA64(build.getLabelOffset(prologueEnd), 32, {x29, x30, x28});
unwind->finishFunction(0, build.getLabelOffset(functionEnd));
build.finalize();
unwind->finishInfo();
size_t blockSize = 1024 * 1024;
size_t maxTotalSize = 1024 * 1024;
CodeAllocator allocator(blockSize, maxTotalSize);
allocator.context = unwind.get();
allocator.createBlockUnwindInfo = createBlockUnwindInfo;
allocator.destroyBlockUnwindInfo = destroyBlockUnwindInfo;
uint8_t* nativeData;
size_t sizeNativeData;
uint8_t* nativeEntry;
REQUIRE(allocator.allocate(build.data.data(), build.data.size(), reinterpret_cast<uint8_t*>(build.code.data()), build.code.size() * 4, nativeData,
sizeNativeData, nativeEntry));
REQUIRE(nativeEntry);
using FunctionType = int64_t(int64_t, void (*)(int64_t));
FunctionType* f = (FunctionType*)nativeEntry;
// To simplify debugging, CHECK_THROWS_WITH_AS is not used here
try
{
f(10, throwing);
}
catch (const std::runtime_error& error)
{
CHECK(strcmp(error.what(), "testing") == 0);
}
}
#endif #endif
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -532,6 +532,30 @@ bb_0:
)"); )");
} }
TEST_CASE_FIXTURE(IrBuilderFixture, "ReplacementPreservesUses")
{
IrOp block = build.block(IrBlockKind::Internal);
build.beginBlock(block);
IrOp unk = build.inst(IrCmd::LOAD_INT, build.vmReg(0));
build.inst(IrCmd::STORE_INT, build.vmReg(8), build.inst(IrCmd::BITXOR_UINT, unk, build.constInt(~0u)));
build.inst(IrCmd::RETURN, build.constUint(0));
updateUseCounts(build.function);
constantFold();
CHECK("\n" + toString(build.function, /* includeUseInfo */ true) == R"(
bb_0: ; useCount: 0
%0 = LOAD_INT R0 ; useCount: 1, lastUse: %0
%1 = BITNOT_UINT %0 ; useCount: 1, lastUse: %0
STORE_INT R8, %1 ; %2
RETURN 0u ; %3
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "NumericNan") TEST_CASE_FIXTURE(IrBuilderFixture, "NumericNan")
{ {
IrOp block = build.block(IrBlockKind::Internal); IrOp block = build.block(IrBlockKind::Internal);

View file

@ -470,8 +470,6 @@ TEST_SUITE_END();
struct NormalizeFixture : Fixture struct NormalizeFixture : Fixture
{ {
ScopedFastFlag sff2{"LuauNegatedClassTypes", true};
TypeArena arena; TypeArena arena;
InternalErrorReporter iceHandler; InternalErrorReporter iceHandler;
UnifierSharedState unifierState{&iceHandler}; UnifierSharedState unifierState{&iceHandler};
@ -632,11 +630,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "union_function_and_top_function")
TEST_CASE_FIXTURE(NormalizeFixture, "negated_function_is_anything_except_a_function") TEST_CASE_FIXTURE(NormalizeFixture, "negated_function_is_anything_except_a_function")
{ {
ScopedFastFlag sffs[] = {
{"LuauNegatedTableTypes", true},
{"LuauNegatedClassTypes", true},
};
CHECK("(boolean | class | number | string | table | thread)?" == toString(normal(R"( CHECK("(boolean | class | number | string | table | thread)?" == toString(normal(R"(
Not<fun> Not<fun>
)"))); )")));
@ -649,11 +642,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "specific_functions_cannot_be_negated")
TEST_CASE_FIXTURE(NormalizeFixture, "bare_negated_boolean") TEST_CASE_FIXTURE(NormalizeFixture, "bare_negated_boolean")
{ {
ScopedFastFlag sffs[] = {
{"LuauNegatedTableTypes", true},
{"LuauNegatedClassTypes", true},
};
// TODO: We don't yet have a way to say number | string | thread | nil | Class | Table | Function // TODO: We don't yet have a way to say number | string | thread | nil | Class | Table | Function
CHECK("(class | function | number | string | table | thread)?" == toString(normal(R"( CHECK("(class | function | number | string | table | thread)?" == toString(normal(R"(
Not<boolean> Not<boolean>
@ -723,8 +711,6 @@ export type t0 = (((any)&({_:l0.t0,n0:t0,_G:any,}))&({_:any,}))&(((any)&({_:l0.t
TEST_CASE_FIXTURE(NormalizeFixture, "unions_of_classes") TEST_CASE_FIXTURE(NormalizeFixture, "unions_of_classes")
{ {
ScopedFastFlag sff{"LuauNegatedClassTypes", true};
createSomeClasses(&frontend); createSomeClasses(&frontend);
CHECK("Parent | Unrelated" == toString(normal("Parent | Unrelated"))); CHECK("Parent | Unrelated" == toString(normal("Parent | Unrelated")));
CHECK("Parent" == toString(normal("Parent | Child"))); CHECK("Parent" == toString(normal("Parent | Child")));
@ -733,8 +719,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "unions_of_classes")
TEST_CASE_FIXTURE(NormalizeFixture, "intersections_of_classes") TEST_CASE_FIXTURE(NormalizeFixture, "intersections_of_classes")
{ {
ScopedFastFlag sff{"LuauNegatedClassTypes", true};
createSomeClasses(&frontend); createSomeClasses(&frontend);
CHECK("Child" == toString(normal("Parent & Child"))); CHECK("Child" == toString(normal("Parent & Child")));
CHECK("never" == toString(normal("Child & Unrelated"))); CHECK("never" == toString(normal("Child & Unrelated")));
@ -742,8 +726,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "intersections_of_classes")
TEST_CASE_FIXTURE(NormalizeFixture, "narrow_union_of_classes_with_intersection") TEST_CASE_FIXTURE(NormalizeFixture, "narrow_union_of_classes_with_intersection")
{ {
ScopedFastFlag sff{"LuauNegatedClassTypes", true};
createSomeClasses(&frontend); createSomeClasses(&frontend);
CHECK("Child" == toString(normal("(Child | Unrelated) & Child"))); CHECK("Child" == toString(normal("(Child | Unrelated) & Child")));
} }
@ -764,11 +746,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "crazy_metatable")
TEST_CASE_FIXTURE(NormalizeFixture, "negations_of_classes") TEST_CASE_FIXTURE(NormalizeFixture, "negations_of_classes")
{ {
ScopedFastFlag sffs[] = {
{"LuauNegatedTableTypes", true},
{"LuauNegatedClassTypes", true},
};
createSomeClasses(&frontend); createSomeClasses(&frontend);
CHECK("(Parent & ~Child) | Unrelated" == toString(normal("(Parent & Not<Child>) | Unrelated"))); CHECK("(Parent & ~Child) | Unrelated" == toString(normal("(Parent & Not<Child>) | Unrelated")));
CHECK("((class & ~Child) | boolean | function | number | string | table | thread)?" == toString(normal("Not<Child>"))); CHECK("((class & ~Child) | boolean | function | number | string | table | thread)?" == toString(normal("Not<Child>")));
@ -781,24 +758,18 @@ TEST_CASE_FIXTURE(NormalizeFixture, "negations_of_classes")
TEST_CASE_FIXTURE(NormalizeFixture, "classes_and_unknown") TEST_CASE_FIXTURE(NormalizeFixture, "classes_and_unknown")
{ {
ScopedFastFlag sff{"LuauNegatedClassTypes", true};
createSomeClasses(&frontend); createSomeClasses(&frontend);
CHECK("Parent" == toString(normal("Parent & unknown"))); CHECK("Parent" == toString(normal("Parent & unknown")));
} }
TEST_CASE_FIXTURE(NormalizeFixture, "classes_and_never") TEST_CASE_FIXTURE(NormalizeFixture, "classes_and_never")
{ {
ScopedFastFlag sff{"LuauNegatedClassTypes", true};
createSomeClasses(&frontend); createSomeClasses(&frontend);
CHECK("never" == toString(normal("Parent & never"))); CHECK("never" == toString(normal("Parent & never")));
} }
TEST_CASE_FIXTURE(NormalizeFixture, "top_table_type") TEST_CASE_FIXTURE(NormalizeFixture, "top_table_type")
{ {
ScopedFastFlag sff{"LuauNegatedTableTypes", true};
CHECK("table" == toString(normal("{} | tbl"))); CHECK("table" == toString(normal("{} | tbl")));
CHECK("{| |}" == toString(normal("{} & tbl"))); CHECK("{| |}" == toString(normal("{} & tbl")));
CHECK("never" == toString(normal("number & tbl"))); CHECK("never" == toString(normal("number & tbl")));
@ -806,8 +777,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "top_table_type")
TEST_CASE_FIXTURE(NormalizeFixture, "negations_of_tables") TEST_CASE_FIXTURE(NormalizeFixture, "negations_of_tables")
{ {
ScopedFastFlag sff{"LuauNegatedTableTypes", true};
CHECK(nullptr == toNormalizedType("Not<{}>")); CHECK(nullptr == toNormalizedType("Not<{}>"));
CHECK("(boolean | class | function | number | string | thread)?" == toString(normal("Not<tbl>"))); CHECK("(boolean | class | function | number | string | thread)?" == toString(normal("Not<tbl>")));
CHECK("table" == toString(normal("Not<Not<tbl>>"))); CHECK("table" == toString(normal("Not<Not<tbl>>")));

View file

@ -112,14 +112,6 @@ TEST_CASE_FIXTURE(Fixture, "can_haz_annotations")
REQUIRE(block != nullptr); REQUIRE(block != nullptr);
} }
TEST_CASE_FIXTURE(Fixture, "local_cannot_have_annotation_with_extensions_disabled")
{
Luau::ParseOptions options;
options.allowTypeAnnotations = false;
CHECK_THROWS_AS(parse("local foo: string = \"Hello Types!\"", options), std::exception);
}
TEST_CASE_FIXTURE(Fixture, "local_with_annotation") TEST_CASE_FIXTURE(Fixture, "local_with_annotation")
{ {
AstStatBlock* block = parse(R"( AstStatBlock* block = parse(R"(
@ -150,14 +142,6 @@ TEST_CASE_FIXTURE(Fixture, "type_names_can_contain_dots")
REQUIRE(block != nullptr); REQUIRE(block != nullptr);
} }
TEST_CASE_FIXTURE(Fixture, "functions_cannot_have_return_annotations_if_extensions_are_disabled")
{
Luau::ParseOptions options;
options.allowTypeAnnotations = false;
CHECK_THROWS_AS(parse("function foo(): number return 55 end", options), std::exception);
}
TEST_CASE_FIXTURE(Fixture, "functions_can_have_return_annotations") TEST_CASE_FIXTURE(Fixture, "functions_can_have_return_annotations")
{ {
AstStatBlock* block = parse(R"( AstStatBlock* block = parse(R"(
@ -395,14 +379,6 @@ TEST_CASE_FIXTURE(Fixture, "return_type_is_an_intersection_type_if_led_with_one_
CHECK(returnAnnotation->types.data[1]->as<AstTypeFunction>()); CHECK(returnAnnotation->types.data[1]->as<AstTypeFunction>());
} }
TEST_CASE_FIXTURE(Fixture, "illegal_type_alias_if_extensions_are_disabled")
{
Luau::ParseOptions options;
options.allowTypeAnnotations = false;
CHECK_THROWS_AS(parse("type A = number", options), std::exception);
}
TEST_CASE_FIXTURE(Fixture, "type_alias_to_a_typeof") TEST_CASE_FIXTURE(Fixture, "type_alias_to_a_typeof")
{ {
AstStatBlock* block = parse(R"( AstStatBlock* block = parse(R"(
@ -2837,8 +2813,6 @@ TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_no_comma_after_last_t
TEST_CASE_FIXTURE(Fixture, "missing_default_type_pack_argument_after_variadic_type_parameter") TEST_CASE_FIXTURE(Fixture, "missing_default_type_pack_argument_after_variadic_type_parameter")
{ {
ScopedFastFlag sff{"LuauParserErrorsOnMissingDefaultTypePackArgument", true};
ParseResult result = tryParse(R"( ParseResult result = tryParse(R"(
type Foo<T... = > = nil type Foo<T... = > = nil
)"); )");

View file

@ -108,6 +108,9 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2")
end end
)"); )");
if (FFlag::DebugLuauDeferredConstraintResolution)
LUAU_REQUIRE_ERROR_COUNT(2, result);
else
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("*error-type*", toString(requireType("a"))); CHECK_EQ("*error-type*", toString(requireType("a")));

View file

@ -169,13 +169,26 @@ TEST_CASE_FIXTURE(Fixture, "list_only_alternative_overloads_that_match_argument_
LUAU_REQUIRE_ERROR_COUNT(2, result); LUAU_REQUIRE_ERROR_COUNT(2, result);
if (FFlag::DebugLuauDeferredConstraintResolution)
{
GenericError* g = get<GenericError>(result.errors[0]);
REQUIRE(g);
CHECK(g->message == "None of the overloads for function that accept 1 arguments are compatible.");
}
else
{
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]); TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm); REQUIRE(tm);
CHECK_EQ(builtinTypes->numberType, tm->wantedType); CHECK_EQ(builtinTypes->numberType, tm->wantedType);
CHECK_EQ(builtinTypes->stringType, tm->givenType); CHECK_EQ(builtinTypes->stringType, tm->givenType);
}
ExtraInformation* ei = get<ExtraInformation>(result.errors[1]); ExtraInformation* ei = get<ExtraInformation>(result.errors[1]);
REQUIRE(ei); REQUIRE(ei);
if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK("Available overloads: (number) -> number; and (number) -> string" == ei->message);
else
CHECK_EQ("Other overloads are also not viable: (number) -> string", ei->message); CHECK_EQ("Other overloads are also not viable: (number) -> string", ei->message);
} }

View file

@ -2,6 +2,7 @@
#include "Luau/AstQuery.h" #include "Luau/AstQuery.h"
#include "Luau/BuiltinDefinitions.h" #include "Luau/BuiltinDefinitions.h"
#include "Luau/Frontend.h"
#include "Luau/Scope.h" #include "Luau/Scope.h"
#include "Luau/TypeInfer.h" #include "Luau/TypeInfer.h"
#include "Luau/Type.h" #include "Luau/Type.h"
@ -31,6 +32,53 @@ TEST_CASE_FIXTURE(Fixture, "for_loop")
CHECK_EQ(*builtinTypes->numberType, *requireType("q")); CHECK_EQ(*builtinTypes->numberType, *requireType("q"));
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "iteration_no_table_passed")
{
ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true};
CheckResult result = check(R"(
type Iterable = typeof(setmetatable(
{},
{}::{
__iter: (self: Iterable) -> (any, number) -> (number, string)
}
))
local t: Iterable
for a, b in t do end
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
GenericError* ge = get<GenericError>(result.errors[0]);
REQUIRE(ge);
CHECK_EQ("__iter metamethod must return (next[, table[, state]])", ge->message);
}
TEST_CASE_FIXTURE(BuiltinsFixture, "iteration_regression_issue_69967")
{
ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true};
CheckResult result = check(R"(
type Iterable = typeof(setmetatable(
{},
{}::{
__iter: (self: Iterable) -> () -> (number, string)
}
))
local t: Iterable
for a, b in t do end
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
GenericError* ge = get<GenericError>(result.errors[0]);
REQUIRE(ge);
CHECK_EQ("__iter metamethod must return (next[, table[, state]])", ge->message);
}
TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop") TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop")
{ {
CheckResult result = check(R"( CheckResult result = check(R"(

View file

@ -26,9 +26,17 @@ TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_not_defi
someTable.Function1() -- Argument count mismatch someTable.Function1() -- Argument count mismatch
)"); )");
if (FFlag::DebugLuauDeferredConstraintResolution)
{
LUAU_REQUIRE_ERROR_COUNT(2, result);
CHECK(toString(result.errors[0]) == "No overload for function accepts 0 arguments.");
CHECK(toString(result.errors[1]) == "Available overloads: <a>(a) -> ()");
}
else
{
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERROR_COUNT(1, result);
REQUIRE(get<CountMismatch>(result.errors[0])); REQUIRE(get<CountMismatch>(result.errors[0]));
}
} }
TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2") TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2")
@ -42,9 +50,17 @@ TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_it_wont_
someTable.Function2() -- Argument count mismatch someTable.Function2() -- Argument count mismatch
)"); )");
if (FFlag::DebugLuauDeferredConstraintResolution)
{
LUAU_REQUIRE_ERROR_COUNT(2, result);
CHECK(toString(result.errors[0]) == "No overload for function accepts 0 arguments.");
CHECK(toString(result.errors[1]) == "Available overloads: <a, b>(a, b) -> ()");
}
else
{
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERROR_COUNT(1, result);
REQUIRE(get<CountMismatch>(result.errors[0])); REQUIRE(get<CountMismatch>(result.errors[0]));
}
} }
TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_another_overload_works") TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_another_overload_works")

View file

@ -52,6 +52,43 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete")
CHECK_EQ(expected, decorateWithTypes(code)); CHECK_EQ(expected, decorateWithTypes(code));
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "luau-polyfill.Array.filter")
{
// This test exercises the fact that we should reduce sealed/unsealed/free tables
// res is a unsealed table with type {((T & ~nil)?) & any}
// Because we do not reduce it fully, we cannot unify it with `Array<T> = { [number] : T}
// TLDR; reduction needs to reduce the indexer on res so it unifies with Array<T>
CheckResult result = check(R"(
--!strict
-- Implements Javascript's `Array.prototype.filter` as defined below
-- https://developer.cmozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/filter
type Array<T> = { [number]: T }
type callbackFn<T> = (element: T, index: number, array: Array<T>) -> boolean
type callbackFnWithThisArg<T, U> = (thisArg: U, element: T, index: number, array: Array<T>) -> boolean
type Object = { [string]: any }
return function<T, U>(t: Array<T>, callback: callbackFn<T> | callbackFnWithThisArg<T, U>, thisArg: U?): Array<T>
local len = #t
local res = {}
if thisArg == nil then
for i = 1, len do
local kValue = t[i]
if kValue ~= nil then
if (callback :: callbackFn<T>)(kValue, i, t) then
res[i] = kValue
end
end
end
else
end
return res
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(BuiltinsFixture, "xpcall_returns_what_f_returns") TEST_CASE_FIXTURE(BuiltinsFixture, "xpcall_returns_what_f_returns")
{ {
const std::string code = R"( const std::string code = R"(

View file

@ -8,7 +8,6 @@
#include "doctest.h" #include "doctest.h"
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAG(LuauNegatedClassTypes)
using namespace Luau; using namespace Luau;
@ -64,7 +63,7 @@ struct RefinementClassFixture : BuiltinsFixture
TypeArena& arena = frontend.globals.globalTypes; TypeArena& arena = frontend.globals.globalTypes;
NotNull<Scope> scope{frontend.globals.globalScope.get()}; NotNull<Scope> scope{frontend.globals.globalScope.get()};
std::optional<TypeId> rootSuper = FFlag::LuauNegatedClassTypes ? std::make_optional(builtinTypes->classType) : std::nullopt; std::optional<TypeId> rootSuper = std::make_optional(builtinTypes->classType);
unfreeze(arena); unfreeze(arena);
TypeId vec3 = arena.addType(ClassType{"Vector3", {}, rootSuper, std::nullopt, {}, nullptr, "Test"}); TypeId vec3 = arena.addType(ClassType{"Vector3", {}, rootSuper, std::nullopt, {}, nullptr, "Test"});

View file

@ -131,8 +131,16 @@ TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons_mismatch")
)"); )");
LUAU_REQUIRE_ERROR_COUNT(2, result); LUAU_REQUIRE_ERROR_COUNT(2, result);
if (FFlag::DebugLuauDeferredConstraintResolution)
{
CHECK_EQ("None of the overloads for function that accept 2 arguments are compatible.", toString(result.errors[0]));
CHECK_EQ("Available overloads: (true, string) -> (); and (false, number) -> ()", toString(result.errors[1]));
}
else
{
CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0]));
CHECK_EQ("Other overloads are also not viable: (false, number) -> ()", toString(result.errors[1])); CHECK_EQ("Other overloads are also not viable: (false, number) -> ()", toString(result.errors[1]));
}
} }
TEST_CASE_FIXTURE(Fixture, "enums_using_singletons") TEST_CASE_FIXTURE(Fixture, "enums_using_singletons")

View file

@ -3625,4 +3625,23 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "top_table_type_is_isomorphic_to_empty_sealed
)"); )");
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "luau-polyfill.Array.includes")
{
CheckResult result = check(R"(
type Array<T> = { [number]: T }
function indexOf<T>(array: Array<T>, searchElement: any, fromIndex: number?): number
return -1
end
return function<T>(array: Array<T>, searchElement: any, fromIndex: number?): boolean
return -1 ~= indexOf(array, searchElement, fromIndex)
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -1,9 +1,5 @@
AnnotationTests.too_many_type_params AnnotationTests.too_many_type_params
AstQuery.last_argument_function_call_type AstQuery.last_argument_function_call_type
AstQuery::getDocumentationSymbolAtPosition.overloaded_class_method
AstQuery::getDocumentationSymbolAtPosition.overloaded_fn
AstQuery::getDocumentationSymbolAtPosition.table_overloaded_function_prop
AutocompleteTest.autocomplete_response_perf1
BuiltinTests.aliased_string_format BuiltinTests.aliased_string_format
BuiltinTests.assert_removes_falsy_types BuiltinTests.assert_removes_falsy_types
BuiltinTests.assert_removes_falsy_types2 BuiltinTests.assert_removes_falsy_types2
@ -54,6 +50,7 @@ ProvisionalTests.error_on_eq_metamethod_returning_a_type_other_than_boolean
ProvisionalTests.free_options_cannot_be_unified_together ProvisionalTests.free_options_cannot_be_unified_together
ProvisionalTests.generic_type_leak_to_module_interface_variadic ProvisionalTests.generic_type_leak_to_module_interface_variadic
ProvisionalTests.greedy_inference_with_shared_self_triggers_function_with_no_returns ProvisionalTests.greedy_inference_with_shared_self_triggers_function_with_no_returns
ProvisionalTests.luau-polyfill.Array.filter
ProvisionalTests.setmetatable_constrains_free_type_into_free_table ProvisionalTests.setmetatable_constrains_free_type_into_free_table
ProvisionalTests.specialization_binds_with_prototypes_too_early ProvisionalTests.specialization_binds_with_prototypes_too_early
ProvisionalTests.table_insert_with_a_singleton_argument ProvisionalTests.table_insert_with_a_singleton_argument
@ -146,7 +143,6 @@ TypeInferClasses.index_instance_property
TypeInferClasses.table_class_unification_reports_sane_errors_for_missing_properties TypeInferClasses.table_class_unification_reports_sane_errors_for_missing_properties
TypeInferClasses.warn_when_prop_almost_matches TypeInferClasses.warn_when_prop_almost_matches
TypeInferFunctions.cannot_hoist_interior_defns_into_signature TypeInferFunctions.cannot_hoist_interior_defns_into_signature
TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists
TypeInferFunctions.function_cast_error_uses_correct_language TypeInferFunctions.function_cast_error_uses_correct_language
TypeInferFunctions.function_decl_non_self_sealed_overwrite_2 TypeInferFunctions.function_decl_non_self_sealed_overwrite_2
TypeInferFunctions.function_decl_non_self_unsealed_overwrite TypeInferFunctions.function_decl_non_self_unsealed_overwrite
@ -158,7 +154,6 @@ TypeInferFunctions.infer_that_function_does_not_return_a_table
TypeInferFunctions.luau_subtyping_is_np_hard TypeInferFunctions.luau_subtyping_is_np_hard
TypeInferFunctions.no_lossy_function_type TypeInferFunctions.no_lossy_function_type
TypeInferFunctions.occurs_check_failure_in_function_return_type TypeInferFunctions.occurs_check_failure_in_function_return_type
TypeInferFunctions.record_matching_overload
TypeInferFunctions.report_exiting_without_return_strict TypeInferFunctions.report_exiting_without_return_strict
TypeInferFunctions.return_type_by_overload TypeInferFunctions.return_type_by_overload
TypeInferFunctions.too_few_arguments_variadic TypeInferFunctions.too_few_arguments_variadic
@ -205,6 +200,7 @@ TypePackTests.variadic_packs
TypeSingletons.function_call_with_singletons TypeSingletons.function_call_with_singletons
TypeSingletons.function_call_with_singletons_mismatch TypeSingletons.function_call_with_singletons_mismatch
TypeSingletons.no_widening_from_callsites TypeSingletons.no_widening_from_callsites
TypeSingletons.overloaded_function_call_with_singletons_mismatch
TypeSingletons.return_type_of_f_is_not_widened TypeSingletons.return_type_of_f_is_not_widened
TypeSingletons.table_properties_type_error_escapes TypeSingletons.table_properties_type_error_escapes
TypeSingletons.widen_the_supertype_if_it_is_free_and_subtype_has_singleton TypeSingletons.widen_the_supertype_if_it_is_free_and_subtype_has_singleton