Merge branch 'upstream' into merge

This commit is contained in:
Arseny Kapoulkine 2022-05-05 16:53:50 -07:00
commit 08ac2176c5
50 changed files with 2658 additions and 727 deletions

View file

@ -145,7 +145,6 @@ struct Frontend
*/ */
std::pair<SourceModule, LintResult> lintFragment(std::string_view source, std::optional<LintOptions> enabledLintWarnings = {}); std::pair<SourceModule, LintResult> lintFragment(std::string_view source, std::optional<LintOptions> enabledLintWarnings = {});
CheckResult check(const SourceModule& module); // OLD. TODO KILL
LintResult lint(const SourceModule& module, std::optional<LintOptions> enabledLintWarnings = {}); LintResult lint(const SourceModule& module, std::optional<LintOptions> enabledLintWarnings = {});
bool isDirty(const ModuleName& name, bool forAutocomplete = false) const; bool isDirty(const ModuleName& name, bool forAutocomplete = false) const;

View file

@ -1,9 +1,15 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include <unordered_set>
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/TypeVar.h" #include "Luau/RecursionCounter.h"
#include "Luau/TypePack.h" #include "Luau/TypePack.h"
#include "Luau/TypeVar.h"
LUAU_FASTFLAG(LuauUseVisitRecursionLimit)
LUAU_FASTINT(LuauVisitRecursionLimit)
namespace Luau namespace Luau
{ {
@ -219,24 +225,321 @@ void visit(TypePackId tp, F& f, Set& seen)
} // namespace visit_detail } // namespace visit_detail
template<typename S>
struct GenericTypeVarVisitor
{
using Set = S;
Set seen;
int recursionCounter = 0;
GenericTypeVarVisitor() = default;
explicit GenericTypeVarVisitor(Set seen)
: seen(std::move(seen))
{
}
virtual void cycle(TypeId) {}
virtual void cycle(TypePackId) {}
virtual bool visit(TypeId ty)
{
return true;
}
virtual bool visit(TypeId ty, const BoundTypeVar& btv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const FreeTypeVar& ftv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const GenericTypeVar& gtv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const ErrorTypeVar& etv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const ConstrainedTypeVar& ctv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const PrimitiveTypeVar& ptv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const FunctionTypeVar& ftv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const TableTypeVar& ttv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const MetatableTypeVar& mtv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const ClassTypeVar& ctv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const AnyTypeVar& atv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const UnionTypeVar& utv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const IntersectionTypeVar& itv)
{
return visit(ty);
}
virtual bool visit(TypePackId tp)
{
return true;
}
virtual bool visit(TypePackId tp, const BoundTypePack& btp)
{
return visit(tp);
}
virtual bool visit(TypePackId tp, const FreeTypePack& ftp)
{
return visit(tp);
}
virtual bool visit(TypePackId tp, const GenericTypePack& gtp)
{
return visit(tp);
}
virtual bool visit(TypePackId tp, const Unifiable::Error& etp)
{
return visit(tp);
}
virtual bool visit(TypePackId tp, const TypePack& pack)
{
return visit(tp);
}
virtual bool visit(TypePackId tp, const VariadicTypePack& vtp)
{
return visit(tp);
}
void traverse(TypeId ty)
{
RecursionLimiter limiter{&recursionCounter, FInt::LuauVisitRecursionLimit, "TypeVarVisitor"};
if (visit_detail::hasSeen(seen, ty))
{
cycle(ty);
return;
}
if (auto btv = get<BoundTypeVar>(ty))
{
if (visit(ty, *btv))
traverse(btv->boundTo);
}
else if (auto ftv = get<FreeTypeVar>(ty))
visit(ty, *ftv);
else if (auto gtv = get<GenericTypeVar>(ty))
visit(ty, *gtv);
else if (auto etv = get<ErrorTypeVar>(ty))
visit(ty, *etv);
else if (auto ctv = get<ConstrainedTypeVar>(ty))
{
if (visit(ty, *ctv))
{
for (TypeId part : ctv->parts)
traverse(part);
}
}
else if (auto ptv = get<PrimitiveTypeVar>(ty))
visit(ty, *ptv);
else if (auto ftv = get<FunctionTypeVar>(ty))
{
if (visit(ty, *ftv))
{
traverse(ftv->argTypes);
traverse(ftv->retType);
}
}
else if (auto ttv = get<TableTypeVar>(ty))
{
// Some visitors want to see bound tables, that's why we traverse the original type
if (visit(ty, *ttv))
{
if (ttv->boundTo)
{
traverse(*ttv->boundTo);
}
else
{
for (auto& [_name, prop] : ttv->props)
traverse(prop.type);
if (ttv->indexer)
{
traverse(ttv->indexer->indexType);
traverse(ttv->indexer->indexResultType);
}
}
}
}
else if (auto mtv = get<MetatableTypeVar>(ty))
{
if (visit(ty, *mtv))
{
traverse(mtv->table);
traverse(mtv->metatable);
}
}
else if (auto ctv = get<ClassTypeVar>(ty))
{
if (visit(ty, *ctv))
{
for (const auto& [name, prop] : ctv->props)
traverse(prop.type);
if (ctv->parent)
traverse(*ctv->parent);
if (ctv->metatable)
traverse(*ctv->metatable);
}
}
else if (auto atv = get<AnyTypeVar>(ty))
visit(ty, *atv);
else if (auto utv = get<UnionTypeVar>(ty))
{
if (visit(ty, *utv))
{
for (TypeId optTy : utv->options)
traverse(optTy);
}
}
else if (auto itv = get<IntersectionTypeVar>(ty))
{
if (visit(ty, *itv))
{
for (TypeId partTy : itv->parts)
traverse(partTy);
}
}
visit_detail::unsee(seen, ty);
}
void traverse(TypePackId tp)
{
if (visit_detail::hasSeen(seen, tp))
{
cycle(tp);
return;
}
if (auto btv = get<BoundTypePack>(tp))
{
if (visit(tp, *btv))
traverse(btv->boundTo);
}
else if (auto ftv = get<Unifiable::Free>(tp))
visit(tp, *ftv);
else if (auto gtv = get<Unifiable::Generic>(tp))
visit(tp, *gtv);
else if (auto etv = get<Unifiable::Error>(tp))
visit(tp, *etv);
else if (auto pack = get<TypePack>(tp))
{
visit(tp, *pack);
for (TypeId ty : pack->head)
traverse(ty);
if (pack->tail)
traverse(*pack->tail);
}
else if (auto pack = get<VariadicTypePack>(tp))
{
visit(tp, *pack);
traverse(pack->ty);
}
else
LUAU_ASSERT(!"GenericTypeVarVisitor::traverse(TypePackId) is not exhaustive!");
visit_detail::unsee(seen, tp);
}
};
/** Visit each type under a given type. Skips over cycles and keeps recursion depth under control.
*
* The same type may be visited multiple times if there are multiple distinct paths to it. If this is undesirable, use
* TypeVarOnceVisitor.
*/
struct TypeVarVisitor : GenericTypeVarVisitor<std::unordered_set<void*>>
{
};
/// Visit each type under a given type. Each type will only be checked once even if there are multiple paths to it.
struct TypeVarOnceVisitor : GenericTypeVarVisitor<DenseHashSet<void*>>
{
TypeVarOnceVisitor()
: GenericTypeVarVisitor{DenseHashSet<void*>{nullptr}}
{
}
};
// Clip with FFlagLuauUseVisitRecursionLimit
template<typename TID, typename F> template<typename TID, typename F>
void visitTypeVar(TID ty, F& f, std::unordered_set<void*>& seen) void DEPRECATED_visitTypeVar(TID ty, F& f, std::unordered_set<void*>& seen)
{ {
visit_detail::visit(ty, f, seen); visit_detail::visit(ty, f, seen);
} }
// Delete and inline when clipping FFlagLuauUseVisitRecursionLimit
template<typename TID, typename F> template<typename TID, typename F>
void visitTypeVar(TID ty, F& f) void DEPRECATED_visitTypeVar(TID ty, F& f)
{ {
if (FFlag::LuauUseVisitRecursionLimit)
f.traverse(ty);
else
{
std::unordered_set<void*> seen; std::unordered_set<void*> seen;
visit_detail::visit(ty, f, seen); visit_detail::visit(ty, f, seen);
}
} }
// Delete and inline when clipping FFlagLuauUseVisitRecursionLimit
template<typename TID, typename F> template<typename TID, typename F>
void visitTypeVarOnce(TID ty, F& f, DenseHashSet<void*>& seen) void DEPRECATED_visitTypeVarOnce(TID ty, F& f, DenseHashSet<void*>& seen)
{ {
if (FFlag::LuauUseVisitRecursionLimit)
f.traverse(ty);
else
{
seen.clear(); seen.clear();
visit_detail::visit(ty, f, seen); visit_detail::visit(ty, f, seen);
}
} }
} // namespace Luau } // namespace Luau

View file

