Sync to upstream/release/594 (#1036)

* Fixed `Frontend::markDirty` not working on modules that were not
typechecked yet
* Fixed generic variadic function unification succeeding when it should
have reported an error

New Type Solver:
* Implemented semantic subtyping check for function types

Native Code Generation:
* Improved performance of numerical loops with a constant step
* Simplified IR for `bit32.extract` calls extracting first/last bits
* Improved performance of NaN checks
This commit is contained in:
vegorov-rbx 2023-09-07 17:13:49 -07:00 committed by GitHub
parent bf1fb8f1e4
commit c7c986b996
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
46 changed files with 1628 additions and 1070 deletions

View file

@ -14,8 +14,6 @@ struct GlobalTypes;
struct TypeChecker; struct TypeChecker;
struct TypeArena; struct TypeArena;
void registerBuiltinTypes(GlobalTypes& globals);
void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeCheckForAutocomplete = false); void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeCheckForAutocomplete = false);
TypeId makeUnion(TypeArena& arena, std::vector<TypeId>&& types); TypeId makeUnion(TypeArena& arena, std::vector<TypeId>&& types);
TypeId makeIntersection(TypeArena& arena, std::vector<TypeId>&& types); TypeId makeIntersection(TypeArena& arena, std::vector<TypeId>&& types);

View file

@ -2,12 +2,12 @@
#pragma once #pragma once
#include "Luau/Config.h" #include "Luau/Config.h"
#include "Luau/GlobalTypes.h"
#include "Luau/Module.h" #include "Luau/Module.h"
#include "Luau/ModuleResolver.h" #include "Luau/ModuleResolver.h"
#include "Luau/RequireTracer.h" #include "Luau/RequireTracer.h"
#include "Luau/Scope.h" #include "Luau/Scope.h"
#include "Luau/TypeCheckLimits.h" #include "Luau/TypeCheckLimits.h"
#include "Luau/TypeInfer.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
#include <mutex> #include <mutex>

View file

@ -0,0 +1,26 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Module.h"
#include "Luau/NotNull.h"
#include "Luau/Scope.h"
#include "Luau/TypeArena.h"
namespace Luau
{
struct BuiltinTypes;
struct GlobalTypes
{
explicit GlobalTypes(NotNull<BuiltinTypes> builtinTypes);
NotNull<BuiltinTypes> builtinTypes; // Global types are based on builtin types
TypeArena globalTypes;
SourceModule globalNames; // names for symbols entered into globalScope
ScopePtr globalScope; // shared by all modules
};
}

View file

@ -19,6 +19,7 @@ class TypeIds;
class Normalizer; class Normalizer;
struct NormalizedType; struct NormalizedType;
struct NormalizedClassType; struct NormalizedClassType;
struct NormalizedFunctionType;
struct SubtypingResult struct SubtypingResult
{ {
@ -103,6 +104,7 @@ private:
SubtypingResult isSubtype_(const NormalizedType* subNorm, const NormalizedType* superNorm); SubtypingResult isSubtype_(const NormalizedType* subNorm, const NormalizedType* superNorm);
SubtypingResult isSubtype_(const NormalizedClassType& subClass, const NormalizedClassType& superClass, const TypeIds& superTables); SubtypingResult isSubtype_(const NormalizedClassType& subClass, const NormalizedClassType& superClass, const TypeIds& superTables);
SubtypingResult isSubtype_(const NormalizedFunctionType& subFunction, const NormalizedFunctionType& superFunction);
SubtypingResult isSubtype_(const TypeIds& subTypes, const TypeIds& superTypes); SubtypingResult isSubtype_(const TypeIds& subTypes, const TypeIds& superTypes);
SubtypingResult isSubtype_(const VariadicTypePack* subVariadic, const VariadicTypePack* superVariadic); SubtypingResult isSubtype_(const VariadicTypePack* subVariadic, const VariadicTypePack* superVariadic);

View file

@ -798,12 +798,13 @@ struct BuiltinTypes
TypeId errorRecoveryType() const; TypeId errorRecoveryType() const;
TypePackId errorRecoveryTypePack() const; TypePackId errorRecoveryTypePack() const;
friend TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes);
friend struct GlobalTypes;
private: private:
std::unique_ptr<struct TypeArena> arena; std::unique_ptr<struct TypeArena> arena;
bool debugFreezeArena = false; bool debugFreezeArena = false;
TypeId makeStringMetatable();
public: public:
const TypeId nilType; const TypeId nilType;
const TypeId numberType; const TypeId numberType;

View file

@ -57,17 +57,6 @@ struct HashBoolNamePair
size_t operator()(const std::pair<bool, Name>& pair) const; size_t operator()(const std::pair<bool, Name>& pair) const;
}; };
struct GlobalTypes
{
GlobalTypes(NotNull<BuiltinTypes> builtinTypes);
NotNull<BuiltinTypes> builtinTypes; // Global types are based on builtin types
TypeArena globalTypes;
SourceModule globalNames; // names for symbols entered into globalScope
ScopePtr globalScope; // shared by all modules
};
// All Types are retained via Environment::types. All TypeIds // All Types are retained via Environment::types. All TypeIds
// within a program are borrowed pointers into this set. // within a program are borrowed pointers into this set.
struct TypeChecker struct TypeChecker

View file

@ -13,8 +13,6 @@
#include <utility> #include <utility>
LUAU_FASTFLAG(DebugLuauReadWriteProperties) LUAU_FASTFLAG(DebugLuauReadWriteProperties)
LUAU_FASTFLAGVARIABLE(LuauAnonymousAutofilled1, false);
LUAU_FASTFLAGVARIABLE(LuauAutocompleteLastTypecheck, false)
LUAU_FASTFLAGVARIABLE(LuauAutocompleteDoEnd, false) LUAU_FASTFLAGVARIABLE(LuauAutocompleteDoEnd, false)
LUAU_FASTFLAGVARIABLE(LuauAutocompleteStringLiteralBounds, false); LUAU_FASTFLAGVARIABLE(LuauAutocompleteStringLiteralBounds, false);
@ -611,7 +609,6 @@ std::optional<TypeId> getLocalTypeInScopeAt(const Module& module, Position posit
template <typename T> template <typename T>
static std::optional<std::string> tryToStringDetailed(const ScopePtr& scope, T ty, bool functionTypeArguments) static std::optional<std::string> tryToStringDetailed(const ScopePtr& scope, T ty, bool functionTypeArguments)
{ {
LUAU_ASSERT(FFlag::LuauAnonymousAutofilled1);
ToStringOptions opts; ToStringOptions opts;
opts.useLineBreaks = false; opts.useLineBreaks = false;
opts.hideTableKind = true; opts.hideTableKind = true;
@ -630,24 +627,8 @@ static std::optional<Name> tryGetTypeNameInScope(ScopePtr scope, TypeId ty, bool
if (!canSuggestInferredType(scope, ty)) if (!canSuggestInferredType(scope, ty))
return std::nullopt; return std::nullopt;
if (FFlag::LuauAnonymousAutofilled1)
{
return tryToStringDetailed(scope, ty, functionTypeArguments); return tryToStringDetailed(scope, ty, functionTypeArguments);
} }
else
{
ToStringOptions opts;
opts.useLineBreaks = false;
opts.hideTableKind = true;
opts.scope = scope;
ToStringResult name = toStringDetailed(ty, opts);
if (name.error || name.invalid || name.cycle || name.truncated)
return std::nullopt;
return name.name;
}
}
static bool tryAddTypeCorrectSuggestion(AutocompleteEntryMap& result, ScopePtr scope, AstType* topType, TypeId inferredType, Position position) static bool tryAddTypeCorrectSuggestion(AutocompleteEntryMap& result, ScopePtr scope, AstType* topType, TypeId inferredType, Position position)
{ {
@ -1417,7 +1398,6 @@ static AutocompleteResult autocompleteWhileLoopKeywords(std::vector<AstNode*> an
static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& funcTy) static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& funcTy)
{ {
LUAU_ASSERT(FFlag::LuauAnonymousAutofilled1);
std::string result = "function("; std::string result = "function(";
auto [args, tail] = Luau::flatten(funcTy.argTypes); auto [args, tail] = Luau::flatten(funcTy.argTypes);
@ -1483,7 +1463,6 @@ static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& func
static std::optional<AutocompleteEntry> makeAnonymousAutofilled(const ModulePtr& module, Position position, const AstNode* node, const std::vector<AstNode*>& ancestry) static std::optional<AutocompleteEntry> makeAnonymousAutofilled(const ModulePtr& module, Position position, const AstNode* node, const std::vector<AstNode*>& ancestry)
{ {
LUAU_ASSERT(FFlag::LuauAnonymousAutofilled1);
const AstExprCall* call = node->as<AstExprCall>(); const AstExprCall* call = node->as<AstExprCall>();
if (!call && ancestry.size() > 1) if (!call && ancestry.size() > 1)
call = ancestry[ancestry.size() - 2]->as<AstExprCall>(); call = ancestry[ancestry.size() - 2]->as<AstExprCall>();
@ -1800,19 +1779,12 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
return {}; return {};
if (node->asExpr()) if (node->asExpr())
{
if (FFlag::LuauAnonymousAutofilled1)
{ {
AutocompleteResult ret = autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); AutocompleteResult ret = autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position);
if (std::optional<AutocompleteEntry> generated = makeAnonymousAutofilled(module, position, node, ancestry)) if (std::optional<AutocompleteEntry> generated = makeAnonymousAutofilled(module, position, node, ancestry))
ret.entryMap[kGeneratedAnonymousFunctionEntryName] = std::move(*generated); ret.entryMap[kGeneratedAnonymousFunctionEntryName] = std::move(*generated);
return ret; return ret;
} }
else
{
return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position);
}
}
else if (node->asStat()) else if (node->asStat())
return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement};
@ -1821,15 +1793,6 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback) AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback)
{ {
if (!FFlag::LuauAutocompleteLastTypecheck)
{
// FIXME: We can improve performance here by parsing without checking.
// The old type graph is probably fine. (famous last words!)
FrontendOptions opts;
opts.forAutocomplete = true;
frontend.check(moduleName, opts);
}
const SourceModule* sourceModule = frontend.getSourceModule(moduleName); const SourceModule* sourceModule = frontend.getSourceModule(moduleName);
if (!sourceModule) if (!sourceModule)
return {}; return {};

View file

@ -201,18 +201,6 @@ void assignPropDocumentationSymbols(TableType::Props& props, const std::string&
} }
} }
void registerBuiltinTypes(GlobalTypes& globals)
{
globals.globalScope->addBuiltinTypeBinding("any", TypeFun{{}, globals.builtinTypes->anyType});
globals.globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, globals.builtinTypes->nilType});
globals.globalScope->addBuiltinTypeBinding("number", TypeFun{{}, globals.builtinTypes->numberType});
globals.globalScope->addBuiltinTypeBinding("string", TypeFun{{}, globals.builtinTypes->stringType});
globals.globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, globals.builtinTypes->booleanType});
globals.globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, globals.builtinTypes->threadType});
globals.globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, globals.builtinTypes->unknownType});
globals.globalScope->addBuiltinTypeBinding("never", TypeFun{{}, globals.builtinTypes->neverType});
}
void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeCheckForAutocomplete) void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeCheckForAutocomplete)
{ {
LUAU_ASSERT(!globals.globalTypes.types.isFrozen()); LUAU_ASSERT(!globals.globalTypes.types.isFrozen());
@ -310,6 +298,520 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire); attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire);
} }
static std::vector<TypeId> parseFormatString(NotNull<BuiltinTypes> builtinTypes, const char* data, size_t size)
{
const char* options = "cdiouxXeEfgGqs*";
std::vector<TypeId> result;
for (size_t i = 0; i < size; ++i)
{
if (data[i] == '%')
{
i++;
if (i < size && data[i] == '%')
continue;
// we just ignore all characters (including flags/precision) up until first alphabetic character
while (i < size && !(data[i] > 0 && (isalpha(data[i]) || data[i] == '*')))
i++;
if (i == size)
break;
if (data[i] == 'q' || data[i] == 's')
result.push_back(builtinTypes->stringType);
else if (data[i] == '*')
result.push_back(builtinTypes->unknownType);
else if (strchr(options, data[i]))
result.push_back(builtinTypes->numberType);
else
result.push_back(builtinTypes->errorRecoveryType(builtinTypes->anyType));
}
}
return result;
}
std::optional<WithPredicate<TypePackId>> magicFunctionFormat(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
auto [paramPack, _predicates] = withPredicate;
TypeArena& arena = typechecker.currentModule->internalTypes;
AstExprConstantString* fmt = nullptr;
if (auto index = expr.func->as<AstExprIndexName>(); index && expr.self)
{
if (auto group = index->expr->as<AstExprGroup>())
fmt = group->expr->as<AstExprConstantString>();
else
fmt = index->expr->as<AstExprConstantString>();
}
if (!expr.self && expr.args.size > 0)
fmt = expr.args.data[0]->as<AstExprConstantString>();
if (!fmt)
return std::nullopt;
std::vector<TypeId> expected = parseFormatString(typechecker.builtinTypes, fmt->value.data, fmt->value.size);
const auto& [params, tail] = flatten(paramPack);
size_t paramOffset = 1;
size_t dataOffset = expr.self ? 0 : 1;
// unify the prefix one argument at a time
for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i)
{
Location location = expr.args.data[std::min(i + dataOffset, expr.args.size - 1)]->location;
typechecker.unify(params[i + paramOffset], expected[i], scope, location);
}
// if we know the argument count or if we have too many arguments for sure, we can issue an error
size_t numActualParams = params.size();
size_t numExpectedParams = expected.size() + 1; // + 1 for the format string
if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams))
typechecker.reportError(TypeError{expr.location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}});
return WithPredicate<TypePackId>{arena.addTypePack({typechecker.stringType})};
}
static bool dcrMagicFunctionFormat(MagicFunctionCallContext context)
{
TypeArena* arena = context.solver->arena;
AstExprConstantString* fmt = nullptr;
if (auto index = context.callSite->func->as<AstExprIndexName>(); index && context.callSite->self)
{
if (auto group = index->expr->as<AstExprGroup>())
fmt = group->expr->as<AstExprConstantString>();
else
fmt = index->expr->as<AstExprConstantString>();
}
if (!context.callSite->self && context.callSite->args.size > 0)
fmt = context.callSite->args.data[0]->as<AstExprConstantString>();
if (!fmt)
return false;
std::vector<TypeId> expected = parseFormatString(context.solver->builtinTypes, fmt->value.data, fmt->value.size);
const auto& [params, tail] = flatten(context.arguments);
size_t paramOffset = 1;
// unify the prefix one argument at a time
for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i)
{
context.solver->unify(context.solver->rootScope, context.callSite->location, params[i + paramOffset], expected[i]);
}
// if we know the argument count or if we have too many arguments for sure, we can issue an error
size_t numActualParams = params.size();
size_t numExpectedParams = expected.size() + 1; // + 1 for the format string
if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams))
context.solver->reportError(TypeError{context.callSite->location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}});
TypePackId resultPack = arena->addTypePack({context.solver->builtinTypes->stringType});
asMutable(context.result)->ty.emplace<BoundTypePack>(resultPack);
return true;
}
static std::vector<TypeId> parsePatternString(NotNull<BuiltinTypes> builtinTypes, const char* data, size_t size)
{
std::vector<TypeId> result;
int depth = 0;
bool parsingSet = false;
for (size_t i = 0; i < size; ++i)
{
if (data[i] == '%')
{
++i;
if (!parsingSet && i < size && data[i] == 'b')
i += 2;
}
else if (!parsingSet && data[i] == '[')
{
parsingSet = true;
if (i + 1 < size && data[i + 1] == ']')
i += 1;
}
else if (parsingSet && data[i] == ']')
{
parsingSet = false;
}
else if (data[i] == '(')
{
if (parsingSet)
continue;
if (i + 1 < size && data[i + 1] == ')')
{
i++;
result.push_back(builtinTypes->optionalNumberType);
continue;
}
++depth;
result.push_back(builtinTypes->optionalStringType);
}
else if (data[i] == ')')
{
if (parsingSet)
continue;
--depth;
if (depth < 0)
break;
}
}
if (depth != 0 || parsingSet)
return std::vector<TypeId>();
if (result.empty())
result.push_back(builtinTypes->optionalStringType);
return result;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionGmatch(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
auto [paramPack, _predicates] = withPredicate;
const auto& [params, tail] = flatten(paramPack);
if (params.size() != 2)
return std::nullopt;
TypeArena& arena = typechecker.currentModule->internalTypes;
AstExprConstantString* pattern = nullptr;
size_t index = expr.self ? 0 : 1;
if (expr.args.size > index)
pattern = expr.args.data[index]->as<AstExprConstantString>();
if (!pattern)
return std::nullopt;
std::vector<TypeId> returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return std::nullopt;
typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location);
const TypePackId emptyPack = arena.addTypePack({});
const TypePackId returnList = arena.addTypePack(returnTypes);
const TypeId iteratorType = arena.addType(FunctionType{emptyPack, returnList});
return WithPredicate<TypePackId>{arena.addTypePack({iteratorType})};
}
static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context)
{
const auto& [params, tail] = flatten(context.arguments);
if (params.size() != 2)
return false;
TypeArena* arena = context.solver->arena;
AstExprConstantString* pattern = nullptr;
size_t index = context.callSite->self ? 0 : 1;
if (context.callSite->args.size > index)
pattern = context.callSite->args.data[index]->as<AstExprConstantString>();
if (!pattern)
return false;
std::vector<TypeId> returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return false;
context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], context.solver->builtinTypes->stringType);
const TypePackId emptyPack = arena->addTypePack({});
const TypePackId returnList = arena->addTypePack(returnTypes);
const TypeId iteratorType = arena->addType(FunctionType{emptyPack, returnList});
const TypePackId resTypePack = arena->addTypePack({iteratorType});
asMutable(context.result)->ty.emplace<BoundTypePack>(resTypePack);
return true;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionMatch(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
auto [paramPack, _predicates] = withPredicate;
const auto& [params, tail] = flatten(paramPack);
if (params.size() < 2 || params.size() > 3)
return std::nullopt;
TypeArena& arena = typechecker.currentModule->internalTypes;
AstExprConstantString* pattern = nullptr;
size_t patternIndex = expr.self ? 0 : 1;
if (expr.args.size > patternIndex)
pattern = expr.args.data[patternIndex]->as<AstExprConstantString>();
if (!pattern)
return std::nullopt;
std::vector<TypeId> returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return std::nullopt;
typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location);
const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}});
size_t initIndex = expr.self ? 1 : 2;
if (params.size() == 3 && expr.args.size > initIndex)
typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location);
const TypePackId returnList = arena.addTypePack(returnTypes);
return WithPredicate<TypePackId>{returnList};
}
static bool dcrMagicFunctionMatch(MagicFunctionCallContext context)
{
const auto& [params, tail] = flatten(context.arguments);
if (params.size() < 2 || params.size() > 3)
return false;
TypeArena* arena = context.solver->arena;
AstExprConstantString* pattern = nullptr;
size_t patternIndex = context.callSite->self ? 0 : 1;
if (context.callSite->args.size > patternIndex)
pattern = context.callSite->args.data[patternIndex]->as<AstExprConstantString>();
if (!pattern)
return false;
std::vector<TypeId> returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return false;
context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], context.solver->builtinTypes->stringType);
const TypeId optionalNumber = arena->addType(UnionType{{context.solver->builtinTypes->nilType, context.solver->builtinTypes->numberType}});
size_t initIndex = context.callSite->self ? 1 : 2;
if (params.size() == 3 && context.callSite->args.size > initIndex)
context.solver->unify(context.solver->rootScope, context.callSite->location, params[2], optionalNumber);
const TypePackId returnList = arena->addTypePack(returnTypes);
asMutable(context.result)->ty.emplace<BoundTypePack>(returnList);
return true;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionFind(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
auto [paramPack, _predicates] = withPredicate;
const auto& [params, tail] = flatten(paramPack);
if (params.size() < 2 || params.size() > 4)
return std::nullopt;
TypeArena& arena = typechecker.currentModule->internalTypes;
AstExprConstantString* pattern = nullptr;
size_t patternIndex = expr.self ? 0 : 1;
if (expr.args.size > patternIndex)
pattern = expr.args.data[patternIndex]->as<AstExprConstantString>();
if (!pattern)
return std::nullopt;
bool plain = false;
size_t plainIndex = expr.self ? 2 : 3;
if (expr.args.size > plainIndex)
{
AstExprConstantBool* p = expr.args.data[plainIndex]->as<AstExprConstantBool>();
plain = p && p->value;
}
std::vector<TypeId> returnTypes;
if (!plain)
{
returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return std::nullopt;
}
typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location);
const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}});
const TypeId optionalBoolean = arena.addType(UnionType{{typechecker.nilType, typechecker.booleanType}});
size_t initIndex = expr.self ? 1 : 2;
if (params.size() >= 3 && expr.args.size > initIndex)
typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location);
if (params.size() == 4 && expr.args.size > plainIndex)
typechecker.unify(params[3], optionalBoolean, scope, expr.args.data[plainIndex]->location);
returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber});
const TypePackId returnList = arena.addTypePack(returnTypes);
return WithPredicate<TypePackId>{returnList};
}
static bool dcrMagicFunctionFind(MagicFunctionCallContext context)
{
const auto& [params, tail] = flatten(context.arguments);
if (params.size() < 2 || params.size() > 4)
return false;
TypeArena* arena = context.solver->arena;
NotNull<BuiltinTypes> builtinTypes = context.solver->builtinTypes;
AstExprConstantString* pattern = nullptr;
size_t patternIndex = context.callSite->self ? 0 : 1;
if (context.callSite->args.size > patternIndex)
pattern = context.callSite->args.data[patternIndex]->as<AstExprConstantString>();
if (!pattern)
return false;
bool plain = false;
size_t plainIndex = context.callSite->self ? 2 : 3;
if (context.callSite->args.size > plainIndex)
{
AstExprConstantBool* p = context.callSite->args.data[plainIndex]->as<AstExprConstantBool>();
plain = p && p->value;
}
std::vector<TypeId> returnTypes;
if (!plain)
{
returnTypes = parsePatternString(builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return false;
}
context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], builtinTypes->stringType);
const TypeId optionalNumber = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->numberType}});
const TypeId optionalBoolean = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->booleanType}});
size_t initIndex = context.callSite->self ? 1 : 2;
if (params.size() >= 3 && context.callSite->args.size > initIndex)
context.solver->unify(context.solver->rootScope, context.callSite->location, params[2], optionalNumber);
if (params.size() == 4 && context.callSite->args.size > plainIndex)
context.solver->unify(context.solver->rootScope, context.callSite->location, params[3], optionalBoolean);
returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber});
const TypePackId returnList = arena->addTypePack(returnTypes);
asMutable(context.result)->ty.emplace<BoundTypePack>(returnList);
return true;
}
TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
{
NotNull<TypeArena> arena{builtinTypes->arena.get()};
const TypeId nilType = builtinTypes->nilType;
const TypeId numberType = builtinTypes->numberType;
const TypeId booleanType = builtinTypes->booleanType;
const TypeId stringType = builtinTypes->stringType;
const TypeId anyType = builtinTypes->anyType;
const TypeId optionalNumber = arena->addType(UnionType{{nilType, numberType}});
const TypeId optionalString = arena->addType(UnionType{{nilType, stringType}});
const TypeId optionalBoolean = arena->addType(UnionType{{nilType, booleanType}});
const TypePackId oneStringPack = arena->addTypePack({stringType});
const TypePackId anyTypePack = arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, true});
FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack};
formatFTV.magicFunction = &magicFunctionFormat;
const TypeId formatFn = arena->addType(formatFTV);
attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat);
const TypePackId emptyPack = arena->addTypePack({});
const TypePackId stringVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{stringType}});
const TypePackId numberVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{numberType}});
const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType});
const TypeId replArgType =
arena->addType(UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)),
makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType})}});
const TypeId gsubFunc = makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType});
const TypeId gmatchFunc =
makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})});
attachMagicFunction(gmatchFunc, magicFunctionGmatch);
attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch);
const TypeId matchFunc = arena->addType(
FunctionType{arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})});
attachMagicFunction(matchFunc, magicFunctionMatch);
attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch);
const TypeId findFunc = arena->addType(FunctionType{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}),
arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})});
attachMagicFunction(findFunc, magicFunctionFind);
attachDcrMagicFunction(findFunc, dcrMagicFunctionFind);
TableType::Props stringLib = {
{"byte", {arena->addType(FunctionType{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}},
{"char", {arena->addType(FunctionType{numberVariadicList, arena->addTypePack({stringType})})}},
{"find", {findFunc}},
{"format", {formatFn}}, // FIXME
{"gmatch", {gmatchFunc}},
{"gsub", {gsubFunc}},
{"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}},
{"lower", {stringToStringType}},
{"match", {matchFunc}},
{"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}},
{"reverse", {stringToStringType}},
{"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}},
{"upper", {stringToStringType}},
{"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {},
{arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})})}},
{"pack", {arena->addType(FunctionType{
arena->addTypePack(TypePack{{stringType}, anyTypePack}),
oneStringPack,
})}},
{"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}},
{"unpack", {arena->addType(FunctionType{
arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}),
anyTypePack,
})}},
};
assignPropDocumentationSymbols(stringLib, "@luau/global/string");
TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed});
if (TableType* ttv = getMutable<TableType>(tableType))
ttv->name = "typeof(string)";
return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed});
}
static std::optional<WithPredicate<TypePackId>> magicFunctionSelect( static std::optional<WithPredicate<TypePackId>> magicFunctionSelect(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate) TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{ {

View file

@ -36,6 +36,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false)
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false)
LUAU_FASTFLAGVARIABLE(DebugLuauReadWriteProperties, false) LUAU_FASTFLAGVARIABLE(DebugLuauReadWriteProperties, false)
LUAU_FASTFLAGVARIABLE(LuauTypecheckLimitControls, false) LUAU_FASTFLAGVARIABLE(LuauTypecheckLimitControls, false)
LUAU_FASTFLAGVARIABLE(CorrectEarlyReturnInMarkDirty, false)
namespace Luau namespace Luau
{ {
@ -928,7 +929,6 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item)
{ {
// The autocomplete typecheck is always in strict mode with DM awareness // The autocomplete typecheck is always in strict mode with DM awareness
// to provide better type information for IDE features // to provide better type information for IDE features
TypeCheckLimits typeCheckLimits;
if (autocompleteTimeLimit != 0.0) if (autocompleteTimeLimit != 0.0)
typeCheckLimits.finishTime = TimeTrace::getClock() + autocompleteTimeLimit; typeCheckLimits.finishTime = TimeTrace::getClock() + autocompleteTimeLimit;
@ -1148,9 +1148,17 @@ bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const
* It would be nice for this function to be O(1) * It would be nice for this function to be O(1)
*/ */
void Frontend::markDirty(const ModuleName& name, std::vector<ModuleName>* markedDirty) void Frontend::markDirty(const ModuleName& name, std::vector<ModuleName>* markedDirty)
{
if (FFlag::CorrectEarlyReturnInMarkDirty)
{
if (sourceNodes.count(name) == 0)
return;
}
else
{ {
if (!moduleResolver.getModule(name) && !moduleResolverForAutocomplete.getModule(name)) if (!moduleResolver.getModule(name) && !moduleResolverForAutocomplete.getModule(name))
return; return;
}
std::unordered_map<ModuleName, std::vector<ModuleName>> reverseDeps; std::unordered_map<ModuleName, std::vector<ModuleName>> reverseDeps;
for (const auto& module : sourceNodes) for (const auto& module : sourceNodes)

View file

@ -0,0 +1,34 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/GlobalTypes.h"
LUAU_FASTFLAG(LuauInitializeStringMetatableInGlobalTypes)
namespace Luau
{
GlobalTypes::GlobalTypes(NotNull<BuiltinTypes> builtinTypes)
: builtinTypes(builtinTypes)
{
globalScope = std::make_shared<Scope>(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}}));
globalScope->addBuiltinTypeBinding("any", TypeFun{{}, builtinTypes->anyType});
globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, builtinTypes->nilType});
globalScope->addBuiltinTypeBinding("number", TypeFun{{}, builtinTypes->numberType});
globalScope->addBuiltinTypeBinding("string", TypeFun{{}, builtinTypes->stringType});
globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, builtinTypes->booleanType});
globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, builtinTypes->threadType});
globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, builtinTypes->unknownType});
globalScope->addBuiltinTypeBinding("never", TypeFun{{}, builtinTypes->neverType});
if (FFlag::LuauInitializeStringMetatableInGlobalTypes)
{
unfreeze(*builtinTypes->arena);
TypeId stringMetatableTy = makeStringMetatable(builtinTypes);
asMutable(builtinTypes->stringType)->ty.emplace<PrimitiveType>(PrimitiveType::String, stringMetatableTy);
persist(stringMetatableTy);
freeze(*builtinTypes->arena);
}
}
}

