Sync to upstream/release/655

* General
- Fix the benchmark require wrapper function to work in Lua
- Fix memory leak in the new Luau C API test

* New Solver
- Luau: type functions should be able to signal whether or not irreducibility is due to an error
- Do not generate extra expansion constraint for uninvoked user-defined type functions
- Print in a user-defined type function should be reported as an error
instead of logging to stdout
- Many e-graphs bugfixes and performance improvements
- Many general bugfixes and improvements to the new solver as a whole
- Fixed issue with Luau used-defined type functions not having all environments initialized
- Infer types of globals under new type solver

* Fragment Autocomplete
- Miscellaneous fixes to make interop with the old solver better

* Runtime
- Support disabling specific Luau built-in functions from being
fast-called or constant-evaluated
- Added constant folding for vector arithmetic
- Added constant propagation and type inference for Vector3 globals

----------------------------------------------------------
9 contributors:

Co-authored-by: Aaron Weiss <aaronweiss@roblox.com>
Co-authored-by: Andy Friesen <afriesen@roblox.com>
Co-authored-by: Aviral Goel <agoel@roblox.com>
Co-authored-by: Daniel Angel <danielangel@roblox.com>
Co-authored-by: Jonathan Kelaty <jkelaty@roblox.com>
Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: Varun Saini <vsaini@roblox.com>
Co-authored-by: Vighnesh Vijay <vvijay@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>
This commit is contained in:
Vighnesh 2024-12-13 10:57:30 -08:00
parent 230ab81326
commit 906a00d498
171 changed files with 2783 additions and 904 deletions

View file

@ -53,7 +53,7 @@ LUAU_EQSAT_NODE_SET(Intersection);
LUAU_EQSAT_NODE_ARRAY(Negation, 1); LUAU_EQSAT_NODE_ARRAY(Negation, 1);
LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(TTypeFun, const TypeFunction*); LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(TTypeFun, std::shared_ptr<const TypeFunctionInstanceType>);
LUAU_EQSAT_UNIT(TNoRefine); LUAU_EQSAT_UNIT(TNoRefine);
LUAU_EQSAT_UNIT(Invalid); LUAU_EQSAT_UNIT(Invalid);
@ -218,6 +218,7 @@ struct Simplifier
void simplifyUnion(Id id); void simplifyUnion(Id id);
void uninhabitedIntersection(Id id); void uninhabitedIntersection(Id id);
void intersectWithNegatedClass(Id id); void intersectWithNegatedClass(Id id);
void intersectWithNegatedAtom(Id id);
void intersectWithNoRefine(Id id); void intersectWithNoRefine(Id id);
void cyclicIntersectionOfUnion(Id id); void cyclicIntersectionOfUnion(Id id);
void cyclicUnionOfIntersection(Id id); void cyclicUnionOfIntersection(Id id);
@ -228,6 +229,7 @@ struct Simplifier
void unneededTableModification(Id id); void unneededTableModification(Id id);
void builtinTypeFunctions(Id id); void builtinTypeFunctions(Id id);
void iffyTypeFunctions(Id id); void iffyTypeFunctions(Id id);
void strictMetamethods(Id id);
}; };
template<typename Tag> template<typename Tag>

View file