@ -14,7 +14,6 @@
#include <utility> #include <utility>
LUAU_FASTFLAGVARIABLE(LuauIfElseExprFixCompletionIssue, false); LUAU_FASTFLAGVARIABLE(LuauIfElseExprFixCompletionIssue, false);
LUAU_FASTFLAGVARIABLE(LuauAutocompleteSingletonTypes, false);
LUAU_FASTFLAGVARIABLE(LuauFixAutocompleteClassSecurityLevel, false); LUAU_FASTFLAGVARIABLE(LuauFixAutocompleteClassSecurityLevel, false);
LUAU_FASTFLAG(LuauSelfCallAutocompleteFix) LUAU_FASTFLAG(LuauSelfCallAutocompleteFix)
@ -1341,8 +1340,6 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul
scope = scope->parent; scope = scope->parent;
} }
if (FFlag::LuauAutocompleteSingletonTypes)
{
TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.nilType); TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.nilType);
TypeCorrectKind correctForTrue = checkTypeCorrectKind(module, typeArena, node, position, getSingletonTypes().trueType); TypeCorrectKind correctForTrue = checkTypeCorrectKind(module, typeArena, node, position, getSingletonTypes().trueType);
TypeCorrectKind correctForFalse = checkTypeCorrectKind(module, typeArena, node, position, getSingletonTypes().falseType); TypeCorrectKind correctForFalse = checkTypeCorrectKind(module, typeArena, node, position, getSingletonTypes().falseType);
@ -1359,21 +1356,6 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul
if (auto ty = findExpectedTypeAt(module, node, position)) if (auto ty = findExpectedTypeAt(module, node, position))
autocompleteStringSingleton(*ty, true, result); autocompleteStringSingleton(*ty, true, result);
} }
else
{
TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.nilType);
TypeCorrectKind correctForBoolean = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.booleanType);
TypeCorrectKind correctForFunction =
functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None;
result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false};
result["true"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean};
result["false"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean};
result["nil"] = {AutocompleteEntryKind::Keyword, typeChecker.nilType, false, false, correctForNil};
result["not"] = {AutocompleteEntryKind::Keyword};
result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction};
}
}
} }
static AutocompleteEntryMap autocompleteExpression(const SourceModule& sourceModule, const Module& module, const TypeChecker& typeChecker, static AutocompleteEntryMap autocompleteExpression(const SourceModule& sourceModule, const Module& module, const TypeChecker& typeChecker,
@ -1680,11 +1662,8 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
{ {
AutocompleteEntryMap result; AutocompleteEntryMap result;
if (FFlag::LuauAutocompleteSingletonTypes)
{
if (auto it = module->astExpectedTypes.find(node->asExpr())) if (auto it = module->astExpectedTypes.find(node->asExpr()))
autocompleteStringSingleton(*it, false, result); autocompleteStringSingleton(*it, false, result);
}
if (finder.ancestry.size() >= 2) if (finder.ancestry.size() >= 2)
{ {
@ -1693,8 +1672,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
if (auto it = module->astTypes.find(idxExpr->expr)) if (auto it = module->astTypes.find(idxExpr->expr))
autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, finder.ancestry, result); autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, finder.ancestry, result);
} }
else if (auto binExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as<AstExprBinary>(); else if (auto binExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as<AstExprBinary>())
binExpr && FFlag::LuauAutocompleteSingletonTypes)
{ {
if (binExpr->op == AstExprBinary::CompareEq || binExpr->op == AstExprBinary::CompareNe) if (binExpr->op == AstExprBinary::CompareEq || binExpr->op == AstExprBinary::CompareNe)
{ {

View file

@ -18,7 +18,6 @@
LUAU_FASTINT(LuauTypeInferIterationLimit) LUAU_FASTINT(LuauTypeInferIterationLimit)
LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTINT(LuauTarjanChildLimit)
LUAU_FASTFLAG(LuauCyclicModuleTypeSurface)
LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAG(LuauInferInNoCheckMode)
LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false)
LUAU_FASTFLAGVARIABLE(LuauSeparateTypechecks, false) LUAU_FASTFLAGVARIABLE(LuauSeparateTypechecks, false)
@ -433,7 +432,6 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
{ {
// The autocomplete typecheck is always in strict mode with DM awareness // The autocomplete typecheck is always in strict mode with DM awareness
// to provide better type information for IDE features // to provide better type information for IDE features
if (FFlag::LuauCyclicModuleTypeSurface)
typeCheckerForAutocomplete.requireCycles = requireCycles; typeCheckerForAutocomplete.requireCycles = requireCycles;
if (autocompleteTimeLimit != 0.0) if (autocompleteTimeLimit != 0.0)
@ -483,7 +481,6 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
continue; continue;
} }
if (FFlag::LuauCyclicModuleTypeSurface)
typeChecker.requireCycles = requireCycles; typeChecker.requireCycles = requireCycles;
ModulePtr module = typeChecker.check(sourceModule, mode, environmentScope); ModulePtr module = typeChecker.check(sourceModule, mode, environmentScope);
@ -493,7 +490,6 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
// to provide better typen information for IDE features. // to provide better typen information for IDE features.
if (!FFlag::LuauSeparateTypechecks && frontendOptions.typecheckTwice_DEPRECATED) if (!FFlag::LuauSeparateTypechecks && frontendOptions.typecheckTwice_DEPRECATED)
{ {
if (FFlag::LuauCyclicModuleTypeSurface)
typeCheckerForAutocomplete.requireCycles = requireCycles; typeCheckerForAutocomplete.requireCycles = requireCycles;
ModulePtr moduleForAutocomplete = typeCheckerForAutocomplete.check(sourceModule, Mode::Strict); ModulePtr moduleForAutocomplete = typeCheckerForAutocomplete.check(sourceModule, Mode::Strict);
@ -706,30 +702,6 @@ std::pair<SourceModule, LintResult> Frontend::lintFragment(std::string_view sour
return {std::move(sourceModule), classifyLints(warnings, config)}; return {std::move(sourceModule), classifyLints(warnings, config)};
} }
CheckResult Frontend::check(const SourceModule& module)
{
LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend");
LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str());
const Config& config = configResolver->getConfig(module.name);
Mode mode = module.mode.value_or(config.mode);
double timestamp = getTimestamp();
ModulePtr checkedModule = typeChecker.check(module, mode);
stats.timeCheck += getTimestamp() - timestamp;
stats.filesStrict += mode == Mode::Strict;
stats.filesNonstrict += mode == Mode::Nonstrict;
if (checkedModule == nullptr)
throw std::runtime_error("Frontend::check produced a nullptr module for module " + module.name);
moduleResolver.modules[module.name] = checkedModule;
return CheckResult{checkedModule->errors};
}
LintResult Frontend::lint(const SourceModule& module, std::optional<Luau::LintOptions> enabledLintWarnings) LintResult Frontend::lint(const SourceModule& module, std::optional<Luau::LintOptions> enabledLintWarnings)
{ {
LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend");

View file

@ -304,37 +304,23 @@ static bool areNormal(TypePackId tp, const std::unordered_set<void*>& seen, Inte
++iterationLimit; \ ++iterationLimit; \
} while (false) } while (false)
struct Normalize struct Normalize final : TypeVarVisitor
{ {
using TypeVarVisitor::Set;
Normalize(TypeArena& arena, InternalErrorReporter& ice)
: arena(arena)
, ice(ice)
{
}
TypeArena& arena; TypeArena& arena;
InternalErrorReporter& ice; InternalErrorReporter& ice;
// Debug data. Types being normalized are invalidated but trying to see what's going on is painful.
// To actually see the original type, read it by using the pointer of the type being normalized.
// e.g. in lldb, `e dump(originalTys[ty])`.
SeenTypes originalTys;
SeenTypePacks originalTps;
int iterationLimit = 0; int iterationLimit = 0;
bool limitExceeded = false; bool limitExceeded = false;
template<typename T> // TODO: Clip with FFlag::LuauUseVisitRecursionLimit
bool operator()(TypePackId, const T&)
{
return true;
}
template<typename TID>
void cycle(TID)
{
}
bool operator()(TypeId ty, const FreeTypeVar&)
{
LUAU_ASSERT(!ty->normal);
return false;
}
bool operator()(TypeId ty, const BoundTypeVar& btv, std::unordered_set<void*>& seen) bool operator()(TypeId ty, const BoundTypeVar& btv, std::unordered_set<void*>& seen)
{ {
// A type could be considered normal when it is in the stack, but we will eventually find out it is not normal as normalization progresses. // A type could be considered normal when it is in the stack, but we will eventually find out it is not normal as normalization progresses.
@ -349,27 +335,22 @@ struct Normalize
return !ty->normal; return !ty->normal;
} }
bool operator()(TypeId ty, const PrimitiveTypeVar&) bool operator()(TypeId ty, const FreeTypeVar& ftv)
{ {
LUAU_ASSERT(ty->normal); return visit(ty, ftv);
return false;
} }
bool operator()(TypeId ty, const PrimitiveTypeVar& ptv)
bool operator()(TypeId ty, const GenericTypeVar&)
{ {
if (!ty->normal) return visit(ty, ptv);
asMutable(ty)->normal = true;
return false;
} }
bool operator()(TypeId ty, const GenericTypeVar& gtv)
bool operator()(TypeId ty, const ErrorTypeVar&)
{ {
if (!ty->normal) return visit(ty, gtv);
asMutable(ty)->normal = true; }
return false; bool operator()(TypeId ty, const ErrorTypeVar& etv)
{
return visit(ty, etv);
} }
bool operator()(TypeId ty, const ConstrainedTypeVar& ctvRef, std::unordered_set<void*>& seen) bool operator()(TypeId ty, const ConstrainedTypeVar& ctvRef, std::unordered_set<void*>& seen)
{ {
CHECK_ITERATION_LIMIT(false); CHECK_ITERATION_LIMIT(false);
@ -470,17 +451,12 @@ struct Normalize
bool operator()(TypeId ty, const ClassTypeVar& ctv) bool operator()(TypeId ty, const ClassTypeVar& ctv)
{ {
if (!ty->normal) return visit(ty, ctv);
asMutable(ty)->normal = true;
return false;
} }
bool operator()(TypeId ty, const AnyTypeVar& atv)
bool operator()(TypeId ty, const AnyTypeVar&)
{ {
LUAU_ASSERT(ty->normal); return visit(ty, atv);
return false;
} }
bool operator()(TypeId ty, const UnionTypeVar& utvRef, std::unordered_set<void*>& seen) bool operator()(TypeId ty, const UnionTypeVar& utvRef, std::unordered_set<void*>& seen)
{ {
CHECK_ITERATION_LIMIT(false); CHECK_ITERATION_LIMIT(false);
@ -570,8 +546,257 @@ struct Normalize
return false; return false;
} }
bool operator()(TypeId ty, const LazyTypeVar&) // TODO: Clip with FFlag::LuauUseVisitRecursionLimit
template<typename T>
bool operator()(TypePackId, const T&)
{ {
return true;
}
// TODO: Clip with FFlag::LuauUseVisitRecursionLimit
template<typename TID>
void cycle(TID)
{
}
bool visit(TypeId ty, const FreeTypeVar&) override
{
LUAU_ASSERT(!ty->normal);
return false;
}
bool visit(TypeId ty, const BoundTypeVar& btv) override
{
// A type could be considered normal when it is in the stack, but we will eventually find out it is not normal as normalization progresses.
// So we need to avoid eagerly saying that this bound type is normal if the thing it is bound to is in the stack.
if (seen.find(asMutable(btv.boundTo)) != seen.end())
return false;
// It should never be the case that this TypeVar is normal, but is bound to a non-normal type, except in nontrivial cases.
LUAU_ASSERT(!ty->normal || ty->normal == btv.boundTo->normal);
asMutable(ty)->normal = btv.boundTo->normal;
return !ty->normal;
}
bool visit(TypeId ty, const PrimitiveTypeVar&) override
{
LUAU_ASSERT(ty->normal);
return false;
}
bool visit(TypeId ty, const GenericTypeVar&) override
{
if (!ty->normal)
asMutable(ty)->normal = true;
return false;
}
bool visit(TypeId ty, const ErrorTypeVar&) override
{
if (!ty->normal)
asMutable(ty)->normal = true;
return false;
}
bool visit(TypeId ty, const ConstrainedTypeVar& ctvRef) override
{
CHECK_ITERATION_LIMIT(false);
ConstrainedTypeVar* ctv = const_cast<ConstrainedTypeVar*>(&ctvRef);
std::vector<TypeId> parts = std::move(ctv->parts);
// We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar
for (TypeId part : parts)
traverse(part);
std::vector<TypeId> newParts = normalizeUnion(parts);
const bool normal = areNormal(newParts, seen, ice);
if (newParts.size() == 1)
*asMutable(ty) = BoundTypeVar{newParts[0]};
else
*asMutable(ty) = UnionTypeVar{std::move(newParts)};
asMutable(ty)->normal = normal;
return false;
}
bool visit(TypeId ty, const FunctionTypeVar& ftv) override
{
CHECK_ITERATION_LIMIT(false);
if (ty->normal)
return false;
traverse(ftv.argTypes);
traverse(ftv.retType);
asMutable(ty)->normal = areNormal(ftv.argTypes, seen, ice) && areNormal(ftv.retType, seen, ice);
return false;
}
bool visit(TypeId ty, const TableTypeVar& ttv) override
{
CHECK_ITERATION_LIMIT(false);
if (ty->normal)
return false;
bool normal = true;
auto checkNormal = [&](TypeId t) {
// if t is on the stack, it is possible that this type is normal.
// If t is not normal and it is not on the stack, this type is definitely not normal.
if (!t->normal && seen.find(asMutable(t)) == seen.end())
normal = false;
};
if (ttv.boundTo)
{
traverse(*ttv.boundTo);
asMutable(ty)->normal = (*ttv.boundTo)->normal;
return false;
}
for (const auto& [_name, prop] : ttv.props)
{
traverse(prop.type);
checkNormal(prop.type);
}
if (ttv.indexer)
{
traverse(ttv.indexer->indexType);
checkNormal(ttv.indexer->indexType);
traverse(ttv.indexer->indexResultType);
checkNormal(ttv.indexer->indexResultType);
}
asMutable(ty)->normal = normal;
return false;
}
bool visit(TypeId ty, const MetatableTypeVar& mtv) override
{
CHECK_ITERATION_LIMIT(false);
if (ty->normal)
return false;
traverse(mtv.table);
traverse(mtv.metatable);
asMutable(ty)->normal = mtv.table->normal && mtv.metatable->normal;
return false;
}
bool visit(TypeId ty, const ClassTypeVar& ctv) override
{
if (!ty->normal)
asMutable(ty)->normal = true;
return false;
}
bool visit(TypeId ty, const AnyTypeVar&) override
{
LUAU_ASSERT(ty->normal);
return false;
}
bool visit(TypeId ty, const UnionTypeVar& utvRef) override
{
CHECK_ITERATION_LIMIT(false);
if (ty->normal)
return false;
UnionTypeVar* utv = &const_cast<UnionTypeVar&>(utvRef);
std::vector<TypeId> options = std::move(utv->options);
// We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar
for (TypeId option : options)
traverse(option);
std::vector<TypeId> newOptions = normalizeUnion(options);
const bool normal = areNormal(newOptions, seen, ice);
LUAU_ASSERT(!newOptions.empty());
if (newOptions.size() == 1)
*asMutable(ty) = BoundTypeVar{newOptions[0]};
else
utv->options = std::move(newOptions);
asMutable(ty)->normal = normal;
return false;
}
bool visit(TypeId ty, const IntersectionTypeVar& itvRef) override
{
CHECK_ITERATION_LIMIT(false);
if (ty->normal)
return false;
IntersectionTypeVar* itv = &const_cast<IntersectionTypeVar&>(itvRef);
std::vector<TypeId> oldParts = std::move(itv->parts);
for (TypeId part : oldParts)
traverse(part);
std::vector<TypeId> tables;
for (TypeId part : oldParts)
{
part = follow(part);
if (get<TableTypeVar>(part))
tables.push_back(part);
else
{
Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD
combineIntoIntersection(replacer, itv, part);
}
}
// Don't allocate a new table if there's just one in the intersection.
if (tables.size() == 1)
itv->parts.push_back(tables[0]);
else if (!tables.empty())
{
const TableTypeVar* first = get<TableTypeVar>(tables[0]);
LUAU_ASSERT(first);
TypeId newTable = arena.addType(TableTypeVar{first->state, first->level});
TableTypeVar* ttv = getMutable<TableTypeVar>(newTable);
for (TypeId part : tables)
{
// Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need
// to be rewritten to point at 'newTable' in the clone.
Replacer replacer{&arena, part, newTable};
combineIntoTable(replacer, ttv, part);
}
itv->parts.push_back(newTable);
}
asMutable(ty)->normal = areNormal(itv->parts, seen, ice);
if (itv->parts.size() == 1)
{
TypeId part = itv->parts[0];
*asMutable(ty) = BoundTypeVar{part};
}
return false; return false;
} }
@ -778,9 +1003,9 @@ std::pair<TypeId, bool> normalize(TypeId ty, TypeArena& arena, InternalErrorRepo
if (FFlag::DebugLuauCopyBeforeNormalizing) if (FFlag::DebugLuauCopyBeforeNormalizing)
(void)clone(ty, arena, state); (void)clone(ty, arena, state);
Normalize n{arena, ice, std::move(state.seenTypes), std::move(state.seenTypePacks)}; Normalize n{arena, ice};
std::unordered_set<void*> seen; std::unordered_set<void*> seen;
visitTypeVar(ty, n, seen); DEPRECATED_visitTypeVar(ty, n, seen);
return {ty, !n.limitExceeded}; return {ty, !n.limitExceeded};
} }
@ -803,9 +1028,9 @@ std::pair<TypePackId, bool> normalize(TypePackId tp, TypeArena& arena, InternalE
if (FFlag::DebugLuauCopyBeforeNormalizing) if (FFlag::DebugLuauCopyBeforeNormalizing)
(void)clone(tp, arena, state); (void)clone(tp, arena, state);
Normalize n{arena, ice, std::move(state.seenTypes), std::move(state.seenTypePacks)}; Normalize n{arena, ice};
std::unordered_set<void*> seen; std::unordered_set<void*> seen;
visitTypeVar(tp, n, seen); DEPRECATED_visitTypeVar(tp, n, seen);
return {tp, !n.limitExceeded}; return {tp, !n.limitExceeded};
} }

View file

@ -9,7 +9,7 @@ LUAU_FASTFLAG(LuauTypecheckOptPass)
namespace Luau namespace Luau
{ {
struct Quantifier struct Quantifier final : TypeVarOnceVisitor
{ {
TypeLevel level; TypeLevel level;
std::vector<TypeId> generics; std::vector<TypeId> generics;
@ -17,26 +17,17 @@ struct Quantifier
bool seenGenericType = false; bool seenGenericType = false;
bool seenMutableType = false; bool seenMutableType = false;
Quantifier(TypeLevel level) explicit Quantifier(TypeLevel level)
: level(level) : level(level)
{ {
} }
void cycle(TypeId) {} void cycle(TypeId) override {}
void cycle(TypePackId) {} void cycle(TypePackId) override {}
bool operator()(TypeId ty, const FreeTypeVar& ftv) bool operator()(TypeId ty, const FreeTypeVar& ftv)
{ {
if (FFlag::LuauTypecheckOptPass) return visit(ty, ftv);
seenMutableType = true;
if (!level.subsumes(ftv.level))
return false;
*asMutable(ty) = GenericTypeVar{level};
generics.push_back(ty);
return false;
} }
template<typename T> template<typename T>
@ -56,8 +47,33 @@ struct Quantifier
return true; return true;
} }
bool operator()(TypeId ty, const TableTypeVar&) bool operator()(TypeId ty, const TableTypeVar& ttv)
{ {
return visit(ty, ttv);
}
bool operator()(TypePackId tp, const FreeTypePack& ftp)
{
return visit(tp, ftp);
}
bool visit(TypeId ty, const FreeTypeVar& ftv) override
{
if (FFlag::LuauTypecheckOptPass)
seenMutableType = true;
if (!level.subsumes(ftv.level))
return false;
*asMutable(ty) = GenericTypeVar{level};
generics.push_back(ty);
return false;
}
bool visit(TypeId ty, const TableTypeVar&) override
{
LUAU_ASSERT(getMutable<TableTypeVar>(ty));
TableTypeVar& ttv = *getMutable<TableTypeVar>(ty); TableTypeVar& ttv = *getMutable<TableTypeVar>(ty);
if (FFlag::LuauTypecheckOptPass) if (FFlag::LuauTypecheckOptPass)
@ -93,7 +109,7 @@ struct Quantifier
return true; return true;
} }
bool operator()(TypePackId tp, const FreeTypePack& ftp) bool visit(TypePackId tp, const FreeTypePack& ftp) override
{ {
if (FFlag::LuauTypecheckOptPass) if (FFlag::LuauTypecheckOptPass)
seenMutableType = true; seenMutableType = true;
@ -111,7 +127,7 @@ void quantify(TypeId ty, TypeLevel level)
{ {
Quantifier q{level}; Quantifier q{level};
DenseHashSet<void*> seen{nullptr}; DenseHashSet<void*> seen{nullptr};
visitTypeVarOnce(ty, q, seen); DEPRECATED_visitTypeVarOnce(ty, q, seen);
FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(ty); FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(ty);
LUAU_ASSERT(ftv); LUAU_ASSERT(ftv);

View file

@ -26,7 +26,7 @@ namespace Luau
namespace namespace
{ {
struct FindCyclicTypes struct FindCyclicTypes final : TypeVarVisitor
{ {
FindCyclicTypes() = default; FindCyclicTypes() = default;
FindCyclicTypes(const FindCyclicTypes&) = delete; FindCyclicTypes(const FindCyclicTypes&) = delete;
@ -38,20 +38,22 @@ struct FindCyclicTypes
std::set<TypeId> cycles; std::set<TypeId> cycles;
std::set<TypePackId> cycleTPs; std::set<TypePackId> cycleTPs;
void cycle(TypeId ty) void cycle(TypeId ty) override
{ {
cycles.insert(ty); cycles.insert(ty);
} }
void cycle(TypePackId tp) void cycle(TypePackId tp) override
{ {
cycleTPs.insert(tp); cycleTPs.insert(tp);
} }
// TODO: Clip all the operator()s when we clip FFlagLuauUseVisitRecursionLimit
template<typename T> template<typename T>
bool operator()(TypeId ty, const T&) bool operator()(TypeId ty, const T&)
{ {
return visited.insert(ty).second; return visit(ty);
} }
bool operator()(TypeId ty, const TableTypeVar& ttv) = delete; bool operator()(TypeId ty, const TableTypeVar& ttv) = delete;
@ -64,10 +66,10 @@ struct FindCyclicTypes
if (ttv.name || ttv.syntheticName) if (ttv.name || ttv.syntheticName)
{ {
for (TypeId itp : ttv.instantiatedTypeParams) for (TypeId itp : ttv.instantiatedTypeParams)
visitTypeVar(itp, *this, seen); DEPRECATED_visitTypeVar(itp, *this, seen);
for (TypePackId itp : ttv.instantiatedTypePackParams) for (TypePackId itp : ttv.instantiatedTypePackParams)
visitTypeVar(itp, *this, seen); DEPRECATED_visitTypeVar(itp, *this, seen);
return exhaustive; return exhaustive;
} }
@ -82,9 +84,43 @@ struct FindCyclicTypes
template<typename T> template<typename T>
bool operator()(TypePackId tp, const T&) bool operator()(TypePackId tp, const T&)
{
return visit(tp);
}
bool visit(TypeId ty) override
{
return visited.insert(ty).second;
}
bool visit(TypePackId tp) override
{ {
return visitedPacks.insert(tp).second; return visitedPacks.insert(tp).second;
} }
bool visit(TypeId ty, const TableTypeVar& ttv) override
{
if (!visited.insert(ty).second)
return false;
if (ttv.name || ttv.syntheticName)
{
for (TypeId itp : ttv.instantiatedTypeParams)
traverse(itp);
for (TypePackId itp : ttv.instantiatedTypePackParams)
traverse(itp);
return exhaustive;
}
return true;
}
bool visit(TypeId ty, const ClassTypeVar&) override
{
return false;
}
}; };
template<typename TID> template<typename TID>
@ -92,7 +128,7 @@ void findCyclicTypes(std::set<TypeId>& cycles, std::set<TypePackId>& cycleTPs, T
{ {
FindCyclicTypes fct; FindCyclicTypes fct;
fct.exhaustive = exhaustive; fct.exhaustive = exhaustive;
visitTypeVar(ty, fct); DEPRECATED_visitTypeVar(ty, fct);
cycles = std::move(fct.cycles); cycles = std::move(fct.cycles);
cycleTPs = std::move(fct.cycleTPs); cycleTPs = std::move(fct.cycleTPs);

View file

@ -7,7 +7,6 @@
#include <algorithm> #include <algorithm>
#include <stdexcept> #include <stdexcept>
LUAU_FASTFLAGVARIABLE(LuauTxnLogPreserveOwner, false)
LUAU_FASTFLAGVARIABLE(LuauJustOneCallFrameForHaveSeen, false) LUAU_FASTFLAGVARIABLE(LuauJustOneCallFrameForHaveSeen, false)
namespace Luau namespace Luau
@ -81,8 +80,6 @@ void TxnLog::concat(TxnLog rhs)
void TxnLog::commit() void TxnLog::commit()
{ {
if (FFlag::LuauTxnLogPreserveOwner)
{
for (auto& [ty, rep] : typeVarChanges) for (auto& [ty, rep] : typeVarChanges)
{ {
TypeArena* owningArena = ty->owningArena; TypeArena* owningArena = ty->owningArena;
@ -98,15 +95,6 @@ void TxnLog::commit()
*mpv = rep.get()->pending; *mpv = rep.get()->pending;
mpv->owningArena = owningArena; mpv->owningArena = owningArena;
} }
}
else
{
for (auto& [ty, rep] : typeVarChanges)
*asMutable(ty) = rep.get()->pending;
for (auto& [tp, rep] : typePackChanges)
*asMutable(tp) = rep.get()->pending;
}
clear(); clear();
} }

View file

@ -26,11 +26,11 @@ LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 165)
LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 20000) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 20000)
LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000)
LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300)
LUAU_FASTFLAGVARIABLE(LuauUseVisitRecursionLimit, false)
LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500)
LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauKnowsTheDataModel3)
LUAU_FASTFLAG(LuauSeparateTypechecks) LUAU_FASTFLAG(LuauSeparateTypechecks)
LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTFLAG(LuauAutocompleteDynamicLimits)
LUAU_FASTFLAG(LuauAutocompleteSingletonTypes)
LUAU_FASTFLAGVARIABLE(LuauCyclicModuleTypeSurface, false)
LUAU_FASTFLAGVARIABLE(LuauDoNotRelyOnNextBinding, false) LUAU_FASTFLAGVARIABLE(LuauDoNotRelyOnNextBinding, false)
LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false)
LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false.
@ -40,6 +40,7 @@ LUAU_FASTFLAGVARIABLE(LuauInferStatFunction, false)
LUAU_FASTFLAGVARIABLE(LuauInstantiateFollows, false) LUAU_FASTFLAGVARIABLE(LuauInstantiateFollows, false)
LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix, false) LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix, false)
LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions2, false) LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions2, false)
LUAU_FASTFLAGVARIABLE(LuauReduceUnionRecursion, false)
LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false)
LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify4, false) LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify4, false)
LUAU_FASTFLAGVARIABLE(LuauTypecheckOptPass, false) LUAU_FASTFLAGVARIABLE(LuauTypecheckOptPass, false)
@ -57,6 +58,7 @@ LUAU_FASTFLAGVARIABLE(LuauTableUseCounterInstead, false)
LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false)
LUAU_FASTFLAGVARIABLE(LuauRecursionLimitException, false); LUAU_FASTFLAGVARIABLE(LuauRecursionLimitException, false);
LUAU_FASTFLAG(LuauLosslessClone) LUAU_FASTFLAG(LuauLosslessClone)
LUAU_FASTFLAGVARIABLE(LuauTypecheckIter, false);
namespace Luau namespace Luau
{ {
@ -1159,6 +1161,47 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin)
iterTy = follow(instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location)); iterTy = follow(instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location));
} }
if (FFlag::LuauTypecheckIter)
{
if (std::optional<TypeId> iterMM = findMetatableEntry(iterTy, "__iter", firstValue->location))
{
// if __iter metamethod is present, it will be called and the results are going to be called as if they are functions
// TODO: this needs to typecheck all returned values by __iter as if they were for loop arguments
// the structure of the function makes it difficult to do this especially since we don't have actual expressions, only types
for (TypeId var : varTypes)
unify(anyType, var, forin.location);
return check(loopScope, *forin.body);
}
else if (const TableTypeVar* iterTable = get<TableTypeVar>(iterTy))
{
// TODO: note that this doesn't cleanly handle iteration over mixed tables and tables without an indexer
// this behavior is more or less consistent with what we do for pairs(), but really both are pretty wrong and need revisiting
if (iterTable->indexer)
{
if (varTypes.size() > 0)
unify(iterTable->indexer->indexType, varTypes[0], forin.location);
if (varTypes.size() > 1)
unify(iterTable->indexer->indexResultType, varTypes[1], forin.location);
for (size_t i = 2; i < varTypes.size(); ++i)
unify(nilType, varTypes[i], forin.location);
}
else
{
TypeId varTy = errorRecoveryType(loopScope);
for (TypeId var : varTypes)
unify(varTy, var, forin.location);
reportError(firstValue->location, GenericError{"Cannot iterate over a table without indexer"});
}
return check(loopScope, *forin.body);
}
}
const FunctionTypeVar* iterFunc = get<FunctionTypeVar>(iterTy); const FunctionTypeVar* iterFunc = get<FunctionTypeVar>(iterTy);
if (!iterFunc) if (!iterFunc)
{ {
@ -2025,6 +2068,19 @@ std::vector<TypeId> TypeChecker::reduceUnion(const std::vector<TypeId>& types)
return {t}; return {t};
if (const UnionTypeVar* utv = get<UnionTypeVar>(t)) if (const UnionTypeVar* utv = get<UnionTypeVar>(t))
{
if (FFlag::LuauReduceUnionRecursion)
{
for (TypeId ty : utv)
{
if (get<ErrorTypeVar>(ty) || get<AnyTypeVar>(ty))
return {ty};
if (result.end() == std::find(result.begin(), result.end(), ty))
result.push_back(ty);
}
}
else
{ {
std::vector<TypeId> r = reduceUnion(utv->options); std::vector<TypeId> r = reduceUnion(utv->options);
for (TypeId ty : r) for (TypeId ty : r)
@ -2037,6 +2093,7 @@ std::vector<TypeId> TypeChecker::reduceUnion(const std::vector<TypeId>& types)
result.push_back(ty); result.push_back(ty);
} }
} }
}
else if (std::find(result.begin(), result.end(), t) == result.end()) else if (std::find(result.begin(), result.end(), t) == result.end())
result.push_back(t); result.push_back(t);
} }
@ -4372,18 +4429,13 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module
} }
// Types of requires that transitively refer to current module have to be replaced with 'any' // Types of requires that transitively refer to current module have to be replaced with 'any'
std::string humanReadableName; std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name);
if (FFlag::LuauCyclicModuleTypeSurface)
{
humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name);
for (const auto& [location, path] : requireCycles) for (const auto& [location, path] : requireCycles)
{ {
if (!path.empty() && path.front() == humanReadableName) if (!path.empty() && path.front() == humanReadableName)
return anyType; return anyType;
} }
}
ModulePtr module = resolver->getModule(moduleInfo.name); ModulePtr module = resolver->getModule(moduleInfo.name);
if (!module) if (!module)
@ -4392,32 +4444,14 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module
// either the file does not exist or there's a cycle. If there's a cycle // either the file does not exist or there's a cycle. If there's a cycle
// we will already have reported the error. // we will already have reported the error.
if (!resolver->moduleExists(moduleInfo.name) && !moduleInfo.optional) if (!resolver->moduleExists(moduleInfo.name) && !moduleInfo.optional)
{
if (FFlag::LuauCyclicModuleTypeSurface)
{
reportError(TypeError{location, UnknownRequire{humanReadableName}}); reportError(TypeError{location, UnknownRequire{humanReadableName}});
}
else
{
std::string reportedModulePath = resolver->getHumanReadableModuleName(moduleInfo.name);
reportError(TypeError{location, UnknownRequire{reportedModulePath}});
}
}
return errorRecoveryType(scope); return errorRecoveryType(scope);
} }
if (module->type != SourceCode::Module) if (module->type != SourceCode::Module)
{
if (FFlag::LuauCyclicModuleTypeSurface)
{ {
reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}); reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."});
}
else
{
std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name);
reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."});
}
return errorRecoveryType(scope); return errorRecoveryType(scope);
} }
@ -4428,16 +4462,8 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module
std::optional<TypeId> moduleType = first(modulePack); std::optional<TypeId> moduleType = first(modulePack);
if (!moduleType) if (!moduleType)
{
if (FFlag::LuauCyclicModuleTypeSurface)
{ {
reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}); reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."});
}
else
{
std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name);
reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."});
}
return errorRecoveryType(scope); return errorRecoveryType(scope);
} }
@ -4947,10 +4973,7 @@ TypeId TypeChecker::freshType(TypeLevel level)
TypeId TypeChecker::singletonType(bool value) TypeId TypeChecker::singletonType(bool value)
{ {
if (FFlag::LuauAutocompleteSingletonTypes)
return value ? getSingletonTypes().trueType : getSingletonTypes().falseType; return value ? getSingletonTypes().trueType : getSingletonTypes().falseType;
return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(BooleanSingleton{value})));
} }
TypeId TypeChecker::singletonType(std::string value) TypeId TypeChecker::singletonType(std::string value)

View file