View file

@ -664,9 +664,8 @@ SubtypingResult Subtyping::isSubtype_(const NormalizedType* subNorm, const Norma
result.andAlso(isSubtype_(subNorm->tables, superNorm->tables)); result.andAlso(isSubtype_(subNorm->tables, superNorm->tables));
// isSubtype_(subNorm->tables, superNorm->strings); // isSubtype_(subNorm->tables, superNorm->strings);
// isSubtype_(subNorm->tables, superNorm->classes); // isSubtype_(subNorm->tables, superNorm->classes);
// isSubtype_(subNorm->functions, superNorm->functions); result.andAlso(isSubtype_(subNorm->functions, superNorm->functions));
// isSubtype_(subNorm->tyvars, superNorm->tyvars); // isSubtype_(subNorm->tyvars, superNorm->tyvars);
return result; return result;
} }
@ -703,6 +702,16 @@ SubtypingResult Subtyping::isSubtype_(const NormalizedClassType& subClass, const
return {true}; return {true};
} }
SubtypingResult Subtyping::isSubtype_(const NormalizedFunctionType& subFunction, const NormalizedFunctionType& superFunction)
{
if (subFunction.isNever())
return {true};
else if (superFunction.isTop)
return {true};
else
return isSubtype_(subFunction.parts, superFunction.parts);
}
SubtypingResult Subtyping::isSubtype_(const TypeIds& subTypes, const TypeIds& superTypes) SubtypingResult Subtyping::isSubtype_(const TypeIds& subTypes, const TypeIds& superTypes)
{ {
std::vector<SubtypingResult> results; std::vector<SubtypingResult> results;

View file

@ -9,6 +9,8 @@
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
namespace Luau namespace Luau
{ {
@ -52,7 +54,7 @@ bool StateDot::canDuplicatePrimitive(TypeId ty)
if (get<BoundType>(ty)) if (get<BoundType>(ty))
return false; return false;
return get<PrimitiveType>(ty) || get<AnyType>(ty); return get<PrimitiveType>(ty) || get<AnyType>(ty) || get<UnknownType>(ty) || get<NeverType>(ty);
} }
void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName) void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName)
@ -76,6 +78,10 @@ void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName)
formatAppend(result, "n%d [label=\"%s\"];\n", index, toString(ty).c_str()); formatAppend(result, "n%d [label=\"%s\"];\n", index, toString(ty).c_str());
else if (get<AnyType>(ty)) else if (get<AnyType>(ty))
formatAppend(result, "n%d [label=\"any\"];\n", index); formatAppend(result, "n%d [label=\"any\"];\n", index);
else if (get<UnknownType>(ty))
formatAppend(result, "n%d [label=\"unknown\"];\n", index);
else if (get<NeverType>(ty))
formatAppend(result, "n%d [label=\"never\"];\n", index);
} }
else else
{ {
@ -139,142 +145,185 @@ void StateDot::visitChildren(TypeId ty, int index)
startNode(index); startNode(index);
startNodeLabel(); startNodeLabel();
if (const BoundType* btv = get<BoundType>(ty)) auto go = [&](auto&& t)
{
using T = std::decay_t<decltype(t)>;
if constexpr (std::is_same_v<T, BoundType>)
{ {
formatAppend(result, "BoundType %d", index); formatAppend(result, "BoundType %d", index);
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
visitChild(btv->boundTo, index); visitChild(t.boundTo, index);
} }
else if (const FunctionType* ftv = get<FunctionType>(ty)) else if constexpr (std::is_same_v<T, BlockedType>)
{
formatAppend(result, "BlockedType %d", index);
finishNodeLabel(ty);
finishNode();
}
else if constexpr (std::is_same_v<T, FunctionType>)
{ {
formatAppend(result, "FunctionType %d", index); formatAppend(result, "FunctionType %d", index);
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
visitChild(ftv->argTypes, index, "arg"); visitChild(t.argTypes, index, "arg");
visitChild(ftv->retTypes, index, "ret"); visitChild(t.retTypes, index, "ret");
} }
else if (const TableType* ttv = get<TableType>(ty)) else if constexpr (std::is_same_v<T, TableType>)
{ {
if (ttv->name) if (t.name)
formatAppend(result, "TableType %s", ttv->name->c_str()); formatAppend(result, "TableType %s", t.name->c_str());
else if (ttv->syntheticName) else if (t.syntheticName)
formatAppend(result, "TableType %s", ttv->syntheticName->c_str()); formatAppend(result, "TableType %s", t.syntheticName->c_str());
else else
formatAppend(result, "TableType %d", index); formatAppend(result, "TableType %d", index);
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
if (ttv->boundTo) if (t.boundTo)
return visitChild(*ttv->boundTo, index, "boundTo"); return visitChild(*t.boundTo, index, "boundTo");
for (const auto& [name, prop] : ttv->props) for (const auto& [name, prop] : t.props)
visitChild(prop.type(), index, name.c_str()); visitChild(prop.type(), index, name.c_str());
if (ttv->indexer) if (t.indexer)
{ {
visitChild(ttv->indexer->indexType, index, "[index]"); visitChild(t.indexer->indexType, index, "[index]");
visitChild(ttv->indexer->indexResultType, index, "[value]"); visitChild(t.indexer->indexResultType, index, "[value]");
} }
for (TypeId itp : ttv->instantiatedTypeParams) for (TypeId itp : t.instantiatedTypeParams)
visitChild(itp, index, "typeParam"); visitChild(itp, index, "typeParam");
for (TypePackId itp : ttv->instantiatedTypePackParams) for (TypePackId itp : t.instantiatedTypePackParams)
visitChild(itp, index, "typePackParam"); visitChild(itp, index, "typePackParam");
} }
else if (const MetatableType* mtv = get<MetatableType>(ty)) else if constexpr (std::is_same_v<T, MetatableType>)
{ {
formatAppend(result, "MetatableType %d", index); formatAppend(result, "MetatableType %d", index);
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
visitChild(mtv->table, index, "table"); visitChild(t.table, index, "table");
visitChild(mtv->metatable, index, "metatable"); visitChild(t.metatable, index, "metatable");
} }
else if (const UnionType* utv = get<UnionType>(ty)) else if constexpr (std::is_same_v<T, UnionType>)
{ {
formatAppend(result, "UnionType %d", index); formatAppend(result, "UnionType %d", index);
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
for (TypeId opt : utv->options) for (TypeId opt : t.options)
visitChild(opt, index); visitChild(opt, index);
} }
else if (const IntersectionType* itv = get<IntersectionType>(ty)) else if constexpr (std::is_same_v<T, IntersectionType>)
{ {
formatAppend(result, "IntersectionType %d", index); formatAppend(result, "IntersectionType %d", index);
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
for (TypeId part : itv->parts) for (TypeId part : t.parts)
visitChild(part, index); visitChild(part, index);
} }
else if (const GenericType* gtv = get<GenericType>(ty)) else if constexpr (std::is_same_v<T, LazyType>)
{ {
if (gtv->explicitName) formatAppend(result, "LazyType %d", index);
formatAppend(result, "GenericType %s", gtv->name.c_str()); finishNodeLabel(ty);
finishNode();
}
else if constexpr (std::is_same_v<T, PendingExpansionType>)
{
formatAppend(result, "PendingExpansionType %d", index);
finishNodeLabel(ty);
finishNode();
}
else if constexpr (std::is_same_v<T, GenericType>)
{
if (t.explicitName)
formatAppend(result, "GenericType %s", t.name.c_str());
else else
formatAppend(result, "GenericType %d", index); formatAppend(result, "GenericType %d", index);
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
} }
else if (const FreeType* ftv = get<FreeType>(ty)) else if constexpr (std::is_same_v<T, FreeType>)
{ {
formatAppend(result, "FreeType %d", index); formatAppend(result, "FreeType %d", index);
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
if (FFlag::DebugLuauDeferredConstraintResolution)
{
if (!get<NeverType>(t.lowerBound))
visitChild(t.lowerBound, index, "[lowerBound]");
if (!get<UnknownType>(t.upperBound))
visitChild(t.upperBound, index, "[upperBound]");
} }
else if (get<AnyType>(ty)) }
else if constexpr (std::is_same_v<T, AnyType>)
{ {
formatAppend(result, "AnyType %d", index); formatAppend(result, "AnyType %d", index);
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
} }
else if (get<PrimitiveType>(ty)) else if constexpr (std::is_same_v<T, UnknownType>)
{
formatAppend(result, "UnknownType %d", index);
finishNodeLabel(ty);
finishNode();
}
else if constexpr (std::is_same_v<T, NeverType>)
{
formatAppend(result, "NeverType %d", index);
finishNodeLabel(ty);
finishNode();
}
else if constexpr (std::is_same_v<T, PrimitiveType>)
{ {
formatAppend(result, "PrimitiveType %s", toString(ty).c_str()); formatAppend(result, "PrimitiveType %s", toString(ty).c_str());
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
} }
else if (get<ErrorType>(ty)) else if constexpr (std::is_same_v<T, ErrorType>)
{ {
formatAppend(result, "ErrorType %d", index); formatAppend(result, "ErrorType %d", index);
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
} }
else if (const ClassType* ctv = get<ClassType>(ty)) else if constexpr (std::is_same_v<T, ClassType>)
{ {
formatAppend(result, "ClassType %s", ctv->name.c_str()); formatAppend(result, "ClassType %s", t.name.c_str());
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
for (const auto& [name, prop] : ctv->props) for (const auto& [name, prop] : t.props)
visitChild(prop.type(), index, name.c_str()); visitChild(prop.type(), index, name.c_str());
if (ctv->parent) if (t.parent)
visitChild(*ctv->parent, index, "[parent]"); visitChild(*t.parent, index, "[parent]");
if (ctv->metatable) if (t.metatable)
visitChild(*ctv->metatable, index, "[metatable]"); visitChild(*t.metatable, index, "[metatable]");
if (ctv->indexer) if (t.indexer)
{ {
visitChild(ctv->indexer->indexType, index, "[index]"); visitChild(t.indexer->indexType, index, "[index]");
visitChild(ctv->indexer->indexResultType, index, "[value]"); visitChild(t.indexer->indexResultType, index, "[value]");
} }
} }
else if (const SingletonType* stv = get<SingletonType>(ty)) else if constexpr (std::is_same_v<T, SingletonType>)
{ {
std::string res; std::string res;
if (const StringSingleton* ss = get<StringSingleton>(stv)) if (const StringSingleton* ss = get<StringSingleton>(&t))
{ {
// Don't put in quotes anywhere. If it's outside of the call to escape, // Don't put in quotes anywhere. If it's outside of the call to escape,
// then it's invalid syntax. If it's inside, then escaping is super noisy. // then it's invalid syntax. If it's inside, then escaping is super noisy.
res = "string: " + escape(ss->value); res = "string: " + escape(ss->value);
} }
else if (const BooleanSingleton* bs = get<BooleanSingleton>(stv)) else if (const BooleanSingleton* bs = get<BooleanSingleton>(&t))
{ {
res = "boolean: "; res = "boolean: ";
res += bs->value ? "true" : "false"; res += bs->value ? "true" : "false";
@ -286,12 +335,25 @@ void StateDot::visitChildren(TypeId ty, int index)
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
} }
else else if constexpr (std::is_same_v<T, NegationType>)
{ {
LUAU_ASSERT(!"unknown type kind"); formatAppend(result, "NegationType %d", index);
finishNodeLabel(ty);
finishNode();
visitChild(t.ty, index, "[negated]");
}
else if constexpr (std::is_same_v<T, TypeFamilyInstanceType>)
{
formatAppend(result, "TypeFamilyInstanceType %d", index);
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
} }
else
static_assert(always_false_v<T>, "unknown type kind");
};
visit(go, ty->ty);
} }
void StateDot::visitChildren(TypePackId tp, int index) void StateDot::visitChildren(TypePackId tp, int index)

View file

@ -27,26 +27,11 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit)
LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAG(LuauNormalizeBlockedTypes) LUAU_FASTFLAG(LuauNormalizeBlockedTypes)
LUAU_FASTFLAG(DebugLuauReadWriteProperties) LUAU_FASTFLAG(DebugLuauReadWriteProperties)
LUAU_FASTFLAGVARIABLE(LuauInitializeStringMetatableInGlobalTypes, false)
namespace Luau namespace Luau
{ {
std::optional<WithPredicate<TypePackId>> magicFunctionFormat(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
static bool dcrMagicFunctionFormat(MagicFunctionCallContext context);
static std::optional<WithPredicate<TypePackId>> magicFunctionGmatch(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context);
static std::optional<WithPredicate<TypePackId>> magicFunctionMatch(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
static bool dcrMagicFunctionMatch(MagicFunctionCallContext context);
static std::optional<WithPredicate<TypePackId>> magicFunctionFind(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
static bool dcrMagicFunctionFind(MagicFunctionCallContext context);
// LUAU_NOINLINE prevents unwrapLazy from being inlined into advance below; advance is important to keep inlineable // LUAU_NOINLINE prevents unwrapLazy from being inlined into advance below; advance is important to keep inlineable
static LUAU_NOINLINE TypeId unwrapLazy(LazyType* ltv) static LUAU_NOINLINE TypeId unwrapLazy(LazyType* ltv)
{ {
@ -933,6 +918,8 @@ TypeId makeFunction(TypeArena& arena, std::optional<TypeId> selfType, std::initi
std::initializer_list<TypePackId> genericPacks, std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames, std::initializer_list<TypePackId> genericPacks, std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames,
std::initializer_list<TypeId> retTypes); std::initializer_list<TypeId> retTypes);
TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes); // BuiltinDefinitions.cpp
BuiltinTypes::BuiltinTypes() BuiltinTypes::BuiltinTypes()
: arena(new TypeArena) : arena(new TypeArena)
, debugFreezeArena(FFlag::DebugLuauFreezeArena) , debugFreezeArena(FFlag::DebugLuauFreezeArena)
@ -961,9 +948,12 @@ BuiltinTypes::BuiltinTypes()
, uninhabitableTypePack(arena->addTypePack(TypePackVar{TypePack{{neverType}, neverTypePack}, /*persistent*/ true})) , uninhabitableTypePack(arena->addTypePack(TypePackVar{TypePack{{neverType}, neverTypePack}, /*persistent*/ true}))
, errorTypePack(arena->addTypePack(TypePackVar{Unifiable::Error{}, /*persistent*/ true})) , errorTypePack(arena->addTypePack(TypePackVar{Unifiable::Error{}, /*persistent*/ true}))
{ {
TypeId stringMetatable = makeStringMetatable(); if (!FFlag::LuauInitializeStringMetatableInGlobalTypes)
{
TypeId stringMetatable = makeStringMetatable(NotNull{this});
asMutable(stringType)->ty = PrimitiveType{PrimitiveType::String, stringMetatable}; asMutable(stringType)->ty = PrimitiveType{PrimitiveType::String, stringMetatable};
persist(stringMetatable); persist(stringMetatable);
}
freeze(*arena); freeze(*arena);
} }
@ -980,82 +970,6 @@ BuiltinTypes::~BuiltinTypes()
FFlag::DebugLuauFreezeArena.value = prevFlag; FFlag::DebugLuauFreezeArena.value = prevFlag;
} }
TypeId BuiltinTypes::makeStringMetatable()
{
const TypeId optionalNumber = arena->addType(UnionType{{nilType, numberType}});
const TypeId optionalString = arena->addType(UnionType{{nilType, stringType}});
const TypeId optionalBoolean = arena->addType(UnionType{{nilType, booleanType}});
const TypePackId oneStringPack = arena->addTypePack({stringType});
const TypePackId anyTypePack = arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, true});
FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack};
formatFTV.magicFunction = &magicFunctionFormat;
const TypeId formatFn = arena->addType(formatFTV);
attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat);
const TypePackId emptyPack = arena->addTypePack({});
const TypePackId stringVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{stringType}});
const TypePackId numberVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{numberType}});
const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType});
const TypeId replArgType =
arena->addType(UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)),
makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType})}});
const TypeId gsubFunc = makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType});
const TypeId gmatchFunc =
makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})});
attachMagicFunction(gmatchFunc, magicFunctionGmatch);
attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch);
const TypeId matchFunc = arena->addType(
FunctionType{arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})});
attachMagicFunction(matchFunc, magicFunctionMatch);
attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch);
const TypeId findFunc = arena->addType(FunctionType{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}),
arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})});
attachMagicFunction(findFunc, magicFunctionFind);
attachDcrMagicFunction(findFunc, dcrMagicFunctionFind);
TableType::Props stringLib = {
{"byte", {arena->addType(FunctionType{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}},
{"char", {arena->addType(FunctionType{numberVariadicList, arena->addTypePack({stringType})})}},
{"find", {findFunc}},
{"format", {formatFn}}, // FIXME
{"gmatch", {gmatchFunc}},
{"gsub", {gsubFunc}},
{"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}},
{"lower", {stringToStringType}},
{"match", {matchFunc}},
{"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}},
{"reverse", {stringToStringType}},
{"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}},
{"upper", {stringToStringType}},
{"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {},
{arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})})}},
{"pack", {arena->addType(FunctionType{
arena->addTypePack(TypePack{{stringType}, anyTypePack}),
oneStringPack,
})}},
{"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}},
{"unpack", {arena->addType(FunctionType{
arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}),
anyTypePack,
})}},
};
assignPropDocumentationSymbols(stringLib, "@luau/global/string");
TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed});
if (TableType* ttv = getMutable<TableType>(tableType))
ttv->name = "typeof(string)";
return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed});
}
TypeId BuiltinTypes::errorRecoveryType() const TypeId BuiltinTypes::errorRecoveryType() const
{ {
return errorType; return errorType;
@ -1261,436 +1175,6 @@ IntersectionTypeIterator end(const IntersectionType* itv)
return IntersectionTypeIterator{}; return IntersectionTypeIterator{};
} }
static std::vector<TypeId> parseFormatString(NotNull<BuiltinTypes> builtinTypes, const char* data, size_t size)
{
const char* options = "cdiouxXeEfgGqs*";
std::vector<TypeId> result;
for (size_t i = 0; i < size; ++i)
{
if (data[i] == '%')
{
i++;
if (i < size && data[i] == '%')
continue;
// we just ignore all characters (including flags/precision) up until first alphabetic character
while (i < size && !(data[i] > 0 && (isalpha(data[i]) || data[i] == '*')))
i++;
if (i == size)
break;
if (data[i] == 'q' || data[i] == 's')
result.push_back(builtinTypes->stringType);
else if (data[i] == '*')
result.push_back(builtinTypes->unknownType);
else if (strchr(options, data[i]))
result.push_back(builtinTypes->numberType);
else
result.push_back(builtinTypes->errorRecoveryType(builtinTypes->anyType));
}
}
return result;
}
std::optional<WithPredicate<TypePackId>> magicFunctionFormat(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
auto [paramPack, _predicates] = withPredicate;
TypeArena& arena = typechecker.currentModule->internalTypes;
AstExprConstantString* fmt = nullptr;
if (auto index = expr.func->as<AstExprIndexName>(); index && expr.self)
{
if (auto group = index->expr->as<AstExprGroup>())
fmt = group->expr->as<AstExprConstantString>();
else
fmt = index->expr->as<AstExprConstantString>();
}
if (!expr.self && expr.args.size > 0)
fmt = expr.args.data[0]->as<AstExprConstantString>();
if (!fmt)
return std::nullopt;
std::vector<TypeId> expected = parseFormatString(typechecker.builtinTypes, fmt->value.data, fmt->value.size);
const auto& [params, tail] = flatten(paramPack);
size_t paramOffset = 1;
size_t dataOffset = expr.self ? 0 : 1;
// unify the prefix one argument at a time
for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i)
{
Location location = expr.args.data[std::min(i + dataOffset, expr.args.size - 1)]->location;
typechecker.unify(params[i + paramOffset], expected[i], scope, location);
}
// if we know the argument count or if we have too many arguments for sure, we can issue an error
size_t numActualParams = params.size();
size_t numExpectedParams = expected.size() + 1; // + 1 for the format string
if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams))
typechecker.reportError(TypeError{expr.location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}});
return WithPredicate<TypePackId>{arena.addTypePack({typechecker.stringType})};
}
static bool dcrMagicFunctionFormat(MagicFunctionCallContext context)
{
TypeArena* arena = context.solver->arena;
AstExprConstantString* fmt = nullptr;
if (auto index = context.callSite->func->as<AstExprIndexName>(); index && context.callSite->self)
{
if (auto group = index->expr->as<AstExprGroup>())
fmt = group->expr->as<AstExprConstantString>();
else
fmt = index->expr->as<AstExprConstantString>();
}
if (!context.callSite->self && context.callSite->args.size > 0)
fmt = context.callSite->args.data[0]->as<AstExprConstantString>();
if (!fmt)
return false;
std::vector<TypeId> expected = parseFormatString(context.solver->builtinTypes, fmt->value.data, fmt->value.size);
const auto& [params, tail] = flatten(context.arguments);
size_t paramOffset = 1;
// unify the prefix one argument at a time
for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i)
{
context.solver->unify(context.solver->rootScope, context.callSite->location, params[i + paramOffset], expected[i]);
}
// if we know the argument count or if we have too many arguments for sure, we can issue an error
size_t numActualParams = params.size();
size_t numExpectedParams = expected.size() + 1; // + 1 for the format string
if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams))
context.solver->reportError(TypeError{context.callSite->location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}});
TypePackId resultPack = arena->addTypePack({context.solver->builtinTypes->stringType});
asMutable(context.result)->ty.emplace<BoundTypePack>(resultPack);
return true;
}
static std::vector<TypeId> parsePatternString(NotNull<BuiltinTypes> builtinTypes, const char* data, size_t size)
{
std::vector<TypeId> result;
int depth = 0;
bool parsingSet = false;
for (size_t i = 0; i < size; ++i)
{
if (data[i] == '%')
{
++i;
if (!parsingSet && i < size && data[i] == 'b')
i += 2;
}
else if (!parsingSet && data[i] == '[')
{
parsingSet = true;
if (i + 1 < size && data[i + 1] == ']')
i += 1;
}
else if (parsingSet && data[i] == ']')
{
parsingSet = false;
}
else if (data[i] == '(')
{
if (parsingSet)
continue;
if (i + 1 < size && data[i + 1] == ')')
{
i++;
result.push_back(builtinTypes->optionalNumberType);
continue;
}
++depth;
result.push_back(builtinTypes->optionalStringType);
}
else if (data[i] == ')')
{
if (parsingSet)
continue;
--depth;
if (depth < 0)
break;
}
}
if (depth != 0 || parsingSet)
return std::vector<TypeId>();
if (result.empty())
result.push_back(builtinTypes->optionalStringType);
return result;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionGmatch(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
auto [paramPack, _predicates] = withPredicate;
const auto& [params, tail] = flatten(paramPack);
if (params.size() != 2)
return std::nullopt;
TypeArena& arena = typechecker.currentModule->internalTypes;
AstExprConstantString* pattern = nullptr;
size_t index = expr.self ? 0 : 1;
if (expr.args.size > index)
pattern = expr.args.data[index]->as<AstExprConstantString>();
if (!pattern)
return std::nullopt;
std::vector<TypeId> returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return std::nullopt;
typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location);
const TypePackId emptyPack = arena.addTypePack({});
const TypePackId returnList = arena.addTypePack(returnTypes);
const TypeId iteratorType = arena.addType(FunctionType{emptyPack, returnList});
return WithPredicate<TypePackId>{arena.addTypePack({iteratorType})};
}
static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context)
{
const auto& [params, tail] = flatten(context.arguments);
if (params.size() != 2)
return false;
TypeArena* arena = context.solver->arena;
AstExprConstantString* pattern = nullptr;
size_t index = context.callSite->self ? 0 : 1;
if (context.callSite->args.size > index)
pattern = context.callSite->args.data[index]->as<AstExprConstantString>();
if (!pattern)
return false;
std::vector<TypeId> returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return false;
context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], context.solver->builtinTypes->stringType);
const TypePackId emptyPack = arena->addTypePack({});
const TypePackId returnList = arena->addTypePack(returnTypes);
const TypeId iteratorType = arena->addType(FunctionType{emptyPack, returnList});
const TypePackId resTypePack = arena->addTypePack({iteratorType});
asMutable(context.result)->ty.emplace<BoundTypePack>(resTypePack);
return true;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionMatch(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
auto [paramPack, _predicates] = withPredicate;
const auto& [params, tail] = flatten(paramPack);
if (params.size() < 2 || params.size() > 3)
return std::nullopt;
TypeArena& arena = typechecker.currentModule->internalTypes;
AstExprConstantString* pattern = nullptr;
size_t patternIndex = expr.self ? 0 : 1;
if (expr.args.size > patternIndex)
pattern = expr.args.data[patternIndex]->as<AstExprConstantString>();
if (!pattern)
return std::nullopt;
std::vector<TypeId> returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return std::nullopt;
typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location);
const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}});
size_t initIndex = expr.self ? 1 : 2;
if (params.size() == 3 && expr.args.size > initIndex)
typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location);
const TypePackId returnList = arena.addTypePack(returnTypes);
return WithPredicate<TypePackId>{returnList};
}
static bool dcrMagicFunctionMatch(MagicFunctionCallContext context)
{
const auto& [params, tail] = flatten(context.arguments);
if (params.size() < 2 || params.size() > 3)
return false;
TypeArena* arena = context.solver->arena;
AstExprConstantString* pattern = nullptr;
size_t patternIndex = context.callSite->self ? 0 : 1;
if (context.callSite->args.size > patternIndex)
pattern = context.callSite->args.data[patternIndex]->as<AstExprConstantString>();
if (!pattern)
return false;
std::vector<TypeId> returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return false;
context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], context.solver->builtinTypes->stringType);
const TypeId optionalNumber = arena->addType(UnionType{{context.solver->builtinTypes->nilType, context.solver->builtinTypes->numberType}});
size_t initIndex = context.callSite->self ? 1 : 2;
if (params.size() == 3 && context.callSite->args.size > initIndex)
context.solver->unify(context.solver->rootScope, context.callSite->location, params[2], optionalNumber);
const TypePackId returnList = arena->addTypePack(returnTypes);
asMutable(context.result)->ty.emplace<BoundTypePack>(returnList);
return true;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionFind(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
auto [paramPack, _predicates] = withPredicate;
const auto& [params, tail] = flatten(paramPack);
if (params.size() < 2 || params.size() > 4)
return std::nullopt;
TypeArena& arena = typechecker.currentModule->internalTypes;
AstExprConstantString* pattern = nullptr;
size_t patternIndex = expr.self ? 0 : 1;
if (expr.args.size > patternIndex)
pattern = expr.args.data[patternIndex]->as<AstExprConstantString>();
if (!pattern)
return std::nullopt;
bool plain = false;
size_t plainIndex = expr.self ? 2 : 3;
if (expr.args.size > plainIndex)
{
AstExprConstantBool* p = expr.args.data[plainIndex]->as<AstExprConstantBool>();
plain = p && p->value;
}
std::vector<TypeId> returnTypes;
if (!plain)
{
returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return std::nullopt;
}
typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location);
const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}});
const TypeId optionalBoolean = arena.addType(UnionType{{typechecker.nilType, typechecker.booleanType}});
size_t initIndex = expr.self ? 1 : 2;
if (params.size() >= 3 && expr.args.size > initIndex)
typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location);
if (params.size() == 4 && expr.args.size > plainIndex)
typechecker.unify(params[3], optionalBoolean, scope, expr.args.data[plainIndex]->location);
returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber});
const TypePackId returnList = arena.addTypePack(returnTypes);
return WithPredicate<TypePackId>{returnList};
}
static bool dcrMagicFunctionFind(MagicFunctionCallContext context)
{
const auto& [params, tail] = flatten(context.arguments);
if (params.size() < 2 || params.size() > 4)
return false;
TypeArena* arena = context.solver->arena;
NotNull<BuiltinTypes> builtinTypes = context.solver->builtinTypes;
AstExprConstantString* pattern = nullptr;
size_t patternIndex = context.callSite->self ? 0 : 1;
if (context.callSite->args.size > patternIndex)
pattern = context.callSite->args.data[patternIndex]->as<AstExprConstantString>();
if (!pattern)
return false;
bool plain = false;
size_t plainIndex = context.callSite->self ? 2 : 3;
if (context.callSite->args.size > plainIndex)
{
AstExprConstantBool* p = context.callSite->args.data[plainIndex]->as<AstExprConstantBool>();
plain = p && p->value;
}
std::vector<TypeId> returnTypes;
if (!plain)
{
returnTypes = parsePatternString(builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return false;
}
context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], builtinTypes->stringType);
const TypeId optionalNumber = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->numberType}});
const TypeId optionalBoolean = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->booleanType}});
size_t initIndex = context.callSite->self ? 1 : 2;
if (params.size() >= 3 && context.callSite->args.size > initIndex)
context.solver->unify(context.solver->rootScope, context.callSite->location, params[2], optionalNumber);
if (params.size() == 4 && context.callSite->args.size > plainIndex)
context.solver->unify(context.solver->rootScope, context.callSite->location, params[3], optionalBoolean);
returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber});
const TypePackId returnList = arena->addTypePack(returnTypes);
asMutable(context.result)->ty.emplace<BoundTypePack>(returnList);
return true;
}
TypeId freshType(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, Scope* scope) TypeId freshType(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, Scope* scope)
{ {
return arena->addType(FreeType{scope, builtinTypes->neverType, builtinTypes->unknownType}); return arena->addType(FreeType{scope, builtinTypes->neverType, builtinTypes->unknownType});

View file

@ -38,6 +38,7 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false)
LUAU_FASTFLAG(LuauOccursIsntAlwaysFailure) LUAU_FASTFLAG(LuauOccursIsntAlwaysFailure)
LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false) LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false)
LUAU_FASTFLAGVARIABLE(LuauVariadicOverloadFix, false)
LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false) LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false)
LUAU_FASTFLAG(LuauParseDeclareClassIndexer) LUAU_FASTFLAG(LuauParseDeclareClassIndexer)
LUAU_FASTFLAG(LuauFloorDivision); LUAU_FASTFLAG(LuauFloorDivision);
@ -210,21 +211,6 @@ size_t HashBoolNamePair::operator()(const std::pair<bool, Name>& pair) const
return std::hash<bool>()(pair.first) ^ std::hash<Name>()(pair.second); return std::hash<bool>()(pair.first) ^ std::hash<Name>()(pair.second);
} }
GlobalTypes::GlobalTypes(NotNull<BuiltinTypes> builtinTypes)
: builtinTypes(builtinTypes)
{
globalScope = std::make_shared<Scope>(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}}));
globalScope->addBuiltinTypeBinding("any", TypeFun{{}, builtinTypes->anyType});
globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, builtinTypes->nilType});
globalScope->addBuiltinTypeBinding("number", TypeFun{{}, builtinTypes->numberType});
globalScope->addBuiltinTypeBinding("string", TypeFun{{}, builtinTypes->stringType});
globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, builtinTypes->booleanType});
globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, builtinTypes->threadType});
globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, builtinTypes->unknownType});
globalScope->addBuiltinTypeBinding("never", TypeFun{{}, builtinTypes->neverType});
}
TypeChecker::TypeChecker(const ScopePtr& globalScope, ModuleResolver* resolver, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter* iceHandler) TypeChecker::TypeChecker(const ScopePtr& globalScope, ModuleResolver* resolver, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter* iceHandler)
: globalScope(globalScope) : globalScope(globalScope)
, resolver(resolver) , resolver(resolver)
@ -4038,6 +4024,12 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam
if (argIndex < argLocations.size()) if (argIndex < argLocations.size())
location = argLocations[argIndex]; location = argLocations[argIndex];
if (FFlag::LuauVariadicOverloadFix)
{
state.location = location;
state.tryUnify(*argIter, vtp->ty);
}
else
unify(*argIter, vtp->ty, scope, location); unify(*argIter, vtp->ty, scope, location);
++argIter; ++argIter;
++argIndex; ++argIndex;