@ -15,6 +15,7 @@ struct TypeCheckLimits;
void checkNonStrict( void checkNonStrict(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime, NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> ice, NotNull<InternalErrorReporter> ice,
NotNull<UnifierSharedState> unifierState, NotNull<UnifierSharedState> unifierState,

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/EqSatSimplification.h"
#include "Luau/NotNull.h" #include "Luau/NotNull.h"
#include "Luau/Set.h" #include "Luau/Set.h"
#include "Luau/TypeFwd.h" #include "Luau/TypeFwd.h"
@ -21,8 +22,22 @@ struct Scope;
using ModulePtr = std::shared_ptr<Module>; using ModulePtr = std::shared_ptr<Module>;
bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice); bool isSubtype(
bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice); TypeId subTy,
TypeId superTy,
NotNull<Scope> scope,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
InternalErrorReporter& ice
);
bool isSubtype(
TypePackId subPack,
TypePackId superPack,
NotNull<Scope> scope,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
InternalErrorReporter& ice
);
class TypeIds class TypeIds
{ {

View file

@ -2,12 +2,13 @@
#pragma once #pragma once
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/InsertionOrderedMap.h" #include "Luau/EqSatSimplification.h"
#include "Luau/NotNull.h"
#include "Luau/TypeFwd.h"
#include "Luau/Location.h"
#include "Luau/Error.h" #include "Luau/Error.h"
#include "Luau/InsertionOrderedMap.h"
#include "Luau/Location.h"
#include "Luau/NotNull.h"
#include "Luau/Subtyping.h" #include "Luau/Subtyping.h"
#include "Luau/TypeFwd.h"
namespace Luau namespace Luau
{ {
@ -34,6 +35,7 @@ struct OverloadResolver
OverloadResolver( OverloadResolver(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena, NotNull<TypeArena> arena,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime, NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> scope, NotNull<Scope> scope,
@ -44,6 +46,7 @@ struct OverloadResolver
NotNull<BuiltinTypes> builtinTypes; NotNull<BuiltinTypes> builtinTypes;
NotNull<TypeArena> arena; NotNull<TypeArena> arena;
NotNull<Simplifier> simplifier;
NotNull<Normalizer> normalizer; NotNull<Normalizer> normalizer;
NotNull<TypeFunctionRuntime> typeFunctionRuntime; NotNull<TypeFunctionRuntime> typeFunctionRuntime;
NotNull<Scope> scope; NotNull<Scope> scope;
@ -110,6 +113,7 @@ struct SolveResult
SolveResult solveFunctionCall( SolveResult solveFunctionCall(
NotNull<TypeArena> arena, NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime, NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> iceReporter, NotNull<InternalErrorReporter> iceReporter,

View file

@ -85,6 +85,10 @@ struct Scope
void inheritAssignments(const ScopePtr& childScope); void inheritAssignments(const ScopePtr& childScope);
void inheritRefinements(const ScopePtr& childScope); void inheritRefinements(const ScopePtr& childScope);
// Track globals that should emit warnings during type checking.
DenseHashSet<std::string> globalsToWarn{""};
bool shouldWarnGlobal(std::string name) const;
// For mutually recursive type aliases, it's important that // For mutually recursive type aliases, it's important that
// they use the same types for the same names. // they use the same types for the same names.
// For instance, in `type Tree<T> { data: T, children: Forest<T> } type Forest<T> = {Tree<T>}` // For instance, in `type Tree<T> { data: T, children: Forest<T> } type Forest<T> = {Tree<T>}`

View file

@ -1,13 +1,14 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/DenseHash.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/Set.h" #include "Luau/Set.h"
#include "Luau/TypeCheckLimits.h"
#include "Luau/TypeFunction.h"
#include "Luau/TypeFwd.h" #include "Luau/TypeFwd.h"
#include "Luau/TypePairHash.h" #include "Luau/TypePairHash.h"
#include "Luau/TypePath.h" #include "Luau/TypePath.h"
#include "Luau/TypeFunction.h"
#include "Luau/TypeCheckLimits.h"
#include "Luau/DenseHash.h"
#include <vector> #include <vector>
#include <optional> #include <optional>
@ -134,6 +135,7 @@ struct Subtyping
{ {
NotNull<BuiltinTypes> builtinTypes; NotNull<BuiltinTypes> builtinTypes;
NotNull<TypeArena> arena; NotNull<TypeArena> arena;
NotNull<Simplifier> simplifier;
NotNull<Normalizer> normalizer; NotNull<Normalizer> normalizer;
NotNull<TypeFunctionRuntime> typeFunctionRuntime; NotNull<TypeFunctionRuntime> typeFunctionRuntime;
NotNull<InternalErrorReporter> iceReporter; NotNull<InternalErrorReporter> iceReporter;
@ -155,6 +157,7 @@ struct Subtyping
Subtyping( Subtyping(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> typeArena, NotNull<TypeArena> typeArena,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime, NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> iceReporter NotNull<InternalErrorReporter> iceReporter

View file

@ -608,7 +608,8 @@ struct UserDefinedFunctionData
// References to AST elements are owned by the Module allocator which also stores this type // References to AST elements are owned by the Module allocator which also stores this type
AstStatTypeFunction* definition = nullptr; AstStatTypeFunction* definition = nullptr;
DenseHashMap<Name, AstStatTypeFunction*> environment{""}; DenseHashMap<Name, std::pair<AstStatTypeFunction*, size_t>> environment{""};
DenseHashMap<Name, AstStatTypeFunction*> environment_DEPRECATED{""};
}; };
/** /**

View file

@ -2,15 +2,16 @@
#pragma once #pragma once
#include "Luau/Error.h"
#include "Luau/NotNull.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/TypeUtils.h" #include "Luau/EqSatSimplification.h"
#include "Luau/Error.h"
#include "Luau/Normalize.h"
#include "Luau/NotNull.h"
#include "Luau/Subtyping.h"
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Luau/TypeFwd.h" #include "Luau/TypeFwd.h"
#include "Luau/TypeOrPack.h" #include "Luau/TypeOrPack.h"
#include "Luau/Normalize.h" #include "Luau/TypeUtils.h"
#include "Luau/Subtyping.h"
namespace Luau namespace Luau
{ {
@ -60,6 +61,7 @@ struct Reasonings
void check( void check(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime, NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<UnifierSharedState> sharedState, NotNull<UnifierSharedState> sharedState,
NotNull<TypeCheckLimits> limits, NotNull<TypeCheckLimits> limits,
@ -71,6 +73,7 @@ void check(
struct TypeChecker2 struct TypeChecker2
{ {
NotNull<BuiltinTypes> builtinTypes; NotNull<BuiltinTypes> builtinTypes;
NotNull<Simplifier> simplifier;
NotNull<TypeFunctionRuntime> typeFunctionRuntime; NotNull<TypeFunctionRuntime> typeFunctionRuntime;
DcrLogger* logger; DcrLogger* logger;
const NotNull<TypeCheckLimits> limits; const NotNull<TypeCheckLimits> limits;
@ -90,6 +93,7 @@ struct TypeChecker2
TypeChecker2( TypeChecker2(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime, NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<UnifierSharedState> unifierState, NotNull<UnifierSharedState> unifierState,
NotNull<TypeCheckLimits> limits, NotNull<TypeCheckLimits> limits,
@ -213,6 +217,9 @@ private:
std::vector<TypeError>& errors std::vector<TypeError>& errors
); );
// Avoid duplicate warnings being emitted for the same global variable.
DenseHashSet<std::string> warnedGlobals{""};
void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data) const; void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data) const;
bool isErrorSuppressing(Location loc, TypeId ty); bool isErrorSuppressing(Location loc, TypeId ty);
bool isErrorSuppressing(Location loc1, TypeId ty1, Location loc2, TypeId ty2); bool isErrorSuppressing(Location loc1, TypeId ty1, Location loc2, TypeId ty2);

View file

@ -2,6 +2,7 @@
#pragma once #pragma once
#include "Luau/Constraint.h" #include "Luau/Constraint.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/Error.h" #include "Luau/Error.h"
#include "Luau/NotNull.h" #include "Luau/NotNull.h"
#include "Luau/TypeCheckLimits.h" #include "Luau/TypeCheckLimits.h"
@ -41,9 +42,15 @@ struct TypeFunctionRuntime
StateRef state; StateRef state;
// Set of functions which have their environment table initialized
DenseHashSet<AstStatTypeFunction*> initialized{nullptr};
// Evaluation of type functions should only be performed in the absence of parse errors in the source module // Evaluation of type functions should only be performed in the absence of parse errors in the source module
bool allowEvaluation = true; bool allowEvaluation = true;
// Output created by 'print' function
std::vector<std::string> messages;
private: private:
void prepareState(); void prepareState();
}; };
@ -53,6 +60,7 @@ struct TypeFunctionContext
NotNull<TypeArena> arena; NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtins; NotNull<BuiltinTypes> builtins;
NotNull<Scope> scope; NotNull<Scope> scope;
NotNull<Simplifier> simplifier;
NotNull<Normalizer> normalizer; NotNull<Normalizer> normalizer;
NotNull<TypeFunctionRuntime> typeFunctionRuntime; NotNull<TypeFunctionRuntime> typeFunctionRuntime;
NotNull<InternalErrorReporter> ice; NotNull<InternalErrorReporter> ice;
@ -71,6 +79,7 @@ struct TypeFunctionContext
NotNull<TypeArena> arena, NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtins, NotNull<BuiltinTypes> builtins,
NotNull<Scope> scope, NotNull<Scope> scope,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime, NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> ice, NotNull<InternalErrorReporter> ice,
@ -79,6 +88,7 @@ struct TypeFunctionContext
: arena(arena) : arena(arena)
, builtins(builtins) , builtins(builtins)
, scope(scope) , scope(scope)
, simplifier(simplifier)
, normalizer(normalizer) , normalizer(normalizer)
, typeFunctionRuntime(typeFunctionRuntime) , typeFunctionRuntime(typeFunctionRuntime)
, ice(ice) , ice(ice)
@ -91,19 +101,31 @@ struct TypeFunctionContext
NotNull<Constraint> pushConstraint(ConstraintV&& c) const; NotNull<Constraint> pushConstraint(ConstraintV&& c) const;
}; };
enum class Reduction
{
// The type function is either known to be reducible or the determination is blocked.
MaybeOk,
// The type function is known to be irreducible, but maybe not be erroneous, e.g. when it's over generics or free types.
Irreducible,
// The type function is known to be irreducible, and is definitely erroneous.
Erroneous,
};
/// Represents a reduction result, which may have successfully reduced the type, /// Represents a reduction result, which may have successfully reduced the type,
/// may have concretely failed to reduce the type, or may simply be stuck /// may have concretely failed to reduce the type, or may simply be stuck
/// without more information. /// without more information.
template<typename Ty> template<typename Ty>
struct TypeFunctionReductionResult struct TypeFunctionReductionResult
{ {
/// The result of the reduction, if any. If this is nullopt, the type function /// The result of the reduction, if any. If this is nullopt, the type function
/// could not be reduced. /// could not be reduced.
std::optional<Ty> result; std::optional<Ty> result;
/// Whether the result is uninhabited: whether we know, unambiguously and /// Indicates the status of this reduction: is `Reduction::Irreducible` if
/// permanently, whether this type function reduction results in an /// the this result indicates the type function is irreducible, and
/// uninhabitable type. This will trigger an error to be reported. /// `Reduction::Erroneous` if this result indicates the type function is
bool uninhabited; /// erroneous. `Reduction::MaybeOk` otherwise.
Reduction reductionStatus;
/// Any types that need to be progressed or mutated before the reduction may /// Any types that need to be progressed or mutated before the reduction may
/// proceed. /// proceed.
std::vector<TypeId> blockedTypes; std::vector<TypeId> blockedTypes;
@ -112,6 +134,8 @@ struct TypeFunctionReductionResult
std::vector<TypePackId> blockedPacks; std::vector<TypePackId> blockedPacks;
/// A runtime error message from user-defined type functions /// A runtime error message from user-defined type functions
std::optional<std::string> error; std::optional<std::string> error;
/// Messages printed out from user-defined type functions
std::vector<std::string> messages;
}; };
template<typename T> template<typename T>
@ -145,6 +169,7 @@ struct TypePackFunction
struct FunctionGraphReductionResult struct FunctionGraphReductionResult
{ {
ErrorVec errors; ErrorVec errors;
ErrorVec messages;
DenseHashSet<TypeId> blockedTypes{nullptr}; DenseHashSet<TypeId> blockedTypes{nullptr};
DenseHashSet<TypePackId> blockedPacks{nullptr}; DenseHashSet<TypePackId> blockedPacks{nullptr};
DenseHashSet<TypeId> reducedTypes{nullptr}; DenseHashSet<TypeId> reducedTypes{nullptr};

View file

@ -150,6 +150,7 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull<Scope> scope, T
{ {
InternalErrorReporter iceReporter; InternalErrorReporter iceReporter;
UnifierSharedState unifierState(&iceReporter); UnifierSharedState unifierState(&iceReporter);
SimplifierPtr simplifier = newSimplifier(NotNull{typeArena}, builtinTypes);
Normalizer normalizer{typeArena, builtinTypes, NotNull{&unifierState}}; Normalizer normalizer{typeArena, builtinTypes, NotNull{&unifierState}};
if (FFlag::LuauSolverV2) if (FFlag::LuauSolverV2)
@ -162,7 +163,9 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull<Scope> scope, T
unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;
unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit; unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit;
Subtyping subtyping{builtinTypes, NotNull{typeArena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter}}; Subtyping subtyping{
builtinTypes, NotNull{typeArena}, NotNull{simplifier.get()}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter}
};
return subtyping.isSubtype(subTy, superTy, scope).isSubtype; return subtyping.isSubtype(subTy, superTy, scope).isSubtype;
} }

View file

@ -33,10 +33,14 @@ LUAU_FASTFLAG(DebugLuauLogSolverToJson)
LUAU_FASTFLAG(DebugLuauMagicTypes) LUAU_FASTFLAG(DebugLuauMagicTypes)
LUAU_FASTFLAG(DebugLuauEqSatSimplification) LUAU_FASTFLAG(DebugLuauEqSatSimplification)
LUAU_FASTFLAG(LuauTypestateBuiltins2) LUAU_FASTFLAG(LuauTypestateBuiltins2)
LUAU_FASTFLAG(LuauUserTypeFunUpdateAllEnvs)
LUAU_FASTFLAGVARIABLE(LuauNewSolverVisitErrorExprLvalues) LUAU_FASTFLAGVARIABLE(LuauNewSolverVisitErrorExprLvalues)
LUAU_FASTFLAGVARIABLE(LuauUserTypeFunExportedAndLocal) LUAU_FASTFLAGVARIABLE(LuauUserTypeFunExportedAndLocal)
LUAU_FASTFLAGVARIABLE(LuauNewSolverPopulateTableLocations) LUAU_FASTFLAGVARIABLE(LuauNewSolverPopulateTableLocations)
LUAU_FASTFLAGVARIABLE(LuauUserTypeFunNoExtraConstraint)
LUAU_FASTFLAGVARIABLE(InferGlobalTypes)
namespace Luau namespace Luau
{ {
@ -791,9 +795,10 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc
} }
// Fill it with all visible type functions // Fill it with all visible type functions
if (mainTypeFun) if (FFlag::LuauUserTypeFunUpdateAllEnvs && mainTypeFun)
{ {
UserDefinedFunctionData& userFuncData = mainTypeFun->userFuncData; UserDefinedFunctionData& userFuncData = mainTypeFun->userFuncData;
size_t level = 0;
for (Scope* curr = scope.get(); curr; curr = curr->parent.get()) for (Scope* curr = scope.get(); curr; curr = curr->parent.get())
{ {
@ -803,7 +808,7 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc
continue; continue;
if (auto ty = get<TypeFunctionInstanceType>(tf.type); ty && ty->userFuncData.definition) if (auto ty = get<TypeFunctionInstanceType>(tf.type); ty && ty->userFuncData.definition)
userFuncData.environment[name] = ty->userFuncData.definition; userFuncData.environment[name] = std::make_pair(ty->userFuncData.definition, level);
} }
for (auto& [name, tf] : curr->exportedTypeBindings) for (auto& [name, tf] : curr->exportedTypeBindings)
@ -812,7 +817,34 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc
continue; continue;
if (auto ty = get<TypeFunctionInstanceType>(tf.type); ty && ty->userFuncData.definition) if (auto ty = get<TypeFunctionInstanceType>(tf.type); ty && ty->userFuncData.definition)
userFuncData.environment[name] = ty->userFuncData.definition; userFuncData.environment[name] = std::make_pair(ty->userFuncData.definition, level);
}
level++;
}
}
else if (mainTypeFun)
{
UserDefinedFunctionData& userFuncData = mainTypeFun->userFuncData;
for (Scope* curr = scope.get(); curr; curr = curr->parent.get())
{
for (auto& [name, tf] : curr->privateTypeBindings)
{
if (userFuncData.environment_DEPRECATED.find(name))
continue;
if (auto ty = get<TypeFunctionInstanceType>(tf.type); ty && ty->userFuncData.definition)
userFuncData.environment_DEPRECATED[name] = ty->userFuncData.definition;
}
for (auto& [name, tf] : curr->exportedTypeBindings)
{
if (userFuncData.environment_DEPRECATED.find(name))
continue;
if (auto ty = get<TypeFunctionInstanceType>(tf.type); ty && ty->userFuncData.definition)
userFuncData.environment_DEPRECATED[name] = ty->userFuncData.definition;
} }
} }
} }
@ -1543,18 +1575,22 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatTypeAlias*
ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatTypeFunction* function) ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatTypeFunction* function)
{ {
// If a type function with the same name was already defined, we skip over if (!FFlag::LuauUserTypeFunNoExtraConstraint)
auto bindingIt = scope->privateTypeBindings.find(function->name.value);
if (bindingIt == scope->privateTypeBindings.end())
return ControlFlow::None;
TypeFun typeFunction = bindingIt->second;
// Adding typeAliasExpansionConstraint on user-defined type function for the constraint solver
if (auto typeFunctionTy = get<TypeFunctionInstanceType>(follow(typeFunction.type)))
{ {
TypeId expansionTy = arena->addType(PendingExpansionType{{}, function->name, typeFunctionTy->typeArguments, typeFunctionTy->packArguments}); // If a type function with the same name was already defined, we skip over
addConstraint(scope, function->location, TypeAliasExpansionConstraint{/* target */ expansionTy}); auto bindingIt = scope->privateTypeBindings.find(function->name.value);
if (bindingIt == scope->privateTypeBindings.end())
return ControlFlow::None;
TypeFun typeFunction = bindingIt->second;
// Adding typeAliasExpansionConstraint on user-defined type function for the constraint solver
if (auto typeFunctionTy = get<TypeFunctionInstanceType>(follow(typeFunction.type)))
{
TypeId expansionTy =
arena->addType(PendingExpansionType{{}, function->name, typeFunctionTy->typeArguments, typeFunctionTy->packArguments});
addConstraint(scope, function->location, TypeAliasExpansionConstraint{/* target */ expansionTy});
}
} }
return ControlFlow::None; return ControlFlow::None;
@ -2747,6 +2783,14 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprGlobal* glob
DefId def = dfg->getDef(global); DefId def = dfg->getDef(global);
rootScope->lvalueTypes[def] = rhsType; rootScope->lvalueTypes[def] = rhsType;
if (FFlag::InferGlobalTypes)
{
// Sketchy: We're specifically looking for BlockedTypes that were
// initially created by ConstraintGenerator::prepopulateGlobalScope.
if (auto bt = get<BlockedType>(follow(*annotatedTy)); bt && !bt->getOwner())
emplaceType<BoundType>(asMutable(*annotatedTy), rhsType);
}
addConstraint(scope, global->location, SubtypeConstraint{rhsType, *annotatedTy}); addConstraint(scope, global->location, SubtypeConstraint{rhsType, *annotatedTy});
} }
} }
@ -3139,9 +3183,9 @@ TypeId ConstraintGenerator::resolveReferenceType(
if (alias.has_value()) if (alias.has_value())
{ {
// If the alias is not generic, we don't need to set up a blocked // If the alias is not generic, we don't need to set up a blocked type and an instantiation constraint
// type and an instantiation constraint. if (alias.has_value() && alias->typeParams.empty() && alias->typePackParams.empty() &&
if (alias.has_value() && alias->typeParams.empty() && alias->typePackParams.empty()) (!FFlag::LuauUserTypeFunNoExtraConstraint || !ref->hasParameterList))
{ {
result = alias->type; result = alias->type;
} }
@ -3651,6 +3695,26 @@ struct GlobalPrepopulator : AstVisitor
return true; return true;
} }
bool visit(AstStatAssign* assign) override
{
if (FFlag::InferGlobalTypes)
{
for (const Luau::AstExpr* expr : assign->vars)
{
if (const AstExprGlobal* g = expr->as<AstExprGlobal>())
{
if (!globalScope->lookup(g->name))
globalScope->globalsToWarn.insert(g->name.value);
TypeId bt = arena->addType(BlockedType{});
globalScope->bindings[g->name] = Binding{bt, g->location};
}
}
}
return true;
}
bool visit(AstStatFunction* function) override bool visit(AstStatFunction* function) override
{ {
if (AstExprGlobal* g = function->name->as<AstExprGlobal>()) if (AstExprGlobal* g = function->name->as<AstExprGlobal>())

View file

@ -35,6 +35,7 @@ LUAU_FASTFLAGVARIABLE(LuauRemoveNotAnyHack)
LUAU_FASTFLAGVARIABLE(DebugLuauEqSatSimplification) LUAU_FASTFLAGVARIABLE(DebugLuauEqSatSimplification)
LUAU_FASTFLAG(LuauNewSolverPopulateTableLocations) LUAU_FASTFLAG(LuauNewSolverPopulateTableLocations)
LUAU_FASTFLAGVARIABLE(LuauAllowNilAssignmentToIndexer) LUAU_FASTFLAGVARIABLE(LuauAllowNilAssignmentToIndexer)
LUAU_FASTFLAG(LuauUserTypeFunNoExtraConstraint)
namespace Luau namespace Luau
{ {
@ -939,12 +940,14 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul
if (auto typeFn = get<TypeFunctionInstanceType>(follow(tf->type))) if (auto typeFn = get<TypeFunctionInstanceType>(follow(tf->type)))
pushConstraint(NotNull(constraint->scope.get()), constraint->location, ReduceConstraint{tf->type}); pushConstraint(NotNull(constraint->scope.get()), constraint->location, ReduceConstraint{tf->type});
// If there are no parameters to the type function we can just use the type if (!FFlag::LuauUserTypeFunNoExtraConstraint)
// directly.
if (tf->typeParams.empty() && tf->typePackParams.empty())
{ {
bindResult(tf->type); // If there are no parameters to the type function we can just use the type directly
return true; if (tf->typeParams.empty() && tf->typePackParams.empty())
{
bindResult(tf->type);
return true;
}
} }
// Due to how pending expansion types and TypeFun's are created // Due to how pending expansion types and TypeFun's are created
@ -959,6 +962,16 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul
return true; return true;
} }
if (FFlag::LuauUserTypeFunNoExtraConstraint)
{
// If there are no parameters to the type function we can just use the type directly
if (tf->typeParams.empty() && tf->typePackParams.empty())
{
bindResult(tf->type);
return true;
}
}
auto [typeArguments, packArguments] = saturateArguments(arena, builtinTypes, *tf, petv->typeArguments, petv->packArguments); auto [typeArguments, packArguments] = saturateArguments(arena, builtinTypes, *tf, petv->typeArguments, petv->packArguments);
bool sameTypes = std::equal( bool sameTypes = std::equal(
@ -1263,6 +1276,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
OverloadResolver resolver{ OverloadResolver resolver{
builtinTypes, builtinTypes,
NotNull{arena}, NotNull{arena},
simplifier,
normalizer, normalizer,
typeFunctionRuntime, typeFunctionRuntime,
constraint->scope, constraint->scope,
@ -2102,6 +2116,11 @@ bool ConstraintSolver::tryDispatch(const ReduceConstraint& c, NotNull<const Cons
if (force || reductionFinished) if (force || reductionFinished)
{ {
for (auto& message : result.messages)
{
reportError(std::move(message));
}
// if we're completely dispatching this constraint, we want to record any uninhabited type functions to unblock. // if we're completely dispatching this constraint, we want to record any uninhabited type functions to unblock.
for (auto error : result.errors) for (auto error : result.errors)
{ {

View file

@ -143,6 +143,62 @@ static bool isTerminal(const EType& node)
node.get<TNever>() || node.get<TNoRefine>(); node.get<TNever>() || node.get<TNoRefine>();
} }
static bool areTerminalAndDefinitelyDisjoint(const EType& lhs, const EType& rhs)
{
// If either node is non-terminal, then we early exit: we're not going to
// do a state space search for whether something like:
// (A | B | C | D) & (E | F | G | H)
// ... is a disjoint intersection.
if (!isTerminal(lhs) || !isTerminal(rhs))
return false;
// Special case some types that aren't strict, disjoint subsets.
if (lhs.get<TTopClass>() || lhs.get<TClass>())
return !(rhs.get<TTopClass>() || rhs.get<TClass>());
// Handling strings / booleans: these are the types for which we
// expect something like:
//
// "foo" & ~"bar"
//
// ... to simplify to "foo".
if (lhs.get<TString>())
return !(rhs.get<TString>() || rhs.get<SString>());
if (lhs.get<TBoolean>())
return !(rhs.get<TBoolean>() || rhs.get<SBoolean>());
if (auto lhsSString = lhs.get<SString>())
{
auto rhsSString = rhs.get<SString>();
if (!rhsSString)
return !rhs.get<TString>();
return lhsSString->value() != rhsSString->value();
}
if (auto lhsSBoolean = lhs.get<SBoolean>())
{
auto rhsSBoolean = rhs.get<SBoolean>();
if (!rhsSBoolean)
return !rhs.get<TBoolean>();
return lhsSBoolean->value() != rhsSBoolean->value();
}
// At this point:
// - We know both nodes are terminal
// - We know that the LHS is not any boolean, string, or class
// At this point, we have two classes of checks left:
// - Whether the two enodes are exactly the same set (now that the static
// sets have been covered).
// - Whether one of the enodes is a large semantic set such as TAny,
// TUnknown, or TError.
return !(
lhs.index() == rhs.index() ||
lhs.get<TUnknown>() || rhs.get<TUnknown>() || lhs.get<TAny>() || rhs.get<TAny>() || lhs.get<TNoRefine>() || rhs.get<TNoRefine>() ||
lhs.get<TError>() || rhs.get<TError>() || lhs.get<TOpaque>() || rhs.get<TOpaque>()
);
}
static bool isTerminal(const EGraph& egraph, Id eclass) static bool isTerminal(const EGraph& egraph, Id eclass)
{ {
const auto& nodes = egraph[eclass].nodes; const auto& nodes = egraph[eclass].nodes;
@ -336,10 +392,20 @@ Id toId(
LUAU_ASSERT(tfun->packArguments.empty()); LUAU_ASSERT(tfun->packArguments.empty());
std::vector<Id> parts; std::vector<Id> parts;
parts.reserve(tfun->typeArguments.size());
for (TypeId part : tfun->typeArguments) for (TypeId part : tfun->typeArguments)
parts.push_back(toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, part)); parts.push_back(toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, part));
return cache(egraph.add(TTypeFun{tfun->function.get(), std::move(parts)})); // This looks sily, but we're making a copy of the specific
// `TypeFunctionInstanceType` outside of the provided arena so that
// we can access the members without fear of the specific TFIT being
// overwritten with a bound type.
return cache(egraph.add(TTypeFun{
std::make_shared<const TypeFunctionInstanceType>(
tfun->function, tfun->typeArguments, tfun->packArguments, tfun->userFuncName, tfun->userFuncData
),
std::move(parts)
}));
} }
else if (get<NoRefineType>(ty)) else if (get<NoRefineType>(ty))
return egraph.add(TNoRefine{}); return egraph.add(TNoRefine{});
@ -574,18 +640,24 @@ TypeId flattenTableNode(
// If a TTable is its own basis, it must be the case that some other // If a TTable is its own basis, it must be the case that some other
// node on this eclass is a TImportedTable. Let's use that. // node on this eclass is a TImportedTable. Let's use that.
bool found = false;
for (size_t i = 0; i < eclass.nodes.size(); ++i) for (size_t i = 0; i < eclass.nodes.size(); ++i)
{ {
if (eclass.nodes[i].get<TImportedTable>()) if (eclass.nodes[i].get<TImportedTable>())
{ {
found = true;
index = i; index = i;
break; break;
} }
} }
// If we couldn't find one, we don't know what to do. Use ErrorType. if (!found)
LUAU_ASSERT(0); {
return builtinTypes->errorType; // If we couldn't find one, we don't know what to do. Use ErrorType.
LUAU_ASSERT(0);
return builtinTypes->errorType;
}
} }
const auto& node = eclass.nodes[index]; const auto& node = eclass.nodes[index];
@ -703,7 +775,20 @@ TypeId fromId(
if (parts.empty()) if (parts.empty())
return builtinTypes->neverType; return builtinTypes->neverType;
else if (parts.size() == 1) else if (parts.size() == 1)
return fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, parts[0]); {
TypeId placeholder = arena->addType(BlockedType{});
seen[rootId] = placeholder;
auto result = fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, parts[0]);
if (follow(result) == placeholder)
{
emplaceType<GenericType>(asMutable(placeholder), "EGRAPH-SINGLETON-CYCLE");
}
else
{
emplaceType<BoundType>(asMutable(placeholder), result);
}
return result;
}
else else
{ {
TypeId res = arena->addType(BlockedType{}); TypeId res = arena->addType(BlockedType{});
@ -768,7 +853,11 @@ TypeId fromId(
for (Id part : tfun->operands()) for (Id part : tfun->operands())
args.push_back(fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, part)); args.push_back(fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, part));
asMutable(res)->ty.emplace<TypeFunctionInstanceType>(*tfun->value(), std::move(args)); auto oldInstance = tfun->value();
asMutable(res)->ty.emplace<TypeFunctionInstanceType>(
oldInstance->function, std::move(args), std::vector<TypePackId>(), oldInstance->userFuncName, oldInstance->userFuncData
);
newTypeFunctions.push_back(res); newTypeFunctions.push_back(res);
@ -906,7 +995,7 @@ static std::string getNodeName(const StringCache& strings, const EType& node)
else if (node.get<TNever>()) else if (node.get<TNever>())
return "never"; return "never";
else if (auto tfun = node.get<TTypeFun>()) else if (auto tfun = node.get<TTypeFun>())
return "tfun " + tfun->value()->name; return "tfun " + tfun->value()->function->name;
else if (node.get<Negation>()) else if (node.get<Negation>())
return "~"; return "~";
else if (node.get<Invalid>()) else if (node.get<Invalid>())
@ -1552,11 +1641,6 @@ std::optional<EType> intersectOne(EGraph& egraph, Id hereId, const EType* hereNo
thereNode->get<Intersection>() || thereNode->get<Negation>() || hereNode->get<TOpaque>() || thereNode->get<TOpaque>()) thereNode->get<Intersection>() || thereNode->get<Negation>() || hereNode->get<TOpaque>() || thereNode->get<TOpaque>())
return std::nullopt; return std::nullopt;
if (hereNode->get<TAny>())
return *thereNode;
if (thereNode->get<TAny>())
return *hereNode;
if (hereNode->get<TUnknown>()) if (hereNode->get<TUnknown>())
return *thereNode; return *thereNode;
if (thereNode->get<TUnknown>()) if (thereNode->get<TUnknown>())
@ -1839,6 +1923,74 @@ void Simplifier::intersectWithNegatedClass(Id id)
} }
} }
void Simplifier::intersectWithNegatedAtom(Id id)
{
// Let I and ~J be two arbitrary distinct operands of an intersection where
// I and J are terminal but are not type variables. (free, generic, or
// otherwise opaque)
//
// If I and J are equal, then the whole intersection is equivalent to never.
//
// If I and J are inequal, then J & ~I == J
for (const auto [intersection, intersectionIndex] : Query<Intersection>(&egraph, id))
{
const Slice<const Id>& intersectionOperands = intersection->operands();
for (size_t i = 0; i < intersectionOperands.size(); ++i)
{
for (const auto [negation, negationIndex] : Query<Negation>(&egraph, intersectionOperands[i]))
{
for (size_t negationOperandIndex = 0; negationOperandIndex < egraph[negation->operands()[0]].nodes.size(); ++negationOperandIndex)
{
const EType* negationOperand = &egraph[negation->operands()[0]].nodes[negationOperandIndex];
if (!isTerminal(*negationOperand) || negationOperand->get<TOpaque>())
continue;
for (size_t j = 0; j < intersectionOperands.size(); ++j)
{
if (j == i)
continue;
for (size_t jNodeIndex = 0; jNodeIndex < egraph[intersectionOperands[j]].nodes.size(); ++jNodeIndex)
{
const EType* jNode = &egraph[intersectionOperands[j]].nodes[jNodeIndex];
if (!isTerminal(*jNode) || jNode->get<TOpaque>())
continue;
if (*negationOperand == *jNode)
{
// eg "Hello" & ~"Hello"
// or boolean & ~boolean
subst(
id,
egraph.add(TNever{}),
"intersectWithNegatedAtom",
{{id, intersectionIndex}, {intersectionOperands[i], negationIndex}, {intersectionOperands[j], jNodeIndex}}
);
return;
}
else if (areTerminalAndDefinitelyDisjoint(*jNode, *negationOperand))
{
// eg "Hello" & ~"World"
// or boolean & ~string
std::vector<Id> newOperands(intersectionOperands.begin(), intersectionOperands.end());
newOperands.erase(newOperands.begin() + std::vector<Id>::difference_type(i));
subst(
id,
egraph.add(Intersection{newOperands}),
"intersectWithNegatedAtom",
{{id, intersectionIndex}, {intersectionOperands[i], negationIndex}, {intersectionOperands[j], jNodeIndex}}
);
}
}
}
}
}
}
}
}
void Simplifier::intersectWithNoRefine(Id id) void Simplifier::intersectWithNoRefine(Id id)
{ {
for (const auto pair : Query<Intersection>(&egraph, id)) for (const auto pair : Query<Intersection>(&egraph, id))
@ -2160,7 +2312,7 @@ void Simplifier::intersectTableProperty(Id id)
subst( subst(
id, id,
egraph.add(Intersection{std::move(newIntersectionParts)}), mkIntersection(egraph, std::move(newIntersectionParts)),
"intersectTableProperty", "intersectTableProperty",
{{id, intersectionIndex}, {iId, table1Index}, {jId, table2Index}} {{id, intersectionIndex}, {iId, table1Index}, {jId, table2Index}}
); );
@ -2250,7 +2402,7 @@ void Simplifier::builtinTypeFunctions(Id id)
if (args.size() != 2) if (args.size() != 2)
continue; continue;
const std::string& name = tfun->value()->name; const std::string& name = tfun->value()->function->name;
if (name == "add" || name == "sub" || name == "mul" || name == "div" || name == "idiv" || name == "pow" || name == "mod") if (name == "add" || name == "sub" || name == "mul" || name == "div" || name == "idiv" || name == "pow" || name == "mod")
{ {
if (isTag<TNumber>(args[0]) && isTag<TNumber>(args[1])) if (isTag<TNumber>(args[0]) && isTag<TNumber>(args[1]))
@ -2272,15 +2424,43 @@ void Simplifier::iffyTypeFunctions(Id id)
{ {
const Slice<const Id>& args = tfun->operands(); const Slice<const Id>& args = tfun->operands();
const std::string& name = tfun->value()->name; const std::string& name = tfun->value()->function->name;
if (name == "union") if (name == "union")
subst(id, add(Union{std::vector(args.begin(), args.end())}), "iffyTypeFunctions", {{id, index}}); subst(id, add(Union{std::vector(args.begin(), args.end())}), "iffyTypeFunctions", {{id, index}});
else if (name == "intersect" || name == "refine") else if (name == "intersect")
subst(id, add(Intersection{std::vector(args.begin(), args.end())}), "iffyTypeFunctions", {{id, index}}); subst(id, add(Intersection{std::vector(args.begin(), args.end())}), "iffyTypeFunctions", {{id, index}});
} }
} }
// Replace instances of `lt<X, Y>` and `le<X, Y>` when either X or Y is `number`
// or `string` with `boolean`. Lua semantics are that if we see the expression:
//
// x < y
//
// ... we error if `x` and `y` don't have the same type. We know that for
// `string` and `number`, comparisons will always return a boolean. So if either
// of the arguments to `lt<>` are equivalent to `number` or `string`, then the
// type is effectively `boolean`: either the other type is equivalent, in which
// case we eval to `boolean`, or we diverge (raise an error).
void Simplifier::strictMetamethods(Id id)
{
for (const auto [tfun, index] : Query<TTypeFun>(&egraph, id))
{
const Slice<const Id>& args = tfun->operands();
const std::string& name = tfun->value()->function->name;
if (!(name == "lt" || name == "le") || args.size() != 2)
continue;
if (isTag<TNumber>(args[0]) || isTag<TString>(args[0]) || isTag<TNumber>(args[1]) || isTag<TString>(args[1]))
{
subst(id, add(TBoolean{}), __FUNCTION__, {{id, index}});
}
}
}
static void deleteSimplifier(Simplifier* s) static void deleteSimplifier(Simplifier* s)
{ {
delete s; delete s;
@ -2308,6 +2488,7 @@ std::optional<EqSatSimplificationResult> eqSatSimplify(NotNull<Simplifier> simpl
&Simplifier::simplifyUnion, &Simplifier::simplifyUnion,
&Simplifier::uninhabitedIntersection, &Simplifier::uninhabitedIntersection,
&Simplifier::intersectWithNegatedClass, &Simplifier::intersectWithNegatedClass,
&Simplifier::intersectWithNegatedAtom,
&Simplifier::intersectWithNoRefine, &Simplifier::intersectWithNoRefine,
&Simplifier::cyclicIntersectionOfUnion, &Simplifier::cyclicIntersectionOfUnion,
&Simplifier::cyclicUnionOfIntersection, &Simplifier::cyclicUnionOfIntersection,
@ -2318,6 +2499,7 @@ std::optional<EqSatSimplificationResult> eqSatSimplify(NotNull<Simplifier> simpl
&Simplifier::unneededTableModification, &Simplifier::unneededTableModification,
&Simplifier::builtinTypeFunctions, &Simplifier::builtinTypeFunctions,
&Simplifier::iffyTypeFunctions, &Simplifier::iffyTypeFunctions,
&Simplifier::strictMetamethods,
}; };
std::unordered_set<Id> seen; std::unordered_set<Id> seen;

View file

@ -50,6 +50,8 @@ LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauRunCustomModuleChecks, false)
LUAU_FASTFLAG(StudioReportLuauAny2) LUAU_FASTFLAG(StudioReportLuauAny2)
LUAU_FASTFLAGVARIABLE(LuauStoreSolverTypeOnModule) LUAU_FASTFLAGVARIABLE(LuauStoreSolverTypeOnModule)
LUAU_FASTFLAGVARIABLE(LuauReferenceAllocatorInNewSolver)
namespace Luau namespace Luau
{ {
@ -1317,6 +1319,11 @@ ModulePtr check(
result->mode = mode; result->mode = mode;
result->internalTypes.owningModule = result.get(); result->internalTypes.owningModule = result.get();
result->interfaceTypes.owningModule = result.get(); result->interfaceTypes.owningModule = result.get();
if (FFlag::LuauReferenceAllocatorInNewSolver)
{
result->allocator = sourceModule.allocator;
result->names = sourceModule.names;
}
iceHandler->moduleName = sourceModule.name; iceHandler->moduleName = sourceModule.name;
@ -1427,6 +1434,7 @@ ModulePtr check(
case Mode::Nonstrict: case Mode::Nonstrict:
Luau::checkNonStrict( Luau::checkNonStrict(
builtinTypes, builtinTypes,
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime}, NotNull{&typeFunctionRuntime},
iceHandler, iceHandler,
NotNull{&unifierState}, NotNull{&unifierState},
@ -1440,7 +1448,14 @@ ModulePtr check(
// fallthrough intentional // fallthrough intentional
case Mode::Strict: case Mode::Strict:
Luau::check( Luau::check(
builtinTypes, NotNull{&typeFunctionRuntime}, NotNull{&unifierState}, NotNull{&limits}, logger.get(), sourceModule, result.get() builtinTypes,
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime},
NotNull{&unifierState},
NotNull{&limits},
logger.get(),
sourceModule,
result.get()
); );
break; break;
case Mode::NoCheck: case Mode::NoCheck:

View file

@ -10,12 +10,14 @@
#include "Luau/VisitType.h" #include "Luau/VisitType.h"
LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete) LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete)
LUAU_FASTFLAGVARIABLE(LuauGeneralizationRemoveRecursiveUpperBound)
namespace Luau namespace Luau
{ {
struct MutatingGeneralizer : TypeOnceVisitor struct MutatingGeneralizer : TypeOnceVisitor
{ {
NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtinTypes; NotNull<BuiltinTypes> builtinTypes;
NotNull<Scope> scope; NotNull<Scope> scope;
@ -29,6 +31,7 @@ struct MutatingGeneralizer : TypeOnceVisitor
bool avoidSealingTables = false; bool avoidSealingTables = false;
MutatingGeneralizer( MutatingGeneralizer(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<Scope> scope, NotNull<Scope> scope,
NotNull<DenseHashSet<TypeId>> cachedTypes, NotNull<DenseHashSet<TypeId>> cachedTypes,
@ -37,6 +40,7 @@ struct MutatingGeneralizer : TypeOnceVisitor
bool avoidSealingTables bool avoidSealingTables
) )
: TypeOnceVisitor(/* skipBoundTypes */ true) : TypeOnceVisitor(/* skipBoundTypes */ true)
, arena(arena)
, builtinTypes(builtinTypes) , builtinTypes(builtinTypes)
, scope(scope) , scope(scope)
, cachedTypes(cachedTypes) , cachedTypes(cachedTypes)
@ -229,6 +233,53 @@ struct MutatingGeneralizer : TypeOnceVisitor
else else
{ {
TypeId ub = follow(ft->upperBound); TypeId ub = follow(ft->upperBound);
if (FFlag::LuauGeneralizationRemoveRecursiveUpperBound)
{
// If the upper bound is a union type or an intersection type,
// and one of it's members is the free type we're
// generalizing, don't include it in the upper bound. For a
// free type such as:
//
// t1 where t1 = D <: 'a <: (A | B | C | t1)
//
// Naively replacing it with it's upper bound creates:
//
// t1 where t1 = A | B | C | t1
//
// It makes sense to just optimize this and exclude the
// recursive component by semantic subtyping rules.
if (auto itv = get<IntersectionType>(ub))
{
std::vector<TypeId> newIds;
newIds.reserve(itv->parts.size());
for (auto part : itv)
{
if (part != ty)
newIds.push_back(part);
}
if (newIds.size() == 1)
ub = newIds[0];
else if (newIds.size() > 0)
ub = arena->addType(IntersectionType{std::move(newIds)});
}
else if (auto utv = get<UnionType>(ub))
{
std::vector<TypeId> newIds;
newIds.reserve(utv->options.size());
for (auto part : utv)
{
if (part != ty)
newIds.push_back(part);
}
if (newIds.size() == 1)
ub = newIds[0];
else if (newIds.size() > 0)
ub = arena->addType(UnionType{std::move(newIds)});
}
}
if (FreeType* upperFree = getMutable<FreeType>(ub); upperFree && upperFree->lowerBound == ty) if (FreeType* upperFree = getMutable<FreeType>(ub); upperFree && upperFree->lowerBound == ty)
upperFree->lowerBound = builtinTypes->neverType; upperFree->lowerBound = builtinTypes->neverType;
else else
@ -969,7 +1020,7 @@ std::optional<TypeId> generalize(
FreeTypeSearcher fts{scope, cachedTypes}; FreeTypeSearcher fts{scope, cachedTypes};
fts.traverse(ty); fts.traverse(ty);
MutatingGeneralizer gen{builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), std::move(fts.negativeTypes), avoidSealingTables}; MutatingGeneralizer gen{arena, builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), std::move(fts.negativeTypes), avoidSealingTables};
gen.traverse(ty); gen.traverse(ty);

View file

@ -19,7 +19,6 @@
#include <iostream> #include <iostream>
#include <iterator> #include <iterator>
LUAU_FASTFLAGVARIABLE(LuauUserTypeFunNonstrict)
LUAU_FASTFLAGVARIABLE(LuauCountSelfCallsNonstrict) LUAU_FASTFLAGVARIABLE(LuauCountSelfCallsNonstrict)
namespace Luau namespace Luau
@ -158,6 +157,7 @@ private:
struct NonStrictTypeChecker struct NonStrictTypeChecker
{ {
NotNull<BuiltinTypes> builtinTypes; NotNull<BuiltinTypes> builtinTypes;
NotNull<Simplifier> simplifier;
NotNull<TypeFunctionRuntime> typeFunctionRuntime; NotNull<TypeFunctionRuntime> typeFunctionRuntime;
const NotNull<InternalErrorReporter> ice; const NotNull<InternalErrorReporter> ice;
NotNull<TypeArena> arena; NotNull<TypeArena> arena;
@ -174,6 +174,7 @@ struct NonStrictTypeChecker
NonStrictTypeChecker( NonStrictTypeChecker(
NotNull<TypeArena> arena, NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime, NotNull<TypeFunctionRuntime> typeFunctionRuntime,
const NotNull<InternalErrorReporter> ice, const NotNull<InternalErrorReporter> ice,
NotNull<UnifierSharedState> unifierState, NotNull<UnifierSharedState> unifierState,
@ -182,12 +183,13 @@ struct NonStrictTypeChecker
Module* module Module* module
) )
: builtinTypes(builtinTypes) : builtinTypes(builtinTypes)
, simplifier(simplifier)
, typeFunctionRuntime(typeFunctionRuntime) , typeFunctionRuntime(typeFunctionRuntime)
, ice(ice) , ice(ice)
, arena(arena) , arena(arena)
, module(module) , module(module)
, normalizer{arena, builtinTypes, unifierState, /* cache inhabitance */ true} , normalizer{arena, builtinTypes, unifierState, /* cache inhabitance */ true}
, subtyping{builtinTypes, arena, NotNull(&normalizer), typeFunctionRuntime, ice} , subtyping{builtinTypes, arena, simplifier, NotNull(&normalizer), typeFunctionRuntime, ice}
, dfg(dfg) , dfg(dfg)
, limits(limits) , limits(limits)
{ {
@ -232,13 +234,14 @@ struct NonStrictTypeChecker
if (noTypeFunctionErrors.find(instance)) if (noTypeFunctionErrors.find(instance))
return instance; return instance;
ErrorVec errors = reduceTypeFunctions( ErrorVec errors =
instance, reduceTypeFunctions(
location, instance,
TypeFunctionContext{arena, builtinTypes, stack.back(), NotNull{&normalizer}, typeFunctionRuntime, ice, limits}, location,
true TypeFunctionContext{arena, builtinTypes, stack.back(), simplifier, NotNull{&normalizer}, typeFunctionRuntime, ice, limits},
) true
.errors; )
.errors;
if (errors.empty()) if (errors.empty())
noTypeFunctionErrors.insert(instance); noTypeFunctionErrors.insert(instance);
@ -424,9 +427,6 @@ struct NonStrictTypeChecker
NonStrictContext visit(AstStatTypeFunction* typeFunc) NonStrictContext visit(AstStatTypeFunction* typeFunc)
{ {
if (!FFlag::LuauUserTypeFunNonstrict)
reportError(GenericError{"This syntax is not supported"}, typeFunc->location);
return {}; return {};
} }
@ -888,6 +888,7 @@ private:
void checkNonStrict( void checkNonStrict(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime, NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> ice, NotNull<InternalErrorReporter> ice,
NotNull<UnifierSharedState> unifierState, NotNull<UnifierSharedState> unifierState,
@ -899,7 +900,9 @@ void checkNonStrict(
{ {
LUAU_TIMETRACE_SCOPE("checkNonStrict", "Typechecking"); LUAU_TIMETRACE_SCOPE("checkNonStrict", "Typechecking");
NonStrictTypeChecker typeChecker{NotNull{&module->internalTypes}, builtinTypes, typeFunctionRuntime, ice, unifierState, dfg, limits, module}; NonStrictTypeChecker typeChecker{
NotNull{&module->internalTypes}, builtinTypes, simplifier, typeFunctionRuntime, ice, unifierState, dfg, limits, module
};
typeChecker.visit(sourceModule.root); typeChecker.visit(sourceModule.root);
unfreeze(module->interfaceTypes); unfreeze(module->interfaceTypes);
copyErrors(module->errors, module->interfaceTypes, builtinTypes); copyErrors(module->errors, module->interfaceTypes, builtinTypes);

View file

@ -3465,7 +3465,14 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm)
return arena->addType(UnionType{std::move(result)}); return arena->addType(UnionType{std::move(result)});
} }
bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice) bool isSubtype(
TypeId subTy,
TypeId superTy,
NotNull<Scope> scope,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
InternalErrorReporter& ice
)
{ {
UnifierSharedState sharedState{&ice}; UnifierSharedState sharedState{&ice};
TypeArena arena; TypeArena arena;
@ -3478,7 +3485,7 @@ bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<Built
// Subtyping under DCR is not implemented using unification! // Subtyping under DCR is not implemented using unification!
if (FFlag::LuauSolverV2) if (FFlag::LuauSolverV2)
{ {
Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&ice}}; Subtyping subtyping{builtinTypes, NotNull{&arena}, simplifier, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&ice}};
return subtyping.isSubtype(subTy, superTy, scope).isSubtype; return subtyping.isSubtype(subTy, superTy, scope).isSubtype;
} }
@ -3491,7 +3498,14 @@ bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<Built
} }
} }
bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice) bool isSubtype(
TypePackId subPack,
TypePackId superPack,
NotNull<Scope> scope,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
InternalErrorReporter& ice
)
{ {
UnifierSharedState sharedState{&ice}; UnifierSharedState sharedState{&ice};
TypeArena arena; TypeArena arena;
@ -3504,7 +3518,7 @@ bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull<Scope> scope, N
// Subtyping under DCR is not implemented using unification! // Subtyping under DCR is not implemented using unification!
if (FFlag::LuauSolverV2) if (FFlag::LuauSolverV2)
{ {
Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&ice}}; Subtyping subtyping{builtinTypes, NotNull{&arena}, simplifier, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&ice}};
return subtyping.isSubtype(subPack, superPack, scope).isSubtype; return subtyping.isSubtype(subPack, superPack, scope).isSubtype;
} }

View file

@ -16,6 +16,7 @@ namespace Luau
OverloadResolver::OverloadResolver( OverloadResolver::OverloadResolver(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena, NotNull<TypeArena> arena,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime, NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> scope, NotNull<Scope> scope,
@ -25,12 +26,13 @@ OverloadResolver::OverloadResolver(
) )
: builtinTypes(builtinTypes) : builtinTypes(builtinTypes)
, arena(arena) , arena(arena)
, simplifier(simplifier)
, normalizer(normalizer) , normalizer(normalizer)
, typeFunctionRuntime(typeFunctionRuntime) , typeFunctionRuntime(typeFunctionRuntime)
, scope(scope) , scope(scope)
, ice(reporter) , ice(reporter)
, limits(limits) , limits(limits)
, subtyping({builtinTypes, arena, normalizer, typeFunctionRuntime, ice}) , subtyping({builtinTypes, arena, simplifier, normalizer, typeFunctionRuntime, ice})
, callLoc(callLocation) , callLoc(callLocation)
{ {
} }
@ -202,7 +204,7 @@ std::pair<OverloadResolver::Analysis, ErrorVec> OverloadResolver::checkOverload_
) )
{ {
FunctionGraphReductionResult result = reduceTypeFunctions( FunctionGraphReductionResult result = reduceTypeFunctions(
fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, normalizer, typeFunctionRuntime, ice, limits}, /*force=*/true fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, simplifier, normalizer, typeFunctionRuntime, ice, limits}, /*force=*/true
); );
if (!result.errors.empty()) if (!result.errors.empty())
return {OverloadIsNonviable, result.errors}; return {OverloadIsNonviable, result.errors};
@ -404,9 +406,10 @@ void OverloadResolver::add(Analysis analysis, TypeId ty, ErrorVec&& errors)
// we wrap calling the overload resolver in a separate function to reduce overall stack pressure in `solveFunctionCall`. // we wrap calling the overload resolver in a separate function to reduce overall stack pressure in `solveFunctionCall`.
// this limits the lifetime of `OverloadResolver`, a large type, to only as long as it is actually needed. // this limits the lifetime of `OverloadResolver`, a large type, to only as long as it is actually needed.
std::optional<TypeId> selectOverload( static std::optional<TypeId> selectOverload(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena, NotNull<TypeArena> arena,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime, NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> scope, NotNull<Scope> scope,
@ -417,7 +420,7 @@ std::optional<TypeId> selectOverload(
TypePackId argsPack TypePackId argsPack
) )
{ {
auto resolver = std::make_unique<OverloadResolver>(builtinTypes, arena, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location); auto resolver = std::make_unique<OverloadResolver>(builtinTypes, arena, simplifier, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location);
auto [status, overload] = resolver->selectOverload(fn, argsPack); auto [status, overload] = resolver->selectOverload(fn, argsPack);
if (status == OverloadResolver::Analysis::Ok) if (status == OverloadResolver::Analysis::Ok)
@ -432,6 +435,7 @@ std::optional<TypeId> selectOverload(
SolveResult solveFunctionCall( SolveResult solveFunctionCall(
NotNull<TypeArena> arena, NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime, NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> iceReporter, NotNull<InternalErrorReporter> iceReporter,
@ -443,7 +447,7 @@ SolveResult solveFunctionCall(
) )
{ {
std::optional<TypeId> overloadToUse = std::optional<TypeId> overloadToUse =
selectOverload(builtinTypes, arena, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location, fn, argsPack); selectOverload(builtinTypes, arena, simplifier, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location, fn, argsPack);
if (!overloadToUse) if (!overloadToUse)
return {SolveResult::NoMatchingOverload}; return {SolveResult::NoMatchingOverload};

View file

@ -211,6 +211,16 @@ void Scope::inheritRefinements(const ScopePtr& childScope)
} }
} }
bool Scope::shouldWarnGlobal(std::string name) const
{
for (const Scope* current = this; current; current = current->parent.get())
{
if (current->globalsToWarn.contains(name))
return true;
}
return false;
}
bool subsumesStrict(Scope* left, Scope* right) bool subsumesStrict(Scope* left, Scope* right)
{ {
while (right) while (right)

View file

@ -396,12 +396,14 @@ TypePackId* SubtypingEnvironment::getMappedPackBounds(TypePackId tp)
Subtyping::Subtyping( Subtyping::Subtyping(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> typeArena, NotNull<TypeArena> typeArena,
NotNull<Simplifier> simplifier,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime, NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> iceReporter NotNull<InternalErrorReporter> iceReporter
) )
: builtinTypes(builtinTypes) : builtinTypes(builtinTypes)
, arena(typeArena) , arena(typeArena)
, simplifier(simplifier)
, normalizer(normalizer) , normalizer(normalizer)
, typeFunctionRuntime(typeFunctionRuntime) , typeFunctionRuntime(typeFunctionRuntime)
, iceReporter(iceReporter) , iceReporter(iceReporter)
@ -1861,7 +1863,7 @@ TypeId Subtyping::makeAggregateType(const Container& container, TypeId orElse)
std::pair<TypeId, ErrorVec> Subtyping::handleTypeFunctionReductionResult(const TypeFunctionInstanceType* functionInstance, NotNull<Scope> scope) std::pair<TypeId, ErrorVec> Subtyping::handleTypeFunctionReductionResult(const TypeFunctionInstanceType* functionInstance, NotNull<Scope> scope)
{ {
TypeFunctionContext context{arena, builtinTypes, scope, normalizer, typeFunctionRuntime, iceReporter, NotNull{&limits}}; TypeFunctionContext context{arena, builtinTypes, scope, simplifier, normalizer, typeFunctionRuntime, iceReporter, NotNull{&limits}};
TypeId function = arena->addType(*functionInstance); TypeId function = arena->addType(*functionInstance);
FunctionGraphReductionResult result = reduceTypeFunctions(function, {}, context, true); FunctionGraphReductionResult result = reduceTypeFunctions(function, {}, context, true);
ErrorVec errors; ErrorVec errors;

View file

@ -386,8 +386,12 @@ public:
} }
AstType* operator()(const NegationType& ntv) AstType* operator()(const NegationType& ntv)
{ {
// FIXME: do the same thing we do with ErrorType AstArray<AstTypeOrPack> params;
throw InternalCompilerError("Cannot convert NegationType into AstNode"); params.size = 1;
params.data = static_cast<AstTypeOrPack*>(allocator->allocate(sizeof(AstType*)));
params.data[0] = AstTypeOrPack{Luau::visit(*this, ntv.ty->ty), nullptr};
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("negate"), std::nullopt, Location(), true, params);
} }
AstType* operator()(const TypeFunctionInstanceType& tfit) AstType* operator()(const TypeFunctionInstanceType& tfit)
{ {

View file

@ -32,6 +32,7 @@
LUAU_FASTFLAG(DebugLuauMagicTypes) LUAU_FASTFLAG(DebugLuauMagicTypes)
LUAU_FASTFLAG(InferGlobalTypes)
LUAU_FASTFLAGVARIABLE(LuauTableKeysAreRValues) LUAU_FASTFLAGVARIABLE(LuauTableKeysAreRValues)
namespace Luau namespace Luau
@ -268,6 +269,7 @@ struct InternalTypeFunctionFinder : TypeOnceVisitor
void check( void check(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime, NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<UnifierSharedState> unifierState, NotNull<UnifierSharedState> unifierState,
NotNull<TypeCheckLimits> limits, NotNull<TypeCheckLimits> limits,
@ -278,7 +280,7 @@ void check(
{ {
LUAU_TIMETRACE_SCOPE("check", "Typechecking"); LUAU_TIMETRACE_SCOPE("check", "Typechecking");
TypeChecker2 typeChecker{builtinTypes, typeFunctionRuntime, unifierState, limits, logger, &sourceModule, module}; TypeChecker2 typeChecker{builtinTypes, simplifier, typeFunctionRuntime, unifierState, limits, logger, &sourceModule, module};
typeChecker.visit(sourceModule.root); typeChecker.visit(sourceModule.root);
@ -295,6 +297,7 @@ void check(
TypeChecker2::TypeChecker2( TypeChecker2::TypeChecker2(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime, NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<UnifierSharedState> unifierState, NotNull<UnifierSharedState> unifierState,
NotNull<TypeCheckLimits> limits, NotNull<TypeCheckLimits> limits,
@ -303,6 +306,7 @@ TypeChecker2::TypeChecker2(
Module* module Module* module
) )
: builtinTypes(builtinTypes) : builtinTypes(builtinTypes)
, simplifier(simplifier)
, typeFunctionRuntime(typeFunctionRuntime) , typeFunctionRuntime(typeFunctionRuntime)
, logger(logger) , logger(logger)
, limits(limits) , limits(limits)
@ -310,7 +314,7 @@ TypeChecker2::TypeChecker2(
, sourceModule(sourceModule) , sourceModule(sourceModule)
, module(module) , module(module)
, normalizer{&module->internalTypes, builtinTypes, unifierState, /* cacheInhabitance */ true} , normalizer{&module->internalTypes, builtinTypes, unifierState, /* cacheInhabitance */ true}
, _subtyping{builtinTypes, NotNull{&module->internalTypes}, NotNull{&normalizer}, typeFunctionRuntime, NotNull{unifierState->iceHandler}} , _subtyping{builtinTypes, NotNull{&module->internalTypes}, simplifier, NotNull{&normalizer}, typeFunctionRuntime, NotNull{unifierState->iceHandler}}
, subtyping(&_subtyping) , subtyping(&_subtyping)
{ {
} }
@ -492,7 +496,9 @@ TypeId TypeChecker2::checkForTypeFunctionInhabitance(TypeId instance, Location l
reduceTypeFunctions( reduceTypeFunctions(
instance, instance,
location, location,
TypeFunctionContext{NotNull{&module->internalTypes}, builtinTypes, stack.back(), NotNull{&normalizer}, typeFunctionRuntime, ice, limits}, TypeFunctionContext{
NotNull{&module->internalTypes}, builtinTypes, stack.back(), simplifier, NotNull{&normalizer}, typeFunctionRuntime, ice, limits
},
true true
) )
.errors; .errors;
@ -1349,7 +1355,17 @@ void TypeChecker2::visit(AstExprGlobal* expr)
{ {
NotNull<Scope> scope = stack.back(); NotNull<Scope> scope = stack.back();
if (!scope->lookup(expr->name)) if (!scope->lookup(expr->name))
{
reportError(UnknownSymbol{expr->name.value, UnknownSymbol::Binding}, expr->location); reportError(UnknownSymbol{expr->name.value, UnknownSymbol::Binding}, expr->location);
}
else if (FFlag::InferGlobalTypes)
{
if (scope->shouldWarnGlobal(expr->name.value) && !warnedGlobals.contains(expr->name.value))
{
reportError(UnknownSymbol{expr->name.value, UnknownSymbol::Binding}, expr->location);
warnedGlobals.insert(expr->name.value);
}
}
} }
void TypeChecker2::visit(AstExprVarargs* expr) void TypeChecker2::visit(AstExprVarargs* expr)
@ -1448,6 +1464,7 @@ void TypeChecker2::visitCall(AstExprCall* call)
OverloadResolver resolver{ OverloadResolver resolver{
builtinTypes, builtinTypes,
NotNull{&module->internalTypes}, NotNull{&module->internalTypes},
simplifier,
NotNull{&normalizer}, NotNull{&normalizer},
typeFunctionRuntime, typeFunctionRuntime,
NotNull{stack.back()}, NotNull{stack.back()},
@ -3024,7 +3041,7 @@ PropertyType TypeChecker2::hasIndexTypeFromType(
{ {
TypeId indexType = follow(tt->indexer->indexType); TypeId indexType = follow(tt->indexer->indexType);
TypeId givenType = module->internalTypes.addType(SingletonType{StringSingleton{prop}}); TypeId givenType = module->internalTypes.addType(SingletonType{StringSingleton{prop}});
if (isSubtype(givenType, indexType, NotNull{module->getModuleScope().get()}, builtinTypes, *ice)) if (isSubtype(givenType, indexType, NotNull{module->getModuleScope().get()}, builtinTypes, simplifier, *ice))
return {NormalizationResult::True, {tt->indexer->indexResultType}}; return {NormalizationResult::True, {tt->indexer->indexResultType}};
} }

File diff suppressed because it is too large Load diff

View file

@ -14,7 +14,7 @@
#include <vector> #include <vector>
LUAU_DYNAMIC_FASTINT(LuauTypeFunctionSerdeIterationLimit) LUAU_DYNAMIC_FASTINT(LuauTypeFunctionSerdeIterationLimit)
LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixRegister) LUAU_FASTFLAGVARIABLE(LuauUserTypeFunPrintToError)
LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixNoReadWrite) LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixNoReadWrite)
LUAU_FASTFLAGVARIABLE(LuauUserTypeFunThreadBuffer) LUAU_FASTFLAGVARIABLE(LuauUserTypeFunThreadBuffer)
@ -1408,8 +1408,6 @@ static int isEqualToType(lua_State* L)
void registerTypesLibrary(lua_State* L) void registerTypesLibrary(lua_State* L)
{ {
LUAU_ASSERT(FFlag::LuauUserTypeFunFixRegister);
luaL_Reg fields[] = { luaL_Reg fields[] = {
{"unknown", createUnknown}, {"unknown", createUnknown},
{"never", createNever}, {"never", createNever},
@ -1464,170 +1462,99 @@ static int typeUserdataIndex(lua_State* L)
void registerTypeUserData(lua_State* L) void registerTypeUserData(lua_State* L)
{ {
if (FFlag::LuauUserTypeFunFixRegister) luaL_Reg typeUserdataMethods[] = {
{ {"is", checkTag},
luaL_Reg typeUserdataMethods[] = {
{"is", checkTag},
// Negation type methods // Negation type methods
{"inner", getNegatedValue}, {"inner", getNegatedValue},
// Singleton type methods // Singleton type methods
{"value", getSingletonValue}, {"value", getSingletonValue},
// Table type methods // Table type methods
{"setproperty", setTableProp}, {"setproperty", setTableProp},
{"setreadproperty", setReadTableProp}, {"setreadproperty", setReadTableProp},
{"setwriteproperty", setWriteTableProp}, {"setwriteproperty", setWriteTableProp},
{"readproperty", readTableProp}, {"readproperty", readTableProp},
{"writeproperty", writeTableProp}, {"writeproperty", writeTableProp},
{"properties", getProps}, {"properties", getProps},
{"setindexer", setTableIndexer}, {"setindexer", setTableIndexer},
{"setreadindexer", setTableReadIndexer}, {"setreadindexer", setTableReadIndexer},
{"setwriteindexer", setTableWriteIndexer}, {"setwriteindexer", setTableWriteIndexer},
{"indexer", getIndexer}, {"indexer", getIndexer},
{"readindexer", getReadIndexer}, {"readindexer", getReadIndexer},
{"writeindexer", getWriteIndexer}, {"writeindexer", getWriteIndexer},
{"setmetatable", setTableMetatable}, {"setmetatable", setTableMetatable},
{"metatable", getMetatable}, {"metatable", getMetatable},
// Function type methods // Function type methods
{"setparameters", setFunctionParameters}, {"setparameters", setFunctionParameters},
{"parameters", getFunctionParameters}, {"parameters", getFunctionParameters},
{"setreturns", setFunctionReturns}, {"setreturns", setFunctionReturns},
{"returns", getFunctionReturns}, {"returns", getFunctionReturns},
// Union and Intersection type methods // Union and Intersection type methods
{"components", getComponents}, {"components", getComponents},
// Class type methods // Class type methods
{"parent", getClassParent}, {"parent", getClassParent},
{nullptr, nullptr} {nullptr, nullptr}
}; };
// Create and register metatable for type userdata // Create and register metatable for type userdata
luaL_newmetatable(L, "type"); luaL_newmetatable(L, "type");
// Protect metatable from being changed // Protect metatable from being changed
lua_pushstring(L, "The metatable is locked"); lua_pushstring(L, "The metatable is locked");
lua_setfield(L, -2, "__metatable"); lua_setfield(L, -2, "__metatable");
lua_pushcfunction(L, isEqualToType, "__eq"); lua_pushcfunction(L, isEqualToType, "__eq");
lua_setfield(L, -2, "__eq"); lua_setfield(L, -2, "__eq");
// Indexing will be a dynamic function because some type fields are dynamic // Indexing will be a dynamic function because some type fields are dynamic
lua_newtable(L); lua_newtable(L);
luaL_register(L, nullptr, typeUserdataMethods); luaL_register(L, nullptr, typeUserdataMethods);
lua_setreadonly(L, -1, true); lua_setreadonly(L, -1, true);
lua_pushcclosure(L, typeUserdataIndex, "__index", 1); lua_pushcclosure(L, typeUserdataIndex, "__index", 1);
lua_setfield(L, -2, "__index"); lua_setfield(L, -2, "__index");
lua_setreadonly(L, -1, true); lua_setreadonly(L, -1, true);
lua_pop(L, 1); lua_pop(L, 1);
}
else
{
// List of fields for type userdata
luaL_Reg typeUserdataFields[] = {
{"unknown", createUnknown},
{"never", createNever},
{"any", createAny},
{"boolean", createBoolean},
{"number", createNumber},
{"string", createString},
{nullptr, nullptr}
};
// List of methods for type userdata
luaL_Reg typeUserdataMethods[] = {
{"singleton", createSingleton},
{"negationof", createNegation},
{"unionof", createUnion},
{"intersectionof", createIntersection},
{"newtable", createTable},
{"newfunction", createFunction},
{"copy", deepCopy},
// Common methods
{"is", checkTag},
// Negation type methods
{"inner", getNegatedValue},
// Singleton type methods
{"value", getSingletonValue},
// Table type methods
{"setproperty", setTableProp},
{"setreadproperty", setReadTableProp},
{"setwriteproperty", setWriteTableProp},
{"readproperty", readTableProp},
{"writeproperty", writeTableProp},
{"properties", getProps},
{"setindexer", setTableIndexer},
{"setreadindexer", setTableReadIndexer},
{"setwriteindexer", setTableWriteIndexer},
{"indexer", getIndexer},
{"readindexer", getReadIndexer},
{"writeindexer", getWriteIndexer},
{"setmetatable", setTableMetatable},
{"metatable", getMetatable},
// Function type methods
{"setparameters", setFunctionParameters},
{"parameters", getFunctionParameters},
{"setreturns", setFunctionReturns},
{"returns", getFunctionReturns},
// Union and Intersection type methods
{"components", getComponents},
// Class type methods
{"parent", getClassParent},
{"indexer", getIndexer},
{nullptr, nullptr}
};
// Create and register metatable for type userdata
luaL_newmetatable(L, "type");
// Protect metatable from being fetched.
lua_pushstring(L, "The metatable is locked");
lua_setfield(L, -2, "__metatable");
// Set type userdata metatable's __eq to type_equals()
lua_pushcfunction(L, isEqualToType, "__eq");
lua_setfield(L, -2, "__eq");
// Set type userdata metatable's __index to itself
lua_pushvalue(L, -1); // Push a copy of type userdata metatable
lua_setfield(L, -2, "__index");
luaL_register(L, nullptr, typeUserdataMethods);
// Set fields for type userdata
for (luaL_Reg* l = typeUserdataFields; l->name; l++)
{
l->func(L);
lua_setfield(L, -2, l->name);
}
// Set types library as a global name "types"
lua_setglobal(L, "types");
}
// Sets up a destructor for the type userdata. // Sets up a destructor for the type userdata.
lua_setuserdatadtor(L, kTypeUserdataTag, deallocTypeUserData); lua_setuserdatadtor(L, kTypeUserdataTag, deallocTypeUserData);
} }
// Used to redirect all the removed global functions to say "this function is unsupported" // Used to redirect all the removed global functions to say "this function is unsupported"
int unsupportedFunction(lua_State* L) static int unsupportedFunction(lua_State* L)
{ {
luaL_errorL(L, "this function is not supported in type functions"); luaL_errorL(L, "this function is not supported in type functions");
return 0; return 0;
} }
static int print(lua_State* L)
{
std::string result;
int n = lua_gettop(L);
for (int i = 1; i <= n; i++)
{
size_t l = 0;
const char* s = luaL_tolstring(L, i, &l); // convert to string using __tostring et al
if (i > 1)
result.append('\t', 1);
result.append(s, l);
lua_pop(L, 1);
}
auto ctx = getTypeFunctionRuntime(L);
ctx->messages.push_back(std::move(result));
return 0;
}
// Add libraries / globals for type function environment // Add libraries / globals for type function environment
void setTypeFunctionEnvironment(lua_State* L) void setTypeFunctionEnvironment(lua_State* L)
{ {
@ -1659,12 +1586,28 @@ void setTypeFunctionEnvironment(lua_State* L)
luaopen_base(L); luaopen_base(L);
lua_pop(L, 1); lua_pop(L, 1);
// Remove certain global functions from the base library if (FFlag::LuauUserTypeFunPrintToError)
static const std::string unavailableGlobals[] = {"gcinfo", "getfenv", "newproxy", "setfenv", "pcall", "xpcall"};
for (auto& name : unavailableGlobals)
{ {
lua_pushcfunction(L, unsupportedFunction, "Removing global function from type function environment"); // Remove certain global functions from the base library
lua_setglobal(L, name.c_str()); static const char* unavailableGlobals[] = {"gcinfo", "getfenv", "newproxy", "setfenv", "pcall", "xpcall"};
for (auto& name : unavailableGlobals)
{
lua_pushcfunction(L, unsupportedFunction, name);
lua_setglobal(L, name);
}
lua_pushcfunction(L, print, "print");
lua_setglobal(L, "print");
}
else
{
// Remove certain global functions from the base library
static const std::string unavailableGlobals[] = {"gcinfo", "getfenv", "newproxy", "setfenv", "pcall", "xpcall"};
for (auto& name : unavailableGlobals)
{
lua_pushcfunction(L, unsupportedFunction, "Removing global function from type function environment");
lua_setglobal(L, name.c_str());
}
} }
} }

View file

@ -20,7 +20,6 @@
// currently, controls serialization, deserialization, and `type.copy` // currently, controls serialization, deserialization, and `type.copy`
LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFunctionSerdeIterationLimit, 100'000); LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFunctionSerdeIterationLimit, 100'000);
LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixMetatable)
LUAU_FASTFLAG(LuauUserTypeFunThreadBuffer) LUAU_FASTFLAG(LuauUserTypeFunThreadBuffer)
namespace Luau namespace Luau
@ -383,21 +382,9 @@ private:
void serializeChildren(const MetatableType* m1, TypeFunctionTableType* m2) void serializeChildren(const MetatableType* m1, TypeFunctionTableType* m2)
{ {
if (FFlag::LuauUserTypeFunFixMetatable) // Serialize main part of the metatable immediately
{ if (auto tableTy = get<TableType>(m1->table))
// Serialize main part of the metatable immediately serializeChildren(tableTy, m2);
if (auto tableTy = get<TableType>(m1->table))
serializeChildren(tableTy, m2);
}
else
{
auto tmpTable = get<TypeFunctionTableType>(shallowSerialize(m1->table));
if (!tmpTable)
state->ctx->ice->ice("Serializing user defined type function arguments: metatable's table is not a TableType");
m2->props = tmpTable->props;
m2->indexer = tmpTable->indexer;
}
m2->metatable = shallowSerialize(m1->metatable); m2->metatable = shallowSerialize(m1->metatable);
} }

View file

@ -32,7 +32,7 @@ LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500)
LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauKnowsTheDataModel3)
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification)
LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAGVARIABLE(LuauMetatableFollow) LUAU_FASTFLAGVARIABLE(LuauOldSolverCreatesChildScopePointers)
namespace Luau namespace Luau
{ {
@ -2866,7 +2866,7 @@ TypeId TypeChecker::checkRelationalOperation(
std::optional<TypeId> metamethod = findMetatableEntry(lhsType, metamethodName, expr.location, /* addErrors= */ true); std::optional<TypeId> metamethod = findMetatableEntry(lhsType, metamethodName, expr.location, /* addErrors= */ true);
if (metamethod) if (metamethod)
{ {
if (const FunctionType* ftv = get<FunctionType>(FFlag::LuauMetatableFollow ? follow(*metamethod) : *metamethod)) if (const FunctionType* ftv = get<FunctionType>(follow(*metamethod)))
{ {
if (isEquality) if (isEquality)
{ {
@ -5205,6 +5205,13 @@ LUAU_NOINLINE void TypeChecker::reportErrorCodeTooComplex(const Location& locati
ScopePtr TypeChecker::childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel) ScopePtr TypeChecker::childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel)
{ {
ScopePtr scope = std::make_shared<Scope>(parent, subLevel); ScopePtr scope = std::make_shared<Scope>(parent, subLevel);
if (FFlag::LuauOldSolverCreatesChildScopePointers)
{
scope->location = location;
scope->returnType = parent->returnType;
parent->children.emplace_back(scope.get());
}
currentModule->scopes.push_back(std::make_pair(location, scope)); currentModule->scopes.push_back(std::make_pair(location, scope));
return scope; return scope;
} }
@ -5215,6 +5222,12 @@ ScopePtr TypeChecker::childScope(const ScopePtr& parent, const Location& locatio
ScopePtr scope = std::make_shared<Scope>(parent); ScopePtr scope = std::make_shared<Scope>(parent);
scope->level = parent->level; scope->level = parent->level;
scope->varargPack = parent->varargPack; scope->varargPack = parent->varargPack;
if (FFlag::LuauOldSolverCreatesChildScopePointers)
{
scope->location = location;
scope->returnType = parent->returnType;
parent->children.emplace_back(scope.get());
}
currentModule->scopes.push_back(std::make_pair(location, scope)); currentModule->scopes.push_back(std::make_pair(location, scope));
return scope; return scope;

View file

@ -230,17 +230,6 @@ private:
bool skipComments; bool skipComments;
bool readNames; bool readNames;
// This offset represents a column offset to be applied to any positions created by the lexer until the next new line.
// For example:
// local x = 4
// local y = 5
// If we start lexing from the position of `l` in `local x = 4`, the line number will be 1, and the column will be 4
// However, because the lexer calculates line offsets by 'index in source buffer where there is a newline', the column
// count will start at 0. For this reason, for just the first line, we'll need to store the offset.
unsigned int lexResumeOffset;
enum class BraceType enum class BraceType
{ {
InterpolatedString, InterpolatedString,

View file

@ -8,7 +8,7 @@
#include <limits.h> #include <limits.h>
LUAU_FASTFLAGVARIABLE(LexerResumesFromPosition) LUAU_FASTFLAGVARIABLE(LexerResumesFromPosition2)
namespace Luau namespace Luau
{ {
@ -308,16 +308,15 @@ Lexer::Lexer(const char* buffer, size_t bufferSize, AstNameTable& names, Positio
: buffer(buffer) : buffer(buffer)
, bufferSize(bufferSize) , bufferSize(bufferSize)
, offset(0) , offset(0)
, line(FFlag::LexerResumesFromPosition ? startPosition.line : 0) , line(FFlag::LexerResumesFromPosition2 ? startPosition.line : 0)
, lineOffset(0) , lineOffset(FFlag::LexerResumesFromPosition2 ? 0u - startPosition.column : 0)
, lexeme( , lexeme(
(FFlag::LexerResumesFromPosition ? Location(Position(startPosition.line, startPosition.column), 0) : Location(Position(0, 0), 0)), (FFlag::LexerResumesFromPosition2 ? Location(Position(startPosition.line, startPosition.column), 0) : Location(Position(0, 0), 0)),
Lexeme::Eof Lexeme::Eof
) )
, names(names) , names(names)
, skipComments(false) , skipComments(false)
, readNames(true) , readNames(true)
, lexResumeOffset(FFlag::LexerResumesFromPosition ? startPosition.column : 0)
{ {
} }
@ -372,7 +371,6 @@ Lexeme Lexer::lookahead()
Location currentPrevLocation = prevLocation; Location currentPrevLocation = prevLocation;
size_t currentBraceStackSize = braceStack.size(); size_t currentBraceStackSize = braceStack.size();
BraceType currentBraceType = braceStack.empty() ? BraceType::Normal : braceStack.back(); BraceType currentBraceType = braceStack.empty() ? BraceType::Normal : braceStack.back();
unsigned int currentLexResumeOffset = lexResumeOffset;
Lexeme result = next(); Lexeme result = next();
@ -381,7 +379,6 @@ Lexeme Lexer::lookahead()
lineOffset = currentLineOffset; lineOffset = currentLineOffset;
lexeme = currentLexeme; lexeme = currentLexeme;
prevLocation = currentPrevLocation; prevLocation = currentPrevLocation;
lexResumeOffset = currentLexResumeOffset;
if (braceStack.size() < currentBraceStackSize) if (braceStack.size() < currentBraceStackSize)
braceStack.push_back(currentBraceType); braceStack.push_back(currentBraceType);
@ -412,9 +409,10 @@ char Lexer::peekch(unsigned int lookahead) const
return (offset + lookahead < bufferSize) ? buffer[offset + lookahead] : 0; return (offset + lookahead < bufferSize) ? buffer[offset + lookahead] : 0;
} }
LUAU_FORCEINLINE
Position Lexer::position() const Position Lexer::position() const
{ {
return Position(line, offset - lineOffset + (FFlag::LexerResumesFromPosition ? lexResumeOffset : 0)); return Position(line, offset - lineOffset);
} }
LUAU_FORCEINLINE LUAU_FORCEINLINE
@ -433,9 +431,6 @@ void Lexer::consumeAny()
{ {
line++; line++;
lineOffset = offset + 1; lineOffset = offset + 1;
// every new line, we reset
if (FFlag::LexerResumesFromPosition)
lexResumeOffset = 0;
} }
offset++; offset++;

View file

@ -18,10 +18,8 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100)
// flag so that we don't break production games by reverting syntax changes. // flag so that we don't break production games by reverting syntax changes.
// See docs/SyntaxChanges.md for an explanation. // See docs/SyntaxChanges.md for an explanation.
LUAU_FASTFLAGVARIABLE(LuauSolverV2) LUAU_FASTFLAGVARIABLE(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionsSyntax2)
LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunParseExport) LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunParseExport)
LUAU_FASTFLAGVARIABLE(LuauAllowFragmentParsing) LUAU_FASTFLAGVARIABLE(LuauAllowFragmentParsing)
LUAU_FASTFLAGVARIABLE(LuauPortableStringZeroCheck)
LUAU_FASTFLAGVARIABLE(LuauAllowComplexTypesInGenericParams) LUAU_FASTFLAGVARIABLE(LuauAllowComplexTypesInGenericParams)
LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryForTableTypes) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryForTableTypes)
LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryForClassNames) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryForClassNames)
@ -910,11 +908,8 @@ AstStat* Parser::parseReturn()
AstStat* Parser::parseTypeAlias(const Location& start, bool exported) AstStat* Parser::parseTypeAlias(const Location& start, bool exported)
{ {
// parsing a type function // parsing a type function
if (FFlag::LuauUserDefinedTypeFunctionsSyntax2) if (lexer.current().type == Lexeme::ReservedFunction)
{ return parseTypeFunction(start, exported);
if (lexer.current().type == Lexeme::ReservedFunction)
return parseTypeFunction(start, exported);
}
// parsing a type alias // parsing a type alias
@ -1134,8 +1129,7 @@ AstStat* Parser::parseDeclaration(const Location& start, const AstArray<AstAttr*
AstType* type = parseType(); AstType* type = parseType();
// since AstName contains a char*, it can't contain null // since AstName contains a char*, it can't contain null
bool containsNull = chars && (FFlag::LuauPortableStringZeroCheck ? memchr(chars->data, 0, chars->size) != nullptr bool containsNull = chars && (memchr(chars->data, 0, chars->size) != nullptr);
: strnlen(chars->data, chars->size) < chars->size);
if (chars && !containsNull) if (chars && !containsNull)
{ {
@ -1647,8 +1641,7 @@ AstType* Parser::parseTableType(bool inDeclarationContext)
AstType* type = parseType(); AstType* type = parseType();
// since AstName contains a char*, it can't contain null // since AstName contains a char*, it can't contain null
bool containsNull = chars && (FFlag::LuauPortableStringZeroCheck ? memchr(chars->data, 0, chars->size) != nullptr bool containsNull = chars && (memchr(chars->data, 0, chars->size) != nullptr);
: strnlen(chars->data, chars->size) < chars->size);
if (chars && !containsNull) if (chars && !containsNull)
props.push_back(AstTableProp{AstName(chars->data), begin.location, type, access, accessLocation}); props.push_back(AstTableProp{AstName(chars->data), begin.location, type, access, accessLocation});
@ -2352,11 +2345,8 @@ AstExpr* Parser::parseNameExpr(const char* context)
{ {
AstLocal* local = *value; AstLocal* local = *value;
if (FFlag::LuauUserDefinedTypeFunctionsSyntax2) if (local->functionDepth < typeFunctionDepth)
{ return reportExprError(lexer.current().location, {}, "Type function cannot reference outer local '%s'", local->name.value);
if (local->functionDepth < typeFunctionDepth)
return reportExprError(lexer.current().location, {}, "Type function cannot reference outer local '%s'", local->name.value);
}
return allocator.alloc<AstExprLocal>(name->location, local, local->functionDepth != functionStack.size() - 1); return allocator.alloc<AstExprLocal>(name->location, local, local->functionDepth != functionStack.size() - 1);
} }

View file

@ -13,6 +13,16 @@ struct ParseResult;
class BytecodeBuilder; class BytecodeBuilder;
class BytecodeEncoder; class BytecodeEncoder;
using CompileConstant = void*;
// return a type identifier for a global library member
// values are defined by 'enum LuauBytecodeType' in Bytecode.h
using LibraryMemberTypeCallback = int (*)(const char* library, const char* member);
// setup a value of a constant for a global library member
// use setCompileConstant*** set of functions for values
using LibraryMemberConstantCallback = void (*)(const char* library, const char* member, CompileConstant* constant);
// Note: this structure is duplicated in luacode.h, don't forget to change these in sync! // Note: this structure is duplicated in luacode.h, don't forget to change these in sync!
struct CompileOptions struct CompileOptions
{ {
@ -49,6 +59,15 @@ struct CompileOptions
// null-terminated array of userdata types that will be included in the type information // null-terminated array of userdata types that will be included in the type information
const char* const* userdataTypes = nullptr; const char* const* userdataTypes = nullptr;
// null-terminated array of globals which act as libraries and have members with known type and/or constant value
// when an import of one of these libraries is accessed, callbacks below will be called to receive that information
const char* const* librariesWithKnownMembers = nullptr;
LibraryMemberTypeCallback libraryMemberTypeCb = nullptr;
LibraryMemberConstantCallback libraryMemberConstantCb = nullptr;
// null-terminated array of library functions that should not be compiled into a built-in fastcall ("name" "lib.name")
const char* const* disabledBuiltins = nullptr;
}; };
class CompileError : public std::exception class CompileError : public std::exception
@ -81,4 +100,10 @@ std::string compile(
BytecodeEncoder* encoder = nullptr BytecodeEncoder* encoder = nullptr
); );
void setCompileConstantNil(CompileConstant* constant);
void setCompileConstantBoolean(CompileConstant* constant, bool b);
void setCompileConstantNumber(CompileConstant* constant, double n);
void setCompileConstantVector(CompileConstant* constant, float x, float y, float z, float w);
void setCompileConstantString(CompileConstant* constant, const char* s, size_t l);
} // namespace Luau } // namespace Luau

View file

@ -3,12 +3,21 @@
#include <stddef.h> #include <stddef.h>
// Can be used to reconfigure visibility/exports for public APIs // can be used to reconfigure visibility/exports for public APIs
#ifndef LUACODE_API #ifndef LUACODE_API
#define LUACODE_API extern #define LUACODE_API extern
#endif #endif
typedef struct lua_CompileOptions lua_CompileOptions; typedef struct lua_CompileOptions lua_CompileOptions;
typedef void* lua_CompileConstant;
// return a type identifier for a global library member
// values are defined by 'enum LuauBytecodeType' in Bytecode.h
typedef int (*lua_LibraryMemberTypeCallback)(const char* library, const char* member);
// setup a value of a constant for a global library member
// use luau_set_compile_constant_*** set of functions for values
typedef void (*lua_LibraryMemberConstantCallback)(const char* library, const char* member, lua_CompileConstant* constant);
struct lua_CompileOptions struct lua_CompileOptions
{ {
@ -45,7 +54,25 @@ struct lua_CompileOptions
// null-terminated array of userdata types that will be included in the type information // null-terminated array of userdata types that will be included in the type information
const char* const* userdataTypes; const char* const* userdataTypes;
// null-terminated array of globals which act as libraries and have members with known type and/or constant value
// when an import of one of these libraries is accessed, callbacks below will be called to receive that information
const char* const* librariesWithKnownMembers;
lua_LibraryMemberTypeCallback libraryMemberTypeCb;
lua_LibraryMemberConstantCallback libraryMemberConstantCb;
// null-terminated array of library functions that should not be compiled into a built-in fastcall ("name" "lib.name")
const char* const* disabledBuiltins;
}; };
// compile source to bytecode; when source compilation fails, the resulting bytecode contains the encoded error. use free() to destroy // compile source to bytecode; when source compilation fails, the resulting bytecode contains the encoded error. use free() to destroy
LUACODE_API char* luau_compile(const char* source, size_t size, lua_CompileOptions* options, size_t* outsize); LUACODE_API char* luau_compile(const char* source, size_t size, lua_CompileOptions* options, size_t* outsize);
// when libraryMemberConstantCb is called, these methods can be used to set a value of the opaque lua_CompileConstant struct
// vector component 'w' is not visible to VM runtime configured with LUA_VECTOR_SIZE == 3, but can affect constant folding during compilation
// string storage must outlive the invocation of 'luau_compile' which used the callback
LUACODE_API void luau_set_compile_constant_nil(lua_CompileConstant* constant);
LUACODE_API void luau_set_compile_constant_boolean(lua_CompileConstant* constant, int b);
LUACODE_API void luau_set_compile_constant_number(lua_CompileConstant* constant, double n);
LUACODE_API void luau_set_compile_constant_vector(lua_CompileConstant* constant, float x, float y, float z, float w);
LUACODE_API void luau_set_compile_constant_string(lua_CompileConstant* constant, const char* s, size_t l);

View file

@ -3,8 +3,12 @@
#include "Luau/Bytecode.h" #include "Luau/Bytecode.h"
#include "Luau/Compiler.h" #include "Luau/Compiler.h"
#include "Luau/Lexer.h"
#include <array>
LUAU_FASTFLAGVARIABLE(LuauVectorBuiltins) LUAU_FASTFLAGVARIABLE(LuauVectorBuiltins)
LUAU_FASTFLAGVARIABLE(LuauCompileDisabledBuiltins)
namespace Luau namespace Luau
{ {
@ -270,23 +274,61 @@ static int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& op
struct BuiltinVisitor : AstVisitor struct BuiltinVisitor : AstVisitor
{ {
DenseHashMap<AstExprCall*, int>& result; DenseHashMap<AstExprCall*, int>& result;
std::array<bool, 256> builtinIsDisabled;
const DenseHashMap<AstName, Global>& globals; const DenseHashMap<AstName, Global>& globals;
const DenseHashMap<AstLocal*, Variable>& variables; const DenseHashMap<AstLocal*, Variable>& variables;
const CompileOptions& options; const CompileOptions& options;
const AstNameTable& names;
BuiltinVisitor( BuiltinVisitor(
DenseHashMap<AstExprCall*, int>& result, DenseHashMap<AstExprCall*, int>& result,
const DenseHashMap<AstName, Global>& globals, const DenseHashMap<AstName, Global>& globals,
const DenseHashMap<AstLocal*, Variable>& variables, const DenseHashMap<AstLocal*, Variable>& variables,
const CompileOptions& options const CompileOptions& options,
const AstNameTable& names
) )
: result(result) : result(result)
, globals(globals) , globals(globals)
, variables(variables) , variables(variables)
, options(options) , options(options)
, names(names)
{ {
if (FFlag::LuauCompileDisabledBuiltins)
{
builtinIsDisabled.fill(false);
if (const char* const* ptr = options.disabledBuiltins)
{
for (; *ptr; ++ptr)
{
if (const char* dot = strchr(*ptr, '.'))
{
AstName library = names.getWithType(*ptr, dot - *ptr).first;
AstName name = names.get(dot + 1);
if (library.value && name.value && getGlobalState(globals, name) == Global::Default)
{
Builtin builtin = Builtin{library, name};
if (int bfid = getBuiltinFunctionId(builtin, options); bfid >= 0)
builtinIsDisabled[bfid] = true;
}
}
else
{
if (AstName name = names.get(*ptr); name.value && getGlobalState(globals, name) == Global::Default)
{
Builtin builtin = Builtin{AstName(), name};
if (int bfid = getBuiltinFunctionId(builtin, options); bfid >= 0)
builtinIsDisabled[bfid] = true;
}
}
}
}
}
} }
bool visit(AstExprCall* node) override bool visit(AstExprCall* node) override
@ -297,6 +339,9 @@ struct BuiltinVisitor : AstVisitor
int bfid = getBuiltinFunctionId(builtin, options); int bfid = getBuiltinFunctionId(builtin, options);
if (FFlag::LuauCompileDisabledBuiltins && bfid >= 0 && builtinIsDisabled[bfid])
bfid = -1;
// getBuiltinFunctionId optimistically assumes all select() calls are builtin but actually the second argument must be a vararg // getBuiltinFunctionId optimistically assumes all select() calls are builtin but actually the second argument must be a vararg
if (bfid == LBF_SELECT_VARARG && !(node->args.size == 2 && node->args.data[1]->is<AstExprVarargs>())) if (bfid == LBF_SELECT_VARARG && !(node->args.size == 2 && node->args.data[1]->is<AstExprVarargs>()))
bfid = -1; bfid = -1;
@ -313,10 +358,11 @@ void analyzeBuiltins(
const DenseHashMap<AstName, Global>& globals, const DenseHashMap<AstName, Global>& globals,
const DenseHashMap<AstLocal*, Variable>& variables, const DenseHashMap<AstLocal*, Variable>& variables,
const CompileOptions& options, const CompileOptions& options,
AstNode* root AstNode* root,
const AstNameTable& names
) )
{ {
BuiltinVisitor visitor{result, globals, variables, options}; BuiltinVisitor visitor{result, globals, variables, options, names};
root->visit(&visitor); root->visit(&visitor);
} }

View file

@ -41,7 +41,8 @@ void analyzeBuiltins(
const DenseHashMap<AstName, Global>& globals, const DenseHashMap<AstName, Global>& globals,
const DenseHashMap<AstLocal*, Variable>& variables, const DenseHashMap<AstLocal*, Variable>& variables,
const CompileOptions& options, const CompileOptions& options,
AstNode* root AstNode* root,
const AstNameTable& names
); );
struct BuiltinInfo struct BuiltinInfo

View file

@ -27,6 +27,7 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300)
LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5)
LUAU_FASTFLAGVARIABLE(LuauCompileOptimizeRevArith) LUAU_FASTFLAGVARIABLE(LuauCompileOptimizeRevArith)
LUAU_FASTFLAGVARIABLE(LuauCompileLibraryConstants)
namespace Luau namespace Luau
{ {
@ -725,7 +726,7 @@ struct Compiler
inlineFrames.push_back({func, oldLocals, target, targetCount}); inlineFrames.push_back({func, oldLocals, target, targetCount});
// fold constant values updated above into expressions in the function body // fold constant values updated above into expressions in the function body
foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldMathK, func->body); foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldLibraryK, options.libraryMemberConstantCb, func->body);
bool usedFallthrough = false; bool usedFallthrough = false;
@ -770,7 +771,7 @@ struct Compiler
var->type = Constant::Type_Unknown; var->type = Constant::Type_Unknown;
} }
foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldMathK, func->body); foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldLibraryK, options.libraryMemberConstantCb, func->body);
} }
void compileExprCall(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop = false, bool multRet = false) void compileExprCall(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop = false, bool multRet = false)
@ -3052,7 +3053,7 @@ struct Compiler
locstants[var].type = Constant::Type_Number; locstants[var].type = Constant::Type_Number;
locstants[var].valueNumber = from + iv * step; locstants[var].valueNumber = from + iv * step;
foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldMathK, stat); foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldLibraryK, options.libraryMemberConstantCb, stat);
size_t iterJumps = loopJumps.size(); size_t iterJumps = loopJumps.size();
@ -3080,7 +3081,7 @@ struct Compiler
// clean up fold state in case we need to recompile - normally we compile the loop body once, but due to inlining we may need to do it again // clean up fold state in case we need to recompile - normally we compile the loop body once, but due to inlining we may need to do it again
locstants[var].type = Constant::Type_Unknown; locstants[var].type = Constant::Type_Unknown;
foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldMathK, stat); foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldLibraryK, options.libraryMemberConstantCb, stat);
} }
void compileStatFor(AstStatFor* stat) void compileStatFor(AstStatFor* stat)
@ -4141,7 +4142,7 @@ struct Compiler
BuiltinAstTypes builtinTypes; BuiltinAstTypes builtinTypes;
const DenseHashMap<AstExprCall*, int>* builtinsFold = nullptr; const DenseHashMap<AstExprCall*, int>* builtinsFold = nullptr;
bool builtinsFoldMathK = false; bool builtinsFoldLibraryK = false;
// compileFunction state, gets reset for every function // compileFunction state, gets reset for every function
unsigned int regTop = 0; unsigned int regTop = 0;
@ -4221,16 +4222,40 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c
compiler.builtinsFold = &compiler.builtins; compiler.builtinsFold = &compiler.builtins;
if (AstName math = names.get("math"); math.value && getGlobalState(compiler.globals, math) == Global::Default) if (AstName math = names.get("math"); math.value && getGlobalState(compiler.globals, math) == Global::Default)
compiler.builtinsFoldMathK = true; {
compiler.builtinsFoldLibraryK = true;
}
else if (FFlag::LuauCompileLibraryConstants)
{
if (const char* const* ptr = options.librariesWithKnownMembers)
{
for (; *ptr; ++ptr)
{
if (AstName name = names.get(*ptr); name.value && getGlobalState(compiler.globals, name) == Global::Default)
{
compiler.builtinsFoldLibraryK = true;
break;
}
}
}
}
} }
if (options.optimizationLevel >= 1) if (options.optimizationLevel >= 1)
{ {
// this pass tracks which calls are builtins and can be compiled more efficiently // this pass tracks which calls are builtins and can be compiled more efficiently
analyzeBuiltins(compiler.builtins, compiler.globals, compiler.variables, options, root); analyzeBuiltins(compiler.builtins, compiler.globals, compiler.variables, options, root, names);
// this pass analyzes constantness of expressions // this pass analyzes constantness of expressions
foldConstants(compiler.constants, compiler.variables, compiler.locstants, compiler.builtinsFold, compiler.builtinsFoldMathK, root); foldConstants(
compiler.constants,
compiler.variables,
compiler.locstants,
compiler.builtinsFold,
compiler.builtinsFoldLibraryK,
options.libraryMemberConstantCb,
root
);
// this pass analyzes table assignments to estimate table shapes for initially empty tables // this pass analyzes table assignments to estimate table shapes for initially empty tables
predictTableShapes(compiler.tableShapes, root); predictTableShapes(compiler.tableShapes, root);
@ -4261,6 +4286,7 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c
compiler.builtinTypes, compiler.builtinTypes,
compiler.builtins, compiler.builtins,
compiler.globals, compiler.globals,
options.libraryMemberTypeCb,
bytecode bytecode
); );
@ -4277,9 +4303,9 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c
AstExprFunction main( AstExprFunction main(
root->location, root->location,
/*attributes=*/AstArray<AstAttr*>({nullptr, 0}), /* attributes= */ AstArray<AstAttr*>({nullptr, 0}),
/*generics= */ AstArray<AstGenericType>(), /* generics= */ AstArray<AstGenericType>(),
/*genericPacks= */ AstArray<AstGenericTypePack>(), /* genericPacks= */ AstArray<AstGenericTypePack>(),
/* self= */ nullptr, /* self= */ nullptr,
AstArray<AstLocal*>(), AstArray<AstLocal*>(),
/* vararg= */ true, /* vararg= */ true,
@ -4340,4 +4366,50 @@ std::string compile(const std::string& source, const CompileOptions& options, co
} }
} }
void setCompileConstantNil(CompileConstant* constant)
{
Compile::Constant* target = reinterpret_cast<Compile::Constant*>(constant);
target->type = Compile::Constant::Type_Nil;
}
void setCompileConstantBoolean(CompileConstant* constant, bool b)
{
Compile::Constant* target = reinterpret_cast<Compile::Constant*>(constant);
target->type = Compile::Constant::Type_Boolean;
target->valueBoolean = b;
}
void setCompileConstantNumber(CompileConstant* constant, double n)
{
Compile::Constant* target = reinterpret_cast<Compile::Constant*>(constant);
target->type = Compile::Constant::Type_Number;
target->valueNumber = n;
}
void setCompileConstantVector(CompileConstant* constant, float x, float y, float z, float w)
{
Compile::Constant* target = reinterpret_cast<Compile::Constant*>(constant);
target->type = Compile::Constant::Type_Vector;
target->valueVector[0] = x;
target->valueVector[1] = y;
target->valueVector[2] = z;
target->valueVector[3] = w;
}
void setCompileConstantString(CompileConstant* constant, const char* s, size_t l)
{
Compile::Constant* target = reinterpret_cast<Compile::Constant*>(constant);
if (l > std::numeric_limits<unsigned int>::max())
CompileError::raise({}, "Exceeded custom string constant length limit");
target->type = Compile::Constant::Type_String;
target->stringLength = l;
target->valueString = s;
}
} // namespace Luau } // namespace Luau

View file

@ -6,6 +6,9 @@
#include <vector> #include <vector>
#include <math.h> #include <math.h>
LUAU_FASTFLAG(LuauCompileLibraryConstants)
LUAU_FASTFLAGVARIABLE(LuauVectorFolding)
namespace Luau namespace Luau
{ {
namespace Compile namespace Compile
@ -57,6 +60,14 @@ static void foldUnary(Constant& result, AstExprUnary::Op op, const Constant& arg
result.type = Constant::Type_Number; result.type = Constant::Type_Number;
result.valueNumber = -arg.valueNumber; result.valueNumber = -arg.valueNumber;
} }
else if (FFlag::LuauVectorFolding && arg.type == Constant::Type_Vector)
{
result.type = Constant::Type_Vector;
result.valueVector[0] = -arg.valueVector[0];
result.valueVector[1] = -arg.valueVector[1];
result.valueVector[2] = -arg.valueVector[2];
result.valueVector[3] = -arg.valueVector[3];
}
break; break;
case AstExprUnary::Len: case AstExprUnary::Len:
@ -82,6 +93,14 @@ static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& l
result.type = Constant::Type_Number; result.type = Constant::Type_Number;
result.valueNumber = la.valueNumber + ra.valueNumber; result.valueNumber = la.valueNumber + ra.valueNumber;
} }
else if (FFlag::LuauVectorFolding && la.type == Constant::Type_Vector && ra.type == Constant::Type_Vector)
{
result.type = Constant::Type_Vector;
result.valueVector[0] = la.valueVector[0] + ra.valueVector[0];
result.valueVector[1] = la.valueVector[1] + ra.valueVector[1];
result.valueVector[2] = la.valueVector[2] + ra.valueVector[2];
result.valueVector[3] = la.valueVector[3] + ra.valueVector[3];
}
break; break;
case AstExprBinary::Sub: case AstExprBinary::Sub:
@ -90,6 +109,14 @@ static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& l
result.type = Constant::Type_Number; result.type = Constant::Type_Number;
result.valueNumber = la.valueNumber - ra.valueNumber; result.valueNumber = la.valueNumber - ra.valueNumber;
} }
else if (FFlag::LuauVectorFolding && la.type == Constant::Type_Vector && ra.type == Constant::Type_Vector)
{
result.type = Constant::Type_Vector;
result.valueVector[0] = la.valueVector[0] - ra.valueVector[0];
result.valueVector[1] = la.valueVector[1] - ra.valueVector[1];
result.valueVector[2] = la.valueVector[2] - ra.valueVector[2];
result.valueVector[3] = la.valueVector[3] - ra.valueVector[3];
}
break; break;
case AstExprBinary::Mul: case AstExprBinary::Mul:
@ -98,6 +125,48 @@ static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& l
result.type = Constant::Type_Number; result.type = Constant::Type_Number;
result.valueNumber = la.valueNumber * ra.valueNumber; result.valueNumber = la.valueNumber * ra.valueNumber;
} }
else if (FFlag::LuauVectorFolding && la.type == Constant::Type_Vector && ra.type == Constant::Type_Vector)
{
bool hadW = la.valueVector[3] != 0.0f || ra.valueVector[3] != 0.0f;
float resultW = la.valueVector[3] * ra.valueVector[3];
if (resultW == 0.0f || hadW)
{
result.type = Constant::Type_Vector;
result.valueVector[0] = la.valueVector[0] * ra.valueVector[0];
result.valueVector[1] = la.valueVector[1] * ra.valueVector[1];
result.valueVector[2] = la.valueVector[2] * ra.valueVector[2];
result.valueVector[3] = resultW;
}
}
else if (FFlag::LuauVectorFolding && la.type == Constant::Type_Number && ra.type == Constant::Type_Vector)
{
bool hadW = ra.valueVector[3] != 0.0f;
float resultW = float(la.valueNumber) * ra.valueVector[3];
if (resultW == 0.0f || hadW)
{
result.type = Constant::Type_Vector;
result.valueVector[0] = float(la.valueNumber) * ra.valueVector[0];
result.valueVector[1] = float(la.valueNumber) * ra.valueVector[1];
result.valueVector[2] = float(la.valueNumber) * ra.valueVector[2];
result.valueVector[3] = resultW;
}
}
else if (FFlag::LuauVectorFolding && la.type == Constant::Type_Vector && ra.type == Constant::Type_Number)
{
bool hadW = la.valueVector[3] != 0.0f;
float resultW = la.valueVector[3] * float(ra.valueNumber);
if (resultW == 0.0f || hadW)
{
result.type = Constant::Type_Vector;
result.valueVector[0] = la.valueVector[0] * float(ra.valueNumber);
result.valueVector[1] = la.valueVector[1] * float(ra.valueNumber);
result.valueVector[2] = la.valueVector[2] * float(ra.valueNumber);
result.valueVector[3] = resultW;
}
}
break; break;
case AstExprBinary::Div: case AstExprBinary::Div:
@ -106,6 +175,48 @@ static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& l
result.type = Constant::Type_Number; result.type = Constant::Type_Number;
result.valueNumber = la.valueNumber / ra.valueNumber; result.valueNumber = la.valueNumber / ra.valueNumber;
} }
else if (FFlag::LuauVectorFolding && la.type == Constant::Type_Vector && ra.type == Constant::Type_Vector)
{
bool hadW = la.valueVector[3] != 0.0f || ra.valueVector[3] != 0.0f;
float resultW = la.valueVector[3] / ra.valueVector[3];
if (resultW == 0.0f || hadW)
{
result.type = Constant::Type_Vector;
result.valueVector[0] = la.valueVector[0] / ra.valueVector[0];
result.valueVector[1] = la.valueVector[1] / ra.valueVector[1];
result.valueVector[2] = la.valueVector[2] / ra.valueVector[2];
result.valueVector[3] = resultW;
}
}
else if (FFlag::LuauVectorFolding && la.type == Constant::Type_Number && ra.type == Constant::Type_Vector)
{
bool hadW = ra.valueVector[3] != 0.0f;
float resultW = float(la.valueNumber) / ra.valueVector[3];
if (resultW == 0.0f || hadW)
{
result.type = Constant::Type_Vector;
result.valueVector[0] = float(la.valueNumber) / ra.valueVector[0];
result.valueVector[1] = float(la.valueNumber) / ra.valueVector[1];
result.valueVector[2] = float(la.valueNumber) / ra.valueVector[2];
result.valueVector[3] = resultW;
}
}
else if (FFlag::LuauVectorFolding && la.type == Constant::Type_Vector && ra.type == Constant::Type_Number)
{
bool hadW = la.valueVector[3] != 0.0f;
float resultW = la.valueVector[3] / float(ra.valueNumber);
if (resultW == 0.0f || hadW)
{
result.type = Constant::Type_Vector;
result.valueVector[0] = la.valueVector[0] / float(ra.valueNumber);
result.valueVector[1] = la.valueVector[1] / float(ra.valueNumber);
result.valueVector[2] = la.valueVector[2] / float(ra.valueNumber);
result.valueVector[3] = resultW;
}
}
break; break;
case AstExprBinary::FloorDiv: case AstExprBinary::FloorDiv:
@ -114,6 +225,48 @@ static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& l
result.type = Constant::Type_Number; result.type = Constant::Type_Number;
result.valueNumber = floor(la.valueNumber / ra.valueNumber); result.valueNumber = floor(la.valueNumber / ra.valueNumber);
} }
else if (FFlag::LuauVectorFolding && la.type == Constant::Type_Vector && ra.type == Constant::Type_Vector)
{
bool hadW = la.valueVector[3] != 0.0f || ra.valueVector[3] != 0.0f;
float resultW = floor(la.valueVector[3] / ra.valueVector[3]);
if (resultW == 0.0f || hadW)
{
result.type = Constant::Type_Vector;
result.valueVector[0] = floor(la.valueVector[0] / ra.valueVector[0]);
result.valueVector[1] = floor(la.valueVector[1] / ra.valueVector[1]);
result.valueVector[2] = floor(la.valueVector[2] / ra.valueVector[2]);
result.valueVector[3] = resultW;
}
}
else if (FFlag::LuauVectorFolding && la.type == Constant::Type_Number && ra.type == Constant::Type_Vector)
{
bool hadW = ra.valueVector[3] != 0.0f;
float resultW = floor(float(la.valueNumber) / ra.valueVector[3]);
if (resultW == 0.0f || hadW)
{
result.type = Constant::Type_Vector;
result.valueVector[0] = floor(float(la.valueNumber) / ra.valueVector[0]);
result.valueVector[1] = floor(float(la.valueNumber) / ra.valueVector[1]);
result.valueVector[2] = floor(float(la.valueNumber) / ra.valueVector[2]);
result.valueVector[3] = resultW;
}
}
else if (FFlag::LuauVectorFolding && la.type == Constant::Type_Vector && ra.type == Constant::Type_Number)
{
bool hadW = la.valueVector[3] != 0.0f;
float resultW = floor(la.valueVector[3] / float(ra.valueNumber));
if (resultW == 0.0f || hadW)
{
result.type = Constant::Type_Vector;
result.valueVector[0] = floor(la.valueVector[0] / float(ra.valueNumber));
result.valueVector[1] = floor(la.valueVector[1] / float(ra.valueNumber));
result.valueVector[2] = floor(la.valueVector[2] / float(ra.valueNumber));
result.valueVector[3] = floor(la.valueVector[3] / float(ra.valueNumber));
}
}
break; break;
case AstExprBinary::Mod: case AstExprBinary::Mod:
@ -209,7 +362,8 @@ struct ConstantVisitor : AstVisitor
DenseHashMap<AstLocal*, Constant>& locals; DenseHashMap<AstLocal*, Constant>& locals;
const DenseHashMap<AstExprCall*, int>* builtins; const DenseHashMap<AstExprCall*, int>* builtins;
bool foldMathK = false; bool foldLibraryK = false;
LibraryMemberConstantCallback libraryMemberConstantCb;
bool wasEmpty = false; bool wasEmpty = false;
@ -220,13 +374,15 @@ struct ConstantVisitor : AstVisitor
DenseHashMap<AstLocal*, Variable>& variables, DenseHashMap<AstLocal*, Variable>& variables,
DenseHashMap<AstLocal*, Constant>& locals, DenseHashMap<AstLocal*, Constant>& locals,
const DenseHashMap<AstExprCall*, int>* builtins, const DenseHashMap<AstExprCall*, int>* builtins,
bool foldMathK bool foldLibraryK,
LibraryMemberConstantCallback libraryMemberConstantCb
) )
: constants(constants) : constants(constants)
, variables(variables) , variables(variables)
, locals(locals) , locals(locals)
, builtins(builtins) , builtins(builtins)
, foldMathK(foldMathK) , foldLibraryK(foldLibraryK)
, libraryMemberConstantCb(libraryMemberConstantCb)
{ {
// since we do a single pass over the tree, if the initial state was empty we don't need to clear out old entries // since we do a single pass over the tree, if the initial state was empty we don't need to clear out old entries
wasEmpty = constants.empty() && locals.empty(); wasEmpty = constants.empty() && locals.empty();
@ -316,11 +472,26 @@ struct ConstantVisitor : AstVisitor
{ {
analyze(expr->expr); analyze(expr->expr);
if (foldMathK) if (foldLibraryK)
{ {
if (AstExprGlobal* eg = expr->expr->as<AstExprGlobal>(); eg && eg->name == "math") if (FFlag::LuauCompileLibraryConstants)
{ {
result = foldBuiltinMath(expr->index); if (AstExprGlobal* eg = expr->expr->as<AstExprGlobal>())
{
if (eg->name == "math")
result = foldBuiltinMath(expr->index);
// if we have a custom handler and the constant hasn't been resolved
if (libraryMemberConstantCb && result.type == Constant::Type_Unknown)
libraryMemberConstantCb(eg->name.value, expr->index.value, reinterpret_cast<Luau::CompileConstant*>(&result));
}
}
else
{
if (AstExprGlobal* eg = expr->expr->as<AstExprGlobal>(); eg && eg->name == "math")
{
result = foldBuiltinMath(expr->index);
}
} }
} }
} }
@ -468,11 +639,12 @@ void foldConstants(
DenseHashMap<AstLocal*, Variable>& variables, DenseHashMap<AstLocal*, Variable>& variables,
DenseHashMap<AstLocal*, Constant>& locals, DenseHashMap<AstLocal*, Constant>& locals,
const DenseHashMap<AstExprCall*, int>* builtins, const DenseHashMap<AstExprCall*, int>* builtins,
bool foldMathK, bool foldLibraryK,
LibraryMemberConstantCallback libraryMemberConstantCb,
AstNode* root AstNode* root
) )
{ {
ConstantVisitor visitor{constants, variables, locals, builtins, foldMathK}; ConstantVisitor visitor{constants, variables, locals, builtins, foldLibraryK, libraryMemberConstantCb};
root->visit(&visitor); root->visit(&visitor);
} }

View file

@ -1,6 +1,8 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/Compiler.h"
#include "ValueTracking.h" #include "ValueTracking.h"
namespace Luau namespace Luau
@ -49,7 +51,8 @@ void foldConstants(
DenseHashMap<AstLocal*, Variable>& variables, DenseHashMap<AstLocal*, Variable>& variables,
DenseHashMap<AstLocal*, Constant>& locals, DenseHashMap<AstLocal*, Constant>& locals,
const DenseHashMap<AstExprCall*, int>* builtins, const DenseHashMap<AstExprCall*, int>* builtins,
bool foldMathK, bool foldLibraryK,
LibraryMemberConstantCallback libraryMemberConstantCb,
AstNode* root AstNode* root
); );

View file

@ -4,6 +4,7 @@
#include "Luau/BytecodeBuilder.h" #include "Luau/BytecodeBuilder.h"
LUAU_FASTFLAGVARIABLE(LuauCompileVectorTypeInfo) LUAU_FASTFLAGVARIABLE(LuauCompileVectorTypeInfo)
LUAU_FASTFLAG(LuauCompileLibraryConstants)
namespace Luau namespace Luau
{ {
@ -175,16 +176,32 @@ static bool isMatchingGlobal(const DenseHashMap<AstName, Compile::Global>& globa
return false; return false;
} }
static bool isMatchingGlobalMember(
const DenseHashMap<AstName, Compile::Global>& globals,
AstExprIndexName* expr,
const char* library,
const char* member
)
{
LUAU_ASSERT(FFlag::LuauCompileLibraryConstants);
if (AstExprGlobal* object = expr->expr->as<AstExprGlobal>())
return getGlobalState(globals, object->name) == Compile::Global::Default && object->name == library && expr->index == member;
return false;
}
struct TypeMapVisitor : AstVisitor struct TypeMapVisitor : AstVisitor
{ {
DenseHashMap<AstExprFunction*, std::string>& functionTypes; DenseHashMap<AstExprFunction*, std::string>& functionTypes;
DenseHashMap<AstLocal*, LuauBytecodeType>& localTypes; DenseHashMap<AstLocal*, LuauBytecodeType>& localTypes;
DenseHashMap<AstExpr*, LuauBytecodeType>& exprTypes; DenseHashMap<AstExpr*, LuauBytecodeType>& exprTypes;
const char* hostVectorType; const char* hostVectorType = nullptr;
const DenseHashMap<AstName, uint8_t>& userdataTypes; const DenseHashMap<AstName, uint8_t>& userdataTypes;
const BuiltinAstTypes& builtinTypes; const BuiltinAstTypes& builtinTypes;
const DenseHashMap<AstExprCall*, int>& builtinCalls; const DenseHashMap<AstExprCall*, int>& builtinCalls;
const DenseHashMap<AstName, Compile::Global>& globals; const DenseHashMap<AstName, Compile::Global>& globals;
LibraryMemberTypeCallback libraryMemberTypeCb = nullptr;
BytecodeBuilder& bytecode; BytecodeBuilder& bytecode;
DenseHashMap<AstName, AstStatTypeAlias*> typeAliases; DenseHashMap<AstName, AstStatTypeAlias*> typeAliases;
@ -201,6 +218,7 @@ struct TypeMapVisitor : AstVisitor
const BuiltinAstTypes& builtinTypes, const BuiltinAstTypes& builtinTypes,
const DenseHashMap<AstExprCall*, int>& builtinCalls, const DenseHashMap<AstExprCall*, int>& builtinCalls,
const DenseHashMap<AstName, Compile::Global>& globals, const DenseHashMap<AstName, Compile::Global>& globals,
LibraryMemberTypeCallback libraryMemberTypeCb,
BytecodeBuilder& bytecode BytecodeBuilder& bytecode
) )
: functionTypes(functionTypes) : functionTypes(functionTypes)
@ -211,6 +229,7 @@ struct TypeMapVisitor : AstVisitor
, builtinTypes(builtinTypes) , builtinTypes(builtinTypes)
, builtinCalls(builtinCalls) , builtinCalls(builtinCalls)
, globals(globals) , globals(globals)
, libraryMemberTypeCb(libraryMemberTypeCb)
, bytecode(bytecode) , bytecode(bytecode)
, typeAliases(AstName()) , typeAliases(AstName())
, resolvedLocals(nullptr) , resolvedLocals(nullptr)
@ -461,7 +480,53 @@ struct TypeMapVisitor : AstVisitor
if (*typeBcPtr == LBC_TYPE_VECTOR) if (*typeBcPtr == LBC_TYPE_VECTOR)
{ {
if (node->index == "X" || node->index == "Y" || node->index == "Z") if (node->index == "X" || node->index == "Y" || node->index == "Z")
{
recordResolvedType(node, &builtinTypes.numberType); recordResolvedType(node, &builtinTypes.numberType);
if (FFlag::LuauCompileLibraryConstants)
return false;
}
}
}
if (FFlag::LuauCompileLibraryConstants)
{
if (isMatchingGlobalMember(globals, node, "vector", "zero") || isMatchingGlobalMember(globals, node, "vector", "one"))
{
recordResolvedType(node, &builtinTypes.vectorType);
return false;
}
if (libraryMemberTypeCb)
{
if (AstExprGlobal* object = node->expr->as<AstExprGlobal>())
{
if (LuauBytecodeType ty = LuauBytecodeType(libraryMemberTypeCb(object->name.value, node->index.value)); ty != LBC_TYPE_ANY)
{
// TODO: 'resolvedExprs' is more limited than 'exprTypes' which limits full inference of more complex types that a user
// callback can return
switch (ty)
{
case LBC_TYPE_BOOLEAN:
resolvedExprs[node] = &builtinTypes.booleanType;
break;
case LBC_TYPE_NUMBER:
resolvedExprs[node] = &builtinTypes.numberType;
break;
case LBC_TYPE_STRING:
resolvedExprs[node] = &builtinTypes.stringType;
break;
case LBC_TYPE_VECTOR:
resolvedExprs[node] = &builtinTypes.vectorType;
break;
default:
break;
}
exprTypes[node] = ty;
return false;
}
}
} }
} }
@ -733,10 +798,13 @@ void buildTypeMap(
const BuiltinAstTypes& builtinTypes, const BuiltinAstTypes& builtinTypes,
const DenseHashMap<AstExprCall*, int>& builtinCalls, const DenseHashMap<AstExprCall*, int>& builtinCalls,
const DenseHashMap<AstName, Compile::Global>& globals, const DenseHashMap<AstName, Compile::Global>& globals,
LibraryMemberTypeCallback libraryMemberTypeCb,
BytecodeBuilder& bytecode BytecodeBuilder& bytecode
) )
{ {
TypeMapVisitor visitor(functionTypes, localTypes, exprTypes, hostVectorType, userdataTypes, builtinTypes, builtinCalls, globals, bytecode); TypeMapVisitor visitor(
functionTypes, localTypes, exprTypes, hostVectorType, userdataTypes, builtinTypes, builtinCalls, globals, libraryMemberTypeCb, bytecode
);
root->visit(&visitor); root->visit(&visitor);
} }

View file

@ -3,6 +3,7 @@
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/Bytecode.h" #include "Luau/Bytecode.h"
#include "Luau/Compiler.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "ValueTracking.h" #include "ValueTracking.h"
@ -19,7 +20,7 @@ struct BuiltinAstTypes
{ {
} }
// AstName use here will not match the AstNameTable, but the was we use them here always force a full string compare // AstName use here will not match the AstNameTable, but the way we use them here always forces a full string compare
AstTypeReference booleanType{{}, std::nullopt, AstName{"boolean"}, std::nullopt, {}}; AstTypeReference booleanType{{}, std::nullopt, AstName{"boolean"}, std::nullopt, {}};
AstTypeReference numberType{{}, std::nullopt, AstName{"number"}, std::nullopt, {}}; AstTypeReference numberType{{}, std::nullopt, AstName{"number"}, std::nullopt, {}};
AstTypeReference stringType{{}, std::nullopt, AstName{"string"}, std::nullopt, {}}; AstTypeReference stringType{{}, std::nullopt, AstName{"string"}, std::nullopt, {}};
@ -38,6 +39,7 @@ void buildTypeMap(
const BuiltinAstTypes& builtinTypes, const BuiltinAstTypes& builtinTypes,
const DenseHashMap<AstExprCall*, int>& builtinCalls, const DenseHashMap<AstExprCall*, int>& builtinCalls,
const DenseHashMap<AstName, Compile::Global>& globals, const DenseHashMap<AstName, Compile::Global>& globals,
LibraryMemberTypeCallback libraryMemberTypeCb,
BytecodeBuilder& bytecode BytecodeBuilder& bytecode
); );

View file

@ -27,3 +27,28 @@ char* luau_compile(const char* source, size_t size, lua_CompileOptions* options,
*outsize = result.size(); *outsize = result.size();
return copy; return copy;
} }
void luau_set_compile_constant_nil(lua_CompileConstant* constant)
{
Luau::setCompileConstantNil(constant);
}
void luau_set_compile_constant_boolean(lua_CompileConstant* constant, int b)
{
Luau::setCompileConstantBoolean(constant, b != 0);
}
void luau_set_compile_constant_number(lua_CompileConstant* constant, double n)
{
Luau::setCompileConstantNumber(constant, n);
}
void luau_set_compile_constant_vector(lua_CompileConstant* constant, float x, float y, float z, float w)
{
Luau::setCompileConstantVector(constant, x, y, z, w);
}
void luau_set_compile_constant_string(lua_CompileConstant* constant, const char* s, size_t l)
{
Luau::setCompileConstantString(constant, s, l);
}

View file

@ -42,6 +42,7 @@ struct Config
{ {
std::string value; std::string value;
std::string_view configLocation; std::string_view configLocation;
std::string originalCase; // The alias in its original case.
}; };
DenseHashMap<std::string, AliasInfo> aliases{""}; DenseHashMap<std::string, AliasInfo> aliases{""};

View file

@ -26,9 +26,9 @@ Config::Config(const Config& other)
, typeErrors(other.typeErrors) , typeErrors(other.typeErrors)
, globals(other.globals) , globals(other.globals)
{ {
for (const auto& [alias, aliasInfo] : other.aliases) for (const auto& [_, aliasInfo] : other.aliases)
{ {
setAlias(alias, aliasInfo.value, std::string(aliasInfo.configLocation)); setAlias(aliasInfo.originalCase, aliasInfo.value, std::string(aliasInfo.configLocation));
} }
} }
@ -44,8 +44,20 @@ Config& Config::operator=(const Config& other)
void Config::setAlias(std::string alias, std::string value, const std::string& configLocation) void Config::setAlias(std::string alias, std::string value, const std::string& configLocation)
{ {
AliasInfo& info = aliases[alias]; std::string lowercasedAlias = alias;
std::transform(
lowercasedAlias.begin(),
lowercasedAlias.end(),
lowercasedAlias.begin(),
[](unsigned char c)
{
return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c;
}
);
AliasInfo& info = aliases[lowercasedAlias];
info.value = std::move(value); info.value = std::move(value);
info.originalCase = std::move(alias);
if (!configLocationCache.contains(configLocation)) if (!configLocationCache.contains(configLocation))
configLocationCache[configLocation] = std::make_unique<std::string>(configLocation); configLocationCache[configLocation] = std::make_unique<std::string>(configLocation);
@ -175,7 +187,7 @@ bool isValidAlias(const std::string& alias)
static Error parseAlias( static Error parseAlias(
Config& config, Config& config,
std::string aliasKey, const std::string& aliasKey,
const std::string& aliasValue, const std::string& aliasValue,
const std::optional<ConfigOptions::AliasOptions>& aliasOptions const std::optional<ConfigOptions::AliasOptions>& aliasOptions
) )
@ -183,21 +195,11 @@ static Error parseAlias(
if (!isValidAlias(aliasKey)) if (!isValidAlias(aliasKey))
return Error{"Invalid alias " + aliasKey}; return Error{"Invalid alias " + aliasKey};
std::transform(
aliasKey.begin(),
aliasKey.end(),
aliasKey.begin(),
[](unsigned char c)
{
return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c;
}
);
if (!aliasOptions) if (!aliasOptions)
return Error("Cannot parse aliases without alias options"); return Error("Cannot parse aliases without alias options");
if (aliasOptions->overwriteAliases || !config.aliases.contains(aliasKey)) if (aliasOptions->overwriteAliases || !config.aliases.contains(aliasKey))
config.setAlias(std::move(aliasKey), aliasValue, aliasOptions->configLocation); config.setAlias(aliasKey, aliasValue, aliasOptions->configLocation);
return std::nullopt; return std::nullopt;
} }

View file

@ -198,8 +198,7 @@ private:
{ {
// An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where // An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where
// canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...). // canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...).
for (Id& id : enode.mutableOperands()) Luau::EqSat::canonicalize(enode, [&](Id id) { return find(id); });
id = find(id);
} }
bool isCanonical(const L& enode) const bool isCanonical(const L& enode) const

View file

@ -244,6 +244,9 @@ private:
template<typename Phantom, typename T> template<typename Phantom, typename T>
struct NodeSet struct NodeSet
{ {
template <typename P_, typename T_, typename Find>
friend void canonicalize(NodeSet<P_, T_>& node, Find&& find);
template<typename... Args> template<typename... Args>
NodeSet(Args&&... args) NodeSet(Args&&... args)
: vector{std::forward<Args>(args)...} : vector{std::forward<Args>(args)...}
@ -299,6 +302,9 @@ struct Language final
template<typename T> template<typename T>
using WithinDomain = std::disjunction<std::is_same<std::decay_t<T>, Ts>...>; using WithinDomain = std::disjunction<std::is_same<std::decay_t<T>, Ts>...>;
template <typename Find, typename... Vs>
friend void canonicalize(Language<Vs...>& enode, Find&& find);
template<typename T> template<typename T>
Language(T&& t, std::enable_if_t<WithinDomain<T>::value>* = 0) noexcept Language(T&& t, std::enable_if_t<WithinDomain<T>::value>* = 0) noexcept
: v(std::forward<T>(t)) : v(std::forward<T>(t))
@ -382,4 +388,37 @@ private:
VariantTy v; VariantTy v;
}; };
template <typename Node, typename Find>
void canonicalize(Node& node, Find&& find)
{
// An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where
// canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...).
for (Id& id : node.mutableOperands())
id = find(id);
}
// Canonicalizing the Ids in a NodeSet may result in the set decreasing in size.
template <typename Phantom, typename T, typename Find>
void canonicalize(NodeSet<Phantom, T>& node, Find&& find)
{
for (Id& id : node.vector)
id = find(id);
std::sort(begin(node.vector), end(node.vector));
auto endIt = std::unique(begin(node.vector), end(node.vector));
node.vector.erase(endIt, end(node.vector));
}
template <typename Find, typename... Vs>
void canonicalize(Language<Vs...>& enode, Find&& find)
{
visit(
[&](auto&& v)
{
Luau::EqSat::canonicalize(v, find);
},
enode.v
);
}
} // namespace Luau::EqSat } // namespace Luau::EqSat

View file

@ -6,7 +6,6 @@
#include "lstate.h" #include "lstate.h"
#include "lvm.h" #include "lvm.h"
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauCoroCheckStack, false)
LUAU_DYNAMIC_FASTFLAG(LuauStackLimit) LUAU_DYNAMIC_FASTFLAG(LuauStackLimit)
#define CO_STATUS_ERROR -1 #define CO_STATUS_ERROR -1
@ -41,7 +40,7 @@ static int auxresume(lua_State* L, lua_State* co, int narg)
luaL_error(L, "too many arguments to resume"); luaL_error(L, "too many arguments to resume");
lua_xmove(L, co, narg); lua_xmove(L, co, narg);
} }
else if (DFFlag::LuauCoroCheckStack) else
{ {
// coroutine might be completely full already // coroutine might be completely full already
if ((co->top - co->base) > LUAI_MAXCSTACK) if ((co->top - co->base) > LUAI_MAXCSTACK)

View file

@ -14,8 +14,6 @@
#include <string.h> #include <string.h>
LUAU_DYNAMIC_FASTFLAG(LuauCoroCheckStack)
/* /*
* Luau uses an incremental non-generational non-moving mark&sweep garbage collector. * Luau uses an incremental non-generational non-moving mark&sweep garbage collector.
* *
@ -439,26 +437,13 @@ static void shrinkstack(lua_State* L)
if (L->size_ci > LUAI_MAXCALLS) // handling overflow? if (L->size_ci > LUAI_MAXCALLS) // handling overflow?
return; // do not touch the stacks return; // do not touch the stacks
if (DFFlag::LuauCoroCheckStack) if (3 * size_t(ci_used) < size_t(L->size_ci) && 2 * BASIC_CI_SIZE < L->size_ci)
{ luaD_reallocCI(L, L->size_ci / 2); // still big enough...
if (3 * size_t(ci_used) < size_t(L->size_ci) && 2 * BASIC_CI_SIZE < L->size_ci) condhardstacktests(luaD_reallocCI(L, ci_used + 1));
luaD_reallocCI(L, L->size_ci / 2); // still big enough...
condhardstacktests(luaD_reallocCI(L, ci_used + 1));
if (3 * size_t(s_used) < size_t(L->stacksize) && 2 * (BASIC_STACK_SIZE + EXTRA_STACK) < L->stacksize) if (3 * size_t(s_used) < size_t(L->stacksize) && 2 * (BASIC_STACK_SIZE + EXTRA_STACK) < L->stacksize)
luaD_reallocstack(L, L->stacksize / 2); // still big enough... luaD_reallocstack(L, L->stacksize / 2); // still big enough...
condhardstacktests(luaD_reallocstack(L, s_used)); condhardstacktests(luaD_reallocstack(L, s_used));
}
else
{
if (3 * ci_used < L->size_ci && 2 * BASIC_CI_SIZE < L->size_ci)
luaD_reallocCI(L, L->size_ci / 2); // still big enough...
condhardstacktests(luaD_reallocCI(L, ci_used + 1));
if (3 * s_used < L->stacksize && 2 * (BASIC_STACK_SIZE + EXTRA_STACK) < L->stacksize)
luaD_reallocstack(L, L->stacksize / 2); // still big enough...
condhardstacktests(luaD_reallocstack(L, s_used));
}
} }
/* /*

View file

@ -66,7 +66,7 @@ end
-- and 'false' otherwise. -- and 'false' otherwise.
-- --
-- Example usage: -- Example usage:
-- local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end -- local function prequire(name) local success, result = pcall(require, name); return success and result end
-- local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") -- local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
-- function testFunc() -- function testFunc()
-- ... -- ...

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,5 +1,5 @@
--!nonstrict --!nonstrict
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
local stretchTreeDepth = 18 -- about 16Mb local stretchTreeDepth = 18 -- about 16Mb

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -21,7 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE. SOFTWARE.
]] ]]
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -22,7 +22,7 @@
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
]] ]]
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
bench.runCode(function() bench.runCode(function()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

View file

@ -1,4 +1,4 @@
local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")
function test() function test()

Some files were not shown because too many files have changed in this diff Show more