@ -22,7 +22,7 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation);
LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauErrorRecoveryType);
LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false) LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false)
LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree2, false) LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree2, false)
LUAU_FASTFLAGVARIABLE(LuauDifferentOrderOfUnificationDoesntMatter, false) LUAU_FASTFLAGVARIABLE(LuauDifferentOrderOfUnificationDoesntMatter2, false)
LUAU_FASTFLAGVARIABLE(LuauTxnLogRefreshFunctionPointers, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogRefreshFunctionPointers, false)
LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional)
LUAU_FASTFLAG(LuauTypecheckOptPass) LUAU_FASTFLAG(LuauTypecheckOptPass)
@ -30,7 +30,7 @@ LUAU_FASTFLAG(LuauTypecheckOptPass)
namespace Luau namespace Luau
{ {
struct PromoteTypeLevels struct PromoteTypeLevels final : TypeVarOnceVisitor
{ {
TxnLog& log; TxnLog& log;
const TypeArena* typeArena = nullptr; const TypeArena* typeArena = nullptr;
@ -53,13 +53,34 @@ struct PromoteTypeLevels
} }
} }
// TODO cycle and operator() need to be clipped when FFlagLuauUseVisitRecursionLimit is clipped
template<typename TID> template<typename TID>
void cycle(TID) void cycle(TID)
{ {
} }
template<typename TID, typename T> template<typename TID, typename T>
bool operator()(TID ty, const T&) bool operator()(TID ty, const T&)
{
return visit(ty);
}
bool operator()(TypeId ty, const FreeTypeVar& ftv)
{
return visit(ty, ftv);
}
bool operator()(TypeId ty, const FunctionTypeVar& ftv)
{
return visit(ty, ftv);
}
bool operator()(TypeId ty, const TableTypeVar& ttv)
{
return visit(ty, ttv);
}
bool operator()(TypePackId tp, const FreeTypePack& ftp)
{
return visit(tp, ftp);
}
bool visit(TypeId ty) override
{ {
// Type levels of types from other modules are already global, so we don't need to promote anything inside // Type levels of types from other modules are already global, so we don't need to promote anything inside
if (ty->owningArena != typeArena) if (ty->owningArena != typeArena)
@ -68,7 +89,16 @@ struct PromoteTypeLevels
return true; return true;
} }
bool operator()(TypeId ty, const FreeTypeVar&) bool visit(TypePackId tp) override
{
// Type levels of types from other modules are already global, so we don't need to promote anything inside
if (tp->owningArena != typeArena)
return false;
return true;
}
bool visit(TypeId ty, const FreeTypeVar&) override
{ {
// Surprise, it's actually a BoundTypeVar that hasn't been committed yet. // Surprise, it's actually a BoundTypeVar that hasn't been committed yet.
// Calling getMutable on this will trigger an assertion. // Calling getMutable on this will trigger an assertion.
@ -79,7 +109,7 @@ struct PromoteTypeLevels
return true; return true;
} }
bool operator()(TypeId ty, const FunctionTypeVar&) bool visit(TypeId ty, const FunctionTypeVar&) override
{ {
// Type levels of types from other modules are already global, so we don't need to promote anything inside // Type levels of types from other modules are already global, so we don't need to promote anything inside
if (ty->owningArena != typeArena) if (ty->owningArena != typeArena)
@ -89,7 +119,7 @@ struct PromoteTypeLevels
return true; return true;
} }
bool operator()(TypeId ty, const TableTypeVar& ttv) bool visit(TypeId ty, const TableTypeVar& ttv) override
{ {
// Type levels of types from other modules are already global, so we don't need to promote anything inside // Type levels of types from other modules are already global, so we don't need to promote anything inside
if (ty->owningArena != typeArena) if (ty->owningArena != typeArena)
@ -102,7 +132,7 @@ struct PromoteTypeLevels
return true; return true;
} }
bool operator()(TypePackId tp, const FreeTypePack&) bool visit(TypePackId tp, const FreeTypePack&) override
{ {
// Surprise, it's actually a BoundTypePack that hasn't been committed yet. // Surprise, it's actually a BoundTypePack that hasn't been committed yet.
// Calling getMutable on this will trigger an assertion. // Calling getMutable on this will trigger an assertion.
@ -122,7 +152,7 @@ static void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel
PromoteTypeLevels ptl{log, typeArena, minLevel}; PromoteTypeLevels ptl{log, typeArena, minLevel};
DenseHashSet<void*> seen{nullptr}; DenseHashSet<void*> seen{nullptr};
visitTypeVarOnce(ty, ptl, seen); DEPRECATED_visitTypeVarOnce(ty, ptl, seen);
} }
void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypePackId tp) void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypePackId tp)
@ -133,10 +163,10 @@ void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLev
PromoteTypeLevels ptl{log, typeArena, minLevel}; PromoteTypeLevels ptl{log, typeArena, minLevel};
DenseHashSet<void*> seen{nullptr}; DenseHashSet<void*> seen{nullptr};
visitTypeVarOnce(tp, ptl, seen); DEPRECATED_visitTypeVarOnce(tp, ptl, seen);
} }
struct SkipCacheForType struct SkipCacheForType final : TypeVarOnceVisitor
{ {
SkipCacheForType(const DenseHashMap<TypeId, bool>& skipCacheForType, const TypeArena* typeArena) SkipCacheForType(const DenseHashMap<TypeId, bool>& skipCacheForType, const TypeArena* typeArena)
: skipCacheForType(skipCacheForType) : skipCacheForType(skipCacheForType)
@ -144,28 +174,68 @@ struct SkipCacheForType
{ {
} }
void cycle(TypeId) {} // TODO cycle() and operator() can be clipped with FFlagLuauUseVisitRecursionLimit
void cycle(TypePackId) {} void cycle(TypeId) override {}
void cycle(TypePackId) override {}
bool operator()(TypeId ty, const FreeTypeVar& ftv) bool operator()(TypeId ty, const FreeTypeVar& ftv)
{ {
result = true; return visit(ty, ftv);
return false;
} }
bool operator()(TypeId ty, const BoundTypeVar& btv) bool operator()(TypeId ty, const BoundTypeVar& btv)
{ {
result = true; return visit(ty, btv);
return false; }
bool operator()(TypeId ty, const GenericTypeVar& gtv)
{
return visit(ty, gtv);
}
bool operator()(TypeId ty, const TableTypeVar& ttv)
{
return visit(ty, ttv);
}
bool operator()(TypePackId tp, const FreeTypePack& ftp)
{
return visit(tp, ftp);
}
bool operator()(TypePackId tp, const BoundTypePack& ftp)
{
return visit(tp, ftp);
}
bool operator()(TypePackId tp, const GenericTypePack& ftp)
{
return visit(tp, ftp);
}
template<typename T>
bool operator()(TypeId ty, const T& t)
{
return visit(ty);
}
template<typename T>
bool operator()(TypePackId tp, const T&)
{
return visit(tp);
} }
bool operator()(TypeId ty, const GenericTypeVar& btv) bool visit(TypeId, const FreeTypeVar&) override
{ {
result = true; result = true;
return false; return false;
} }
bool operator()(TypeId ty, const TableTypeVar&) bool visit(TypeId, const BoundTypeVar&) override
{
result = true;
return false;
}
bool visit(TypeId, const GenericTypeVar&) override
{
result = true;
return false;
}
bool visit(TypeId ty, const TableTypeVar&) override
{ {
// Types from other modules don't contain mutable elements and are ok to cache // Types from other modules don't contain mutable elements and are ok to cache
if (ty->owningArena != typeArena) if (ty->owningArena != typeArena)
@ -188,8 +258,7 @@ struct SkipCacheForType
return true; return true;
} }
template<typename T> bool visit(TypeId ty) override
bool operator()(TypeId ty, const T& t)
{ {
// Types from other modules don't contain mutable elements and are ok to cache // Types from other modules don't contain mutable elements and are ok to cache
if (ty->owningArena != typeArena) if (ty->owningArena != typeArena)
@ -206,8 +275,7 @@ struct SkipCacheForType
return true; return true;
} }
template<typename T> bool visit(TypePackId tp) override
bool operator()(TypePackId tp, const T&)
{ {
// Types from other modules don't contain mutable elements and are ok to cache // Types from other modules don't contain mutable elements and are ok to cache
if (tp->owningArena != typeArena) if (tp->owningArena != typeArena)
@ -216,19 +284,19 @@ struct SkipCacheForType
return true; return true;
} }
bool operator()(TypePackId tp, const FreeTypePack& ftp) bool visit(TypePackId tp, const FreeTypePack&) override
{ {
result = true; result = true;
return false; return false;
} }
bool operator()(TypePackId tp, const BoundTypePack& ftp) bool visit(TypePackId tp, const BoundTypePack&) override
{ {
result = true; result = true;
return false; return false;
} }
bool operator()(TypePackId tp, const GenericTypePack& ftp) bool visit(TypePackId tp, const GenericTypePack&) override
{ {
result = true; result = true;
return false; return false;
@ -578,7 +646,7 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId
failed = true; failed = true;
} }
if (FFlag::LuauDifferentOrderOfUnificationDoesntMatter) if (FFlag::LuauDifferentOrderOfUnificationDoesntMatter2)
{ {
} }
else else
@ -593,7 +661,7 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId
} }
// even if A | B <: T fails, we want to bind some options of T with A | B iff A | B was a subtype of that option. // even if A | B <: T fails, we want to bind some options of T with A | B iff A | B was a subtype of that option.
if (FFlag::LuauDifferentOrderOfUnificationDoesntMatter) if (FFlag::LuauDifferentOrderOfUnificationDoesntMatter2)
{ {
auto tryBind = [this, subTy](TypeId superOption) { auto tryBind = [this, subTy](TypeId superOption) {
superOption = log.follow(superOption); superOption = log.follow(superOption);
@ -603,6 +671,14 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId
if (!log.is<FreeTypeVar>(superOption) && (!ttv || ttv->state != TableState::Free)) if (!log.is<FreeTypeVar>(superOption) && (!ttv || ttv->state != TableState::Free))
return; return;
// If superOption is already present in subTy, do nothing. Nothing new has been learned, but the subtype
// test is successful.
if (auto subUnion = get<UnionTypeVar>(subTy))
{
if (end(subUnion) != std::find(begin(subUnion), end(subUnion), superOption))
return;
}
// Since we have already checked if S <: T, checking it again will not queue up the type for replacement. // Since we have already checked if S <: T, checking it again will not queue up the type for replacement.
// So we'll have to do it ourselves. We assume they unified cleanly if they are still in the seen set. // So we'll have to do it ourselves. We assume they unified cleanly if they are still in the seen set.
if (log.haveSeen(subTy, superOption)) if (log.haveSeen(subTy, superOption))
@ -822,7 +898,7 @@ bool Unifier::canCacheResult(TypeId subTy, TypeId superTy)
auto skipCacheFor = [this](TypeId ty) { auto skipCacheFor = [this](TypeId ty) {
SkipCacheForType visitor{sharedState.skipCacheForType, types}; SkipCacheForType visitor{sharedState.skipCacheForType, types};
visitTypeVarOnce(ty, visitor, sharedState.seenAny); DEPRECATED_visitTypeVarOnce(ty, visitor, sharedState.seenAny);
sharedState.skipCacheForType[ty] = visitor.result; sharedState.skipCacheForType[ty] = visitor.result;

View file

@ -313,7 +313,7 @@ template<typename T>
struct AstArray struct AstArray
{ {
T* data; T* data;
std::size_t size; size_t size;
const T* begin() const const T* begin() const
{ {

View file

@ -10,7 +10,6 @@
// See docs/SyntaxChanges.md for an explanation. // See docs/SyntaxChanges.md for an explanation.
LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000)
LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100)
LUAU_FASTFLAGVARIABLE(LuauParseRecoverUnexpectedPack, false)
LUAU_FASTFLAGVARIABLE(LuauParseLocationIgnoreCommentSkipInCapture, false) LUAU_FASTFLAGVARIABLE(LuauParseLocationIgnoreCommentSkipInCapture, false)
namespace Luau namespace Luau
@ -1430,7 +1429,7 @@ AstType* Parser::parseTypeAnnotation(TempVector<AstType*>& parts, const Location
parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type);
isIntersection = true; isIntersection = true;
} }
else if (FFlag::LuauParseRecoverUnexpectedPack && c == Lexeme::Dot3) else if (c == Lexeme::Dot3)
{ {
report(lexer.current().location, "Unexpected '...' after type annotation"); report(lexer.current().location, "Unexpected '...' after type annotation");
nextLexeme(); nextLexeme();
@ -1551,7 +1550,7 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack)
prefix = name.name; prefix = name.name;
name = parseIndexName("field name", pointPosition); name = parseIndexName("field name", pointPosition);
} }
else if (FFlag::LuauParseRecoverUnexpectedPack && lexer.current().type == Lexeme::Dot3) else if (lexer.current().type == Lexeme::Dot3)
{ {
report(lexer.current().location, "Unexpected '...' after type name; type pack is not allowed in this context"); report(lexer.current().location, "Unexpected '...' after type name; type pack is not allowed in this context");
nextLexeme(); nextLexeme();

View file

@ -353,6 +353,11 @@ enum LuauOpcode
// AUX: constant index // AUX: constant index
LOP_FASTCALL2K, LOP_FASTCALL2K,
// FORGPREP: prepare loop variables for a generic for loop, jump to the loop backedge unconditionally
// A: target register; generic for loops assume a register layout [generator, state, index, variables...]
// D: jump offset (-32768..32767)
LOP_FORGPREP,
// Enum entry for number of opcodes, not a valid opcode by itself! // Enum entry for number of opcodes, not a valid opcode by itself!
LOP__COUNT LOP__COUNT
}; };

View file

@ -96,6 +96,7 @@ inline bool isJumpD(LuauOpcode op)
case LOP_JUMPIFNOTLT: case LOP_JUMPIFNOTLT:
case LOP_FORNPREP: case LOP_FORNPREP:
case LOP_FORNLOOP: case LOP_FORNLOOP:
case LOP_FORGPREP:
case LOP_FORGLOOP: case LOP_FORGLOOP:
case LOP_FORGPREP_INEXT: case LOP_FORGPREP_INEXT:
case LOP_FORGLOOP_INEXT: case LOP_FORGLOOP_INEXT:
@ -1269,6 +1270,11 @@ void BytecodeBuilder::validate() const
VJUMP(LUAU_INSN_D(insn)); VJUMP(LUAU_INSN_D(insn));
break; break;
case LOP_FORGPREP:
VREG(LUAU_INSN_A(insn) + 2 + 1); // forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, ... are loop variables
VJUMP(LUAU_INSN_D(insn));
break;
case LOP_FORGLOOP: case LOP_FORGLOOP:
VREG( VREG(
LUAU_INSN_A(insn) + 2 + insns[i + 1]); // forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, ... are loop variables LUAU_INSN_A(insn) + 2 + insns[i + 1]); // forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, ... are loop variables
@ -1622,6 +1628,10 @@ const uint32_t* BytecodeBuilder::dumpInstruction(const uint32_t* code, std::stri
formatAppend(result, "FORNLOOP R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn)); formatAppend(result, "FORNLOOP R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn));
break; break;
case LOP_FORGPREP:
formatAppend(result, "FORGPREP R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn));
break;
case LOP_FORGLOOP: case LOP_FORGLOOP:
formatAppend(result, "FORGLOOP R%d %+d %d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn), *code++); formatAppend(result, "FORGLOOP R%d %+d %d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn), *code++);
break; break;

View file

@ -17,9 +17,19 @@
#include <math.h> #include <math.h>
#include <limits.h> #include <limits.h>
LUAU_FASTFLAGVARIABLE(LuauCompileSupportInlining, false)
LUAU_FASTFLAGVARIABLE(LuauCompileIter, false)
LUAU_FASTFLAGVARIABLE(LuauCompileIterNoReserve, false)
LUAU_FASTFLAGVARIABLE(LuauCompileIterNoPairs, false)
LUAU_FASTINTVARIABLE(LuauCompileLoopUnrollThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileLoopUnrollThreshold, 25)
LUAU_FASTINTVARIABLE(LuauCompileLoopUnrollThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileLoopUnrollThresholdMaxBoost, 300)
LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25)
LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300)
LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5)
namespace Luau namespace Luau
{ {
@ -147,6 +157,52 @@ struct Compiler
} }
} }
AstExprFunction* getFunctionExpr(AstExpr* node)
{
if (AstExprLocal* le = node->as<AstExprLocal>())
{
Variable* lv = variables.find(le->local);
if (!lv || lv->written || !lv->init)
return nullptr;
return getFunctionExpr(lv->init);
}
else if (AstExprGroup* ge = node->as<AstExprGroup>())
return getFunctionExpr(ge->expr);
else
return node->as<AstExprFunction>();
}
bool canInlineFunctionBody(AstStat* stat)
{
struct CanInlineVisitor : AstVisitor
{
bool result = true;
bool visit(AstExpr* node) override
{
// nested functions may capture function arguments, and our upval handling doesn't handle elided variables (constant)
// TODO: we could remove this case if we changed function compilation to create temporary locals for constant upvalues
// TODO: additionally we would need to change upvalue handling in compileExprFunction to handle upvalue->local migration
result = result && !node->is<AstExprFunction>();
return result;
}
bool visit(AstStat* node) override
{
// loops may need to be unrolled which can result in cost amplification
result = result && !node->is<AstStatFor>();
return result;
}
};
CanInlineVisitor canInline;
stat->visit(&canInline);
return canInline.result;
}
uint32_t compileFunction(AstExprFunction* func) uint32_t compileFunction(AstExprFunction* func)
{ {
LUAU_TIMETRACE_SCOPE("Compiler::compileFunction", "Compiler"); LUAU_TIMETRACE_SCOPE("Compiler::compileFunction", "Compiler");
@ -214,13 +270,21 @@ struct Compiler
bytecode.endFunction(uint8_t(stackSize), uint8_t(upvals.size())); bytecode.endFunction(uint8_t(stackSize), uint8_t(upvals.size()));
stackSize = 0;
Function& f = functions[func]; Function& f = functions[func];
f.id = fid; f.id = fid;
f.upvals = upvals; f.upvals = upvals;
// record information for inlining
if (FFlag::LuauCompileSupportInlining && options.optimizationLevel >= 2 && !func->vararg && canInlineFunctionBody(func->body) &&
!getfenvUsed && !setfenvUsed)
{
f.canInline = true;
f.stackSize = stackSize;
f.costModel = modelCost(func->body, func->args.data, func->args.size);
}
upvals.clear(); // note: instead of std::move above, we copy & clear to preserve capacity for future pushes upvals.clear(); // note: instead of std::move above, we copy & clear to preserve capacity for future pushes
stackSize = 0;
return fid; return fid;
} }
@ -390,12 +454,183 @@ struct Compiler
} }
} }
bool tryCompileInlinedCall(AstExprCall* expr, AstExprFunction* func, uint8_t target, uint8_t targetCount, bool multRet, int thresholdBase,
int thresholdMaxBoost, int depthLimit)
{
Function* fi = functions.find(func);
LUAU_ASSERT(fi);
// make sure we have enough register space
if (regTop > 128 || fi->stackSize > 32)
{
bytecode.addDebugRemark("inlining failed: high register pressure");
return false;
}
// we should ideally aggregate the costs during recursive inlining, but for now simply limit the depth
if (int(inlineFrames.size()) >= depthLimit)
{
bytecode.addDebugRemark("inlining failed: too many inlined frames");
return false;
}
// compiling recursive inlining is difficult because we share constant/variable state but need to bind variables to different registers
for (InlineFrame& frame : inlineFrames)
if (frame.func == func)
{
bytecode.addDebugRemark("inlining failed: can't inline recursive calls");
return false;
}
// TODO: we can compile multret functions if all returns of the function are multret as well
if (multRet)
{
bytecode.addDebugRemark("inlining failed: can't convert fixed returns to multret");
return false;
}
// TODO: we can compile functions with mismatching arity at call site but it's more annoying
if (func->args.size != expr->args.size)
{
bytecode.addDebugRemark("inlining failed: argument count mismatch (expected %d, got %d)", int(func->args.size), int(expr->args.size));
return false;
}
// we use a dynamic cost threshold that's based on the fixed limit boosted by the cost advantage we gain due to inlining
bool varc[8] = {};
for (size_t i = 0; i < expr->args.size && i < 8; ++i)
varc[i] = isConstant(expr->args.data[i]);
int inlinedCost = computeCost(fi->costModel, varc, std::min(int(expr->args.size), 8));
int baselineCost = computeCost(fi->costModel, nullptr, 0) + 3;
int inlineProfit = (inlinedCost == 0) ? thresholdMaxBoost : std::min(thresholdMaxBoost, 100 * baselineCost / inlinedCost);
int threshold = thresholdBase * inlineProfit / 100;
if (inlinedCost > threshold)
{
bytecode.addDebugRemark("inlining failed: too expensive (cost %d, profit %.2fx)", inlinedCost, double(inlineProfit) / 100);
return false;
}
bytecode.addDebugRemark(
"inlining succeeded (cost %d, profit %.2fx, depth %d)", inlinedCost, double(inlineProfit) / 100, int(inlineFrames.size()));
compileInlinedCall(expr, func, target, targetCount);
return true;
}
void compileInlinedCall(AstExprCall* expr, AstExprFunction* func, uint8_t target, uint8_t targetCount)
{
RegScope rs(this);
size_t oldLocals = localStack.size();
// note that we push the frame early; this is needed to block recursive inline attempts
inlineFrames.push_back({func, target, targetCount});
// evaluate all arguments; note that we don't emit code for constant arguments (relying on constant folding)
for (size_t i = 0; i < func->args.size; ++i)
{
AstLocal* var = func->args.data[i];
AstExpr* arg = expr->args.data[i];
if (Variable* vv = variables.find(var); vv && vv->written)
{
// if the argument is mutated, we need to allocate a fresh register even if it's a constant
uint8_t reg = allocReg(arg, 1);
compileExprTemp(arg, reg);
pushLocal(var, reg);
}
else if (const Constant* cv = constants.find(arg); cv && cv->type != Constant::Type_Unknown)
{
// since the argument is not mutated, we can simply fold the value into the expressions that need it
locstants[var] = *cv;
}
else
{
AstExprLocal* le = arg->as<AstExprLocal>();
Variable* lv = le ? variables.find(le->local) : nullptr;
// if the argument is a local that isn't mutated, we will simply reuse the existing register
if (isExprLocalReg(arg) && (!lv || !lv->written))
{
uint8_t reg = getLocal(le->local);
pushLocal(var, reg);
}
else
{
uint8_t reg = allocReg(arg, 1);
compileExprTemp(arg, reg);
pushLocal(var, reg);
}
}
}
// fold constant values updated above into expressions in the function body
foldConstants(constants, variables, locstants, func->body);
bool usedFallthrough = false;
for (size_t i = 0; i < func->body->body.size; ++i)
{
AstStat* stat = func->body->body.data[i];
if (AstStatReturn* ret = stat->as<AstStatReturn>())
{
// Optimization: use fallthrough when compiling return at the end of the function to avoid an extra JUMP
compileInlineReturn(ret, /* fallthrough= */ true);
// TODO: This doesn't work when return is part of control flow; ideally we would track the state somehow and generalize this
usedFallthrough = true;
break;
}
else
compileStat(stat);
}
// for the fallthrough path we need to ensure we clear out target registers
if (!usedFallthrough && !allPathsEndWithReturn(func->body))
{
for (size_t i = 0; i < targetCount; ++i)
bytecode.emitABC(LOP_LOADNIL, uint8_t(target + i), 0, 0);
}
popLocals(oldLocals);
size_t returnLabel = bytecode.emitLabel();
patchJumps(expr, inlineFrames.back().returnJumps, returnLabel);
inlineFrames.pop_back();
// clean up constant state for future inlining attempts
for (size_t i = 0; i < func->args.size; ++i)
if (Constant* var = locstants.find(func->args.data[i]))
var->type = Constant::Type_Unknown;
foldConstants(constants, variables, locstants, func->body);
}
void compileExprCall(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop = false, bool multRet = false) void compileExprCall(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop = false, bool multRet = false)
{ {
LUAU_ASSERT(!targetTop || unsigned(target + targetCount) == regTop); LUAU_ASSERT(!targetTop || unsigned(target + targetCount) == regTop);
setDebugLine(expr); // normally compileExpr sets up line info, but compileExprCall can be called directly setDebugLine(expr); // normally compileExpr sets up line info, but compileExprCall can be called directly
// try inlining the function
if (options.optimizationLevel >= 2 && !expr->self)
{
AstExprFunction* func = getFunctionExpr(expr->func);
Function* fi = func ? functions.find(func) : nullptr;
if (fi && fi->canInline &&
tryCompileInlinedCall(expr, func, target, targetCount, multRet, FInt::LuauCompileInlineThreshold,
FInt::LuauCompileInlineThresholdMaxBoost, FInt::LuauCompileInlineDepth))
return;
if (fi && !fi->canInline)
bytecode.addDebugRemark("inlining failed: complex constructs in function body");
}
RegScope rs(this); RegScope rs(this);
unsigned int regCount = std::max(unsigned(1 + expr->self + expr->args.size), unsigned(targetCount)); unsigned int regCount = std::max(unsigned(1 + expr->self + expr->args.size), unsigned(targetCount));
@ -760,7 +995,7 @@ struct Compiler
{ {
const Constant* c = constants.find(node); const Constant* c = constants.find(node);
if (!c) if (!c || c->type == Constant::Type_Unknown)
return -1; return -1;
int cid = -1; int cid = -1;
@ -1395,27 +1630,29 @@ struct Compiler
{ {
RegScope rs(this); RegScope rs(this);
// note: cv may be invalidated by compileExpr* so we stop using it before calling compile recursively
const Constant* cv = constants.find(expr->index); const Constant* cv = constants.find(expr->index);
if (cv && cv->type == Constant::Type_Number && cv->valueNumber >= 1 && cv->valueNumber <= 256 && if (cv && cv->type == Constant::Type_Number && cv->valueNumber >= 1 && cv->valueNumber <= 256 &&
double(int(cv->valueNumber)) == cv->valueNumber) double(int(cv->valueNumber)) == cv->valueNumber)
{ {
uint8_t rt = compileExprAuto(expr->expr, rs);
uint8_t i = uint8_t(int(cv->valueNumber) - 1); uint8_t i = uint8_t(int(cv->valueNumber) - 1);
uint8_t rt = compileExprAuto(expr->expr, rs);
setDebugLine(expr->index); setDebugLine(expr->index);
bytecode.emitABC(LOP_GETTABLEN, target, rt, i); bytecode.emitABC(LOP_GETTABLEN, target, rt, i);
} }
else if (cv && cv->type == Constant::Type_String) else if (cv && cv->type == Constant::Type_String)
{ {
uint8_t rt = compileExprAuto(expr->expr, rs);
BytecodeBuilder::StringRef iname = sref(cv->getString()); BytecodeBuilder::StringRef iname = sref(cv->getString());
int32_t cid = bytecode.addConstantString(iname); int32_t cid = bytecode.addConstantString(iname);
if (cid < 0) if (cid < 0)
CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile");
uint8_t rt = compileExprAuto(expr->expr, rs);
setDebugLine(expr->index); setDebugLine(expr->index);
bytecode.emitABC(LOP_GETTABLEKS, target, rt, uint8_t(BytecodeBuilder::getStringHash(iname))); bytecode.emitABC(LOP_GETTABLEKS, target, rt, uint8_t(BytecodeBuilder::getStringHash(iname)));
@ -1561,8 +1798,9 @@ struct Compiler
} }
else if (AstExprLocal* expr = node->as<AstExprLocal>()) else if (AstExprLocal* expr = node->as<AstExprLocal>())
{ {
if (expr->upvalue) if (FFlag::LuauCompileSupportInlining ? !isExprLocalReg(expr) : expr->upvalue)
{ {
LUAU_ASSERT(expr->upvalue);
uint8_t uid = getUpval(expr->local); uint8_t uid = getUpval(expr->local);
bytecode.emitABC(LOP_GETUPVAL, target, uid, 0); bytecode.emitABC(LOP_GETUPVAL, target, uid, 0);
@ -1650,12 +1888,12 @@ struct Compiler
// initializes target..target+targetCount-1 range using expressions from the list // initializes target..target+targetCount-1 range using expressions from the list
// if list has fewer expressions, and last expression is a call, we assume the call returns the rest of the values // if list has fewer expressions, and last expression is a call, we assume the call returns the rest of the values
// if list has fewer expressions, and last expression isn't a call, we fill the rest with nil // if list has fewer expressions, and last expression isn't a call, we fill the rest with nil
// assumes target register range can be clobbered and is at the top of the register space // assumes target register range can be clobbered and is at the top of the register space if targetTop = true
void compileExprListTop(const AstArray<AstExpr*>& list, uint8_t target, uint8_t targetCount) void compileExprListTemp(const AstArray<AstExpr*>& list, uint8_t target, uint8_t targetCount, bool targetTop)
{ {
// we assume that target range is at the top of the register space and can be clobbered // we assume that target range is at the top of the register space and can be clobbered
// this is what allows us to compile the last call expression - if it's a call - using targetTop=true // this is what allows us to compile the last call expression - if it's a call - using targetTop=true
LUAU_ASSERT(unsigned(target + targetCount) == regTop); LUAU_ASSERT(!targetTop || unsigned(target + targetCount) == regTop);
if (list.size == targetCount) if (list.size == targetCount)
{ {
@ -1683,7 +1921,7 @@ struct Compiler
if (AstExprCall* expr = last->as<AstExprCall>()) if (AstExprCall* expr = last->as<AstExprCall>())
{ {
compileExprCall(expr, uint8_t(target + list.size - 1), uint8_t(targetCount - (list.size - 1)), /* targetTop= */ true); compileExprCall(expr, uint8_t(target + list.size - 1), uint8_t(targetCount - (list.size - 1)), targetTop);
} }
else if (AstExprVarargs* expr = last->as<AstExprVarargs>()) else if (AstExprVarargs* expr = last->as<AstExprVarargs>())
{ {
@ -1765,8 +2003,10 @@ struct Compiler
if (AstExprLocal* expr = node->as<AstExprLocal>()) if (AstExprLocal* expr = node->as<AstExprLocal>())
{ {
if (expr->upvalue) if (FFlag::LuauCompileSupportInlining ? !isExprLocalReg(expr) : expr->upvalue)
{ {
LUAU_ASSERT(expr->upvalue);
LValue result = {LValue::Kind_Upvalue}; LValue result = {LValue::Kind_Upvalue};
result.upval = getUpval(expr->local); result.upval = getUpval(expr->local);
result.location = node->location; result.location = node->location;
@ -1873,7 +2113,7 @@ struct Compiler
bool isExprLocalReg(AstExpr* expr) bool isExprLocalReg(AstExpr* expr)
{ {
AstExprLocal* le = expr->as<AstExprLocal>(); AstExprLocal* le = expr->as<AstExprLocal>();
if (!le || le->upvalue) if (!le || (!FFlag::LuauCompileSupportInlining && le->upvalue))
return false; return false;
Local* l = locals.find(le->local); Local* l = locals.find(le->local);
@ -2080,6 +2320,23 @@ struct Compiler
loops.pop_back(); loops.pop_back();
} }
void compileInlineReturn(AstStatReturn* stat, bool fallthrough)
{
setDebugLine(stat); // normally compileStat sets up line info, but compileInlineReturn can be called directly
InlineFrame frame = inlineFrames.back();
compileExprListTemp(stat->list, frame.target, frame.targetCount, /* targetTop= */ false);
if (!fallthrough)
{
size_t jumpLabel = bytecode.emitLabel();
bytecode.emitAD(LOP_JUMP, 0, 0);
inlineFrames.back().returnJumps.push_back(jumpLabel);
}
}
void compileStatReturn(AstStatReturn* stat) void compileStatReturn(AstStatReturn* stat)
{ {
RegScope rs(this); RegScope rs(this);
@ -2138,7 +2395,7 @@ struct Compiler
// note: allocReg in this case allocates into parent block register - note that we don't have RegScope here // note: allocReg in this case allocates into parent block register - note that we don't have RegScope here
uint8_t vars = allocReg(stat, unsigned(stat->vars.size)); uint8_t vars = allocReg(stat, unsigned(stat->vars.size));
compileExprListTop(stat->values, vars, uint8_t(stat->vars.size)); compileExprListTemp(stat->values, vars, uint8_t(stat->vars.size), /* targetTop= */ true);
for (size_t i = 0; i < stat->vars.size; ++i) for (size_t i = 0; i < stat->vars.size; ++i)
pushLocal(stat->vars.data[i], uint8_t(vars + i)); pushLocal(stat->vars.data[i], uint8_t(vars + i));
@ -2168,6 +2425,7 @@ struct Compiler
bool visit(AstExpr* node) override bool visit(AstExpr* node) override
{ {
// functions may capture loop variable, and our upval handling doesn't handle elided variables (constant) // functions may capture loop variable, and our upval handling doesn't handle elided variables (constant)
// TODO: we could remove this case if we changed function compilation to create temporary locals for constant upvalues
result = result && !node->is<AstExprFunction>(); result = result && !node->is<AstExprFunction>();
return result; return result;
} }
@ -2251,6 +2509,11 @@ struct Compiler
compileStat(stat->body); compileStat(stat->body);
} }
// clean up fold state in case we need to recompile - normally we compile the loop body once, but due to inlining we may need to do it again
locstants[var].type = Constant::Type_Unknown;
foldConstants(constants, variables, locstants, stat);
return true; return true;
} }
@ -2336,12 +2599,17 @@ struct Compiler
uint8_t regs = allocReg(stat, 3); uint8_t regs = allocReg(stat, 3);
// this puts initial values of (generator, state, index) into the loop registers // this puts initial values of (generator, state, index) into the loop registers
compileExprListTop(stat->values, regs, 3); compileExprListTemp(stat->values, regs, 3, /* targetTop= */ true);
// we don't need this because the extra stack space is just for calling the function with a loop protocol which is similar to calling
// metamethods - it should fit into the extra stack reservation
if (!FFlag::LuauCompileIterNoReserve)
{
// for the general case, we will execute a CALL for every iteration that needs to evaluate "variables... = generator(state, index)" // for the general case, we will execute a CALL for every iteration that needs to evaluate "variables... = generator(state, index)"
// this requires at least extra 3 stack slots after index // this requires at least extra 3 stack slots after index
// note that these stack slots overlap with the variables so we only need to reserve them to make sure stack frame is large enough // note that these stack slots overlap with the variables so we only need to reserve them to make sure stack frame is large enough
reserveReg(stat, 3); reserveReg(stat, 3);
}
// note that we reserve at least 2 variables; this allows our fast path to assume that we need 2 variables instead of 1 or 2 // note that we reserve at least 2 variables; this allows our fast path to assume that we need 2 variables instead of 1 or 2
uint8_t vars = allocReg(stat, std::max(unsigned(stat->vars.size), 2u)); uint8_t vars = allocReg(stat, std::max(unsigned(stat->vars.size), 2u));
@ -2350,7 +2618,7 @@ struct Compiler
// Optimization: when we iterate through pairs/ipairs, we generate special bytecode that optimizes the traversal using internal iteration // Optimization: when we iterate through pairs/ipairs, we generate special bytecode that optimizes the traversal using internal iteration
// index These instructions dynamically check if generator is equal to next/inext and bail out They assume that the generator produces 2 // index These instructions dynamically check if generator is equal to next/inext and bail out They assume that the generator produces 2
// variables, which is why we allocate at least 2 above (see vars assignment) // variables, which is why we allocate at least 2 above (see vars assignment)
LuauOpcode skipOp = LOP_JUMP; LuauOpcode skipOp = FFlag::LuauCompileIter ? LOP_FORGPREP : LOP_JUMP;
LuauOpcode loopOp = LOP_FORGLOOP; LuauOpcode loopOp = LOP_FORGLOOP;
if (options.optimizationLevel >= 1 && stat->vars.size <= 2) if (options.optimizationLevel >= 1 && stat->vars.size <= 2)
@ -2367,7 +2635,7 @@ struct Compiler
else if (builtin.isGlobal("pairs")) // for .. in pairs(t) else if (builtin.isGlobal("pairs")) // for .. in pairs(t)
{ {
skipOp = LOP_FORGPREP_NEXT; skipOp = LOP_FORGPREP_NEXT;
loopOp = LOP_FORGLOOP_NEXT; loopOp = FFlag::LuauCompileIterNoPairs ? LOP_FORGLOOP : LOP_FORGLOOP_NEXT;
} }
} }
else if (stat->values.size == 2) else if (stat->values.size == 2)
@ -2377,7 +2645,7 @@ struct Compiler
if (builtin.isGlobal("next")) // for .. in next,t if (builtin.isGlobal("next")) // for .. in next,t
{ {
skipOp = LOP_FORGPREP_NEXT; skipOp = LOP_FORGPREP_NEXT;
loopOp = LOP_FORGLOOP_NEXT; loopOp = FFlag::LuauCompileIterNoPairs ? LOP_FORGLOOP : LOP_FORGLOOP_NEXT;
} }
} }
} }
@ -2514,10 +2782,10 @@ struct Compiler
// compute values into temporaries // compute values into temporaries
uint8_t regs = allocReg(stat, unsigned(stat->vars.size)); uint8_t regs = allocReg(stat, unsigned(stat->vars.size));
compileExprListTop(stat->values, regs, uint8_t(stat->vars.size)); compileExprListTemp(stat->values, regs, uint8_t(stat->vars.size), /* targetTop= */ true);
// assign variables that have associated values; note that if we have fewer values than variables, we'll assign nil because compileExprListTop // assign variables that have associated values; note that if we have fewer values than variables, we'll assign nil because
// will generate nils // compileExprListTemp will generate nils
for (size_t i = 0; i < stat->vars.size; ++i) for (size_t i = 0; i < stat->vars.size; ++i)
{ {
setDebugLine(stat->vars.data[i]); setDebugLine(stat->vars.data[i]);
@ -2675,6 +2943,9 @@ struct Compiler
} }
else if (AstStatReturn* stat = node->as<AstStatReturn>()) else if (AstStatReturn* stat = node->as<AstStatReturn>())
{ {
if (options.optimizationLevel >= 2 && !inlineFrames.empty())
compileInlineReturn(stat, /* fallthrough= */ false);
else
compileStatReturn(stat); compileStatReturn(stat);
} }
else if (AstStatExpr* stat = node->as<AstStatExpr>()) else if (AstStatExpr* stat = node->as<AstStatExpr>())
@ -3069,6 +3340,10 @@ struct Compiler
{ {
uint32_t id; uint32_t id;
std::vector<AstLocal*> upvals; std::vector<AstLocal*> upvals;
uint64_t costModel = 0;
unsigned int stackSize = 0;
bool canInline = false;
}; };
struct Local struct Local
@ -3098,6 +3373,16 @@ struct Compiler
AstExpr* untilCondition; AstExpr* untilCondition;
}; };
struct InlineFrame
{
AstExprFunction* func;
uint8_t target;
uint8_t targetCount;
std::vector<size_t> returnJumps;
};
BytecodeBuilder& bytecode; BytecodeBuilder& bytecode;
CompileOptions options; CompileOptions options;
@ -3120,6 +3405,7 @@ struct Compiler
std::vector<AstLocal*> upvals; std::vector<AstLocal*> upvals;
std::vector<LoopJump> loopJumps; std::vector<LoopJump> loopJumps;
std::vector<Loop> loops; std::vector<Loop> loops;
std::vector<InlineFrame> inlineFrames;
}; };
void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstNameTable& names, const CompileOptions& options) void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstNameTable& names, const CompileOptions& options)

View file

@ -3,6 +3,8 @@
#include <math.h> #include <math.h>
LUAU_FASTFLAG(LuauCompileSupportInlining)
namespace Luau namespace Luau
{ {
namespace Compile namespace Compile
@ -314,12 +316,35 @@ struct ConstantVisitor : AstVisitor
LUAU_ASSERT(!"Unknown expression type"); LUAU_ASSERT(!"Unknown expression type");
} }
if (result.type != Constant::Type_Unknown) recordConstant(constants, node, result);
constants[node] = result;
return result; return result;
} }
template<typename T>
void recordConstant(DenseHashMap<T, Constant>& map, T key, const Constant& value)
{
if (value.type != Constant::Type_Unknown)
map[key] = value;
else if (!FFlag::LuauCompileSupportInlining)
;
else if (Constant* old = map.find(key))
old->type = Constant::Type_Unknown;
}
void recordValue(AstLocal* local, const Constant& value)
{
// note: we rely on trackValues to have been run before us
Variable* v = variables.find(local);
LUAU_ASSERT(v);
if (!v->written)
{
v->constant = (value.type != Constant::Type_Unknown);
recordConstant(locals, local, value);
}
}
bool visit(AstExpr* node) override bool visit(AstExpr* node) override
{ {
// note: we short-circuit the visitor traversal through any expression trees by returning false // note: we short-circuit the visitor traversal through any expression trees by returning false
@ -336,18 +361,7 @@ struct ConstantVisitor : AstVisitor
{ {
Constant arg = analyze(node->values.data[i]); Constant arg = analyze(node->values.data[i]);
if (arg.type != Constant::Type_Unknown) recordValue(node->vars.data[i], arg);
{
// note: we rely on trackValues to have been run before us
Variable* v = variables.find(node->vars.data[i]);
LUAU_ASSERT(v);
if (!v->written)
{
locals[node->vars.data[i]] = arg;
v->constant = true;
}
}
} }
if (node->vars.size > node->values.size) if (node->vars.size > node->values.size)
@ -361,15 +375,8 @@ struct ConstantVisitor : AstVisitor
{ {
for (size_t i = node->values.size; i < node->vars.size; ++i) for (size_t i = node->values.size; i < node->vars.size; ++i)
{ {
// note: we rely on trackValues to have been run before us Constant nil = {Constant::Type_Nil};
Variable* v = variables.find(node->vars.data[i]); recordValue(node->vars.data[i], nil);
LUAU_ASSERT(v);
if (!v->written)
{
locals[node->vars.data[i]].type = Constant::Type_Nil;
v->constant = true;
}
} }
} }
} }

View file

@ -264,6 +264,7 @@ if(TARGET Luau.UnitTest)
tests/TypePack.test.cpp tests/TypePack.test.cpp
tests/TypeVar.test.cpp tests/TypeVar.test.cpp
tests/Variant.test.cpp tests/Variant.test.cpp
tests/VisitTypeVar.test.cpp
tests/main.cpp) tests/main.cpp)
endif() endif()

View file

@ -1270,7 +1270,7 @@ const char* lua_setupvalue(lua_State* L, int funcindex, int n)
L->top--; L->top--;
setobj(L, val, L->top); setobj(L, val, L->top);
luaC_barrier(L, clvalue(fi), L->top); luaC_barrier(L, clvalue(fi), L->top);
luaC_upvalbarrier(L, NULL, val); luaC_upvalbarrier(L, cast_to(UpVal*, NULL), val);
} }
return name; return name;
} }