View file

@ -25,7 +25,6 @@ LUAU_FASTFLAGVARIABLE(LuauOccursIsntAlwaysFailure, false)
LUAU_FASTFLAG(LuauNormalizeBlockedTypes) LUAU_FASTFLAG(LuauNormalizeBlockedTypes)
LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls) LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAGVARIABLE(LuauTableUnifyRecursionLimit, false)
namespace Luau namespace Luau
{ {
@ -2259,8 +2258,6 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection,
TableType* newSubTable = log.getMutable<TableType>(subTyNew); TableType* newSubTable = log.getMutable<TableType>(subTyNew);
if (superTable != newSuperTable || subTable != newSubTable) if (superTable != newSuperTable || subTable != newSubTable)
{
if (FFlag::LuauTableUnifyRecursionLimit)
{ {
if (errors.empty()) if (errors.empty())
{ {
@ -2270,14 +2267,6 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection,
return; return;
} }
else
{
if (errors.empty())
return tryUnifyTables(subTy, superTy, isIntersection);
else
return;
}
}
} }
for (const auto& [name, prop] : subTable->props) for (const auto& [name, prop] : subTable->props)
@ -2350,8 +2339,6 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection,
TableType* newSubTable = log.getMutable<TableType>(subTyNew); TableType* newSubTable = log.getMutable<TableType>(subTyNew);
if (superTable != newSuperTable || subTable != newSubTable) if (superTable != newSuperTable || subTable != newSubTable)
{
if (FFlag::LuauTableUnifyRecursionLimit)
{ {
if (errors.empty()) if (errors.empty())
{ {
@ -2361,14 +2348,6 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection,
return; return;
} }
else
{
if (errors.empty())
return tryUnifyTables(subTy, superTy, isIntersection);
else
return;
}
}
} }
// Unify indexers // Unify indexers

View file

@ -7,7 +7,6 @@
#include <limits.h> #include <limits.h>
LUAU_FASTFLAGVARIABLE(LuauFloorDivision, false) LUAU_FASTFLAGVARIABLE(LuauFloorDivision, false)
LUAU_FASTFLAGVARIABLE(LuauLexerConsumeFast, false)
LUAU_FASTFLAGVARIABLE(LuauLexerLookaheadRemembersBraceType, false) LUAU_FASTFLAGVARIABLE(LuauLexerLookaheadRemembersBraceType, false)
namespace Luau namespace Luau
@ -460,19 +459,8 @@ Position Lexer::position() const
LUAU_FORCEINLINE LUAU_FORCEINLINE
void Lexer::consume() void Lexer::consume()
{ {
if (isNewline(buffer[offset])) // consume() assumes current character is known to not be a newline; use consumeAny if this is not guaranteed
{
// TODO: When the flag is removed, remove the outer condition
if (FFlag::LuauLexerConsumeFast)
{
LUAU_ASSERT(!isNewline(buffer[offset])); LUAU_ASSERT(!isNewline(buffer[offset]));
}
else
{
line++;
lineOffset = offset + 1;
}
}
offset++; offset++;
} }

View file

@ -66,6 +66,8 @@ struct IrBuilder
bool inTerminatedBlock = false; bool inTerminatedBlock = false;
bool interruptRequested = false;
bool activeFastcallFallback = false; bool activeFastcallFallback = false;
IrOp fastcallFallbackReturn; IrOp fastcallFallbackReturn;
int fastcallSkipTarget = -1; int fastcallSkipTarget = -1;
@ -76,6 +78,8 @@ struct IrBuilder
std::vector<uint32_t> instIndexToBlock; // Block index at the bytecode instruction std::vector<uint32_t> instIndexToBlock; // Block index at the bytecode instruction
std::vector<IrOp> loopStepStack;
// Similar to BytecodeBuilder, duplicate constants are removed used the same method // Similar to BytecodeBuilder, duplicate constants are removed used the same method
struct ConstantKey struct ConstantKey
{ {

View file

@ -199,24 +199,12 @@ enum class IrCmd : uint8_t
// D: block (if false) // D: block (if false)
JUMP_EQ_TAG, JUMP_EQ_TAG,
// Jump if two int numbers are equal // Perform a conditional jump based on the result of integer comparison
// A, B: int
// C: block (if true)
// D: block (if false)
JUMP_EQ_INT,
// Jump if A < B
// A, B: int
// C: block (if true)
// D: block (if false)
JUMP_LT_INT,
// Jump if unsigned(A) >= unsigned(B)
// A, B: int // A, B: int
// C: condition // C: condition
// D: block (if true) // D: block (if true)
// E: block (if false) // E: block (if false)
JUMP_GE_UINT, JUMP_CMP_INT,
// Jump if pointers are equal // Jump if pointers are equal
// A, B: pointer (*) // A, B: pointer (*)

View file

@ -94,9 +94,7 @@ inline bool isBlockTerminator(IrCmd cmd)
case IrCmd::JUMP_IF_TRUTHY: case IrCmd::JUMP_IF_TRUTHY:
case IrCmd::JUMP_IF_FALSY: case IrCmd::JUMP_IF_FALSY:
case IrCmd::JUMP_EQ_TAG: case IrCmd::JUMP_EQ_TAG:
case IrCmd::JUMP_EQ_INT: case IrCmd::JUMP_CMP_INT:
case IrCmd::JUMP_LT_INT:
case IrCmd::JUMP_GE_UINT:
case IrCmd::JUMP_EQ_POINTER: case IrCmd::JUMP_EQ_POINTER:
case IrCmd::JUMP_CMP_NUM: case IrCmd::JUMP_CMP_NUM:
case IrCmd::JUMP_SLOT_MATCH: case IrCmd::JUMP_SLOT_MATCH:

View file

@ -12,6 +12,8 @@
#include "lgc.h" #include "lgc.h"
#include "lstate.h" #include "lstate.h"
#include <utility>
namespace Luau namespace Luau
{ {
namespace CodeGen namespace CodeGen
@ -22,10 +24,15 @@ namespace X64
void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, IrCondition cond, Label& label) void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, IrCondition cond, Label& label)
{ {
// Refresher on comi/ucomi EFLAGS: // Refresher on comi/ucomi EFLAGS:
// all zero: greater
// CF only: less // CF only: less
// ZF only: equal // ZF only: equal
// PF+CF+ZF: unordered (NaN) // PF+CF+ZF: unordered (NaN)
// To avoid the lack of conditional jumps that check for "greater" conditions in IEEE 754 compliant way, we use "less" forms to emulate these
if (cond == IrCondition::Greater || cond == IrCondition::GreaterEqual || cond == IrCondition::NotGreater || cond == IrCondition::NotGreaterEqual)
std::swap(lhs, rhs);
if (rhs.cat == CategoryX64::reg) if (rhs.cat == CategoryX64::reg)
{ {
build.vucomisd(rhs, lhs); build.vucomisd(rhs, lhs);
@ -41,18 +48,22 @@ void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs,
switch (cond) switch (cond)
{ {
case IrCondition::NotLessEqual: case IrCondition::NotLessEqual:
case IrCondition::NotGreaterEqual:
// (b < a) is the same as !(a <= b). jnae checks CF=1 which means < or NaN // (b < a) is the same as !(a <= b). jnae checks CF=1 which means < or NaN
build.jcc(ConditionX64::NotAboveEqual, label); build.jcc(ConditionX64::NotAboveEqual, label);
break; break;
case IrCondition::LessEqual: case IrCondition::LessEqual:
case IrCondition::GreaterEqual:
// (b >= a) is the same as (a <= b). jae checks CF=0 which means >= and not NaN // (b >= a) is the same as (a <= b). jae checks CF=0 which means >= and not NaN
build.jcc(ConditionX64::AboveEqual, label); build.jcc(ConditionX64::AboveEqual, label);
break; break;
case IrCondition::NotLess: case IrCondition::NotLess:
case IrCondition::NotGreater:
// (b <= a) is the same as !(a < b). jna checks CF=1 or ZF=1 which means <= or NaN // (b <= a) is the same as !(a < b). jna checks CF=1 or ZF=1 which means <= or NaN
build.jcc(ConditionX64::NotAbove, label); build.jcc(ConditionX64::NotAbove, label);
break; break;
case IrCondition::Less: case IrCondition::Less:
case IrCondition::Greater:
// (b > a) is the same as (a < b). ja checks CF=0 and ZF=0 which means > and not NaN // (b > a) is the same as (a < b). ja checks CF=0 and ZF=0 which means > and not NaN
build.jcc(ConditionX64::Above, label); build.jcc(ConditionX64::Above, label);
break; break;
@ -66,6 +77,44 @@ void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs,
} }
} }
ConditionX64 getConditionInt(IrCondition cond)
{
switch (cond)
{
case IrCondition::Equal:
return ConditionX64::Equal;
case IrCondition::NotEqual:
return ConditionX64::NotEqual;
case IrCondition::Less:
return ConditionX64::Less;
case IrCondition::NotLess:
return ConditionX64::NotLess;
case IrCondition::LessEqual:
return ConditionX64::LessEqual;
case IrCondition::NotLessEqual:
return ConditionX64::NotLessEqual;
case IrCondition::Greater:
return ConditionX64::Greater;
case IrCondition::NotGreater:
return ConditionX64::NotGreater;
case IrCondition::GreaterEqual:
return ConditionX64::GreaterEqual;
case IrCondition::NotGreaterEqual:
return ConditionX64::NotGreaterEqual;
case IrCondition::UnsignedLess:
return ConditionX64::Below;
case IrCondition::UnsignedLessEqual:
return ConditionX64::BelowEqual;
case IrCondition::UnsignedGreater:
return ConditionX64::Above;
case IrCondition::UnsignedGreaterEqual:
return ConditionX64::AboveEqual;
default:
LUAU_ASSERT(!"Unsupported condition");
return ConditionX64::Zero;
}
}
void getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, RegisterX64 table, int pcpos) void getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, RegisterX64 table, int pcpos)
{ {
LUAU_ASSERT(tmp != node); LUAU_ASSERT(tmp != node);

View file

@ -195,6 +195,8 @@ inline void jumpIfTruthy(AssemblyBuilderX64& build, int ri, Label& target, Label
void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, IrCondition cond, Label& label); void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, IrCondition cond, Label& label);
ConditionX64 getConditionInt(IrCondition cond);
void getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, RegisterX64 table, int pcpos); void getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, RegisterX64 table, int pcpos);
void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, Label& label); void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, Label& label);

View file

@ -149,6 +149,12 @@ void IrBuilder::buildFunctionIr(Proto* proto)
// We skip dead bytecode instructions when they appear after block was already terminated // We skip dead bytecode instructions when they appear after block was already terminated
if (!inTerminatedBlock) if (!inTerminatedBlock)
{ {
if (interruptRequested)
{
interruptRequested = false;
inst(IrCmd::INTERRUPT, constUint(i));
}
translateInst(op, pc, i); translateInst(op, pc, i);
if (fastcallSkipTarget != -1) if (fastcallSkipTarget != -1)

View file

@ -157,12 +157,8 @@ const char* getCmdName(IrCmd cmd)
return "JUMP_IF_FALSY"; return "JUMP_IF_FALSY";
case IrCmd::JUMP_EQ_TAG: case IrCmd::JUMP_EQ_TAG:
return "JUMP_EQ_TAG"; return "JUMP_EQ_TAG";
case IrCmd::JUMP_EQ_INT: case IrCmd::JUMP_CMP_INT:
return "JUMP_EQ_INT"; return "JUMP_CMP_INT";
case IrCmd::JUMP_LT_INT:
return "JUMP_LT_INT";
case IrCmd::JUMP_GE_UINT:
return "JUMP_GE_UINT";
case IrCmd::JUMP_EQ_POINTER: case IrCmd::JUMP_EQ_POINTER:
return "JUMP_EQ_POINTER"; return "JUMP_EQ_POINTER";
case IrCmd::JUMP_CMP_NUM: case IrCmd::JUMP_CMP_NUM:

View file

@ -58,6 +58,58 @@ inline ConditionA64 getConditionFP(IrCondition cond)
} }
} }
inline ConditionA64 getConditionInt(IrCondition cond)
{
switch (cond)
{
case IrCondition::Equal:
return ConditionA64::Equal;
case IrCondition::NotEqual:
return ConditionA64::NotEqual;
case IrCondition::Less:
return ConditionA64::Minus;
case IrCondition::NotLess:
return ConditionA64::Plus;
case IrCondition::LessEqual:
return ConditionA64::LessEqual;
case IrCondition::NotLessEqual:
return ConditionA64::Greater;
case IrCondition::Greater:
return ConditionA64::Greater;
case IrCondition::NotGreater:
return ConditionA64::LessEqual;
case IrCondition::GreaterEqual:
return ConditionA64::GreaterEqual;
case IrCondition::NotGreaterEqual:
return ConditionA64::Less;
case IrCondition::UnsignedLess:
return ConditionA64::CarryClear;
case IrCondition::UnsignedLessEqual:
return ConditionA64::UnsignedLessEqual;
case IrCondition::UnsignedGreater:
return ConditionA64::UnsignedGreater;
case IrCondition::UnsignedGreaterEqual:
return ConditionA64::CarrySet;
default:
LUAU_ASSERT(!"Unexpected condition code");
return ConditionA64::Always;
}
}
static void emitAddOffset(AssemblyBuilderA64& build, RegisterA64 dst, RegisterA64 src, size_t offset) static void emitAddOffset(AssemblyBuilderA64& build, RegisterA64 dst, RegisterA64 src, size_t offset)
{ {
LUAU_ASSERT(dst != src); LUAU_ASSERT(dst != src);
@ -714,31 +766,25 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
} }
break; break;
} }
case IrCmd::JUMP_EQ_INT: case IrCmd::JUMP_CMP_INT:
if (intOp(inst.b) == 0)
{ {
build.cbz(regOp(inst.a), labelOp(inst.c)); IrCondition cond = conditionOp(inst.c);
if (cond == IrCondition::Equal && intOp(inst.b) == 0)
{
build.cbz(regOp(inst.a), labelOp(inst.d));
}
else if (cond == IrCondition::NotEqual && intOp(inst.b) == 0)
{
build.cbnz(regOp(inst.a), labelOp(inst.d));
} }
else else
{ {
LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate); LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate);
build.cmp(regOp(inst.a), uint16_t(intOp(inst.b))); build.cmp(regOp(inst.a), uint16_t(intOp(inst.b)));
build.b(ConditionA64::Equal, labelOp(inst.c)); build.b(getConditionInt(cond), labelOp(inst.d));
} }
jumpOrFallthrough(blockOp(inst.d), next); jumpOrFallthrough(blockOp(inst.e), next);
break;
case IrCmd::JUMP_LT_INT:
LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate);
build.cmp(regOp(inst.a), uint16_t(intOp(inst.b)));
build.b(ConditionA64::Less, labelOp(inst.c));
jumpOrFallthrough(blockOp(inst.d), next);
break;
case IrCmd::JUMP_GE_UINT:
{
LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate);
build.cmp(regOp(inst.a), uint16_t(unsigned(intOp(inst.b))));
build.b(ConditionA64::CarrySet, labelOp(inst.c));
jumpOrFallthrough(blockOp(inst.d), next);
break; break;
} }
case IrCmd::JUMP_EQ_POINTER: case IrCmd::JUMP_EQ_POINTER:

View file

@ -655,42 +655,36 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
} }
break; break;
} }
case IrCmd::JUMP_EQ_INT: case IrCmd::JUMP_CMP_INT:
if (intOp(inst.b) == 0)
{ {
IrCondition cond = conditionOp(inst.c);
if ((cond == IrCondition::Equal || cond == IrCondition::NotEqual) && intOp(inst.b) == 0)
{
bool invert = cond == IrCondition::NotEqual;
build.test(regOp(inst.a), regOp(inst.a)); build.test(regOp(inst.a), regOp(inst.a));
if (isFallthroughBlock(blockOp(inst.c), next)) if (isFallthroughBlock(blockOp(inst.d), next))
{ {
build.jcc(ConditionX64::NotZero, labelOp(inst.d)); build.jcc(invert ? ConditionX64::Zero : ConditionX64::NotZero, labelOp(inst.e));
jumpOrFallthrough(blockOp(inst.c), next); jumpOrFallthrough(blockOp(inst.d), next);
} }
else else
{ {
build.jcc(ConditionX64::Zero, labelOp(inst.c)); build.jcc(invert ? ConditionX64::NotZero : ConditionX64::Zero, labelOp(inst.d));
jumpOrFallthrough(blockOp(inst.d), next); jumpOrFallthrough(blockOp(inst.e), next);
} }
} }
else else
{ {
build.cmp(regOp(inst.a), intOp(inst.b)); build.cmp(regOp(inst.a), intOp(inst.b));
build.jcc(ConditionX64::Equal, labelOp(inst.c)); build.jcc(getConditionInt(cond), labelOp(inst.d));
jumpOrFallthrough(blockOp(inst.d), next); jumpOrFallthrough(blockOp(inst.e), next);
} }
break; break;
case IrCmd::JUMP_LT_INT: }
build.cmp(regOp(inst.a), intOp(inst.b));
build.jcc(ConditionX64::Less, labelOp(inst.c));
jumpOrFallthrough(blockOp(inst.d), next);
break;
case IrCmd::JUMP_GE_UINT:
build.cmp(regOp(inst.a), unsigned(intOp(inst.b)));
build.jcc(ConditionX64::AboveEqual, labelOp(inst.c));
jumpOrFallthrough(blockOp(inst.d), next);
break;
case IrCmd::JUMP_EQ_POINTER: case IrCmd::JUMP_EQ_POINTER:
build.cmp(regOp(inst.a), regOp(inst.b)); build.cmp(regOp(inst.a), regOp(inst.b));
@ -703,7 +697,6 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
ScopedRegX64 tmp{regs, SizeX64::xmmword}; ScopedRegX64 tmp{regs, SizeX64::xmmword};
// TODO: jumpOnNumberCmp should work on IrCondition directly
jumpOnNumberCmp(build, tmp.reg, memRegDoubleOp(inst.a), memRegDoubleOp(inst.b), cond, labelOp(inst.d)); jumpOnNumberCmp(build, tmp.reg, memRegDoubleOp(inst.a), memRegDoubleOp(inst.b), cond, labelOp(inst.d));
jumpOrFallthrough(blockOp(inst.e), next); jumpOrFallthrough(blockOp(inst.e), next);
break; break;

View file

@ -411,7 +411,7 @@ static BuiltinImplResult translateBuiltinBit32BinaryOp(
IrOp falsey = build.block(IrBlockKind::Internal); IrOp falsey = build.block(IrBlockKind::Internal);
IrOp truthy = build.block(IrBlockKind::Internal); IrOp truthy = build.block(IrBlockKind::Internal);
IrOp exit = build.block(IrBlockKind::Internal); IrOp exit = build.block(IrBlockKind::Internal);
build.inst(IrCmd::JUMP_EQ_INT, res, build.constInt(0), falsey, truthy); build.inst(IrCmd::JUMP_CMP_INT, res, build.constInt(0), build.cond(IrCondition::Equal), falsey, truthy);
build.beginBlock(falsey); build.beginBlock(falsey);
build.inst(IrCmd::STORE_INT, build.vmReg(ra), build.constInt(0)); build.inst(IrCmd::STORE_INT, build.vmReg(ra), build.constInt(0));
@ -484,7 +484,7 @@ static BuiltinImplResult translateBuiltinBit32Shift(
if (!knownGoodShift) if (!knownGoodShift)
{ {
IrOp block = build.block(IrBlockKind::Internal); IrOp block = build.block(IrBlockKind::Internal);
build.inst(IrCmd::JUMP_GE_UINT, vbi, build.constInt(32), fallback, block); build.inst(IrCmd::JUMP_CMP_INT, vbi, build.constInt(32), build.cond(IrCondition::UnsignedGreaterEqual), fallback, block);
build.beginBlock(block); build.beginBlock(block);
} }
@ -549,36 +549,56 @@ static BuiltinImplResult translateBuiltinBit32Extract(
IrOp vb = builtinLoadDouble(build, args); IrOp vb = builtinLoadDouble(build, args);
IrOp n = build.inst(IrCmd::NUM_TO_UINT, va); IrOp n = build.inst(IrCmd::NUM_TO_UINT, va);
IrOp f = build.inst(IrCmd::NUM_TO_INT, vb);
IrOp value; IrOp value;
if (nparams == 2) if (nparams == 2)
{ {
IrOp block = build.block(IrBlockKind::Internal); if (vb.kind == IrOpKind::Constant)
build.inst(IrCmd::JUMP_GE_UINT, f, build.constInt(32), fallback, block); {
build.beginBlock(block); int f = int(build.function.doubleOp(vb));
// TODO: this can be optimized using a bit-select instruction (bt on x86) if (unsigned(f) >= 32)
IrOp shift = build.inst(IrCmd::BITRSHIFT_UINT, n, f); build.inst(IrCmd::JUMP, fallback);
value = build.inst(IrCmd::BITAND_UINT, shift, build.constInt(1));
// TODO: this pair can be optimized using a bit-select instruction (bt on x86)
if (f)
value = build.inst(IrCmd::BITRSHIFT_UINT, n, build.constInt(f));
if ((f + 1) < 32)
value = build.inst(IrCmd::BITAND_UINT, value, build.constInt(1));
} }
else else
{ {
IrOp f = build.inst(IrCmd::NUM_TO_INT, vb);
IrOp block = build.block(IrBlockKind::Internal);
build.inst(IrCmd::JUMP_CMP_INT, f, build.constInt(32), build.cond(IrCondition::UnsignedGreaterEqual), fallback, block);
build.beginBlock(block);
// TODO: this pair can be optimized using a bit-select instruction (bt on x86)
IrOp shift = build.inst(IrCmd::BITRSHIFT_UINT, n, f);
value = build.inst(IrCmd::BITAND_UINT, shift, build.constInt(1));
}
}
else
{
IrOp f = build.inst(IrCmd::NUM_TO_INT, vb);
builtinCheckDouble(build, build.vmReg(args.index + 1), pcpos); builtinCheckDouble(build, build.vmReg(args.index + 1), pcpos);
IrOp vc = builtinLoadDouble(build, build.vmReg(args.index + 1)); IrOp vc = builtinLoadDouble(build, build.vmReg(args.index + 1));
IrOp w = build.inst(IrCmd::NUM_TO_INT, vc); IrOp w = build.inst(IrCmd::NUM_TO_INT, vc);
IrOp block1 = build.block(IrBlockKind::Internal); IrOp block1 = build.block(IrBlockKind::Internal);
build.inst(IrCmd::JUMP_LT_INT, f, build.constInt(0), fallback, block1); build.inst(IrCmd::JUMP_CMP_INT, f, build.constInt(0), build.cond(IrCondition::Less), fallback, block1);
build.beginBlock(block1); build.beginBlock(block1);
IrOp block2 = build.block(IrBlockKind::Internal); IrOp block2 = build.block(IrBlockKind::Internal);
build.inst(IrCmd::JUMP_LT_INT, w, build.constInt(1), fallback, block2); build.inst(IrCmd::JUMP_CMP_INT, w, build.constInt(1), build.cond(IrCondition::Less), fallback, block2);
build.beginBlock(block2); build.beginBlock(block2);
IrOp block3 = build.block(IrBlockKind::Internal); IrOp block3 = build.block(IrBlockKind::Internal);
IrOp fw = build.inst(IrCmd::ADD_INT, f, w); IrOp fw = build.inst(IrCmd::ADD_INT, f, w);
build.inst(IrCmd::JUMP_LT_INT, fw, build.constInt(33), block3, fallback); build.inst(IrCmd::JUMP_CMP_INT, fw, build.constInt(33), build.cond(IrCondition::Less), block3, fallback);
build.beginBlock(block3); build.beginBlock(block3);
IrOp shift = build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xfffffffe), build.inst(IrCmd::SUB_INT, w, build.constInt(1))); IrOp shift = build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xfffffffe), build.inst(IrCmd::SUB_INT, w, build.constInt(1)));
@ -615,10 +635,15 @@ static BuiltinImplResult translateBuiltinBit32ExtractK(
uint32_t m = ~(0xfffffffeu << w1); uint32_t m = ~(0xfffffffeu << w1);
IrOp nf = build.inst(IrCmd::BITRSHIFT_UINT, n, build.constInt(f)); IrOp result = n;
IrOp and_ = build.inst(IrCmd::BITAND_UINT, nf, build.constInt(m));
IrOp value = build.inst(IrCmd::UINT_TO_NUM, and_); if (f)
result = build.inst(IrCmd::BITRSHIFT_UINT, result, build.constInt(f));
if ((f + w1 + 1) < 32)
result = build.inst(IrCmd::BITAND_UINT, result, build.constInt(m));
IrOp value = build.inst(IrCmd::UINT_TO_NUM, result);
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), value); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), value);
if (ra != arg) if (ra != arg)
@ -673,7 +698,7 @@ static BuiltinImplResult translateBuiltinBit32Replace(
if (nparams == 3) if (nparams == 3)
{ {
IrOp block = build.block(IrBlockKind::Internal); IrOp block = build.block(IrBlockKind::Internal);
build.inst(IrCmd::JUMP_GE_UINT, f, build.constInt(32), fallback, block); build.inst(IrCmd::JUMP_CMP_INT, f, build.constInt(32), build.cond(IrCondition::UnsignedGreaterEqual), fallback, block);
build.beginBlock(block); build.beginBlock(block);
// TODO: this can be optimized using a bit-select instruction (btr on x86) // TODO: this can be optimized using a bit-select instruction (btr on x86)
@ -694,16 +719,16 @@ static BuiltinImplResult translateBuiltinBit32Replace(
IrOp w = build.inst(IrCmd::NUM_TO_INT, vd); IrOp w = build.inst(IrCmd::NUM_TO_INT, vd);
IrOp block1 = build.block(IrBlockKind::Internal); IrOp block1 = build.block(IrBlockKind::Internal);
build.inst(IrCmd::JUMP_LT_INT, f, build.constInt(0), fallback, block1); build.inst(IrCmd::JUMP_CMP_INT, f, build.constInt(0), build.cond(IrCondition::Less), fallback, block1);
build.beginBlock(block1); build.beginBlock(block1);
IrOp block2 = build.block(IrBlockKind::Internal); IrOp block2 = build.block(IrBlockKind::Internal);
build.inst(IrCmd::JUMP_LT_INT, w, build.constInt(1), fallback, block2); build.inst(IrCmd::JUMP_CMP_INT, w, build.constInt(1), build.cond(IrCondition::Less), fallback, block2);
build.beginBlock(block2); build.beginBlock(block2);
IrOp block3 = build.block(IrBlockKind::Internal); IrOp block3 = build.block(IrBlockKind::Internal);
IrOp fw = build.inst(IrCmd::ADD_INT, f, w); IrOp fw = build.inst(IrCmd::ADD_INT, f, w);
build.inst(IrCmd::JUMP_LT_INT, fw, build.constInt(33), block3, fallback); build.inst(IrCmd::JUMP_CMP_INT, fw, build.constInt(33), build.cond(IrCondition::Less), block3, fallback);
build.beginBlock(block3); build.beginBlock(block3);
IrOp shift1 = build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xfffffffe), build.inst(IrCmd::SUB_INT, w, build.constInt(1))); IrOp shift1 = build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xfffffffe), build.inst(IrCmd::SUB_INT, w, build.constInt(1)));

View file