View file

@ -15,6 +15,8 @@
#include <intrin.h> #include <intrin.h>
#endif #endif
LUAU_FASTFLAGVARIABLE(LuauFixBuiltinsStackLimit, false)
// luauF functions implement FASTCALL instruction that performs a direct execution of some builtin functions from the VM // luauF functions implement FASTCALL instruction that performs a direct execution of some builtin functions from the VM
// The rule of thumb is that FASTCALL functions can not call user code, yield, fail, or reallocate stack. // The rule of thumb is that FASTCALL functions can not call user code, yield, fail, or reallocate stack.
// If types of the arguments mismatch, luauF_* needs to return -1 and the execution will fall back to the usual call path // If types of the arguments mismatch, luauF_* needs to return -1 and the execution will fall back to the usual call path
@ -1003,7 +1005,7 @@ static int luauF_tunpack(lua_State* L, StkId res, TValue* arg0, int nresults, St
else if (nparams == 3 && ttisnumber(args) && ttisnumber(args + 1) && nvalue(args) == 1.0) else if (nparams == 3 && ttisnumber(args) && ttisnumber(args + 1) && nvalue(args) == 1.0)
n = int(nvalue(args + 1)); n = int(nvalue(args + 1));
if (n >= 0 && n <= t->sizearray && cast_int(L->stack_last - res) >= n) if (n >= 0 && n <= t->sizearray && cast_int(L->stack_last - res) >= n && (!FFlag::LuauFixBuiltinsStackLimit || n + nparams <= LUAI_MAXCSTACK))
{ {
TValue* array = t->array; TValue* array = t->array;
for (int i = 0; i < n; ++i) for (int i = 0; i < n; ++i)

View file

@ -120,7 +120,7 @@
#define luaC_upvalbarrier(L, uv, tv) \ #define luaC_upvalbarrier(L, uv, tv) \
{ \ { \
if (iscollectable(tv) && iswhite(gcvalue(tv)) && (!(uv) || ((UpVal*)uv)->v != &((UpVal*)uv)->u.value)) \ if (iscollectable(tv) && iswhite(gcvalue(tv)) && (!(uv) || (uv)->v != &(uv)->u.value)) \
luaC_barrierupval(L, gcvalue(tv)); \ luaC_barrierupval(L, gcvalue(tv)); \
} }

View file

@ -33,8 +33,6 @@
#include <string.h> #include <string.h>
LUAU_FASTFLAGVARIABLE(LuauTableNewBoundary2, false)
// max size of both array and hash part is 2^MAXBITS // max size of both array and hash part is 2^MAXBITS
#define MAXBITS 26 #define MAXBITS 26
#define MAXSIZE (1 << MAXBITS) #define MAXSIZE (1 << MAXBITS)
@ -431,7 +429,6 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize)
static int adjustasize(Table* t, int size, const TValue* ek) static int adjustasize(Table* t, int size, const TValue* ek)
{ {
LUAU_ASSERT(FFlag::LuauTableNewBoundary2);
bool tbound = t->node != dummynode || size < t->sizearray; bool tbound = t->node != dummynode || size < t->sizearray;
int ekindex = ek && ttisnumber(ek) ? arrayindex(nvalue(ek)) : -1; int ekindex = ek && ttisnumber(ek) ? arrayindex(nvalue(ek)) : -1;
/* move the array size up until the boundary is guaranteed to be inside the array part */ /* move the array size up until the boundary is guaranteed to be inside the array part */
@ -443,7 +440,7 @@ static int adjustasize(Table* t, int size, const TValue* ek)
void luaH_resizearray(lua_State* L, Table* t, int nasize) void luaH_resizearray(lua_State* L, Table* t, int nasize)
{ {
int nsize = (t->node == dummynode) ? 0 : sizenode(t); int nsize = (t->node == dummynode) ? 0 : sizenode(t);
int asize = FFlag::LuauTableNewBoundary2 ? adjustasize(t, nasize, NULL) : nasize; int asize = adjustasize(t, nasize, NULL);
resize(L, t, asize, nsize); resize(L, t, asize, nsize);
} }
@ -468,7 +465,6 @@ static void rehash(lua_State* L, Table* t, const TValue* ek)
int na = computesizes(nums, &nasize); int na = computesizes(nums, &nasize);
int nh = totaluse - na; int nh = totaluse - na;
/* enforce the boundary invariant; for performance, only do hash lookups if we must */ /* enforce the boundary invariant; for performance, only do hash lookups if we must */
if (FFlag::LuauTableNewBoundary2)
nasize = adjustasize(t, nasize, ek); nasize = adjustasize(t, nasize, ek);
/* resize the table to new computed sizes */ /* resize the table to new computed sizes */
resize(L, t, nasize, nh); resize(L, t, nasize, nh);
@ -531,7 +527,7 @@ static LuaNode* getfreepos(Table* t)
static TValue* newkey(lua_State* L, Table* t, const TValue* key) static TValue* newkey(lua_State* L, Table* t, const TValue* key)
{ {
/* enforce boundary invariant */ /* enforce boundary invariant */
if (FFlag::LuauTableNewBoundary2 && ttisnumber(key) && nvalue(key) == t->sizearray + 1) if (ttisnumber(key) && nvalue(key) == t->sizearray + 1)
{ {
rehash(L, t, key); /* grow table */ rehash(L, t, key); /* grow table */
@ -713,37 +709,6 @@ TValue* luaH_setstr(lua_State* L, Table* t, TString* key)
} }
} }
static LUAU_NOINLINE int unbound_search(Table* t, unsigned int j)
{
LUAU_ASSERT(!FFlag::LuauTableNewBoundary2);
unsigned int i = j; /* i is zero or a present index */
j++;
/* find `i' and `j' such that i is present and j is not */
while (!ttisnil(luaH_getnum(t, j)))
{
i = j;
j *= 2;
if (j > cast_to(unsigned int, INT_MAX))
{ /* overflow? */
/* table was built with bad purposes: resort to linear search */
i = 1;
while (!ttisnil(luaH_getnum(t, i)))
i++;
return i - 1;
}
}
/* now do a binary search between them */
while (j - i > 1)
{
unsigned int m = (i + j) / 2;
if (ttisnil(luaH_getnum(t, m)))
j = m;
else
i = m;
}
return i;
}
static int updateaboundary(Table* t, int boundary) static int updateaboundary(Table* t, int boundary)
{ {
if (boundary < t->sizearray && ttisnil(&t->array[boundary - 1])) if (boundary < t->sizearray && ttisnil(&t->array[boundary - 1]))
@ -800,17 +765,12 @@ int luaH_getn(Table* t)
maybesetaboundary(t, boundary); maybesetaboundary(t, boundary);
return boundary; return boundary;
} }
else if (FFlag::LuauTableNewBoundary2) else
{ {
/* validate boundary invariant */ /* validate boundary invariant */
LUAU_ASSERT(t->node == dummynode || ttisnil(luaH_getnum(t, j + 1))); LUAU_ASSERT(t->node == dummynode || ttisnil(luaH_getnum(t, j + 1)));
return j; return j;
} }
/* else must find a boundary in hash part */
else if (t->node == dummynode) /* hash part is empty? */
return j; /* that is easy... */
else
return unbound_search(t, j);
} }
Table* luaH_clone(lua_State* L, Table* tt) Table* luaH_clone(lua_State* L, Table* tt)

View file

@ -37,6 +37,8 @@ const char* const luaT_eventname[] = {
"__newindex", "__newindex",
"__mode", "__mode",
"__namecall", "__namecall",
"__call",
"__iter",
"__eq", "__eq",
@ -54,13 +56,13 @@ const char* const luaT_eventname[] = {
"__lt", "__lt",
"__le", "__le",
"__concat", "__concat",
"__call",
"__type", "__type",
}; };
// clang-format on // clang-format on
static_assert(sizeof(luaT_typenames) / sizeof(luaT_typenames[0]) == LUA_T_COUNT, "luaT_typenames size mismatch"); static_assert(sizeof(luaT_typenames) / sizeof(luaT_typenames[0]) == LUA_T_COUNT, "luaT_typenames size mismatch");
static_assert(sizeof(luaT_eventname) / sizeof(luaT_eventname[0]) == TM_N, "luaT_eventname size mismatch"); static_assert(sizeof(luaT_eventname) / sizeof(luaT_eventname[0]) == TM_N, "luaT_eventname size mismatch");
static_assert(TM_EQ < 8, "fasttm optimization stores a bitfield with metamethods in a byte");
void luaT_init(lua_State* L) void luaT_init(lua_State* L)
{ {

View file

@ -16,6 +16,8 @@ typedef enum
TM_NEWINDEX, TM_NEWINDEX,
TM_MODE, TM_MODE,
TM_NAMECALL, TM_NAMECALL,
TM_CALL,
TM_ITER,
TM_EQ, /* last tag method with `fast' access */ TM_EQ, /* last tag method with `fast' access */
@ -33,7 +35,6 @@ typedef enum
TM_LT, TM_LT,
TM_LE, TM_LE,
TM_CONCAT, TM_CONCAT,
TM_CALL,
TM_TYPE, TM_TYPE,
TM_N /* number of elements in the enum */ TM_N /* number of elements in the enum */

View file

@ -16,7 +16,10 @@
#include <string.h> #include <string.h>
LUAU_FASTFLAG(LuauTableNewBoundary2) LUAU_FASTFLAGVARIABLE(LuauIter, false)
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauIterCallTelemetry, false)
void (*lua_iter_call_telemetry)(lua_State* L);
// Disable c99-designator to avoid the warning in CGOTO dispatch table // Disable c99-designator to avoid the warning in CGOTO dispatch table
#ifdef __clang__ #ifdef __clang__
@ -110,7 +113,7 @@ LUAU_FASTFLAG(LuauTableNewBoundary2)
VM_DISPATCH_OP(LOP_FORGLOOP_NEXT), VM_DISPATCH_OP(LOP_GETVARARGS), VM_DISPATCH_OP(LOP_DUPCLOSURE), VM_DISPATCH_OP(LOP_PREPVARARGS), \ VM_DISPATCH_OP(LOP_FORGLOOP_NEXT), VM_DISPATCH_OP(LOP_GETVARARGS), VM_DISPATCH_OP(LOP_DUPCLOSURE), VM_DISPATCH_OP(LOP_PREPVARARGS), \
VM_DISPATCH_OP(LOP_LOADKX), VM_DISPATCH_OP(LOP_JUMPX), VM_DISPATCH_OP(LOP_FASTCALL), VM_DISPATCH_OP(LOP_COVERAGE), \ VM_DISPATCH_OP(LOP_LOADKX), VM_DISPATCH_OP(LOP_JUMPX), VM_DISPATCH_OP(LOP_FASTCALL), VM_DISPATCH_OP(LOP_COVERAGE), \
VM_DISPATCH_OP(LOP_CAPTURE), VM_DISPATCH_OP(LOP_JUMPIFEQK), VM_DISPATCH_OP(LOP_JUMPIFNOTEQK), VM_DISPATCH_OP(LOP_FASTCALL1), \ VM_DISPATCH_OP(LOP_CAPTURE), VM_DISPATCH_OP(LOP_JUMPIFEQK), VM_DISPATCH_OP(LOP_JUMPIFNOTEQK), VM_DISPATCH_OP(LOP_FASTCALL1), \
VM_DISPATCH_OP(LOP_FASTCALL2), VM_DISPATCH_OP(LOP_FASTCALL2K), VM_DISPATCH_OP(LOP_FASTCALL2), VM_DISPATCH_OP(LOP_FASTCALL2K), VM_DISPATCH_OP(LOP_FORGPREP),
#if defined(__GNUC__) || defined(__clang__) #if defined(__GNUC__) || defined(__clang__)
#define VM_USE_CGOTO 1 #define VM_USE_CGOTO 1
@ -150,8 +153,20 @@ LUAU_NOINLINE static void luau_prepareFORN(lua_State* L, StkId plimit, StkId pst
LUAU_NOINLINE static bool luau_loopFORG(lua_State* L, int a, int c) LUAU_NOINLINE static bool luau_loopFORG(lua_State* L, int a, int c)
{ {
// note: it's safe to push arguments past top for complicated reasons (see top of the file)
StkId ra = &L->base[a]; StkId ra = &L->base[a];
LUAU_ASSERT(ra + 6 <= L->top); LUAU_ASSERT(ra + 3 <= L->top);
if (DFFlag::LuauIterCallTelemetry)
{
/* TODO: we might be able to stop supporting this depending on whether it's used in practice */
void (*telemetrycb)(lua_State* L) = lua_iter_call_telemetry;
if (telemetrycb && ttistable(ra) && fasttm(L, hvalue(ra)->metatable, TM_CALL))
telemetrycb(L);
if (telemetrycb && ttisuserdata(ra) && fasttm(L, uvalue(ra)->metatable, TM_CALL))
telemetrycb(L);
}
setobjs2s(L, ra + 3 + 2, ra + 2); setobjs2s(L, ra + 3 + 2, ra + 2);
setobjs2s(L, ra + 3 + 1, ra + 1); setobjs2s(L, ra + 3 + 1, ra + 1);
@ -2204,13 +2219,62 @@ static void luau_execute(lua_State* L)
} }
} }
VM_CASE(LOP_FORGPREP)
{
Instruction insn = *pc++;
StkId ra = VM_REG(LUAU_INSN_A(insn));
if (ttisfunction(ra))
{
/* will be called during FORGLOOP */
}
else if (FFlag::LuauIter)
{
Table* mt = ttistable(ra) ? hvalue(ra)->metatable : ttisuserdata(ra) ? uvalue(ra)->metatable : cast_to(Table*, NULL);
if (const TValue* fn = fasttm(L, mt, TM_ITER))
{
setobj2s(L, ra + 1, ra);
setobj2s(L, ra, fn);
L->top = ra + 2; /* func + self arg */
LUAU_ASSERT(L->top <= L->stack_last);
VM_PROTECT(luaD_call(L, ra, 3));
L->top = L->ci->top;
}
else if (fasttm(L, mt, TM_CALL))
{
/* table or userdata with __call, will be called during FORGLOOP */
/* TODO: we might be able to stop supporting this depending on whether it's used in practice */
}
else if (ttistable(ra))
{
/* set up registers for builtin iteration */
setobj2s(L, ra + 1, ra);
setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(0)));
setnilvalue(ra);
}
else
{
VM_PROTECT(luaG_typeerror(L, ra, "iterate over"));
}
}
pc += LUAU_INSN_D(insn);
LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode));
VM_NEXT();
}
VM_CASE(LOP_FORGLOOP) VM_CASE(LOP_FORGLOOP)
{ {
VM_INTERRUPT(); VM_INTERRUPT();
Instruction insn = *pc++; Instruction insn = *pc++;
StkId ra = VM_REG(LUAU_INSN_A(insn));
uint32_t aux = *pc; uint32_t aux = *pc;
// note: this is a slow generic path, fast-path is FORGLOOP_INEXT/NEXT if (!FFlag::LuauIter)
{
bool stop; bool stop;
VM_PROTECT(stop = luau_loopFORG(L, LUAU_INSN_A(insn), aux)); VM_PROTECT(stop = luau_loopFORG(L, LUAU_INSN_A(insn), aux));
@ -2220,6 +2284,86 @@ static void luau_execute(lua_State* L)
VM_NEXT(); VM_NEXT();
} }
// fast-path: builtin table iteration
if (ttisnil(ra) && ttistable(ra + 1) && ttislightuserdata(ra + 2))
{
Table* h = hvalue(ra + 1);
int index = int(reinterpret_cast<uintptr_t>(pvalue(ra + 2)));
int sizearray = h->sizearray;
int sizenode = 1 << h->lsizenode;
// clear extra variables since we might have more than two
if (LUAU_UNLIKELY(aux > 2))
for (int i = 2; i < int(aux); ++i)
setnilvalue(ra + 3 + i);
// first we advance index through the array portion
while (unsigned(index) < unsigned(sizearray))
{
if (!ttisnil(&h->array[index]))
{
setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(index + 1)));
setnvalue(ra + 3, double(index + 1));
setobj2s(L, ra + 4, &h->array[index]);
pc += LUAU_INSN_D(insn);
LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode));
VM_NEXT();
}
index++;
}
// then we advance index through the hash portion
while (unsigned(index - sizearray) < unsigned(sizenode))
{
LuaNode* n = &h->node[index - sizearray];
if (!ttisnil(gval(n)))
{
setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(index + 1)));
getnodekey(L, ra + 3, n);
setobj2s(L, ra + 4, gval(n));
pc += LUAU_INSN_D(insn);
LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode));
VM_NEXT();
}
index++;
}
// fallthrough to exit
pc++;
VM_NEXT();
}
else
{
// note: it's safe to push arguments past top for complicated reasons (see top of the file)
setobjs2s(L, ra + 3 + 2, ra + 2);
setobjs2s(L, ra + 3 + 1, ra + 1);
setobjs2s(L, ra + 3, ra);
L->top = ra + 3 + 3; /* func + 2 args (state and index) */
LUAU_ASSERT(L->top <= L->stack_last);
VM_PROTECT(luaD_call(L, ra + 3, aux));
L->top = L->ci->top;
// recompute ra since stack might have been reallocated
ra = VM_REG(LUAU_INSN_A(insn));
// copy first variable back into the iteration index
setobjs2s(L, ra + 2, ra + 3);
// note that we need to increment pc by 1 to exit the loop since we need to skip over aux
pc += ttisnil(ra + 3) ? 1 : LUAU_INSN_D(insn);
LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode));
VM_NEXT();
}
}
VM_CASE(LOP_FORGPREP_INEXT) VM_CASE(LOP_FORGPREP_INEXT)
{ {
Instruction insn = *pc++; Instruction insn = *pc++;
@ -2228,8 +2372,15 @@ static void luau_execute(lua_State* L)
// fast-path: ipairs/inext // fast-path: ipairs/inext
if (cl->env->safeenv && ttistable(ra + 1) && ttisnumber(ra + 2) && nvalue(ra + 2) == 0.0) if (cl->env->safeenv && ttistable(ra + 1) && ttisnumber(ra + 2) && nvalue(ra + 2) == 0.0)
{ {
if (FFlag::LuauIter)
setnilvalue(ra);
setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(0))); setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(0)));
} }
else if (FFlag::LuauIter && !ttisfunction(ra))
{
VM_PROTECT(luaG_typeerror(L, ra, "iterate over"));
}
pc += LUAU_INSN_D(insn); pc += LUAU_INSN_D(insn);
LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode));
@ -2268,23 +2419,9 @@ static void luau_execute(lua_State* L)
VM_NEXT(); VM_NEXT();
} }
} }
else if (FFlag::LuauTableNewBoundary2 || (h->lsizenode == 0 && ttisnil(gval(h->node))))
{
// fallthrough to exit
VM_NEXT();
}
else else
{ {
// the table has a hash part; index + 1 may appear in it in which case we need to iterate through the hash portion as well // fallthrough to exit
const TValue* val = luaH_getnum(h, index + 1);
setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(index + 1)));
setnvalue(ra + 3, double(index + 1));
setobj2s(L, ra + 4, val);
// note that nil elements inside the array terminate the traversal
pc += ttisnil(ra + 4) ? 0 : LUAU_INSN_D(insn);
LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode));
VM_NEXT(); VM_NEXT();
} }
} }
@ -2308,8 +2445,15 @@ static void luau_execute(lua_State* L)
// fast-path: pairs/next // fast-path: pairs/next
if (cl->env->safeenv && ttistable(ra + 1) && ttisnil(ra + 2)) if (cl->env->safeenv && ttistable(ra + 1) && ttisnil(ra + 2))
{ {
if (FFlag::LuauIter)
setnilvalue(ra);
setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(0))); setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(0)));
} }
else if (FFlag::LuauIter && !ttisfunction(ra))
{
VM_PROTECT(luaG_typeerror(L, ra, "iterate over"));
}
pc += LUAU_INSN_D(insn); pc += LUAU_INSN_D(insn);
LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode));
@ -2704,7 +2848,7 @@ static void luau_execute(lua_State* L)
{ {
VM_PROTECT_PC(); VM_PROTECT_PC();
int n = f(L, ra, arg, nresults, nullptr, nparams); int n = f(L, ra, arg, nresults, NULL, nparams);
if (n >= 0) if (n >= 0)
{ {

View file

@ -0,0 +1,17 @@
local bench = script and require(script.Parent.bench_support) or require("bench_support")
function test()
local t = {}
for i=1,1000000 do t[i] = i end
local ts0 = os.clock()
local sum = 0
for k,v in t do sum = sum + v end
local ts1 = os.clock()
return ts1-ts0
end
bench.runCode(test, "LargeTableSum: for k,v in {}")

View file

@ -25,7 +25,7 @@ local DisplArea = {}
DisplArea.Width = 300; DisplArea.Width = 300;
DisplArea.Height = 300; DisplArea.Height = 300;
function DrawLine(From, To) local function DrawLine(From, To)
local x1 = From.V[1]; local x1 = From.V[1];
local x2 = To.V[1]; local x2 = To.V[1];
local y1 = From.V[2]; local y1 = From.V[2];
@ -81,7 +81,7 @@ function DrawLine(From, To)
Q.LastPx = NumPix; Q.LastPx = NumPix;
end end
function CalcCross(V0, V1) local function CalcCross(V0, V1)
local Cross = {}; local Cross = {};
Cross[1] = V0[2]*V1[3] - V0[3]*V1[2]; Cross[1] = V0[2]*V1[3] - V0[3]*V1[2];
Cross[2] = V0[3]*V1[1] - V0[1]*V1[3]; Cross[2] = V0[3]*V1[1] - V0[1]*V1[3];
@ -89,7 +89,7 @@ function CalcCross(V0, V1)
return Cross; return Cross;
end end
function CalcNormal(V0, V1, V2) local function CalcNormal(V0, V1, V2)
local A = {}; local B = {}; local A = {}; local B = {};
for i = 1,3 do for i = 1,3 do
A[i] = V0[i] - V1[i]; A[i] = V0[i] - V1[i];
@ -102,14 +102,14 @@ function CalcNormal(V0, V1, V2)
return A; return A;
end end
function CreateP(X,Y,Z) local function CreateP(X,Y,Z)
local result = {} local result = {}
result.V = {X,Y,Z,1}; result.V = {X,Y,Z,1};
return result return result
end end
-- multiplies two matrices -- multiplies two matrices
function MMulti(M1, M2) local function MMulti(M1, M2)
local M = {{},{},{},{}}; local M = {{},{},{},{}};
for i = 1,4 do for i = 1,4 do
for j = 1,4 do for j = 1,4 do
@ -120,7 +120,7 @@ function MMulti(M1, M2)
end end
-- multiplies matrix with vector -- multiplies matrix with vector
function VMulti(M, V) local function VMulti(M, V)
local Vect = {}; local Vect = {};
for i = 1,4 do for i = 1,4 do
Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3] + M[i][4] * V[4]; Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3] + M[i][4] * V[4];
@ -128,7 +128,7 @@ function VMulti(M, V)
return Vect; return Vect;
end end
function VMulti2(M, V) local function VMulti2(M, V)
local Vect = {}; local Vect = {};
for i = 1,3 do for i = 1,3 do
Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3]; Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3];
@ -137,7 +137,7 @@ function VMulti2(M, V)
end end
-- add to matrices -- add to matrices
function MAdd(M1, M2) local function MAdd(M1, M2)
local M = {{},{},{},{}}; local M = {{},{},{},{}};
for i = 1,4 do for i = 1,4 do
for j = 1,4 do for j = 1,4 do
@ -147,7 +147,7 @@ function MAdd(M1, M2)
return M; return M;
end end
function Translate(M, Dx, Dy, Dz) local function Translate(M, Dx, Dy, Dz)
local T = { local T = {
{1,0,0,Dx}, {1,0,0,Dx},
{0,1,0,Dy}, {0,1,0,Dy},
@ -157,7 +157,7 @@ function Translate(M, Dx, Dy, Dz)
return MMulti(T, M); return MMulti(T, M);
end end
function RotateX(M, Phi) local function RotateX(M, Phi)
local a = Phi; local a = Phi;
a = a * math.pi / 180; a = a * math.pi / 180;
local Cos = math.cos(a); local Cos = math.cos(a);
@ -171,7 +171,7 @@ function RotateX(M, Phi)
return MMulti(R, M); return MMulti(R, M);
end end
function RotateY(M, Phi) local function RotateY(M, Phi)
local a = Phi; local a = Phi;
a = a * math.pi / 180; a = a * math.pi / 180;
local Cos = math.cos(a); local Cos = math.cos(a);
@ -185,7 +185,7 @@ function RotateY(M, Phi)
return MMulti(R, M); return MMulti(R, M);
end end
function RotateZ(M, Phi) local function RotateZ(M, Phi)
local a = Phi; local a = Phi;
a = a * math.pi / 180; a = a * math.pi / 180;
local Cos = math.cos(a); local Cos = math.cos(a);
@ -199,7 +199,7 @@ function RotateZ(M, Phi)
return MMulti(R, M); return MMulti(R, M);
end end
function DrawQube() local function DrawQube()
-- calc current normals -- calc current normals
local CurN = {}; local CurN = {};
local i = 5; local i = 5;
@ -245,7 +245,7 @@ function DrawQube()
Q.LastPx = 0; Q.LastPx = 0;
end end
function Loop() local function Loop()
if (Testing.LoopCount > Testing.LoopMax) then return; end if (Testing.LoopCount > Testing.LoopMax) then return; end
local TestingStr = tostring(Testing.LoopCount); local TestingStr = tostring(Testing.LoopCount);
while (#TestingStr < 3) do TestingStr = "0" .. TestingStr; end while (#TestingStr < 3) do TestingStr = "0" .. TestingStr; end
@ -265,7 +265,7 @@ function Loop()
Loop(); Loop();
end end
function Init(CubeSize) local function Init(CubeSize)
-- init/reset vars -- init/reset vars
Origin.V = {150,150,20,1}; Origin.V = {150,150,20,1};
Testing.LoopCount = 0; Testing.LoopCount = 0;

View file

@ -31,7 +31,7 @@ local loops = 15
local nx = 120 local nx = 120
local nz = 120 local nz = 120
function morph(a, f) local function morph(a, f)
local PI2nx = math.pi * 8/nx local PI2nx = math.pi * 8/nx
local sin = math.sin local sin = math.sin
local f30 = -(50 * sin(f*math.pi*2)) local f30 = -(50 * sin(f*math.pi*2))

View file

@ -28,40 +28,40 @@ function test()
local size = 30 local size = 30
function createVector(x,y,z) local function createVector(x,y,z)
return { x,y,z }; return { x,y,z };
end end
function sqrLengthVector(self) local function sqrLengthVector(self)
return self[1] * self[1] + self[2] * self[2] + self[3] * self[3]; return self[1] * self[1] + self[2] * self[2] + self[3] * self[3];
end end
function lengthVector(self) local function lengthVector(self)
return math.sqrt(self[1] * self[1] + self[2] * self[2] + self[3] * self[3]); return math.sqrt(self[1] * self[1] + self[2] * self[2] + self[3] * self[3]);
end end
function addVector(self, v) local function addVector(self, v)
self[1] = self[1] + v[1]; self[1] = self[1] + v[1];
self[2] = self[2] + v[2]; self[2] = self[2] + v[2];
self[3] = self[3] + v[3]; self[3] = self[3] + v[3];
return self; return self;
end end
function subVector(self, v) local function subVector(self, v)
self[1] = self[1] - v[1]; self[1] = self[1] - v[1];
self[2] = self[2] - v[2]; self[2] = self[2] - v[2];
self[3] = self[3] - v[3]; self[3] = self[3] - v[3];
return self; return self;
end end
function scaleVector(self, scale) local function scaleVector(self, scale)
self[1] = self[1] * scale; self[1] = self[1] * scale;
self[2] = self[2] * scale; self[2] = self[2] * scale;
self[3] = self[3] * scale; self[3] = self[3] * scale;
return self; return self;
end end
function normaliseVector(self) local function normaliseVector(self)
local len = math.sqrt(self[1] * self[1] + self[2] * self[2] + self[3] * self[3]); local len = math.sqrt(self[1] * self[1] + self[2] * self[2] + self[3] * self[3]);
self[1] = self[1] / len; self[1] = self[1] / len;
self[2] = self[2] / len; self[2] = self[2] / len;
@ -69,39 +69,39 @@ function normaliseVector(self)
return self; return self;
end end
function add(v1, v2) local function add(v1, v2)
return { v1[1] + v2[1], v1[2] + v2[2], v1[3] + v2[3] }; return { v1[1] + v2[1], v1[2] + v2[2], v1[3] + v2[3] };
end end
function sub(v1, v2) local function sub(v1, v2)
return { v1[1] - v2[1], v1[2] - v2[2], v1[3] - v2[3] }; return { v1[1] - v2[1], v1[2] - v2[2], v1[3] - v2[3] };
end end
function scalev(v1, v2) local function scalev(v1, v2)
return { v1[1] * v2[1], v1[2] * v2[2], v1[3] * v2[3] }; return { v1[1] * v2[1], v1[2] * v2[2], v1[3] * v2[3] };
end end
function dot(v1, v2) local function dot(v1, v2)
return v1[1] * v2[1] + v1[2] * v2[2] + v1[3] * v2[3]; return v1[1] * v2[1] + v1[2] * v2[2] + v1[3] * v2[3];
end end
function scale(v, scale) local function scale(v, scale)
return { v[1] * scale, v[2] * scale, v[3] * scale }; return { v[1] * scale, v[2] * scale, v[3] * scale };
end end
function cross(v1, v2) local function cross(v1, v2)
return { v1[2] * v2[3] - v1[3] * v2[2], return { v1[2] * v2[3] - v1[3] * v2[2],
v1[3] * v2[1] - v1[1] * v2[3], v1[3] * v2[1] - v1[1] * v2[3],
v1[1] * v2[2] - v1[2] * v2[1] }; v1[1] * v2[2] - v1[2] * v2[1] };
end end
function normalise(v) local function normalise(v)
local len = lengthVector(v); local len = lengthVector(v);
return { v[1] / len, v[2] / len, v[3] / len }; return { v[1] / len, v[2] / len, v[3] / len };
end end
function transformMatrix(self, v) local function transformMatrix(self, v)
local vals = self; local vals = self;
local x = vals[1] * v[1] + vals[2] * v[2] + vals[3] * v[3] + vals[4]; local x = vals[1] * v[1] + vals[2] * v[2] + vals[3] * v[3] + vals[4];
local y = vals[5] * v[1] + vals[6] * v[2] + vals[7] * v[3] + vals[8]; local y = vals[5] * v[1] + vals[6] * v[2] + vals[7] * v[3] + vals[8];
@ -109,7 +109,7 @@ function transformMatrix(self, v)
return { x, y, z }; return { x, y, z };
end end
function invertMatrix(self) local function invertMatrix(self)
local temp = {} local temp = {}
local tx = -self[4]; local tx = -self[4];
local ty = -self[8]; local ty = -self[8];
@ -131,7 +131,7 @@ function invertMatrix(self)
end end
-- Triangle intersection using barycentric coord method -- Triangle intersection using barycentric coord method
function Triangle(p1, p2, p3) local function Triangle(p1, p2, p3)
local this = {} local this = {}
local edge1 = sub(p3, p1); local edge1 = sub(p3, p1);
@ -205,7 +205,7 @@ function Triangle(p1, p2, p3)
return this return this
end end
function Scene(a_triangles) local function Scene(a_triangles)
local this = {} local this = {}
this.triangles = a_triangles; this.triangles = a_triangles;
this.lights = {}; this.lights = {};
@ -302,7 +302,7 @@ local zero = { 0,0,0 };
-- this camera code is from notes i made ages ago, it is from *somewhere* -- i cannot remember where -- this camera code is from notes i made ages ago, it is from *somewhere* -- i cannot remember where
-- that somewhere is -- that somewhere is
function Camera(origin, lookat, up) local function Camera(origin, lookat, up)
local this = {} local this = {}
local zaxis = normaliseVector(subVector(lookat, origin)); local zaxis = normaliseVector(subVector(lookat, origin));
@ -357,7 +357,7 @@ function Camera(origin, lookat, up)
return this return this
end end
function raytraceScene() local function raytraceScene()
local startDate = 13154863; local startDate = 13154863;
local numTriangles = 2 * 6; local numTriangles = 2 * 6;
local triangles = {}; -- numTriangles); local triangles = {}; -- numTriangles);
@ -450,7 +450,7 @@ function raytraceScene()
return pixels; return pixels;
end end
function arrayToCanvasCommands(pixels) local function arrayToCanvasCommands(pixels)
local s = {}; local s = {};
table.insert(s, '<!DOCTYPE html><html><head><title>Test</title></head><body><canvas id="renderCanvas" width="' .. size .. 'px" height="' .. size .. 'px"></canvas><scr' .. 'ipt>\nvar pixels = ['); table.insert(s, '<!DOCTYPE html><html><head><title>Test</title></head><body><canvas id="renderCanvas" width="' .. size .. 'px" height="' .. size .. 'px"></canvas><scr' .. 'ipt>\nvar pixels = [');
for y = 0,size-1 do for y = 0,size-1 do
@ -485,7 +485,7 @@ for (var y = 0; y < size; y++) {\n\
return table.concat(s); return table.concat(s);
end end
testOutput = arrayToCanvasCommands(raytraceScene()); local testOutput = arrayToCanvasCommands(raytraceScene());
--local f = io.output("output.html") --local f = io.output("output.html")
--f:write(testOutput) --f:write(testOutput)

View file

@ -1,69 +0,0 @@
--[[
The Great Computer Language Shootout
http://shootout.alioth.debian.org/
contributed by Isaac Gouy
]]
local bench = script and require(script.Parent.bench_support) or require("bench_support")
function test()
function TreeNode(left,right,item)
local this = {}
this.left = left;
this.right = right;
this.item = item;
this.itemCheck = function(self)
if (self.left==nil) then return self.item;
else return self.item + self.left:itemCheck() - self.right:itemCheck(); end
end
return this
end
function bottomUpTree(item,depth)
if (depth>0) then
return TreeNode(
bottomUpTree(2*item-1, depth-1)
,bottomUpTree(2*item, depth-1)
,item
);
else
return TreeNode(nil,nil,item);
end
end
local ret = 0;
for n = 4,7,1 do
local minDepth = 4;
local maxDepth = math.max(minDepth + 2, n);
local stretchDepth = maxDepth + 1;
local check = bottomUpTree(0,stretchDepth):itemCheck();
local longLivedTree = bottomUpTree(0,maxDepth);
for depth = minDepth,maxDepth,2 do
local iterations = 2.0 ^ (maxDepth - depth + minDepth - 1) -- 1 << (maxDepth - depth + minDepth);
check = 0;
for i = 1,iterations do
check = check + bottomUpTree(i,depth):itemCheck();
check = check + bottomUpTree(-i,depth):itemCheck();
end
end
ret = ret + longLivedTree:itemCheck();
end
local expected = -4;
if (ret ~= expected) then
assert(false, "ERROR: bad result: expected " .. expected .. " but got " .. ret);
end
end
bench.runCode(test, "access-binary-trees")

View file

@ -7,18 +7,18 @@ local bench = script and require(script.Parent.bench_support) or require("bench_
function test() function test()
function ack(m,n) local function ack(m,n)
if (m==0) then return n+1; end if (m==0) then return n+1; end
if (n==0) then return ack(m-1,1); end if (n==0) then return ack(m-1,1); end
return ack(m-1, ack(m,n-1) ); return ack(m-1, ack(m,n-1) );
end end
function fib(n) local function fib(n)
if (n < 2) then return 1; end if (n < 2) then return 1; end
return fib(n-2) + fib(n-1); return fib(n-2) + fib(n-1);
end end
function tak(x,y,z) local function tak(x,y,z)
if (y >= x) then return z; end if (y >= x) then return z; end
return tak(tak(x-1,y,z), tak(y-1,z,x), tak(z-1,x,y)); return tak(tak(x-1,y,z), tak(y-1,z,x), tak(z-1,x,y));
end end
@ -27,7 +27,7 @@ local result = 0;
for i = 3,5 do for i = 3,5 do
result = result + ack(3,i); result = result + ack(3,i);
result = result + fib(17.0+i); result = result + fib(17+i);
result = result + tak(3*i+3,2*i+2,i+1); result = result + tak(3*i+3,2*i+2,i+1);
end end

View file

@ -42,7 +42,68 @@ local Rcon = { { 0x00, 0x00, 0x00, 0x00 },
{0x1b, 0x00, 0x00, 0x00}, {0x1b, 0x00, 0x00, 0x00},
{0x36, 0x00, 0x00, 0x00} }; {0x36, 0x00, 0x00, 0x00} };
function Cipher(input, w) -- main Cipher function [§5.1] local function SubBytes(s, Nb) -- apply SBox to state S [§5.1.1]
for r = 0,3 do
for c = 0,Nb-1 do s[r + 1][c + 1] = Sbox[s[r + 1][c + 1] + 1]; end
end
return s;
end
local function ShiftRows(s, Nb) -- shift row r of state S left by r bytes [§5.1.2]
local t = {};
for r = 1,3 do
for c = 0,3 do t[c + 1] = s[r + 1][((c + r) % Nb) + 1] end; -- shift into temp copy
for c = 0,3 do s[r + 1][c + 1] = t[c + 1]; end -- and copy back
end -- note that this will work for Nb=4,5,6, but not 7,8 (always 4 for AES):
return s; -- see fp.gladman.plus.com/cryptography_technology/rijndael/aes.spec.311.pdf
end
local function MixColumns(s, Nb) -- combine bytes of each col of state S [§5.1.3]
for c = 0,3 do
local a = {}; -- 'a' is a copy of the current column from 's'
local b = {}; -- 'b' is a•{02} in GF(2^8)
for i = 0,3 do
a[i + 1] = s[i + 1][c + 1];
if bit32.band(s[i + 1][c + 1], 0x80) ~= 0 then
b[i + 1] = bit32.bxor(bit32.lshift(s[i + 1][c + 1], 1), 0x011b);
else
b[i + 1] = bit32.lshift(s[i + 1][c + 1], 1);
end
end
-- a[n] ^ b[n] is a•{03} in GF(2^8)
s[1][c + 1] = bit32.bxor(b[1], a[2], b[2], a[3], a[4]); -- 2*a0 + 3*a1 + a2 + a3
s[2][c + 1] = bit32.bxor(a[1], b[2], a[3], b[3], a[4]); -- a0 * 2*a1 + 3*a2 + a3
s[3][c + 1] = bit32.bxor(a[1], a[2], b[3], a[4], b[4]); -- a0 + a1 + 2*a2 + 3*a3
s[4][c + 1] = bit32.bxor(a[1], b[1], a[2], a[3], b[4]); -- 3*a0 + a1 + a2 + 2*a3
end
return s;
end
local function SubWord(w) -- apply SBox to 4-byte word w
for i = 0,3 do w[i + 1] = Sbox[w[i + 1] + 1]; end
return w;
end
local function RotWord(w) -- rotate 4-byte word w left by one byte
w[5] = w[1];
for i = 0,3 do w[i + 1] = w[i + 2]; end
return w;
end
local function AddRoundKey(state, w, rnd, Nb) -- xor Round Key into state S [§5.1.4]
for r = 0,3 do
for c = 0,Nb-1 do state[r + 1][c + 1] = bit32.bxor(state[r + 1][c + 1], w[rnd*4+c + 1][r + 1]); end
end
return state;
end
local function Cipher(input, w) -- main Cipher function [§5.1]
local Nb = 4; -- block size (in words): no of columns in state (fixed at 4 for AES) local Nb = 4; -- block size (in words): no of columns in state (fixed at 4 for AES)
local Nr = #w / Nb - 1; -- no of rounds: 10/12/14 for 128/192/256-bit keys local Nr = #w / Nb - 1; -- no of rounds: 10/12/14 for 128/192/256-bit keys
@ -69,56 +130,7 @@ function Cipher(input, w) -- main Cipher function [§5.1]
end end
function SubBytes(s, Nb) -- apply SBox to state S [§5.1.1] local function KeyExpansion(key) -- generate Key Schedule (byte-array Nr+1 x Nb) from Key [§5.2]
for r = 0,3 do
for c = 0,Nb-1 do s[r + 1][c + 1] = Sbox[s[r + 1][c + 1] + 1]; end
end
return s;
end
function ShiftRows(s, Nb) -- shift row r of state S left by r bytes [§5.1.2]
local t = {};
for r = 1,3 do
for c = 0,3 do t[c + 1] = s[r + 1][((c + r) % Nb) + 1] end; -- shift into temp copy
for c = 0,3 do s[r + 1][c + 1] = t[c + 1]; end -- and copy back
end -- note that this will work for Nb=4,5,6, but not 7,8 (always 4 for AES):
return s; -- see fp.gladman.plus.com/cryptography_technology/rijndael/aes.spec.311.pdf
end
function MixColumns(s, Nb) -- combine bytes of each col of state S [§5.1.3]
for c = 0,3 do
local a = {}; -- 'a' is a copy of the current column from 's'
local b = {}; -- 'b' is a•{02} in GF(2^8)
for i = 0,3 do
a[i + 1] = s[i + 1][c + 1];
if bit32.band(s[i + 1][c + 1], 0x80) ~= 0 then
b[i + 1] = bit32.bxor(bit32.lshift(s[i + 1][c + 1], 1), 0x011b);
else
b[i + 1] = bit32.lshift(s[i + 1][c + 1], 1);
end
end
-- a[n] ^ b[n] is a•{03} in GF(2^8)
s[1][c + 1] = bit32.bxor(b[1], a[2], b[2], a[3], a[4]); -- 2*a0 + 3*a1 + a2 + a3
s[2][c + 1] = bit32.bxor(a[1], b[2], a[3], b[3], a[4]); -- a0 * 2*a1 + 3*a2 + a3
s[3][c + 1] = bit32.bxor(a[1], a[2], b[3], a[4], b[4]); -- a0 + a1 + 2*a2 + 3*a3
s[4][c + 1] = bit32.bxor(a[1], b[1], a[2], a[3], b[4]); -- 3*a0 + a1 + a2 + 2*a3
end
return s;
end
function AddRoundKey(state, w, rnd, Nb) -- xor Round Key into state S [§5.1.4]
for r = 0,3 do
for c = 0,Nb-1 do state[r + 1][c + 1] = bit32.bxor(state[r + 1][c + 1], w[rnd*4+c + 1][r + 1]); end
end
return state;
end
function KeyExpansion(key) -- generate Key Schedule (byte-array Nr+1 x Nb) from Key [§5.2]
local Nb = 4; -- block size (in words): no of columns in state (fixed at 4 for AES) local Nb = 4; -- block size (in words): no of columns in state (fixed at 4 for AES)
local Nk = #key / 4 -- key length (in words): 4/6/8 for 128/192/256-bit keys local Nk = #key / 4 -- key length (in words): 4/6/8 for 128/192/256-bit keys
local Nr = Nk + 6; -- no of rounds: 10/12/14 for 128/192/256-bit keys local Nr = Nk + 6; -- no of rounds: 10/12/14 for 128/192/256-bit keys
@ -146,17 +158,17 @@ function KeyExpansion(key) -- generate Key Schedule (byte-array Nr+1 x Nb) from
return w; return w;
end end
function SubWord(w) -- apply SBox to 4-byte word w local function escCtrlChars(str) -- escape control chars which might cause problems handling ciphertext
for i = 0,3 do w[i + 1] = Sbox[w[i + 1] + 1]; end return string.gsub(str, "[\0\t\n\v\f\r\'\"!-]", function(c) return '!' .. string.byte(c, 1) .. '!'; end);
return w;
end end
function RotWord(w) -- rotate 4-byte word w left by one byte local function unescCtrlChars(str) -- unescape potentially problematic control characters
w[5] = w[1]; return string.gsub(str, "!%d%d?%d?!", function(c)
for i = 0,3 do w[i + 1] = w[i + 2]; end local sc = string.sub(c, 2,-2)
return w;
end
return string.char(tonumber(sc));
end);
end
--[[ --[[
* Use AES to encrypt 'plaintext' with 'password' using 'nBits' key, in 'Counter' mode of operation * Use AES to encrypt 'plaintext' with 'password' using 'nBits' key, in 'Counter' mode of operation
@ -166,7 +178,7 @@ end
* - cipherblock = plaintext xor outputblock * - cipherblock = plaintext xor outputblock
]] ]]
function AESEncryptCtr(plaintext, password, nBits) local function AESEncryptCtr(plaintext, password, nBits)
if (not (nBits==128 or nBits==192 or nBits==256)) then return ''; end -- standard allows 128/192/256 bit keys if (not (nBits==128 or nBits==192 or nBits==256)) then return ''; end -- standard allows 128/192/256 bit keys
-- for this example script, generate the key by applying Cipher to 1st 16/24/32 chars of password; -- for this example script, generate the key by applying Cipher to 1st 16/24/32 chars of password;
@ -243,7 +255,7 @@ end
* - cipherblock = plaintext xor outputblock * - cipherblock = plaintext xor outputblock
]] ]]
function AESDecryptCtr(ciphertext, password, nBits) local function AESDecryptCtr(ciphertext, password, nBits)
if (not (nBits==128 or nBits==192 or nBits==256)) then return ''; end -- standard allows 128/192/256 bit keys if (not (nBits==128 or nBits==192 or nBits==256)) then return ''; end -- standard allows 128/192/256 bit keys
local nBytes = nBits/8; -- no bytes in key local nBytes = nBits/8; -- no bytes in key
@ -300,19 +312,7 @@ function AESDecryptCtr(ciphertext, password, nBits)
return table.concat(plaintext) return table.concat(plaintext)
end end
function escCtrlChars(str) -- escape control chars which might cause problems handling ciphertext local function test()
return string.gsub(str, "[\0\t\n\v\f\r\'\"!-]", function(c) return '!' .. string.byte(c, 1) .. '!'; end);
end
function unescCtrlChars(str) -- unescape potentially problematic control characters
return string.gsub(str, "!%d%d?%d?!", function(c)
local sc = string.sub(c, 2,-2)
return string.char(tonumber(sc));
end);
end
function test()
local plainText = "ROMEO: But, soft! what light through yonder window breaks?\n\ local plainText = "ROMEO: But, soft! what light through yonder window breaks?\n\
It is the east, and Juliet is the sun.\n\ It is the east, and Juliet is the sun.\n\