@ -12,6 +12,8 @@
#include "lstate.h" #include "lstate.h"
#include "ltm.h" #include "ltm.h"
LUAU_FASTFLAGVARIABLE(LuauImproveForN, false)
namespace Luau namespace Luau
{ {
namespace CodeGen namespace CodeGen
@ -170,7 +172,7 @@ void translateInstJumpIfEq(IrBuilder& build, const Instruction* pc, int pcpos, b
build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1));
IrOp result = build.inst(IrCmd::CMP_ANY, build.vmReg(ra), build.vmReg(rb), build.cond(IrCondition::Equal)); IrOp result = build.inst(IrCmd::CMP_ANY, build.vmReg(ra), build.vmReg(rb), build.cond(IrCondition::Equal));
build.inst(IrCmd::JUMP_EQ_INT, result, build.constInt(0), not_ ? target : next, not_ ? next : target); build.inst(IrCmd::JUMP_CMP_INT, result, build.constInt(0), build.cond(IrCondition::Equal), not_ ? target : next, not_ ? next : target);
build.beginBlock(next); build.beginBlock(next);
} }
@ -218,7 +220,7 @@ void translateInstJumpIfCond(IrBuilder& build, const Instruction* pc, int pcpos,
} }
IrOp result = build.inst(IrCmd::CMP_ANY, build.vmReg(ra), build.vmReg(rb), build.cond(cond)); IrOp result = build.inst(IrCmd::CMP_ANY, build.vmReg(ra), build.vmReg(rb), build.cond(cond));
build.inst(IrCmd::JUMP_EQ_INT, result, build.constInt(0), reverse ? target : next, reverse ? next : target); build.inst(IrCmd::JUMP_CMP_INT, result, build.constInt(0), build.cond(IrCondition::Equal), reverse ? target : next, reverse ? next : target);
build.beginBlock(next); build.beginBlock(next);
} }
@ -262,7 +264,7 @@ void translateInstJumpxEqB(IrBuilder& build, const Instruction* pc, int pcpos)
build.beginBlock(checkValue); build.beginBlock(checkValue);
IrOp va = build.inst(IrCmd::LOAD_INT, build.vmReg(ra)); IrOp va = build.inst(IrCmd::LOAD_INT, build.vmReg(ra));
build.inst(IrCmd::JUMP_EQ_INT, va, build.constInt(aux & 0x1), not_ ? next : target, not_ ? target : next); build.inst(IrCmd::JUMP_CMP_INT, va, build.constInt(aux & 0x1), build.cond(IrCondition::Equal), not_ ? next : target, not_ ? target : next);
// Fallthrough in original bytecode is implicit, so we start next internal block here // Fallthrough in original bytecode is implicit, so we start next internal block here
if (build.isInternalBlock(next)) if (build.isInternalBlock(next))
@ -607,6 +609,27 @@ IrOp translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool
return fallback; return fallback;
} }
// numeric for loop always ends with the computation of step that targets ra+1
// any conditionals would result in a split basic block, so we can recover the step constants by pattern matching the IR we generated for LOADN/K
static IrOp getLoopStepK(IrBuilder& build, int ra)
{
IrBlock& active = build.function.blocks[build.activeBlockIdx];
if (active.start + 2 < build.function.instructions.size())
{
IrInst& sv = build.function.instructions[build.function.instructions.size() - 2];
IrInst& st = build.function.instructions[build.function.instructions.size() - 1];
// We currently expect to match IR generated from LOADN/LOADK so we match a particular sequence of opcodes
// In the future this can be extended to cover opposite STORE order as well as STORE_SPLIT_TVALUE
if (sv.cmd == IrCmd::STORE_DOUBLE && sv.a.kind == IrOpKind::VmReg && sv.a.index == ra + 1 && sv.b.kind == IrOpKind::Constant &&
st.cmd == IrCmd::STORE_TAG && st.a.kind == IrOpKind::VmReg && st.a.index == ra + 1 && build.function.tagOp(st.b) == LUA_TNUMBER)
return sv.b;
}
return build.undef();
}
void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos) void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos)
{ {
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
@ -614,6 +637,62 @@ void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos)
IrOp loopStart = build.blockAtInst(pcpos + getOpLength(LuauOpcode(LUAU_INSN_OP(*pc)))); IrOp loopStart = build.blockAtInst(pcpos + getOpLength(LuauOpcode(LUAU_INSN_OP(*pc))));
IrOp loopExit = build.blockAtInst(getJumpTarget(*pc, pcpos)); IrOp loopExit = build.blockAtInst(getJumpTarget(*pc, pcpos));
if (FFlag::LuauImproveForN)
{
IrOp stepK = getLoopStepK(build, ra);
build.loopStepStack.push_back(stepK);
// When loop parameters are not numbers, VM tries to perform type coercion from string and raises an exception if that fails
// Performing that fallback in native code increases code size and complicates CFG, obscuring the values when they are constant
// To avoid that overhead for an extremely rare case (that doesn't even typecheck), we exit to VM to handle it
IrOp tagLimit = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 0));
build.inst(IrCmd::CHECK_TAG, tagLimit, build.constTag(LUA_TNUMBER), build.vmExit(pcpos));
IrOp tagIdx = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2));
build.inst(IrCmd::CHECK_TAG, tagIdx, build.constTag(LUA_TNUMBER), build.vmExit(pcpos));
IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0));
IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2));
if (stepK.kind == IrOpKind::Undef)
{
IrOp tagStep = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1));
build.inst(IrCmd::CHECK_TAG, tagStep, build.constTag(LUA_TNUMBER), build.vmExit(pcpos));
IrOp direct = build.block(IrBlockKind::Internal);
IrOp reverse = build.block(IrBlockKind::Internal);
IrOp zero = build.constDouble(0.0);
IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1));
// step > 0
// note: equivalent to 0 < step, but lowers into one instruction on both X64 and A64
build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::Greater), direct, reverse);
// Condition to start the loop: step > 0 ? idx <= limit : limit <= idx
// We invert the condition so that loopStart is the fallthrough (false) label
// step > 0 is false, check limit <= idx
build.beginBlock(reverse);
build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::NotLessEqual), loopExit, loopStart);
// step > 0 is true, check idx <= limit
build.beginBlock(direct);
build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::NotLessEqual), loopExit, loopStart);
}
else
{
double stepN = build.function.doubleOp(stepK);
// Condition to start the loop: step > 0 ? idx <= limit : limit <= idx
// We invert the condition so that loopStart is the fallthrough (false) label
if (stepN > 0)
build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::NotLessEqual), loopExit, loopStart);
else
build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::NotLessEqual), loopExit, loopStart);
}
}
else
{
IrOp direct = build.block(IrBlockKind::Internal); IrOp direct = build.block(IrBlockKind::Internal);
IrOp reverse = build.block(IrBlockKind::Internal); IrOp reverse = build.block(IrBlockKind::Internal);
@ -644,10 +723,17 @@ void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos)
// step <= 0 is true, check limit <= idx // step <= 0 is true, check limit <= idx
build.beginBlock(reverse); build.beginBlock(reverse);
build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopStart, loopExit); build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopStart, loopExit);
}
// Fallthrough in original bytecode is implicit, so we start next internal block here // Fallthrough in original bytecode is implicit, so we start next internal block here
if (build.isInternalBlock(loopStart)) if (build.isInternalBlock(loopStart))
build.beginBlock(loopStart); build.beginBlock(loopStart);
// VM places interrupt in FORNLOOP, but that creates a likely spill point for short loops that use loop index as INTERRUPT always spills
// We place the interrupt at the beginning of the loop body instead; VM uses FORNLOOP because it doesn't want to waste an extra instruction.
// Because loop block may not have been started yet (as it's started when lowering the first instruction!), we need to defer INTERRUPT placement.
if (FFlag::LuauImproveForN)
build.interruptRequested = true;
} }
void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos) void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos)
@ -657,6 +743,52 @@ void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos)
IrOp loopRepeat = build.blockAtInst(getJumpTarget(*pc, pcpos)); IrOp loopRepeat = build.blockAtInst(getJumpTarget(*pc, pcpos));
IrOp loopExit = build.blockAtInst(pcpos + getOpLength(LuauOpcode(LUAU_INSN_OP(*pc)))); IrOp loopExit = build.blockAtInst(pcpos + getOpLength(LuauOpcode(LUAU_INSN_OP(*pc))));
if (FFlag::LuauImproveForN)
{
LUAU_ASSERT(!build.loopStepStack.empty());
IrOp stepK = build.loopStepStack.back();
build.loopStepStack.pop_back();
IrOp zero = build.constDouble(0.0);
IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0));
IrOp step = stepK.kind == IrOpKind::Undef ? build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)) : stepK;
IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2));
idx = build.inst(IrCmd::ADD_NUM, idx, step);
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra + 2), idx);
if (stepK.kind == IrOpKind::Undef)
{
IrOp direct = build.block(IrBlockKind::Internal);
IrOp reverse = build.block(IrBlockKind::Internal);
// step > 0
// note: equivalent to 0 < step, but lowers into one instruction on both X64 and A64
build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::Greater), direct, reverse);
// Condition to continue the loop: step > 0 ? idx <= limit : limit <= idx
// step > 0 is false, check limit <= idx
build.beginBlock(reverse);
build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopRepeat, loopExit);
// step > 0 is true, check idx <= limit
build.beginBlock(direct);
build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopRepeat, loopExit);
}
else
{
double stepN = build.function.doubleOp(stepK);
// Condition to continue the loop: step > 0 ? idx <= limit : limit <= idx
if (stepN > 0)
build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopRepeat, loopExit);
else
build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopRepeat, loopExit);
}
}
else
{
build.inst(IrCmd::INTERRUPT, build.constUint(pcpos)); build.inst(IrCmd::INTERRUPT, build.constUint(pcpos));
IrOp zero = build.constDouble(0.0); IrOp zero = build.constDouble(0.0);
@ -680,6 +812,7 @@ void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos)
// step <= 0 is true, check limit <= idx // step <= 0 is true, check limit <= idx
build.beginBlock(reverse); build.beginBlock(reverse);
build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopRepeat, loopExit);
}
// Fallthrough in original bytecode is implicit, so we start next internal block here // Fallthrough in original bytecode is implicit, so we start next internal block here
if (build.isInternalBlock(loopExit)) if (build.isInternalBlock(loopExit))

View file

@ -72,9 +72,7 @@ IrValueKind getCmdValueKind(IrCmd cmd)
case IrCmd::JUMP_IF_TRUTHY: case IrCmd::JUMP_IF_TRUTHY:
case IrCmd::JUMP_IF_FALSY: case IrCmd::JUMP_IF_FALSY:
case IrCmd::JUMP_EQ_TAG: case IrCmd::JUMP_EQ_TAG:
case IrCmd::JUMP_EQ_INT: case IrCmd::JUMP_CMP_INT:
case IrCmd::JUMP_LT_INT:
case IrCmd::JUMP_GE_UINT:
case IrCmd::JUMP_EQ_POINTER: case IrCmd::JUMP_EQ_POINTER:
case IrCmd::JUMP_CMP_NUM: case IrCmd::JUMP_CMP_NUM:
case IrCmd::JUMP_SLOT_MATCH: case IrCmd::JUMP_SLOT_MATCH:
@ -422,6 +420,45 @@ bool compare(double a, double b, IrCondition cond)
return false; return false;
} }
bool compare(int a, int b, IrCondition cond)
{
switch (cond)
{
case IrCondition::Equal:
return a == b;
case IrCondition::NotEqual:
return a != b;
case IrCondition::Less:
return a < b;
case IrCondition::NotLess:
return !(a < b);
case IrCondition::LessEqual:
return a <= b;
case IrCondition::NotLessEqual:
return !(a <= b);
case IrCondition::Greater:
return a > b;
case IrCondition::NotGreater:
return !(a > b);
case IrCondition::GreaterEqual:
return a >= b;
case IrCondition::NotGreaterEqual:
return !(a >= b);
case IrCondition::UnsignedLess:
return unsigned(a) < unsigned(b);
case IrCondition::UnsignedLessEqual:
return unsigned(a) <= unsigned(b);
case IrCondition::UnsignedGreater:
return unsigned(a) > unsigned(b);
case IrCondition::UnsignedGreaterEqual:
return unsigned(a) >= unsigned(b);
default:
LUAU_ASSERT(!"Unsupported condition");
}
return false;
}
void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint32_t index) void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint32_t index)
{ {
IrInst& inst = function.instructions[index]; IrInst& inst = function.instructions[index];
@ -540,31 +577,13 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3
replace(function, block, index, {IrCmd::JUMP, inst.d}); replace(function, block, index, {IrCmd::JUMP, inst.d});
} }
break; break;
case IrCmd::JUMP_EQ_INT: case IrCmd::JUMP_CMP_INT:
if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant)
{ {
if (function.intOp(inst.a) == function.intOp(inst.b)) if (compare(function.intOp(inst.a), function.intOp(inst.b), conditionOp(inst.c)))
replace(function, block, index, {IrCmd::JUMP, inst.c});
else
replace(function, block, index, {IrCmd::JUMP, inst.d}); replace(function, block, index, {IrCmd::JUMP, inst.d});
}
break;
case IrCmd::JUMP_LT_INT:
if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant)
{
if (function.intOp(inst.a) < function.intOp(inst.b))
replace(function, block, index, {IrCmd::JUMP, inst.c});
else else
replace(function, block, index, {IrCmd::JUMP, inst.d}); replace(function, block, index, {IrCmd::JUMP, inst.e});
}
break;
case IrCmd::JUMP_GE_UINT:
if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant)
{
if (unsigned(function.intOp(inst.a)) >= unsigned(function.intOp(inst.b)))
replace(function, block, index, {IrCmd::JUMP, inst.c});
else
replace(function, block, index, {IrCmd::JUMP, inst.d});
} }
break; break;
case IrCmd::JUMP_CMP_NUM: case IrCmd::JUMP_CMP_NUM:

View file

@ -17,6 +17,7 @@ LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64)
LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks, false) LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks, false)
LUAU_FASTFLAGVARIABLE(LuauReuseHashSlots2, false) LUAU_FASTFLAGVARIABLE(LuauReuseHashSlots2, false)
LUAU_FASTFLAGVARIABLE(LuauKeepVmapLinear, false) LUAU_FASTFLAGVARIABLE(LuauKeepVmapLinear, false)
LUAU_FASTFLAGVARIABLE(LuauMergeTagLoads, false)
namespace Luau namespace Luau
{ {
@ -502,9 +503,16 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
{ {
case IrCmd::LOAD_TAG: case IrCmd::LOAD_TAG:
if (uint8_t tag = state.tryGetTag(inst.a); tag != 0xff) if (uint8_t tag = state.tryGetTag(inst.a); tag != 0xff)
{
substitute(function, inst, build.constTag(tag)); substitute(function, inst, build.constTag(tag));
}
else if (inst.a.kind == IrOpKind::VmReg) else if (inst.a.kind == IrOpKind::VmReg)
{
if (FFlag::LuauMergeTagLoads)
state.substituteOrRecordVmRegLoad(inst);
else
state.createRegLink(index, inst.a); state.createRegLink(index, inst.a);
}
break; break;
case IrCmd::LOAD_POINTER: case IrCmd::LOAD_POINTER:
if (inst.a.kind == IrOpKind::VmReg) if (inst.a.kind == IrOpKind::VmReg)
@ -716,44 +724,20 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
else else
replace(function, block, index, {IrCmd::JUMP, inst.d}); replace(function, block, index, {IrCmd::JUMP, inst.d});
} }
else if (FFlag::LuauMergeTagLoads && inst.a == inst.b)
{
replace(function, block, index, {IrCmd::JUMP, inst.c});
}
break; break;
} }
case IrCmd::JUMP_EQ_INT: case IrCmd::JUMP_CMP_INT:
{ {
std::optional<int> valueA = function.asIntOp(inst.a.kind == IrOpKind::Constant ? inst.a : state.tryGetValue(inst.a)); std::optional<int> valueA = function.asIntOp(inst.a.kind == IrOpKind::Constant ? inst.a : state.tryGetValue(inst.a));
std::optional<int> valueB = function.asIntOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b)); std::optional<int> valueB = function.asIntOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b));
if (valueA && valueB) if (valueA && valueB)
{ {
if (*valueA == *valueB) if (compare(*valueA, *valueB, conditionOp(inst.c)))
replace(function, block, index, {IrCmd::JUMP, inst.c});
else
replace(function, block, index, {IrCmd::JUMP, inst.d});
}
break;
}
case IrCmd::JUMP_LT_INT:
{
std::optional<int> valueA = function.asIntOp(inst.a.kind == IrOpKind::Constant ? inst.a : state.tryGetValue(inst.a));
std::optional<int> valueB = function.asIntOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b));
if (valueA && valueB)
{
if (*valueA < *valueB)
replace(function, block, index, {IrCmd::JUMP, inst.c});
else
replace(function, block, index, {IrCmd::JUMP, inst.d});
}
break;
}
case IrCmd::JUMP_GE_UINT:
{
std::optional<unsigned> valueA = function.asUintOp(inst.a.kind == IrOpKind::Constant ? inst.a : state.tryGetValue(inst.a));
std::optional<unsigned> valueB = function.asUintOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b));
if (valueA && valueB)
{
if (*valueA >= *valueB)
replace(function, block, index, {IrCmd::JUMP, inst.c}); replace(function, block, index, {IrCmd::JUMP, inst.c});
else else
replace(function, block, index, {IrCmd::JUMP, inst.d}); replace(function, block, index, {IrCmd::JUMP, inst.d});

View file

@ -167,6 +167,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/Error.h Analysis/include/Luau/Error.h
Analysis/include/Luau/FileResolver.h Analysis/include/Luau/FileResolver.h
Analysis/include/Luau/Frontend.h Analysis/include/Luau/Frontend.h
Analysis/include/Luau/GlobalTypes.h
Analysis/include/Luau/InsertionOrderedMap.h Analysis/include/Luau/InsertionOrderedMap.h
Analysis/include/Luau/Instantiation.h Analysis/include/Luau/Instantiation.h
Analysis/include/Luau/IostreamHelpers.h Analysis/include/Luau/IostreamHelpers.h
@ -226,6 +227,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/EmbeddedBuiltinDefinitions.cpp Analysis/src/EmbeddedBuiltinDefinitions.cpp
Analysis/src/Error.cpp Analysis/src/Error.cpp
Analysis/src/Frontend.cpp Analysis/src/Frontend.cpp
Analysis/src/GlobalTypes.cpp
Analysis/src/Instantiation.cpp Analysis/src/Instantiation.cpp
Analysis/src/IostreamHelpers.cpp Analysis/src/IostreamHelpers.cpp
Analysis/src/JsonEmitter.cpp Analysis/src/JsonEmitter.cpp
@ -365,6 +367,8 @@ if(TARGET Luau.UnitTest)
tests/AstQueryDsl.cpp tests/AstQueryDsl.cpp
tests/AstQueryDsl.h tests/AstQueryDsl.h
tests/AstVisitor.test.cpp tests/AstVisitor.test.cpp
tests/RegisterCallbacks.h
tests/RegisterCallbacks.cpp
tests/Autocomplete.test.cpp tests/Autocomplete.test.cpp
tests/BuiltinDefinitions.test.cpp tests/BuiltinDefinitions.test.cpp
tests/ClassFixture.cpp tests/ClassFixture.cpp
@ -447,6 +451,8 @@ endif()
if(TARGET Luau.Conformance) if(TARGET Luau.Conformance)
# Luau.Conformance Sources # Luau.Conformance Sources
target_sources(Luau.Conformance PRIVATE target_sources(Luau.Conformance PRIVATE
tests/RegisterCallbacks.h
tests/RegisterCallbacks.cpp
tests/Conformance.test.cpp tests/Conformance.test.cpp
tests/main.cpp) tests/main.cpp)
endif() endif()
@ -464,6 +470,8 @@ if(TARGET Luau.CLI.Test)
CLI/Profiler.cpp CLI/Profiler.cpp
CLI/Repl.cpp CLI/Repl.cpp
tests/RegisterCallbacks.h
tests/RegisterCallbacks.cpp
tests/Repl.test.cpp tests/Repl.test.cpp
tests/main.cpp) tests/main.cpp)
endif() endif()

View file

@ -135,6 +135,8 @@
// Does VM support native execution via ExecutionCallbacks? We mostly assume it does but keep the define to make it easy to quantify the cost. // Does VM support native execution via ExecutionCallbacks? We mostly assume it does but keep the define to make it easy to quantify the cost.
#define VM_HAS_NATIVE 1 #define VM_HAS_NATIVE 1
void (*lua_iter_call_telemetry)(lua_State* L, int gtt, int stt, int itt) = NULL;
LUAU_NOINLINE void luau_callhook(lua_State* L, lua_Hook hook, void* userdata) LUAU_NOINLINE void luau_callhook(lua_State* L, lua_Hook hook, void* userdata)
{ {
ptrdiff_t base = savestack(L, L->base); ptrdiff_t base = savestack(L, L->base);
@ -2289,6 +2291,10 @@ reentry:
{ {
// table or userdata with __call, will be called during FORGLOOP // table or userdata with __call, will be called during FORGLOOP
// TODO: we might be able to stop supporting this depending on whether it's used in practice // TODO: we might be able to stop supporting this depending on whether it's used in practice
void (*telemetrycb)(lua_State* L, int gtt, int stt, int itt) = lua_iter_call_telemetry;
if (telemetrycb)
telemetrycb(L, ttype(ra), ttype(ra + 1), ttype(ra + 2));
} }
else if (ttistable(ra)) else if (ttistable(ra))
{ {

View file

@ -15,7 +15,6 @@
LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2) LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2)
LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel) LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel)
LUAU_FASTFLAG(LuauAutocompleteLastTypecheck)
using namespace Luau; using namespace Luau;
@ -33,37 +32,28 @@ struct ACFixtureImpl : BaseType
} }
AutocompleteResult autocomplete(unsigned row, unsigned column) AutocompleteResult autocomplete(unsigned row, unsigned column)
{
if (FFlag::LuauAutocompleteLastTypecheck)
{ {
FrontendOptions opts; FrontendOptions opts;
opts.forAutocomplete = true; opts.forAutocomplete = true;
this->frontend.check("MainModule", opts); this->frontend.check("MainModule", opts);
}
return Luau::autocomplete(this->frontend, "MainModule", Position{row, column}, nullCallback); return Luau::autocomplete(this->frontend, "MainModule", Position{row, column}, nullCallback);
} }
AutocompleteResult autocomplete(char marker, StringCompletionCallback callback = nullCallback) AutocompleteResult autocomplete(char marker, StringCompletionCallback callback = nullCallback)
{
if (FFlag::LuauAutocompleteLastTypecheck)
{ {
FrontendOptions opts; FrontendOptions opts;
opts.forAutocomplete = true; opts.forAutocomplete = true;
this->frontend.check("MainModule", opts); this->frontend.check("MainModule", opts);
}
return Luau::autocomplete(this->frontend, "MainModule", getPosition(marker), callback); return Luau::autocomplete(this->frontend, "MainModule", getPosition(marker), callback);
} }
AutocompleteResult autocomplete(const ModuleName& name, Position pos, StringCompletionCallback callback = nullCallback) AutocompleteResult autocomplete(const ModuleName& name, Position pos, StringCompletionCallback callback = nullCallback)
{
if (FFlag::LuauAutocompleteLastTypecheck)
{ {
FrontendOptions opts; FrontendOptions opts;
opts.forAutocomplete = true; opts.forAutocomplete = true;
this->frontend.check(name, opts); this->frontend.check(name, opts);
}
return Luau::autocomplete(this->frontend, name, pos, callback); return Luau::autocomplete(this->frontend, name, pos, callback);
} }
@ -3699,8 +3689,6 @@ TEST_CASE_FIXTURE(ACFixture, "string_completion_outside_quotes")
TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_empty") TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_empty")
{ {
ScopedFastFlag flag{"LuauAnonymousAutofilled1", true};
check(R"( check(R"(
local function foo(a: () -> ()) local function foo(a: () -> ())
a() a()
@ -3722,8 +3710,6 @@ foo(@1)
TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_args") TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_args")
{ {
ScopedFastFlag flag{"LuauAnonymousAutofilled1", true};
check(R"( check(R"(
local function foo(a: (number, string) -> ()) local function foo(a: (number, string) -> ())
a() a()
@ -3745,8 +3731,6 @@ foo(@1)
TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_args_single_return") TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_args_single_return")
{ {
ScopedFastFlag flag{"LuauAnonymousAutofilled1", true};
check(R"( check(R"(
local function foo(a: (number, string) -> (string)) local function foo(a: (number, string) -> (string))
a() a()
@ -3768,8 +3752,6 @@ foo(@1)
TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_args_multi_return") TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_args_multi_return")
{ {
ScopedFastFlag flag{"LuauAnonymousAutofilled1", true};
check(R"( check(R"(
local function foo(a: (number, string) -> (string, number)) local function foo(a: (number, string) -> (string, number))
a() a()
@ -3791,8 +3773,6 @@ foo(@1)
TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled__noargs_multi_return") TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled__noargs_multi_return")
{ {
ScopedFastFlag flag{"LuauAnonymousAutofilled1", true};
check(R"( check(R"(
local function foo(a: () -> (string, number)) local function foo(a: () -> (string, number))
a() a()
@ -3814,8 +3794,6 @@ foo(@1)
TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled__varargs_multi_return") TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled__varargs_multi_return")
{ {
ScopedFastFlag flag{"LuauAnonymousAutofilled1", true};
check(R"( check(R"(
local function foo(a: (...number) -> (string, number)) local function foo(a: (...number) -> (string, number))
a() a()
@ -3837,8 +3815,6 @@ foo(@1)
TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_multi_varargs_multi_return") TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_multi_varargs_multi_return")
{ {
ScopedFastFlag flag{"LuauAnonymousAutofilled1", true};
check(R"( check(R"(
local function foo(a: (string, ...number) -> (string, number)) local function foo(a: (string, ...number) -> (string, number))
a() a()
@ -3860,8 +3836,6 @@ foo(@1)
TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_multi_varargs_varargs_return") TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_multi_varargs_varargs_return")
{ {
ScopedFastFlag flag{"LuauAnonymousAutofilled1", true};
check(R"( check(R"(
local function foo(a: (string, ...number) -> ...number) local function foo(a: (string, ...number) -> ...number)
a() a()
@ -3883,8 +3857,6 @@ foo(@1)
TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_multi_varargs_multi_varargs_return") TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_multi_varargs_multi_varargs_return")
{ {
ScopedFastFlag flag{"LuauAnonymousAutofilled1", true};
check(R"( check(R"(
local function foo(a: (string, ...number) -> (boolean, ...number)) local function foo(a: (string, ...number) -> (boolean, ...number))
a() a()
@ -3906,8 +3878,6 @@ foo(@1)
TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_named_args") TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_named_args")
{ {
ScopedFastFlag flag{"LuauAnonymousAutofilled1", true};
check(R"( check(R"(
local function foo(a: (foo: number, bar: string) -> (string, number)) local function foo(a: (foo: number, bar: string) -> (string, number))
a() a()
@ -3929,8 +3899,6 @@ foo(@1)
TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_partially_args") TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_partially_args")
{ {
ScopedFastFlag flag{"LuauAnonymousAutofilled1", true};
check(R"( check(R"(
local function foo(a: (number, bar: string) -> (string, number)) local function foo(a: (number, bar: string) -> (string, number))
a() a()
@ -3952,8 +3920,6 @@ foo(@1)
TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_partially_args_last") TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_partially_args_last")
{ {
ScopedFastFlag flag{"LuauAnonymousAutofilled1", true};
check(R"( check(R"(
local function foo(a: (foo: number, string) -> (string, number)) local function foo(a: (foo: number, string) -> (string, number))
a() a()
@ -3975,8 +3941,6 @@ foo(@1)
TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_typeof_args") TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_typeof_args")
{ {
ScopedFastFlag flag{"LuauAnonymousAutofilled1", true};
check(R"( check(R"(
local t = { a = 1, b = 2 } local t = { a = 1, b = 2 }
@ -4000,8 +3964,6 @@ foo(@1)
TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_table_literal_args") TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_table_literal_args")
{ {
ScopedFastFlag flag{"LuauAnonymousAutofilled1", true};
check(R"( check(R"(
local function foo(a: (tbl: { x: number, y: number }) -> number) return a({x=2, y = 3}) end local function foo(a: (tbl: { x: number, y: number }) -> number) return a({x=2, y = 3}) end
foo(@1) foo(@1)
@ -4020,8 +3982,6 @@ foo(@1)
TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_typeof_returns") TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_typeof_returns")
{ {
ScopedFastFlag flag{"LuauAnonymousAutofilled1", true};
check(R"( check(R"(
local t = { a = 1, b = 2 } local t = { a = 1, b = 2 }
@ -4045,8 +4005,6 @@ foo(@1)
TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_table_literal_args") TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_table_literal_args")
{ {
ScopedFastFlag flag{"LuauAnonymousAutofilled1", true};
check(R"( check(R"(
local function foo(a: () -> { x: number, y: number }) return {x=2, y = 3} end local function foo(a: () -> { x: number, y: number }) return {x=2, y = 3} end
foo(@1) foo(@1)
@ -4065,8 +4023,6 @@ foo(@1)
TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_typeof_vararg") TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_typeof_vararg")
{ {
ScopedFastFlag flag{"LuauAnonymousAutofilled1", true};
check(R"( check(R"(
local t = { a = 1, b = 2 } local t = { a = 1, b = 2 }
@ -4090,8 +4046,6 @@ foo(@1)
TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_generic_type_pack_vararg") TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_generic_type_pack_vararg")
{ {
ScopedFastFlag flag{"LuauAnonymousAutofilled1", true};
check(R"( check(R"(
local function foo<A>(a: (...A) -> number, ...: A) local function foo<A>(a: (...A) -> number, ...: A)
return a(...) return a(...)
@ -4113,8 +4067,6 @@ foo(@1)
TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_generic_on_argument_type_pack_vararg") TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_generic_on_argument_type_pack_vararg")
{ {
ScopedFastFlag flag{"LuauAnonymousAutofilled1", true};
check(R"( check(R"(
local function foo(a: <T...>(...: T...) -> number) local function foo(a: <T...>(...: T...) -> number)
return a(4, 5, 6) return a(4, 5, 6)

View file

@ -282,6 +282,8 @@ TEST_CASE("Assert")
TEST_CASE("Basic") TEST_CASE("Basic")
{ {
ScopedFastFlag sffs{"LuauFloorDivision", true}; ScopedFastFlag sffs{"LuauFloorDivision", true};
ScopedFastFlag sfff{"LuauImproveForN", true};
runConformance("basic.lua"); runConformance("basic.lua");
} }

View file

@ -9,6 +9,7 @@
#include "Luau/Parser.h" #include "Luau/Parser.h"
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Luau/TypeAttach.h" #include "Luau/TypeAttach.h"
#include "Luau/TypeInfer.h"
#include "Luau/Transpiler.h" #include "Luau/Transpiler.h"
#include "doctest.h" #include "doctest.h"
@ -144,8 +145,6 @@ Fixture::Fixture(bool freeze, bool prepareAutocomplete)
configResolver.defaultConfig.enabledLint.warningMask = ~0ull; configResolver.defaultConfig.enabledLint.warningMask = ~0ull;
configResolver.defaultConfig.parseOptions.captureComments = true; configResolver.defaultConfig.parseOptions.captureComments = true;
registerBuiltinTypes(frontend.globals);
Luau::freeze(frontend.globals.globalTypes); Luau::freeze(frontend.globals.globalTypes);
Luau::freeze(frontend.globalsForAutocomplete.globalTypes); Luau::freeze(frontend.globalsForAutocomplete.globalTypes);

View file

@ -1222,4 +1222,28 @@ TEST_CASE_FIXTURE(FrontendFixture, "parse_only")
CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0]));
} }
TEST_CASE_FIXTURE(FrontendFixture, "markdirty_early_return")
{
ScopedFastFlag fflag("CorrectEarlyReturnInMarkDirty", true);
constexpr char moduleName[] = "game/Gui/Modules/A";
fileResolver.source[moduleName] = R"(
return 1
)";
{
std::vector<ModuleName> markedDirty;
frontend.markDirty(moduleName, &markedDirty);
CHECK(markedDirty.empty());
}
frontend.parse(moduleName);
{
std::vector<ModuleName> markedDirty;
frontend.markDirty(moduleName, &markedDirty);
CHECK(!markedDirty.empty());
}
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -621,11 +621,11 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ControlFlowEq")
}); });
withTwoBlocks([this](IrOp a, IrOp b) { withTwoBlocks([this](IrOp a, IrOp b) {
build.inst(IrCmd::JUMP_EQ_INT, build.constInt(0), build.constInt(0), a, b); build.inst(IrCmd::JUMP_CMP_INT, build.constInt(0), build.constInt(0), build.cond(IrCondition::Equal), a, b);
}); });
withTwoBlocks([this](IrOp a, IrOp b) { withTwoBlocks([this](IrOp a, IrOp b) {
build.inst(IrCmd::JUMP_EQ_INT, build.constInt(0), build.constInt(1), a, b); build.inst(IrCmd::JUMP_CMP_INT, build.constInt(0), build.constInt(1), build.cond(IrCondition::Equal), a, b);
}); });
updateUseCounts(build.function); updateUseCounts(build.function);
@ -1359,7 +1359,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "IntEqRemoval")
build.beginBlock(block); build.beginBlock(block);
build.inst(IrCmd::STORE_INT, build.vmReg(1), build.constInt(5)); build.inst(IrCmd::STORE_INT, build.vmReg(1), build.constInt(5));
IrOp value = build.inst(IrCmd::LOAD_INT, build.vmReg(1)); IrOp value = build.inst(IrCmd::LOAD_INT, build.vmReg(1));
build.inst(IrCmd::JUMP_EQ_INT, value, build.constInt(5), trueBlock, falseBlock); build.inst(IrCmd::JUMP_CMP_INT, value, build.constInt(5), build.cond(IrCondition::Equal), trueBlock, falseBlock);
build.beginBlock(trueBlock); build.beginBlock(trueBlock);
build.inst(IrCmd::RETURN, build.constUint(1)); build.inst(IrCmd::RETURN, build.constUint(1));
@ -1556,7 +1556,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval2")
IrOp repeat = build.block(IrBlockKind::Internal); IrOp repeat = build.block(IrBlockKind::Internal);
build.beginBlock(entry); build.beginBlock(entry);
build.inst(IrCmd::JUMP_EQ_INT, build.constInt(0), build.constInt(1), block, exit1); build.inst(IrCmd::JUMP_CMP_INT, build.constInt(0), build.constInt(1), build.cond(IrCondition::Equal), block, exit1);
build.beginBlock(exit1); build.beginBlock(exit1);
build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0));
@ -2785,4 +2785,37 @@ bb_0:
)"); )");
} }
TEST_CASE_FIXTURE(IrBuilderFixture, "TagSelfEqualityCheckRemoval")
{
ScopedFastFlag luauMergeTagLoads{"LuauMergeTagLoads", true};
IrOp entry = build.block(IrBlockKind::Internal);
IrOp trueBlock = build.block(IrBlockKind::Internal);
IrOp falseBlock = build.block(IrBlockKind::Internal);
build.beginBlock(entry);
IrOp tag1 = build.inst(IrCmd::LOAD_TAG, build.vmReg(0));
IrOp tag2 = build.inst(IrCmd::LOAD_TAG, build.vmReg(0));
build.inst(IrCmd::JUMP_EQ_TAG, tag1, tag2, trueBlock, falseBlock);
build.beginBlock(trueBlock);
build.inst(IrCmd::RETURN, build.constUint(1));
build.beginBlock(falseBlock);
build.inst(IrCmd::RETURN, build.constUint(2));
updateUseCounts(build.function);
constPropInBlockChains(build, true);
CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"(
bb_0:
JUMP bb_1
bb_1:
RETURN 1u
)");
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -0,0 +1,20 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "RegisterCallbacks.h"
namespace Luau
{
std::unordered_set<RegisterCallback>& getRegisterCallbacks()
{
static std::unordered_set<RegisterCallback> cbs;
return cbs;
}
int addTestCallback(RegisterCallback cb)
{
getRegisterCallbacks().insert(cb);
return 0;
}
} // namespace Luau

22
tests/RegisterCallbacks.h Normal file
View file

@ -0,0 +1,22 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include <unordered_set>
#include <string>
namespace Luau
{
using RegisterCallback = void (*)();
/// Gets a set of callbacks to run immediately before running tests, intended
/// for registering new tests at runtime.
std::unordered_set<RegisterCallback>& getRegisterCallbacks();
/// Adds a new callback to be ran immediately before running tests.
///
/// @param cb the callback to add.
/// @returns a dummy integer to satisfy a doctest internal contract.
int addTestCallback(RegisterCallback cb);
} // namespace Luau

View file

@ -2,7 +2,9 @@
#include "doctest.h" #include "doctest.h"
#include "Fixture.h" #include "Fixture.h"
#include "RegisterCallbacks.h"
#include "Luau/Normalize.h"
#include "Luau/Subtyping.h" #include "Luau/Subtyping.h"
#include "Luau/TypePack.h" #include "Luau/TypePack.h"
@ -344,14 +346,72 @@ struct SubtypeFixture : Fixture
CHECK_MESSAGE(!result.isErrorSuppressing, "Expected " << leftTy << " to error-suppress " << rightTy); \ CHECK_MESSAGE(!result.isErrorSuppressing, "Expected " << leftTy << " to error-suppress " << rightTy); \
} while (0) } while (0)
/// Internal macro for registering a generated test case.
///
/// @param der the name of the derived fixture struct
/// @param reg the name of the registration callback, invoked immediately before
/// tests are ran to register the test
/// @param run the name of the run callback, invoked to actually run the test case
#define TEST_REGISTER(der, reg, run) \
static inline DOCTEST_NOINLINE void run() \
{ \
der fix; \
fix.test(); \
} \
static inline DOCTEST_NOINLINE void reg() \
{ \
/* we have to mark this as `static` to ensure the memory remains alive \
for the entirety of the test process */ \
static std::string name = der().testName; \
doctest::detail::regTest(doctest::detail::TestCase(run, __FILE__, __LINE__, \
doctest_detail_test_suite_ns::getCurrentTestSuite()) /* the test case's name, determined at runtime */ \
* name.c_str() /* getCurrentTestSuite() only works at static initialization \
time due to implementation details. To ensure that test cases \
are grouped where they should be, manually override the suite \
with the test_suite decorator. */ \
* doctest::test_suite("Subtyping")); \
} \
DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_VAR_), addTestCallback(reg));
/// Internal macro for deriving a test case fixture. Roughly analogous to
/// DOCTEST_IMPLEMENT_FIXTURE.
///
/// @param op a function (or macro) to call that compares the subtype to
/// the supertype.
/// @param symbol the symbol to use in stringification
/// @param der the name of the derived fixture struct
/// @param left the subtype expression
/// @param right the supertype expression
#define TEST_DERIVE(op, symbol, der, left, right) \
namespace \
{ \
struct der : SubtypeFixture \
{ \
const TypeId subTy = (left); \
const TypeId superTy = (right); \
const std::string testName = toString(subTy) + " " symbol " " + toString(superTy); \
inline DOCTEST_NOINLINE void test() \
{ \
op(subTy, superTy); \
} \
}; \
TEST_REGISTER(der, DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_)); \
}
/// Generates a test that checks if a type is a subtype of another.
#define TEST_IS_SUBTYPE(left, right) TEST_DERIVE(CHECK_IS_SUBTYPE, "<:", DOCTEST_ANONYMOUS(DOCTEST_ANON_CLASS_), left, right)
/// Generates a test that checks if a type is _not_ a subtype of another.
/// Uses <!: instead of </: to ensure that rotest doesn't explode when it splits
/// on / characters.
#define TEST_IS_NOT_SUBTYPE(left, right) TEST_DERIVE(CHECK_IS_NOT_SUBTYPE, "<!:", DOCTEST_ANONYMOUS(DOCTEST_ANON_CLASS_), left, right)
TEST_SUITE_BEGIN("Subtyping"); TEST_SUITE_BEGIN("Subtyping");
// We would like to write </: to mean "is not a subtype," but rotest does not like that at all, so we instead use <!: // We would like to write </: to mean "is not a subtype," but rotest does not like that at all, so we instead use <!:
TEST_CASE_FIXTURE(SubtypeFixture, "number <: any") TEST_IS_SUBTYPE(builtinTypes->numberType, builtinTypes->anyType);
{ TEST_IS_NOT_SUBTYPE(builtinTypes->numberType, builtinTypes->stringType);
CHECK_IS_SUBTYPE(builtinTypes->numberType, builtinTypes->anyType);
}
TEST_CASE_FIXTURE(SubtypeFixture, "any <!: unknown") TEST_CASE_FIXTURE(SubtypeFixture, "any <!: unknown")
{ {
@ -375,11 +435,6 @@ TEST_CASE_FIXTURE(SubtypeFixture, "number <: number")
CHECK_IS_SUBTYPE(builtinTypes->numberType, builtinTypes->numberType); CHECK_IS_SUBTYPE(builtinTypes->numberType, builtinTypes->numberType);
} }
TEST_CASE_FIXTURE(SubtypeFixture, "number <!: string")
{
CHECK_IS_NOT_SUBTYPE(builtinTypes->numberType, builtinTypes->stringType);
}
TEST_CASE_FIXTURE(SubtypeFixture, "number <: number?") TEST_CASE_FIXTURE(SubtypeFixture, "number <: number?")
{ {
CHECK_IS_SUBTYPE(builtinTypes->numberType, builtinTypes->optionalNumberType); CHECK_IS_SUBTYPE(builtinTypes->numberType, builtinTypes->optionalNumberType);
@ -895,6 +950,16 @@ TEST_CASE_FIXTURE(SubtypeFixture, "string <!: { insaneThingNoScalarHas : () -> (
CHECK_IS_NOT_SUBTYPE(builtinTypes->stringType, tableWithoutScalarProp); CHECK_IS_NOT_SUBTYPE(builtinTypes->stringType, tableWithoutScalarProp);
} }
TEST_CASE_FIXTURE(SubtypeFixture, "~fun & (string) -> number <: (string) -> number")
{
CHECK_IS_SUBTYPE(meet(negate(builtinTypes->functionType), numberToStringType), numberToStringType);
}
TEST_CASE_FIXTURE(SubtypeFixture, "(string) -> number <: ~fun & (string) -> number")
{
CHECK_IS_NOT_SUBTYPE(numberToStringType, meet(negate(builtinTypes->functionType), numberToStringType));
}
/* /*
* <A>(A) -> A <: <X>(X) -> X * <A>(A) -> A <: <X>(X) -> X
* A can be bound to X. * A can be bound to X.

View file

@ -44,25 +44,34 @@ TEST_SUITE_BEGIN("ToDot");
TEST_CASE_FIXTURE(Fixture, "primitive") TEST_CASE_FIXTURE(Fixture, "primitive")
{ {
CheckResult result = check(R"( CHECK_EQ(R"(digraph graphname {
local a: nil n1 [label="nil"];
local b: number })",
local c: any toDot(builtinTypes->nilType));
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_NE("nil", toDot(requireType("a")));
CHECK_EQ(R"(digraph graphname { CHECK_EQ(R"(digraph graphname {
n1 [label="number"]; n1 [label="number"];
})", })",
toDot(requireType("b"))); toDot(builtinTypes->numberType));
CHECK_EQ(R"(digraph graphname { CHECK_EQ(R"(digraph graphname {
n1 [label="any"]; n1 [label="any"];
})", })",
toDot(requireType("c"))); toDot(builtinTypes->anyType));
CHECK_EQ(R"(digraph graphname {
n1 [label="unknown"];
})",
toDot(builtinTypes->unknownType));
CHECK_EQ(R"(digraph graphname {
n1 [label="never"];
})",
toDot(builtinTypes->neverType));
}
TEST_CASE_FIXTURE(Fixture, "no_duplicatePrimitives")
{
ToDotOptions opts; ToDotOptions opts;
opts.showPointers = false; opts.showPointers = false;
opts.duplicatePrimitives = false; opts.duplicatePrimitives = false;
@ -70,12 +79,22 @@ n1 [label="any"];
CHECK_EQ(R"(digraph graphname { CHECK_EQ(R"(digraph graphname {
n1 [label="PrimitiveType number"]; n1 [label="PrimitiveType number"];
})", })",
toDot(requireType("b"), opts)); toDot(builtinTypes->numberType, opts));
CHECK_EQ(R"(digraph graphname { CHECK_EQ(R"(digraph graphname {
n1 [label="AnyType 1"]; n1 [label="AnyType 1"];
})", })",
toDot(requireType("c"), opts)); toDot(builtinTypes->anyType, opts));
CHECK_EQ(R"(digraph graphname {
n1 [label="UnknownType 1"];
})",
toDot(builtinTypes->unknownType, opts));
CHECK_EQ(R"(digraph graphname {
n1 [label="NeverType 1"];
})",
toDot(builtinTypes->neverType, opts));
} }
TEST_CASE_FIXTURE(Fixture, "bound") TEST_CASE_FIXTURE(Fixture, "bound")
@ -283,6 +302,30 @@ n1 [label="FreeType 1"];
toDot(&type, opts)); toDot(&type, opts));
} }
TEST_CASE_FIXTURE(Fixture, "free_with_constraints")
{
ScopedFastFlag sff[] = {
{"DebugLuauDeferredConstraintResolution", true},
};
Type type{TypeVariant{FreeType{nullptr, builtinTypes->numberType, builtinTypes->optionalNumberType}}};
ToDotOptions opts;
opts.showPointers = false;
CHECK_EQ(R"(digraph graphname {
n1 [label="FreeType 1"];
n1 -> n2 [label="[lowerBound]"];
n2 [label="number"];
n1 -> n3 [label="[upperBound]"];
n3 [label="UnionType 3"];
n3 -> n4;
n4 [label="number"];
n3 -> n5;
n5 [label="nil"];
})",
toDot(&type, opts));
}
TEST_CASE_FIXTURE(Fixture, "error") TEST_CASE_FIXTURE(Fixture, "error")
{ {
Type type{TypeVariant{ErrorType{}}}; Type type{TypeVariant{ErrorType{}}};
@ -440,4 +483,19 @@ n5 [label="SingletonType boolean: false"];
toDot(requireType("x"), opts)); toDot(requireType("x"), opts));
} }
TEST_CASE_FIXTURE(Fixture, "negation")
{
TypeArena arena;
TypeId t = arena.addType(NegationType{builtinTypes->stringType});
ToDotOptions opts;
opts.showPointers = false;
CHECK(R"(digraph graphname {
n1 [label="NegationType 1"];
n1 -> n2 [label="[negated]"];
n2 [label="string"];
})" == toDot(t, opts));
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -1,5 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TypeFamily.h" #include "Luau/TypeFamily.h"
#include "Luau/TxnLog.h"
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Fixture.h" #include "Fixture.h"

View file

@ -1006,8 +1006,6 @@ end
// We would prefer this unification to be able to complete, but at least it should not crash // We would prefer this unification to be able to complete, but at least it should not crash
TEST_CASE_FIXTURE(BuiltinsFixture, "table_unification_infinite_recursion") TEST_CASE_FIXTURE(BuiltinsFixture, "table_unification_infinite_recursion")
{ {
ScopedFastFlag luauTableUnifyRecursionLimit{"LuauTableUnifyRecursionLimit", true};
#if defined(_NOOPT) || defined(_DEBUG) #if defined(_NOOPT) || defined(_DEBUG)
ScopedFastInt LuauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", 100}; ScopedFastInt LuauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", 100};
#endif #endif

View file

@ -1404,4 +1404,32 @@ TEST_CASE_FIXTURE(Fixture, "promote_tail_type_packs")
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
} }
/*
* CLI-49876
*
* We had a bug where we would not use the correct TxnLog when evaluating a
* variadic overload. We could therefore get into a state where the TxnLog has
* logged that a generic matches to one type, but the variadic tail has already
* been bound to another type outside of that TxnLog.
*
* This caused type checking to succeed when it should have failed.
*/
TEST_CASE_FIXTURE(BuiltinsFixture, "be_sure_to_use_active_txnlog_when_evaluating_a_variadic_overload")
{
ScopedFastFlag sff{"LuauVariadicOverloadFix", true};
CheckResult result = check(R"(
local function concat<T>(target: {T}, ...: {T} | T): {T}
return (nil :: any) :: {T}
end
local res = concat({"alic"}, 1, 2)
)");
LUAU_REQUIRE_ERRORS(result);
for (const auto& e: result.errors)
CHECK(5 == e.location.begin.line);
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -177,6 +177,33 @@ assert((function() local a = 1 for b=1,9 do a = a * 2 if a == 128 then break els
-- make sure internal index is protected against modification -- make sure internal index is protected against modification
assert((function() local a = 1 for b=9,1,-2 do a = a * 2 b = nil end return a end)() == 32) assert((function() local a = 1 for b=9,1,-2 do a = a * 2 b = nil end return a end)() == 32)
-- make sure that when step is 0, we treat it as backward iteration (and as such, iterate zero times or indefinitely)
-- this is consistent with Lua 5.1; future Lua versions emit an error when step is 0; LuaJIT instead treats 0 as forward iteration
-- we repeat tests twice, with and without constant folding
local zero = tonumber("0")
assert((function() local c = 0 for i=1,10,0 do c += 1 if c > 10 then break end end return c end)() == 0)
assert((function() local c = 0 for i=10,1,0 do c += 1 if c > 10 then break end end return c end)() == 11)
assert((function() local c = 0 for i=1,10,zero do c += 1 if c > 10 then break end end return c end)() == 0)
assert((function() local c = 0 for i=10,1,zero do c += 1 if c > 10 then break end end return c end)() == 11)
-- make sure that when limit is nan, we iterate zero times (this is consistent with Lua 5.1; future Lua versions break this)
-- we repeat tests twice, with and without constant folding
local nan = tonumber("nan")
assert((function() local c = 0 for i=1,0/0 do c += 1 end return c end)() == 0)
assert((function() local c = 0 for i=1,0/0,-1 do c += 1 end return c end)() == 0)
assert((function() local c = 0 for i=1,nan do c += 1 end return c end)() == 0)
assert((function() local c = 0 for i=1,nan,-1 do c += 1 end return c end)() == 0)
-- make sure that when step is nan, we treat it as backward iteration and as such iterate once iff start<=limit
assert((function() local c = 0 for i=1,10,0/0 do c += 1 end return c end)() == 0)
assert((function() local c = 0 for i=10,1,0/0 do c += 1 end return c end)() == 1)
assert((function() local c = 0 for i=1,10,nan do c += 1 end return c end)() == 0)
assert((function() local c = 0 for i=10,1,nan do c += 1 end return c end)() == 1)
-- make sure that when index becomes nan mid-iteration, we correctly exit the loop (this is broken in Lua 5.1; future Lua versions fix this)
assert((function() local c = 0 for i=-math.huge,0,math.huge do c += 1 end return c end)() == 1)
assert((function() local c = 0 for i=math.huge,math.huge,-math.huge do c += 1 end return c end)() == 1)
-- generic for -- generic for
-- ipairs -- ipairs
assert((function() local a = '' for k in ipairs({5, 6, 7}) do a = a .. k end return a end)() == "123") assert((function() local a = '' for k in ipairs({5, 6, 7}) do a = a .. k end return a end)() == "123")
@ -286,6 +313,10 @@ assert((function()
return result return result
end)() == "ArcticDunesCanyonsWaterMountainsHillsLavaflowPlainsMarsh") end)() == "ArcticDunesCanyonsWaterMountainsHillsLavaflowPlainsMarsh")
-- table literals may contain duplicate fields; the language doesn't specify assignment order but we currently assign left to right
assert((function() local t = {data = 4, data = nil, data = 42} return t.data end)() == 42)
assert((function() local t = {data = 4, data = nil, data = 42, data = nil} return t.data end)() == nil)
-- multiple returns -- multiple returns
-- local= -- local=
assert((function() function foo() return 2, 3, 4 end local a, b, c = foo() return ''..a..b..c end)() == "234") assert((function() function foo() return 2, 3, 4 end local a, b, c = foo() return ''..a..b..c end)() == "234")

View file

@ -189,6 +189,26 @@ do -- testing NaN
assert(a[NaN] == nil) assert(a[NaN] == nil)
end end
-- extra NaN tests, hidden in a function
do
function neq(a) return a ~= a end
function eq(a) return a == a end
function lt(a) return a < a end
function le(a) return a <= a end
function gt(a) return a > a end
function ge(a) return a >= a end
local NaN -- to avoid constant folding
NaN = 10e500 - 10e400
assert(neq(NaN))
assert(not eq(NaN))
assert(not lt(NaN))
assert(not le(NaN))
assert(not gt(NaN))
assert(not ge(NaN))
end
-- require "checktable" -- require "checktable"
-- stat(a) -- stat(a)

View file

@ -6,6 +6,8 @@
#define DOCTEST_CONFIG_OPTIONS_PREFIX "" #define DOCTEST_CONFIG_OPTIONS_PREFIX ""
#include "doctest.h" #include "doctest.h"
#include "RegisterCallbacks.h"
#ifdef _WIN32 #ifdef _WIN32
#ifndef WIN32_LEAN_AND_MEAN #ifndef WIN32_LEAN_AND_MEAN
#define WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN
@ -327,6 +329,14 @@ int main(int argc, char** argv)
} }
} }
// These callbacks register unit tests that need runtime support to be
// correctly set up. Running them here means that all command line flags
// have been parsed, fast flags have been set, and we've potentially already
// exited. Once doctest::Context::run is invoked, the test list will be
// picked up from global state.
for (Luau::RegisterCallback cb : Luau::getRegisterCallbacks())
cb();
int result = context.run(); int result = context.run();
if (doctest::parseFlag(argc, argv, "--help") || doctest::parseFlag(argc, argv, "-h")) if (doctest::parseFlag(argc, argv, "--help") || doctest::parseFlag(argc, argv, "-h"))
{ {