View file

@ -31,15 +31,15 @@ function test()
local AG_CONST = 0.6072529350; local AG_CONST = 0.6072529350;
function FIXED(X) local function FIXED(X)
return X * 65536.0; return X * 65536.0;
end end
function FLOAT(X) local function FLOAT(X)
return X / 65536.0; return X / 65536.0;
end end
function DEG2RAD(X) local function DEG2RAD(X)
return 0.017453 * (X); return 0.017453 * (X);
end end
@ -52,7 +52,7 @@ local Angles = {
local Target = 28.027; local Target = 28.027;
function cordicsincos(Target) local function cordicsincos(Target)
local X; local X;
local Y; local Y;
local TargetAngle; local TargetAngle;
@ -85,7 +85,7 @@ end
local total = 0; local total = 0;
function cordic( runs ) local function cordic( runs )
for i = 1,runs do for i = 1,runs do
total = total + cordicsincos(Target); total = total + cordicsincos(Target);
end end

View file

@ -7,7 +7,7 @@ local bench = script and require(script.Parent.bench_support) or require("bench_
function test() function test()
function partial(n) local function partial(n)
local a1, a2, a3, a4, a5, a6, a7, a8, a9 = 0, 0, 0, 0, 0, 0, 0, 0, 0; local a1, a2, a3, a4, a5, a6, a7, a8, a9 = 0, 0, 0, 0, 0, 0, 0, 0, 0;
local twothirds = 2.0/3.0; local twothirds = 2.0/3.0;
local alt = -1.0; local alt = -1.0;

View file

@ -1,72 +0,0 @@
--[[
The Great Computer Language Shootout
http://shootout.alioth.debian.org/
contributed by Ian Osgood
]]
local bench = script and require(script.Parent.bench_support) or require("bench_support")
function test()
function A(i,j)
return 1/((i+j)*(i+j+1)/2+i+1);
end
function Au(u,v)
for i = 0,#u-1 do
local t = 0;
for j = 0,#u-1 do
t = t + A(i,j) * u[j + 1];
end
v[i + 1] = t;
end
end
function Atu(u,v)
for i = 0,#u-1 do
local t = 0;
for j = 0,#u-1 do
t = t + A(j,i) * u[j + 1];
end
v[i + 1] = t;
end
end
function AtAu(u,v,w)
Au(u,w);
Atu(w,v);
end
function spectralnorm(n)
local u, v, w, vv, vBv = {}, {}, {}, 0, 0;
for i = 1,n do
u[i] = 1; v[i] = 0; w[i] = 0;
end
for i = 0,9 do
AtAu(u,v,w);
AtAu(v,u,w);
end
for i = 1,n do
vBv = vBv + u[i]*v[i];
vv = vv + v[i]*v[i];
end
return math.sqrt(vBv/vv);
end
local total = 0;
local i = 6
while i <= 48 do
total = total + spectralnorm(i);
i = i * 2
end
local expected = 5.086694231303284;
if (total ~= expected) then
assert(false, "ERROR: bad result: expected " .. expected .. " but got " .. total)
end
end
bench.runCode(test, "math-spectral-norm")

View file

@ -2760,8 +2760,6 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_on_string_singletons")
TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singletons") TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singletons")
{ {
ScopedFastFlag luauAutocompleteSingletonTypes{"LuauAutocompleteSingletonTypes", true};
check(R"( check(R"(
type tag = "cat" | "dog" type tag = "cat" | "dog"
local function f(a: tag) end local function f(a: tag) end
@ -2798,8 +2796,6 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singletons")
TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_equality") TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_equality")
{ {
ScopedFastFlag luauAutocompleteSingletonTypes{"LuauAutocompleteSingletonTypes", true};
check(R"( check(R"(
type tagged = {tag:"cat", fieldx:number} | {tag:"dog", fieldy:number} type tagged = {tag:"cat", fieldx:number} | {tag:"dog", fieldy:number}
local x: tagged = {tag="cat", fieldx=2} local x: tagged = {tag="cat", fieldx=2}
@ -2821,8 +2817,6 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_equality")
TEST_CASE_FIXTURE(ACFixture, "autocomplete_boolean_singleton") TEST_CASE_FIXTURE(ACFixture, "autocomplete_boolean_singleton")
{ {
ScopedFastFlag luauAutocompleteSingletonTypes{"LuauAutocompleteSingletonTypes", true};
check(R"( check(R"(
local function f(x: true) end local function f(x: true) end
f(@1) f(@1)
@ -2838,8 +2832,6 @@ f(@1)
TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_escape") TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_escape")
{ {
ScopedFastFlag luauAutocompleteSingletonTypes{"LuauAutocompleteSingletonTypes", true};
check(R"( check(R"(
type tag = "strange\t\"cat\"" | 'nice\t"dog"' type tag = "strange\t\"cat\"" | 'nice\t"dog"'
local function f(x: tag) end local function f(x: tag) end

View file

@ -261,6 +261,9 @@ RETURN R0 0
TEST_CASE("ForBytecode") TEST_CASE("ForBytecode")
{ {
ScopedFastFlag sff("LuauCompileIter", true);
ScopedFastFlag sff2("LuauCompileIterNoPairs", false);
// basic for loop: variable directly refers to internal iteration index (R2) // basic for loop: variable directly refers to internal iteration index (R2)
CHECK_EQ("\n" + compileFunction0("for i=1,5 do print(i) end"), R"( CHECK_EQ("\n" + compileFunction0("for i=1,5 do print(i) end"), R"(
LOADN R2 1 LOADN R2 1
@ -295,7 +298,7 @@ GETIMPORT R0 2
LOADK R1 K3 LOADK R1 K3
LOADK R2 K4 LOADK R2 K4
CALL R0 2 3 CALL R0 2 3
JUMP +4 FORGPREP R0 +4
GETIMPORT R5 6 GETIMPORT R5 6
MOVE R6 R3 MOVE R6 R3
CALL R5 1 0 CALL R5 1 0
@ -347,6 +350,8 @@ RETURN R0 0
TEST_CASE("ForBytecodeBuiltin") TEST_CASE("ForBytecodeBuiltin")
{ {
ScopedFastFlag sff("LuauCompileIter", true);
// we generally recognize builtins like pairs/ipairs and emit special opcodes // we generally recognize builtins like pairs/ipairs and emit special opcodes
CHECK_EQ("\n" + compileFunction0("for k,v in ipairs({}) do end"), R"( CHECK_EQ("\n" + compileFunction0("for k,v in ipairs({}) do end"), R"(
GETIMPORT R0 1 GETIMPORT R0 1
@ -385,7 +390,7 @@ GETIMPORT R0 3
MOVE R1 R0 MOVE R1 R0
NEWTABLE R2 0 0 NEWTABLE R2 0 0
CALL R1 1 3 CALL R1 1 3
JUMP +0 FORGPREP R1 +0
FORGLOOP R1 -1 2 FORGLOOP R1 -1 2
RETURN R0 0 RETURN R0 0
)"); )");
@ -397,7 +402,7 @@ SETGLOBAL R0 K2
GETGLOBAL R0 K2 GETGLOBAL R0 K2
NEWTABLE R1 0 0 NEWTABLE R1 0 0
CALL R0 1 3 CALL R0 1 3
JUMP +0 FORGPREP R0 +0
FORGLOOP R0 -1 2 FORGLOOP R0 -1 2
RETURN R0 0 RETURN R0 0
)"); )");
@ -407,7 +412,7 @@ RETURN R0 0
GETIMPORT R0 1 GETIMPORT R0 1
NEWTABLE R1 0 0 NEWTABLE R1 0 0
CALL R0 1 3 CALL R0 1 3
JUMP +0 FORGPREP R0 +0
FORGLOOP R0 -1 2 FORGLOOP R0 -1 2
RETURN R0 0 RETURN R0 0
)"); )");
@ -2260,6 +2265,8 @@ TEST_CASE("TypeAliasing")
TEST_CASE("DebugLineInfo") TEST_CASE("DebugLineInfo")
{ {
ScopedFastFlag sff("LuauCompileIterNoPairs", false);
Luau::BytecodeBuilder bcb; Luau::BytecodeBuilder bcb;
bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines);
Luau::compileOrThrow(bcb, R"( Luau::compileOrThrow(bcb, R"(
@ -2316,6 +2323,8 @@ return result
TEST_CASE("DebugLineInfoFor") TEST_CASE("DebugLineInfoFor")
{ {
ScopedFastFlag sff("LuauCompileIter", true);
Luau::BytecodeBuilder bcb; Luau::BytecodeBuilder bcb;
bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines);
Luau::compileOrThrow(bcb, R"( Luau::compileOrThrow(bcb, R"(
@ -2336,7 +2345,7 @@ end
5: LOADN R0 1 5: LOADN R0 1
7: LOADN R1 2 7: LOADN R1 2
9: LOADN R2 3 9: LOADN R2 3
9: JUMP +4 9: FORGPREP R0 +4
11: GETIMPORT R5 1 11: GETIMPORT R5 1
11: MOVE R6 R3 11: MOVE R6 R3
11: CALL R5 1 0 11: CALL R5 1 0
@ -2541,6 +2550,8 @@ a
TEST_CASE("DebugSource") TEST_CASE("DebugSource")
{ {
ScopedFastFlag sff("LuauCompileIterNoPairs", false);
const char* source = R"( const char* source = R"(
local kSelectedBiomes = { local kSelectedBiomes = {
['Mountains'] = true, ['Mountains'] = true,
@ -2616,6 +2627,8 @@ RETURN R1 1
TEST_CASE("DebugLocals") TEST_CASE("DebugLocals")
{ {
ScopedFastFlag sff("LuauCompileIterNoPairs", false);
const char* source = R"( const char* source = R"(
function foo(e, f) function foo(e, f)
local a = 1 local a = 1
@ -3767,6 +3780,8 @@ RETURN R0 1
TEST_CASE("SharedClosure") TEST_CASE("SharedClosure")
{ {
ScopedFastFlag sff("LuauCompileIterNoPairs", false);
// closures can be shared even if functions refer to upvalues, as long as upvalues are top-level // closures can be shared even if functions refer to upvalues, as long as upvalues are top-level
CHECK_EQ("\n" + compileFunction(R"( CHECK_EQ("\n" + compileFunction(R"(
local val = ... local val = ...
@ -4452,5 +4467,688 @@ RETURN R0 0
)"); )");
} }
TEST_CASE("InlineBasic")
{
ScopedFastFlag sff("LuauCompileSupportInlining", true);
// inline function that returns a constant
CHECK_EQ("\n" + compileFunction(R"(
local function foo()
return 42
end
local x = foo()
return x
)",
1, 2),
R"(
DUPCLOSURE R0 K0
LOADN R1 42
RETURN R1 1
)");
// inline function that returns the argument
CHECK_EQ("\n" + compileFunction(R"(
local function foo(a)
return a
end
local x = foo(42)
return x
)",
1, 2),
R"(
DUPCLOSURE R0 K0
LOADN R1 42
RETURN R1 1
)");
// inline function that returns one of the two arguments
CHECK_EQ("\n" + compileFunction(R"(
local function foo(a, b, c)
if a then
return b
else
return c
end
end
local x = foo(true, math.random(), 5)
return x
)",
1, 2),
R"(
DUPCLOSURE R0 K0
GETIMPORT R2 3
CALL R2 0 1
MOVE R1 R2
RETURN R1 1
RETURN R1 1
)");
// inline function that returns one of the two arguments
CHECK_EQ("\n" + compileFunction(R"(
local function foo(a, b, c)
if a then
return b
else
return c
end
end
local x = foo(true, 5, math.random())
return x
)",
1, 2),
R"(
DUPCLOSURE R0 K0
GETIMPORT R2 3
CALL R2 0 1
LOADN R1 5
RETURN R1 1
RETURN R1 1
)");
}
TEST_CASE("InlineMutate")
{
ScopedFastFlag sff("LuauCompileSupportInlining", true);
// if the argument is mutated, it gets a register even if the value is constant
CHECK_EQ("\n" + compileFunction(R"(
local function foo(a)
a = a or 5
return a
end
local x = foo(42)
return x
)",
1, 2),
R"(
DUPCLOSURE R0 K0
LOADN R2 42
ORK R2 R2 K1
MOVE R1 R2
RETURN R1 1
)");
// if the argument is a local, it can be used directly
CHECK_EQ("\n" + compileFunction(R"(
local function foo(a)
return a
end
local x = ...
local y = foo(x)
return y
)",
1, 2),
R"(
DUPCLOSURE R0 K0
GETVARARGS R1 1
MOVE R2 R1
RETURN R2 1
)");
// ... but if it's mutated, we move it in case it is mutated through a capture during the inlined function
CHECK_EQ("\n" + compileFunction(R"(
local function foo(a)
return a
end
local x = ...
x = nil
local y = foo(x)
return y
)",
1, 2),
R"(
DUPCLOSURE R0 K0
GETVARARGS R1 1
LOADNIL R1
MOVE R3 R1
MOVE R2 R3
RETURN R2 1
)");
// we also don't inline functions if they have been assigned to
CHECK_EQ("\n" + compileFunction(R"(
local function foo(a)
return a
end
foo = foo
local x = foo(42)
return x
)",
1, 2),
R"(
DUPCLOSURE R0 K0
MOVE R0 R0
MOVE R1 R0
LOADN R2 42
CALL R1 1 1
RETURN R1 1
)");
}
TEST_CASE("InlineUpval")
{
ScopedFastFlag sff("LuauCompileSupportInlining", true);
// if the argument is an upvalue, we naturally need to copy it to a local
CHECK_EQ("\n" + compileFunction(R"(
local function foo(a)
return a
end
local b = ...
function bar()
local x = foo(b)
return x
end
)",
1, 2),
R"(
GETUPVAL R1 0
MOVE R0 R1
RETURN R0 1
)");
// if the function uses an upvalue it's more complicated, because the lexical upvalue may become a local
CHECK_EQ("\n" + compileFunction(R"(
local b = ...
local function foo(a)
return a + b
end
local x = foo(42)
return x
)",
1, 2),
R"(
GETVARARGS R0 1
DUPCLOSURE R1 K0
CAPTURE VAL R0
LOADN R3 42
ADD R2 R3 R0
RETURN R2 1
)");
// sometimes the lexical upvalue is deep enough that it's still an upvalue though
CHECK_EQ("\n" + compileFunction(R"(
local b = ...
function bar()
local function foo(a)
return a + b
end
local x = foo(42)
return x
end
)",
1, 2),
R"(
DUPCLOSURE R0 K0
CAPTURE UPVAL U0
LOADN R2 42
GETUPVAL R3 0
ADD R1 R2 R3
RETURN R1 1
)");
}
TEST_CASE("InlineFallthrough")
{
ScopedFastFlag sff("LuauCompileSupportInlining", true);
// if the function doesn't return, we still fill the results with nil
CHECK_EQ("\n" + compileFunction(R"(
local function foo()
end
local a, b = foo()
return a, b
)",
1, 2),
R"(
DUPCLOSURE R0 K0
LOADNIL R1
LOADNIL R2
MOVE R3 R1
MOVE R4 R2
RETURN R3 2
)");
// this happens even if the function returns conditionally
CHECK_EQ("\n" + compileFunction(R"(
local function foo(a)
if a then return 42 end
end
local a, b = foo(false)
return a, b
)",
1, 2),
R"(
DUPCLOSURE R0 K0
LOADNIL R1
LOADNIL R2
MOVE R3 R1
MOVE R4 R2
RETURN R3 2
)");
// note though that we can't inline a function like this in multret context
// this is because we don't have a SETTOP instruction
CHECK_EQ("\n" + compileFunction(R"(
local function foo()
end
return foo()
)",
1, 2),
R"(
DUPCLOSURE R0 K0
MOVE R1 R0
CALL R1 0 -1
RETURN R1 -1
)");
}
TEST_CASE("InlineCapture")
{
ScopedFastFlag sff("LuauCompileSupportInlining", true);
// can't inline function with nested functions that capture locals because they might be constants
CHECK_EQ("\n" + compileFunction(R"(
local function foo(a)
local function bar()
return a
end
return bar()
end
)",
1, 2),
R"(
NEWCLOSURE R1 P0
CAPTURE VAL R0
MOVE R2 R1
CALL R2 0 -1
RETURN R2 -1
)");
}
TEST_CASE("InlineArgMismatch")
{
ScopedFastFlag sff("LuauCompileSupportInlining", true);
// when inlining a function, we must respect all the usual rules
// caller might not have enough arguments
// TODO: we don't inline this atm
CHECK_EQ("\n" + compileFunction(R"(
local function foo(a)
return a
end
local x = foo()
return x
)",
1, 2),
R"(
DUPCLOSURE R0 K0
MOVE R1 R0
CALL R1 0 1
RETURN R1 1
)");
// caller might be using multret for arguments
// TODO: we don't inline this atm
CHECK_EQ("\n" + compileFunction(R"(
local function foo(a, b)
return a + b
end
local x = foo(math.modf(1.5))
return x
)",
1, 2),
R"(
DUPCLOSURE R0 K0
MOVE R1 R0
LOADK R3 K1
FASTCALL1 20 R3 +2
GETIMPORT R2 4
CALL R2 1 -1
CALL R1 -1 1
RETURN R1 1
)");
// caller might have too many arguments, but we still need to compute them for side effects
// TODO: we don't inline this atm
CHECK_EQ("\n" + compileFunction(R"(
local function foo(a)
return a
end
local x = foo(42, print())
return x
)",
1, 2),
R"(
DUPCLOSURE R0 K0
MOVE R1 R0
LOADN R2 42
GETIMPORT R3 2
CALL R3 0 -1
CALL R1 -1 1
RETURN R1 1
)");
}
TEST_CASE("InlineMultiple")
{
ScopedFastFlag sff("LuauCompileSupportInlining", true);
// we call this with a different set of variable/constant args
CHECK_EQ("\n" + compileFunction(R"(
local function foo(a, b)
return a + b
end
local x, y = ...
local a = foo(x, 1)
local b = foo(1, x)
local c = foo(1, 2)
local d = foo(x, y)
return a, b, c, d
)",
1, 2),
R"(
DUPCLOSURE R0 K0
GETVARARGS R1 2
ADDK R3 R1 K1
LOADN R5 1
ADD R4 R5 R1
LOADN R5 3
ADD R6 R1 R2
MOVE R7 R3
MOVE R8 R4
MOVE R9 R5
MOVE R10 R6
RETURN R7 4
)");
}
TEST_CASE("InlineChain")
{
ScopedFastFlag sff("LuauCompileSupportInlining", true);
// inline a chain of functions
CHECK_EQ("\n" + compileFunction(R"(
local function foo(a, b)
return a + b
end
local function bar(x)
return foo(x, 1) * foo(x, -1)
end
local function baz()
return (bar(42))
end
return (baz())
)",
3, 2),
R"(
DUPCLOSURE R0 K0
DUPCLOSURE R1 K1
DUPCLOSURE R2 K2
LOADN R4 43
LOADN R5 41
MUL R3 R4 R5
RETURN R3 1
)");
}
TEST_CASE("InlineThresholds")
{
ScopedFastFlag sff("LuauCompileSupportInlining", true);
ScopedFastInt sfis[] = {
{"LuauCompileInlineThreshold", 25},
{"LuauCompileInlineThresholdMaxBoost", 300},
{"LuauCompileInlineDepth", 2},
};
// this function has enormous register pressure (50 regs) so we choose not to inline it
CHECK_EQ("\n" + compileFunction(R"(
local function foo()
return {{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}
end
return (foo())
)",
1, 2),
R"(
DUPCLOSURE R0 K0
MOVE R1 R0
CALL R1 0 1
RETURN R1 1
)");
// this function has less register pressure but a large cost
CHECK_EQ("\n" + compileFunction(R"(
local function foo()
return {},{},{},{},{}
end
return (foo())
)",
1, 2),
R"(
DUPCLOSURE R0 K0
MOVE R1 R0
CALL R1 0 1
RETURN R1 1
)");
// this chain of function is of length 3 but our limit in this test is 2, so we call foo twice
CHECK_EQ("\n" + compileFunction(R"(
local function foo(a, b)
return a + b
end
local function bar(x)
return foo(x, 1) * foo(x, -1)
end
local function baz()
return (bar(42))
end
return (baz())
)",
3, 2),
R"(
DUPCLOSURE R0 K0
DUPCLOSURE R1 K1
DUPCLOSURE R2 K2
MOVE R4 R0
LOADN R5 42
LOADN R6 1
CALL R4 2 1
MOVE R5 R0
LOADN R6 42
LOADN R7 -1
CALL R5 2 1
MUL R3 R4 R5
RETURN R3 1
)");
}
TEST_CASE("InlineIIFE")
{
ScopedFastFlag sff("LuauCompileSupportInlining", true);
// IIFE with arguments
CHECK_EQ("\n" + compileFunction(R"(
function choose(a, b, c)
return ((function(a, b, c) if a then return b else return c end end)(a, b, c))
end
)",
1, 2),
R"(
JUMPIFNOT R0 +2
MOVE R3 R1
RETURN R3 1
MOVE R3 R2
RETURN R3 1
RETURN R3 1
)");
// IIFE with upvalues
CHECK_EQ("\n" + compileFunction(R"(
function choose(a, b, c)
return ((function() if a then return b else return c end end)())
end
)",
1, 2),
R"(
JUMPIFNOT R0 +2
MOVE R3 R1
RETURN R3 1
MOVE R3 R2
RETURN R3 1
RETURN R3 1
)");
}
TEST_CASE("InlineRecurseArguments")
{
ScopedFastFlag sff("LuauCompileSupportInlining", true);
// we can't inline a function if it's used to compute its own arguments
CHECK_EQ("\n" + compileFunction(R"(
local function foo(a, b)
end
foo(foo(foo,foo(foo,foo))[foo])
)",
1, 2),
R"(
DUPCLOSURE R0 K0
MOVE R1 R0
MOVE R4 R0
MOVE R5 R0
MOVE R6 R0
CALL R4 2 1
LOADNIL R3
GETTABLE R2 R3 R0
CALL R1 1 0
RETURN R0 0
)");
}
TEST_CASE("InlineFastCallK")
{
ScopedFastFlag sff("LuauCompileSupportInlining", true);
CHECK_EQ("\n" + compileFunction(R"(
local function set(l0)
rawset({}, l0)
end
set(false)
set({})
)",
1, 2),
R"(
DUPCLOSURE R0 K0
NEWTABLE R2 0 0
FASTCALL2K 49 R2 K1 +4
LOADK R3 K1
GETIMPORT R1 3
CALL R1 2 0
NEWTABLE R1 0 0
NEWTABLE R3 0 0
FASTCALL2 49 R3 R1 +4
MOVE R4 R1
GETIMPORT R2 3
CALL R2 2 0
RETURN R0 0
)");
}
TEST_CASE("InlineExprIndexK")
{
ScopedFastFlag sff("LuauCompileSupportInlining", true);
CHECK_EQ("\n" + compileFunction(R"(
local _ = function(l0)
local _ = nil
while _(_)[_] do
end
end
local _ = _(0)[""]
if _ then
do
for l0=0,8 do
end
end
elseif _ then
_ = nil
do
for l0=0,8 do
return true
end
end
end
)",
1, 2),
R"(
DUPCLOSURE R0 K0
LOADNIL R4
LOADNIL R5
CALL R4 1 1
LOADNIL R5
GETTABLE R3 R4 R5
JUMPIFNOT R3 +1
JUMPBACK -7
LOADNIL R2
GETTABLEKS R1 R2 K1
JUMPIFNOT R1 +1
RETURN R0 0
JUMPIFNOT R1 +19
LOADNIL R1
LOADB R2 1
RETURN R2 1
LOADB R2 1
RETURN R2 1
LOADB R2 1
RETURN R2 1
LOADB R2 1
RETURN R2 1
LOADB R2 1
RETURN R2 1
LOADB R2 1
RETURN R2 1
LOADB R2 1
RETURN R2 1
LOADB R2 1
RETURN R2 1
LOADB R2 1
RETURN R2 1
RETURN R0 0
)");
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -241,6 +241,8 @@ TEST_CASE("Math")
TEST_CASE("Table") TEST_CASE("Table")
{ {
ScopedFastFlag sff("LuauFixBuiltinsStackLimit", true);
runConformance("nextvar.lua"); runConformance("nextvar.lua");
} }
@ -1099,4 +1101,14 @@ TEST_CASE("UserdataApi")
CHECK(dtorhits == 42); CHECK(dtorhits == 42);
} }
TEST_CASE("Iter")
{
ScopedFastFlag sffs[] = {
{ "LuauCompileIter", true },
{ "LuauIter", true },
};
runConformance("iter.lua");
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -386,8 +386,6 @@ TEST_CASE_FIXTURE(FrontendFixture, "cycle_error_paths")
TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface") TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface")
{ {
ScopedFastFlag luauCyclicModuleTypeSurface{"LuauCyclicModuleTypeSurface", true};
fileResolver.source["game/A"] = R"( fileResolver.source["game/A"] = R"(
return {hello = 2} return {hello = 2}
)"; )";
@ -410,8 +408,6 @@ TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface")
TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface_longer") TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface_longer")
{ {
ScopedFastFlag luauCyclicModuleTypeSurface{"LuauCyclicModuleTypeSurface", true};
fileResolver.source["game/A"] = R"( fileResolver.source["game/A"] = R"(
return {mod_a = 2} return {mod_a = 2}
)"; )";

View file

@ -2041,8 +2041,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type_errors")
TEST_CASE_FIXTURE(Fixture, "parse_type_pack_errors") TEST_CASE_FIXTURE(Fixture, "parse_type_pack_errors")
{ {
ScopedFastFlag luauParseRecoverUnexpectedPack{"LuauParseRecoverUnexpectedPack", true};
matchParseError("type Y<T...> = {a: T..., b: number}", "Unexpected '...' after type name; type pack is not allowed in this context", matchParseError("type Y<T...> = {a: T..., b: number}", "Unexpected '...' after type name; type pack is not allowed in this context",
Location{{0, 20}, {0, 23}}); Location{{0, 20}, {0, 23}});
matchParseError("type Y<T...> = {a: (number | string)...", "Unexpected '...' after type annotation", Location{{0, 36}, {0, 39}}); matchParseError("type Y<T...> = {a: (number | string)...", "Unexpected '...' after type annotation", Location{{0, 36}, {0, 39}});
@ -2618,8 +2616,6 @@ type Y<T..., U = T...> = (T...) -> U...
TEST_CASE_FIXTURE(Fixture, "recover_unexpected_type_pack") TEST_CASE_FIXTURE(Fixture, "recover_unexpected_type_pack")
{ {
ScopedFastFlag luauParseRecoverUnexpectedPack{"LuauParseRecoverUnexpectedPack", true};
ParseResult result = tryParse(R"( ParseResult result = tryParse(R"(
type X<T...> = { a: T..., b: number } type X<T...> = { a: T..., b: number }
type Y<T> = { a: T..., b: number } type Y<T> = { a: T..., b: number }

View file

@ -35,9 +35,9 @@ bool hasError(const CheckResult& result, T* = nullptr)
return it != result.errors.end(); return it != result.errors.end();
} }
TEST_SUITE_BEGIN("RuntimeLimitTests"); TEST_SUITE_BEGIN("RuntimeLimits");
TEST_CASE_FIXTURE(LimitFixture, "bail_early_on_typescript_port_of_Result_type" * doctest::timeout(1.0)) TEST_CASE_FIXTURE(LimitFixture, "typescript_port_of_Result_type")
{ {
constexpr const char* src = R"LUA( constexpr const char* src = R"LUA(
--!strict --!strict

View file

@ -488,4 +488,71 @@ TEST_CASE_FIXTURE(Fixture, "fuzz_fail_missing_instantitation_follow")
)"); )");
} }
TEST_CASE_FIXTURE(Fixture, "loop_iter_basic")
{
ScopedFastFlag sff{"LuauTypecheckIter", true};
CheckResult result = check(R"(
local t: {string} = {}
local key
for k: number in t do
end
for k: number, v: string in t do
end
for k, v in t do
key = k
end
)");
LUAU_REQUIRE_ERROR_COUNT(0, result);
CHECK_EQ(*typeChecker.numberType, *requireType("key"));
}
TEST_CASE_FIXTURE(Fixture, "loop_iter_trailing_nil")
{
ScopedFastFlag sff{"LuauTypecheckIter", true};
CheckResult result = check(R"(
local t: {string} = {}
local extra
for k, v, e in t do
extra = e
end
)");
LUAU_REQUIRE_ERROR_COUNT(0, result);
CHECK_EQ(*typeChecker.nilType, *requireType("extra"));
}
TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer")
{
ScopedFastFlag sff{"LuauTypecheckIter", true};
CheckResult result = check(R"(
local t = {}
for k, v in t do
end
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
GenericError* ge = get<GenericError>(result.errors[0]);
REQUIRE(ge);
CHECK_EQ("Cannot iterate over a table without indexer", ge->message);
}
TEST_CASE_FIXTURE(Fixture, "loop_iter_iter_metamethod")
{
ScopedFastFlag sff{"LuauTypecheckIter", true};
CheckResult result = check(R"(
local t = {}
setmetatable(t, { __iter = function(o) return next, o.children end })
for k: number, v: string in t do
end
)");
LUAU_REQUIRE_ERROR_COUNT(0, result);
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -5,7 +5,6 @@
#include "Luau/Scope.h" #include "Luau/Scope.h"
#include "Luau/TypeInfer.h" #include "Luau/TypeInfer.h"
#include "Luau/TypeVar.h" #include "Luau/TypeVar.h"
#include "Luau/VisitTypeVar.h"
#include "Fixture.h" #include "Fixture.h"

View file

@ -2331,7 +2331,7 @@ TEST_CASE_FIXTURE(Fixture, "confusing_indexing")
TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table") TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table")
{ {
ScopedFastFlag sff{"LuauDifferentOrderOfUnificationDoesntMatter", true}; ScopedFastFlag sff{"LuauDifferentOrderOfUnificationDoesntMatter2", true};
CheckResult result = check(R"( CheckResult result = check(R"(
local a: {x: number, y: number, [any]: any} | {y: number} local a: {x: number, y: number, [any]: any} | {y: number}
@ -2351,7 +2351,7 @@ TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a
TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table_2") TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table_2")
{ {
ScopedFastFlag sff{"LuauDifferentOrderOfUnificationDoesntMatter", true}; ScopedFastFlag sff{"LuauDifferentOrderOfUnificationDoesntMatter2", true};
CheckResult result = check(R"( CheckResult result = check(R"(
local a: {y: number} | {x: number, y: number, [any]: any} local a: {y: number} | {x: number, y: number, [any]: any}

View file

@ -1034,4 +1034,45 @@ TEST_CASE_FIXTURE(Fixture, "follow_on_new_types_in_substitution")
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
} }
/**
* The problem we had here was that the type of q in B.h was initially inferring to {} | {prop: free} before we bound
* that second table to the enclosing union.
*/
TEST_CASE_FIXTURE(Fixture, "do_not_bind_a_free_table_to_a_union_containing_that_table")
{
ScopedFastFlag flag[] = {
{"LuauStatFunctionSimplify4", true},
{"LuauLowerBoundsCalculation", true},
{"LuauDifferentOrderOfUnificationDoesntMatter2", true},
};
CheckResult result = check(R"(
--!strict
local A = {}
function A:f()
local t = {}
for key, value in pairs(self) do
t[key] = value
end
return t
end
local B = A:f()
function B.g(t)
assert(type(t) == "table")
assert(t.prop ~= nil)
end
function B.h(q)
q = q or {}
return q or {}
end
)");
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -242,8 +242,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "cli_50320_follow_in_any_unification")
TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_type_owner") TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_type_owner")
{ {
ScopedFastFlag luauTxnLogPreserveOwner{"LuauTxnLogPreserveOwner", true};
TypeId a = arena.addType(TypeVar{FreeTypeVar{TypeLevel{}}}); TypeId a = arena.addType(TypeVar{FreeTypeVar{TypeLevel{}}});
TypeId b = typeChecker.numberType; TypeId b = typeChecker.numberType;
@ -255,8 +253,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_type_owner")
TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_pack_owner") TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_pack_owner")
{ {
ScopedFastFlag luauTxnLogPreserveOwner{"LuauTxnLogPreserveOwner", true};
TypePackId a = arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}}); TypePackId a = arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}});
TypePackId b = typeChecker.anyTypePack; TypePackId b = typeChecker.anyTypePack;

View file

@ -313,23 +313,33 @@ TEST_CASE("tagging_props")
CHECK(Luau::hasTag(prop, "foo")); CHECK(Luau::hasTag(prop, "foo"));
} }
struct VisitCountTracker struct VisitCountTracker final : TypeVarOnceVisitor
{ {
std::unordered_map<TypeId, unsigned> tyVisits; std::unordered_map<TypeId, unsigned> tyVisits;
std::unordered_map<TypePackId, unsigned> tpVisits; std::unordered_map<TypePackId, unsigned> tpVisits;
void cycle(TypeId) {} void cycle(TypeId) override {}
void cycle(TypePackId) {} void cycle(TypePackId) override {}
template<typename T> template<typename T>
bool operator()(TypeId ty, const T& t) bool operator()(TypeId ty, const T& t)
{
return visit(ty);
}
template<typename T>
bool operator()(TypePackId tp, const T&)
{
return visit(tp);
}
bool visit(TypeId ty) override
{ {
tyVisits[ty]++; tyVisits[ty]++;
return true; return true;
} }
template<typename T> bool visit(TypePackId tp) override
bool operator()(TypePackId tp, const T&)
{ {
tpVisits[tp]++; tpVisits[tp]++;
return true; return true;
@ -348,7 +358,7 @@ local b: (T, T, T) -> T
VisitCountTracker tester; VisitCountTracker tester;
DenseHashSet<void*> seen{nullptr}; DenseHashSet<void*> seen{nullptr};
visitTypeVarOnce(bType, tester, seen); DEPRECATED_visitTypeVarOnce(bType, tester, seen);
for (auto [_, count] : tester.tyVisits) for (auto [_, count] : tester.tyVisits)
CHECK_EQ(count, 1); CHECK_EQ(count, 1);

View file

@ -0,0 +1,48 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Fixture.h"
#include "Luau/RecursionCounter.h"
#include "doctest.h"
using namespace Luau;
LUAU_FASTFLAG(LuauUseVisitRecursionLimit)
LUAU_FASTINT(LuauVisitRecursionLimit)
struct VisitTypeVarFixture : Fixture
{
ScopedFastFlag flag1 = {"LuauUseVisitRecursionLimit", true};
ScopedFastFlag flag2 = {"LuauRecursionLimitException", true};
};
TEST_SUITE_BEGIN("VisitTypeVar");
TEST_CASE_FIXTURE(VisitTypeVarFixture, "throw_when_limit_is_exceeded")
{
ScopedFastInt sfi{"LuauVisitRecursionLimit", 3};
CheckResult result = check(R"(
local t : {a: {b: {c: {d: {e: boolean}}}}}
)");
TypeId tType = requireType("t");
CHECK_THROWS_AS(toString(tType), RecursionLimitException);
}
TEST_CASE_FIXTURE(VisitTypeVarFixture, "dont_throw_when_limit_is_high_enough")
{
ScopedFastInt sfi{"LuauVisitRecursionLimit", 8};
CheckResult result = check(R"(
local t : {a: {b: {c: {d: {e: boolean}}}}}
)");
TypeId tType = requireType("t");
(void)toString(tType);
}
TEST_SUITE_END();

196
tests/conformance/iter.lua Normal file
View file

@ -0,0 +1,196 @@
-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes
print('testing iteration')
-- basic for loop tests
do
local a
for a,b in pairs{} do error("not here") end
for i=1,0 do error("not here") end
for i=0,1,-1 do error("not here") end
a = nil; for i=1,1 do assert(not a); a=1 end; assert(a)
a = nil; for i=1,1,-1 do assert(not a); a=1 end; assert(a)
a = 0; for i=0, 1, 0.1 do a=a+1 end; assert(a==11)
end
-- precision tests for for loops
do
local a
--a = 0; for i=1, 0, -0.01 do a=a+1 end; assert(a==101)
a = 0; for i=0, 0.999999999, 0.1 do a=a+1 end; assert(a==10)
a = 0; for i=1, 1, 1 do a=a+1 end; assert(a==1)
a = 0; for i=1e10, 1e10, -1 do a=a+1 end; assert(a==1)
a = 0; for i=1, 0.99999, 1 do a=a+1 end; assert(a==0)
a = 0; for i=99999, 1e5, -1 do a=a+1 end; assert(a==0)
a = 0; for i=1, 0.99999, -1 do a=a+1 end; assert(a==1)
end
-- for loops do string->number coercion
do
local a = 0; for i="10","1","-2" do a=a+1 end; assert(a==5)
end
-- generic for with function iterators
do
local function f (n, p)
local t = {}; for i=1,p do t[i] = i*10 end
return function (_,n)
if n > 0 then
n = n-1
return n, unpack(t)
end
end, nil, n
end
local x = 0
for n,a,b,c,d in f(5,3) do
x = x+1
assert(a == 10 and b == 20 and c == 30 and d == nil)
end
assert(x == 5)
end
-- generic for with __call (tables)
do
local f = {}
setmetatable(f, { __call = function(_, _, n) if n > 0 then return n - 1 end end })
local x = 0
for n in f, nil, 5 do
x += n
end
assert(x == 10)
end
-- generic for with __call (userdata)
do
local f = newproxy(true)
getmetatable(f).__call = function(_, _, n) if n > 0 then return n - 1 end end
local x = 0
for n in f, nil, 5 do
x += n
end
assert(x == 10)
end
-- generic for with pairs
do
local x = 0
for k, v in pairs({a = 1, b = 2, c = 3}) do
x += v
end
assert(x == 6)
end
-- generic for with pairs with holes
do
local x = 0
for k, v in pairs({1, 2, 3, nil, 5}) do
x += v
end
assert(x == 11)
end
-- generic for with ipairs
do
local x = 0
for k, v in ipairs({1, 2, 3, nil, 5}) do
x += v
end
assert(x == 6)
end
-- generic for with __iter (tables)
do
local f = {}
setmetatable(f, { __iter = function(x)
assert(f == x)
return next, {1, 2, 3, 4}
end })
local x = 0
for n in f do
x += n
end
assert(x == 10)
end
-- generic for with __iter (userdata)
do
local f = newproxy(true)
getmetatable(f).__iter = function(x)
assert(f == x)
return next, {1, 2, 3, 4}
end
local x = 0
for n in f do
x += n
end
assert(x == 10)
end
-- generic for with tables (dictionary)
do
local x = 0
for k, v in {a = 1, b = 2, c = 3} do
print(k, v)
x += v
end
assert(x == 6)
end
-- generic for with tables (arrays)
do
local x = ''
for k, v in {1, 2, 3, nil, 5} do
x ..= tostring(v)
end
assert(x == "1235")
end
-- generic for with tables (mixed)
do
local x = 0
for k, v in {1, 2, 3, nil, 5, a = 1, b = 2, c = 3} do
x += v
end
assert(x == 17)
end
-- generic for over a non-iterable object
do
local ok, err = pcall(function() for x in 42 do end end)
assert(not ok and err:match("attempt to iterate"))
end
-- generic for over an iterable object that doesn't return a function
do
local obj = {}
setmetatable(obj, { __iter = function() end })
local ok, err = pcall(function() for x in obj do end end)
assert(not ok and err:match("attempt to call a nil value"))
end
-- it's okay to iterate through a table with a single variable
do
local x = 0
for k in {1, 2, 3, 4, 5} do
x += k
end
assert(x == 15)
end
-- all extra variables should be set to nil during builtin traversal
do
local x = 0
for k,v,a,b,c,d,e in {1, 2, 3, 4, 5} do
x += k
assert(a == nil and b == nil and c == nil and d == nil and e == nil)
end
assert(x == 15)
end
return"OK"

View file

@ -368,48 +368,6 @@ assert(next(a,nil) == 1000 and next(a,1000) == nil)
assert(next({}) == nil) assert(next({}) == nil)
assert(next({}, nil) == nil) assert(next({}, nil) == nil)
for a,b in pairs{} do error("not here") end
for i=1,0 do error("not here") end
for i=0,1,-1 do error("not here") end
a = nil; for i=1,1 do assert(not a); a=1 end; assert(a)
a = nil; for i=1,1,-1 do assert(not a); a=1 end; assert(a)
a = 0; for i=0, 1, 0.1 do a=a+1 end; assert(a==11)
-- precision problems
--a = 0; for i=1, 0, -0.01 do a=a+1 end; assert(a==101)
a = 0; for i=0, 0.999999999, 0.1 do a=a+1 end; assert(a==10)
a = 0; for i=1, 1, 1 do a=a+1 end; assert(a==1)
a = 0; for i=1e10, 1e10, -1 do a=a+1 end; assert(a==1)
a = 0; for i=1, 0.99999, 1 do a=a+1 end; assert(a==0)
a = 0; for i=99999, 1e5, -1 do a=a+1 end; assert(a==0)
a = 0; for i=1, 0.99999, -1 do a=a+1 end; assert(a==1)
-- conversion
a = 0; for i="10","1","-2" do a=a+1 end; assert(a==5)
collectgarbage()
-- testing generic 'for'
local function f (n, p)
local t = {}; for i=1,p do t[i] = i*10 end
return function (_,n)
if n > 0 then
n = n-1
return n, unpack(t)
end
end, nil, n
end
local x = 0
for n,a,b,c,d in f(5,3) do
x = x+1
assert(a == 10 and b == 20 and c == 30 and d == nil)
end
assert(x == 5)
-- testing table.create and table.find -- testing table.create and table.find
do do
local t = table.create(5) local t = table.create(5)
@ -596,4 +554,17 @@ do
assert(#t2 == 6) assert(#t2 == 6)
end end
-- test table.unpack fastcall for rejecting large unpacks
do
local ok, res = pcall(function()
local a = table.create(7999, 0)
local b = table.create(8000, 0)
local at = { table.unpack(a) }
local bt = { table.unpack(b) }
end)
assert(not ok)
end
return"OK" return"OK"

View file

@ -97,7 +97,7 @@ class LuauVariantSyntheticChildrenProvider:
if self.current_type: if self.current_type:
storage = self.valobj.GetChildMemberWithName("storage") storage = self.valobj.GetChildMemberWithName("storage")
self.stored_value = storage.Cast(self.current_type.GetPointerType()).Dereference() self.stored_value = storage.Cast(self.current_type)
else: else:
self.stored_value = None self.stored_value = None
else: else: