diff --git a/Analysis/include/Luau/EqSatSimplificationImpl.h b/Analysis/include/Luau/EqSatSimplificationImpl.h index 24e8777a..2e704e98 100644 --- a/Analysis/include/Luau/EqSatSimplificationImpl.h +++ b/Analysis/include/Luau/EqSatSimplificationImpl.h @@ -53,7 +53,7 @@ LUAU_EQSAT_NODE_SET(Intersection); LUAU_EQSAT_NODE_ARRAY(Negation, 1); -LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(TTypeFun, const TypeFunction*); +LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(TTypeFun, std::shared_ptr); LUAU_EQSAT_UNIT(TNoRefine); LUAU_EQSAT_UNIT(Invalid); @@ -218,6 +218,7 @@ struct Simplifier void simplifyUnion(Id id); void uninhabitedIntersection(Id id); void intersectWithNegatedClass(Id id); + void intersectWithNegatedAtom(Id id); void intersectWithNoRefine(Id id); void cyclicIntersectionOfUnion(Id id); void cyclicUnionOfIntersection(Id id); @@ -228,6 +229,7 @@ struct Simplifier void unneededTableModification(Id id); void builtinTypeFunctions(Id id); void iffyTypeFunctions(Id id); + void strictMetamethods(Id id); }; template diff --git a/Analysis/include/Luau/NonStrictTypeChecker.h b/Analysis/include/Luau/NonStrictTypeChecker.h index 6229a932..880d487f 100644 --- a/Analysis/include/Luau/NonStrictTypeChecker.h +++ b/Analysis/include/Luau/NonStrictTypeChecker.h @@ -15,6 +15,7 @@ struct TypeCheckLimits; void checkNonStrict( NotNull builtinTypes, + NotNull simplifier, NotNull typeFunctionRuntime, NotNull ice, NotNull unifierState, diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 97d13a60..f014c433 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/EqSatSimplification.h" #include "Luau/NotNull.h" #include "Luau/Set.h" #include "Luau/TypeFwd.h" @@ -21,8 +22,22 @@ struct Scope; using ModulePtr = std::shared_ptr; -bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice); -bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice); +bool isSubtype( + TypeId subTy, + TypeId superTy, + NotNull scope, + NotNull builtinTypes, + NotNull simplifier, + InternalErrorReporter& ice +); +bool isSubtype( + TypePackId subPack, + TypePackId superPack, + NotNull scope, + NotNull builtinTypes, + NotNull simplifier, + InternalErrorReporter& ice +); class TypeIds { diff --git a/Analysis/include/Luau/OverloadResolution.h b/Analysis/include/Luau/OverloadResolution.h index 83a33215..d85d769e 100644 --- a/Analysis/include/Luau/OverloadResolution.h +++ b/Analysis/include/Luau/OverloadResolution.h @@ -2,12 +2,13 @@ #pragma once #include "Luau/Ast.h" -#include "Luau/InsertionOrderedMap.h" -#include "Luau/NotNull.h" -#include "Luau/TypeFwd.h" -#include "Luau/Location.h" +#include "Luau/EqSatSimplification.h" #include "Luau/Error.h" +#include "Luau/InsertionOrderedMap.h" +#include "Luau/Location.h" +#include "Luau/NotNull.h" #include "Luau/Subtyping.h" +#include "Luau/TypeFwd.h" namespace Luau { @@ -34,6 +35,7 @@ struct OverloadResolver OverloadResolver( NotNull builtinTypes, NotNull arena, + NotNull simplifier, NotNull normalizer, NotNull typeFunctionRuntime, NotNull scope, @@ -44,6 +46,7 @@ struct OverloadResolver NotNull builtinTypes; NotNull arena; + NotNull simplifier; NotNull normalizer; NotNull typeFunctionRuntime; NotNull scope; @@ -110,6 +113,7 @@ struct SolveResult SolveResult solveFunctionCall( NotNull arena, NotNull builtinTypes, + NotNull simplifier, NotNull normalizer, NotNull typeFunctionRuntime, NotNull iceReporter, diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index 0e6eff56..302c273c 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -85,6 +85,10 @@ struct Scope void inheritAssignments(const ScopePtr& childScope); void inheritRefinements(const ScopePtr& childScope); + // Track globals that should emit warnings during type checking. + DenseHashSet globalsToWarn{""}; + bool shouldWarnGlobal(std::string name) const; + // For mutually recursive type aliases, it's important that // they use the same types for the same names. // For instance, in `type Tree { data: T, children: Forest } type Forest = {Tree}` diff --git a/Analysis/include/Luau/Subtyping.h b/Analysis/include/Luau/Subtyping.h index 1e781056..26c4553e 100644 --- a/Analysis/include/Luau/Subtyping.h +++ b/Analysis/include/Luau/Subtyping.h @@ -1,13 +1,14 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/DenseHash.h" +#include "Luau/EqSatSimplification.h" #include "Luau/Set.h" +#include "Luau/TypeCheckLimits.h" +#include "Luau/TypeFunction.h" #include "Luau/TypeFwd.h" #include "Luau/TypePairHash.h" #include "Luau/TypePath.h" -#include "Luau/TypeFunction.h" -#include "Luau/TypeCheckLimits.h" -#include "Luau/DenseHash.h" #include #include @@ -134,6 +135,7 @@ struct Subtyping { NotNull builtinTypes; NotNull arena; + NotNull simplifier; NotNull normalizer; NotNull typeFunctionRuntime; NotNull iceReporter; @@ -155,6 +157,7 @@ struct Subtyping Subtyping( NotNull builtinTypes, NotNull typeArena, + NotNull simplifier, NotNull normalizer, NotNull typeFunctionRuntime, NotNull iceReporter diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 85957bed..9e525ac6 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -608,7 +608,8 @@ struct UserDefinedFunctionData // References to AST elements are owned by the Module allocator which also stores this type AstStatTypeFunction* definition = nullptr; - DenseHashMap environment{""}; + DenseHashMap> environment{""}; + DenseHashMap environment_DEPRECATED{""}; }; /** diff --git a/Analysis/include/Luau/TypeChecker2.h b/Analysis/include/Luau/TypeChecker2.h index 3ede5ca7..306c6413 100644 --- a/Analysis/include/Luau/TypeChecker2.h +++ b/Analysis/include/Luau/TypeChecker2.h @@ -2,15 +2,16 @@ #pragma once -#include "Luau/Error.h" -#include "Luau/NotNull.h" #include "Luau/Common.h" -#include "Luau/TypeUtils.h" +#include "Luau/EqSatSimplification.h" +#include "Luau/Error.h" +#include "Luau/Normalize.h" +#include "Luau/NotNull.h" +#include "Luau/Subtyping.h" #include "Luau/Type.h" #include "Luau/TypeFwd.h" #include "Luau/TypeOrPack.h" -#include "Luau/Normalize.h" -#include "Luau/Subtyping.h" +#include "Luau/TypeUtils.h" namespace Luau { @@ -60,6 +61,7 @@ struct Reasonings void check( NotNull builtinTypes, + NotNull simplifier, NotNull typeFunctionRuntime, NotNull sharedState, NotNull limits, @@ -71,6 +73,7 @@ void check( struct TypeChecker2 { NotNull builtinTypes; + NotNull simplifier; NotNull typeFunctionRuntime; DcrLogger* logger; const NotNull limits; @@ -90,6 +93,7 @@ struct TypeChecker2 TypeChecker2( NotNull builtinTypes, + NotNull simplifier, NotNull typeFunctionRuntime, NotNull unifierState, NotNull limits, @@ -213,6 +217,9 @@ private: std::vector& errors ); + // Avoid duplicate warnings being emitted for the same global variable. + DenseHashSet warnedGlobals{""}; + void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data) const; bool isErrorSuppressing(Location loc, TypeId ty); bool isErrorSuppressing(Location loc1, TypeId ty1, Location loc2, TypeId ty2); diff --git a/Analysis/include/Luau/TypeFunction.h b/Analysis/include/Luau/TypeFunction.h index df696b62..ba864621 100644 --- a/Analysis/include/Luau/TypeFunction.h +++ b/Analysis/include/Luau/TypeFunction.h @@ -2,6 +2,7 @@ #pragma once #include "Luau/Constraint.h" +#include "Luau/EqSatSimplification.h" #include "Luau/Error.h" #include "Luau/NotNull.h" #include "Luau/TypeCheckLimits.h" @@ -41,9 +42,15 @@ struct TypeFunctionRuntime StateRef state; + // Set of functions which have their environment table initialized + DenseHashSet initialized{nullptr}; + // Evaluation of type functions should only be performed in the absence of parse errors in the source module bool allowEvaluation = true; + // Output created by 'print' function + std::vector messages; + private: void prepareState(); }; @@ -53,6 +60,7 @@ struct TypeFunctionContext NotNull arena; NotNull builtins; NotNull scope; + NotNull simplifier; NotNull normalizer; NotNull typeFunctionRuntime; NotNull ice; @@ -71,6 +79,7 @@ struct TypeFunctionContext NotNull arena, NotNull builtins, NotNull scope, + NotNull simplifier, NotNull normalizer, NotNull typeFunctionRuntime, NotNull ice, @@ -79,6 +88,7 @@ struct TypeFunctionContext : arena(arena) , builtins(builtins) , scope(scope) + , simplifier(simplifier) , normalizer(normalizer) , typeFunctionRuntime(typeFunctionRuntime) , ice(ice) @@ -91,19 +101,31 @@ struct TypeFunctionContext NotNull pushConstraint(ConstraintV&& c) const; }; +enum class Reduction +{ + // The type function is either known to be reducible or the determination is blocked. + MaybeOk, + // The type function is known to be irreducible, but maybe not be erroneous, e.g. when it's over generics or free types. + Irreducible, + // The type function is known to be irreducible, and is definitely erroneous. + Erroneous, +}; + /// Represents a reduction result, which may have successfully reduced the type, /// may have concretely failed to reduce the type, or may simply be stuck /// without more information. template struct TypeFunctionReductionResult { + /// The result of the reduction, if any. If this is nullopt, the type function /// could not be reduced. std::optional result; - /// Whether the result is uninhabited: whether we know, unambiguously and - /// permanently, whether this type function reduction results in an - /// uninhabitable type. This will trigger an error to be reported. - bool uninhabited; + /// Indicates the status of this reduction: is `Reduction::Irreducible` if + /// the this result indicates the type function is irreducible, and + /// `Reduction::Erroneous` if this result indicates the type function is + /// erroneous. `Reduction::MaybeOk` otherwise. + Reduction reductionStatus; /// Any types that need to be progressed or mutated before the reduction may /// proceed. std::vector blockedTypes; @@ -112,6 +134,8 @@ struct TypeFunctionReductionResult std::vector blockedPacks; /// A runtime error message from user-defined type functions std::optional error; + /// Messages printed out from user-defined type functions + std::vector messages; }; template @@ -145,6 +169,7 @@ struct TypePackFunction struct FunctionGraphReductionResult { ErrorVec errors; + ErrorVec messages; DenseHashSet blockedTypes{nullptr}; DenseHashSet blockedPacks{nullptr}; DenseHashSet reducedTypes{nullptr}; diff --git a/Analysis/src/AutocompleteCore.cpp b/Analysis/src/AutocompleteCore.cpp index 3e231acf..f9e7e10f 100644 --- a/Analysis/src/AutocompleteCore.cpp +++ b/Analysis/src/AutocompleteCore.cpp @@ -150,6 +150,7 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, T { InternalErrorReporter iceReporter; UnifierSharedState unifierState(&iceReporter); + SimplifierPtr simplifier = newSimplifier(NotNull{typeArena}, builtinTypes); Normalizer normalizer{typeArena, builtinTypes, NotNull{&unifierState}}; if (FFlag::LuauSolverV2) @@ -162,7 +163,9 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, T unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit; - Subtyping subtyping{builtinTypes, NotNull{typeArena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter}}; + Subtyping subtyping{ + builtinTypes, NotNull{typeArena}, NotNull{simplifier.get()}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter} + }; return subtyping.isSubtype(subTy, superTy, scope).isSubtype; } diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index 56b89354..e6ecc8d0 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -33,10 +33,14 @@ LUAU_FASTFLAG(DebugLuauLogSolverToJson) LUAU_FASTFLAG(DebugLuauMagicTypes) LUAU_FASTFLAG(DebugLuauEqSatSimplification) LUAU_FASTFLAG(LuauTypestateBuiltins2) +LUAU_FASTFLAG(LuauUserTypeFunUpdateAllEnvs) LUAU_FASTFLAGVARIABLE(LuauNewSolverVisitErrorExprLvalues) LUAU_FASTFLAGVARIABLE(LuauUserTypeFunExportedAndLocal) LUAU_FASTFLAGVARIABLE(LuauNewSolverPopulateTableLocations) +LUAU_FASTFLAGVARIABLE(LuauUserTypeFunNoExtraConstraint) + +LUAU_FASTFLAGVARIABLE(InferGlobalTypes) namespace Luau { @@ -791,9 +795,10 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc } // Fill it with all visible type functions - if (mainTypeFun) + if (FFlag::LuauUserTypeFunUpdateAllEnvs && mainTypeFun) { UserDefinedFunctionData& userFuncData = mainTypeFun->userFuncData; + size_t level = 0; for (Scope* curr = scope.get(); curr; curr = curr->parent.get()) { @@ -803,7 +808,7 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc continue; if (auto ty = get(tf.type); ty && ty->userFuncData.definition) - userFuncData.environment[name] = ty->userFuncData.definition; + userFuncData.environment[name] = std::make_pair(ty->userFuncData.definition, level); } for (auto& [name, tf] : curr->exportedTypeBindings) @@ -812,7 +817,34 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc continue; if (auto ty = get(tf.type); ty && ty->userFuncData.definition) - userFuncData.environment[name] = ty->userFuncData.definition; + userFuncData.environment[name] = std::make_pair(ty->userFuncData.definition, level); + } + + level++; + } + } + else if (mainTypeFun) + { + UserDefinedFunctionData& userFuncData = mainTypeFun->userFuncData; + + for (Scope* curr = scope.get(); curr; curr = curr->parent.get()) + { + for (auto& [name, tf] : curr->privateTypeBindings) + { + if (userFuncData.environment_DEPRECATED.find(name)) + continue; + + if (auto ty = get(tf.type); ty && ty->userFuncData.definition) + userFuncData.environment_DEPRECATED[name] = ty->userFuncData.definition; + } + + for (auto& [name, tf] : curr->exportedTypeBindings) + { + if (userFuncData.environment_DEPRECATED.find(name)) + continue; + + if (auto ty = get(tf.type); ty && ty->userFuncData.definition) + userFuncData.environment_DEPRECATED[name] = ty->userFuncData.definition; } } } @@ -1543,18 +1575,22 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatTypeAlias* ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatTypeFunction* function) { - // If a type function with the same name was already defined, we skip over - auto bindingIt = scope->privateTypeBindings.find(function->name.value); - if (bindingIt == scope->privateTypeBindings.end()) - return ControlFlow::None; - - TypeFun typeFunction = bindingIt->second; - - // Adding typeAliasExpansionConstraint on user-defined type function for the constraint solver - if (auto typeFunctionTy = get(follow(typeFunction.type))) + if (!FFlag::LuauUserTypeFunNoExtraConstraint) { - TypeId expansionTy = arena->addType(PendingExpansionType{{}, function->name, typeFunctionTy->typeArguments, typeFunctionTy->packArguments}); - addConstraint(scope, function->location, TypeAliasExpansionConstraint{/* target */ expansionTy}); + // If a type function with the same name was already defined, we skip over + auto bindingIt = scope->privateTypeBindings.find(function->name.value); + if (bindingIt == scope->privateTypeBindings.end()) + return ControlFlow::None; + + TypeFun typeFunction = bindingIt->second; + + // Adding typeAliasExpansionConstraint on user-defined type function for the constraint solver + if (auto typeFunctionTy = get(follow(typeFunction.type))) + { + TypeId expansionTy = + arena->addType(PendingExpansionType{{}, function->name, typeFunctionTy->typeArguments, typeFunctionTy->packArguments}); + addConstraint(scope, function->location, TypeAliasExpansionConstraint{/* target */ expansionTy}); + } } return ControlFlow::None; @@ -2747,6 +2783,14 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprGlobal* glob DefId def = dfg->getDef(global); rootScope->lvalueTypes[def] = rhsType; + if (FFlag::InferGlobalTypes) + { + // Sketchy: We're specifically looking for BlockedTypes that were + // initially created by ConstraintGenerator::prepopulateGlobalScope. + if (auto bt = get(follow(*annotatedTy)); bt && !bt->getOwner()) + emplaceType(asMutable(*annotatedTy), rhsType); + } + addConstraint(scope, global->location, SubtypeConstraint{rhsType, *annotatedTy}); } } @@ -3139,9 +3183,9 @@ TypeId ConstraintGenerator::resolveReferenceType( if (alias.has_value()) { - // If the alias is not generic, we don't need to set up a blocked - // type and an instantiation constraint. - if (alias.has_value() && alias->typeParams.empty() && alias->typePackParams.empty()) + // If the alias is not generic, we don't need to set up a blocked type and an instantiation constraint + if (alias.has_value() && alias->typeParams.empty() && alias->typePackParams.empty() && + (!FFlag::LuauUserTypeFunNoExtraConstraint || !ref->hasParameterList)) { result = alias->type; } @@ -3651,6 +3695,26 @@ struct GlobalPrepopulator : AstVisitor return true; } + bool visit(AstStatAssign* assign) override + { + if (FFlag::InferGlobalTypes) + { + for (const Luau::AstExpr* expr : assign->vars) + { + if (const AstExprGlobal* g = expr->as()) + { + if (!globalScope->lookup(g->name)) + globalScope->globalsToWarn.insert(g->name.value); + + TypeId bt = arena->addType(BlockedType{}); + globalScope->bindings[g->name] = Binding{bt, g->location}; + } + } + } + + return true; + } + bool visit(AstStatFunction* function) override { if (AstExprGlobal* g = function->name->as()) diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index c29aa915..f0cd03f2 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -35,6 +35,7 @@ LUAU_FASTFLAGVARIABLE(LuauRemoveNotAnyHack) LUAU_FASTFLAGVARIABLE(DebugLuauEqSatSimplification) LUAU_FASTFLAG(LuauNewSolverPopulateTableLocations) LUAU_FASTFLAGVARIABLE(LuauAllowNilAssignmentToIndexer) +LUAU_FASTFLAG(LuauUserTypeFunNoExtraConstraint) namespace Luau { @@ -939,12 +940,14 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul if (auto typeFn = get(follow(tf->type))) pushConstraint(NotNull(constraint->scope.get()), constraint->location, ReduceConstraint{tf->type}); - // If there are no parameters to the type function we can just use the type - // directly. - if (tf->typeParams.empty() && tf->typePackParams.empty()) + if (!FFlag::LuauUserTypeFunNoExtraConstraint) { - bindResult(tf->type); - return true; + // If there are no parameters to the type function we can just use the type directly + if (tf->typeParams.empty() && tf->typePackParams.empty()) + { + bindResult(tf->type); + return true; + } } // Due to how pending expansion types and TypeFun's are created @@ -959,6 +962,16 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul return true; } + if (FFlag::LuauUserTypeFunNoExtraConstraint) + { + // If there are no parameters to the type function we can just use the type directly + if (tf->typeParams.empty() && tf->typePackParams.empty()) + { + bindResult(tf->type); + return true; + } + } + auto [typeArguments, packArguments] = saturateArguments(arena, builtinTypes, *tf, petv->typeArguments, petv->packArguments); bool sameTypes = std::equal( @@ -1263,6 +1276,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullscope, @@ -2102,6 +2116,11 @@ bool ConstraintSolver::tryDispatch(const ReduceConstraint& c, NotNull() || node.get(); } +static bool areTerminalAndDefinitelyDisjoint(const EType& lhs, const EType& rhs) +{ + // If either node is non-terminal, then we early exit: we're not going to + // do a state space search for whether something like: + // (A | B | C | D) & (E | F | G | H) + // ... is a disjoint intersection. + if (!isTerminal(lhs) || !isTerminal(rhs)) + return false; + + // Special case some types that aren't strict, disjoint subsets. + if (lhs.get() || lhs.get()) + return !(rhs.get() || rhs.get()); + + // Handling strings / booleans: these are the types for which we + // expect something like: + // + // "foo" & ~"bar" + // + // ... to simplify to "foo". + if (lhs.get()) + return !(rhs.get() || rhs.get()); + + if (lhs.get()) + return !(rhs.get() || rhs.get()); + + if (auto lhsSString = lhs.get()) + { + auto rhsSString = rhs.get(); + if (!rhsSString) + return !rhs.get(); + return lhsSString->value() != rhsSString->value(); + } + + if (auto lhsSBoolean = lhs.get()) + { + auto rhsSBoolean = rhs.get(); + if (!rhsSBoolean) + return !rhs.get(); + return lhsSBoolean->value() != rhsSBoolean->value(); + } + + // At this point: + // - We know both nodes are terminal + // - We know that the LHS is not any boolean, string, or class + // At this point, we have two classes of checks left: + // - Whether the two enodes are exactly the same set (now that the static + // sets have been covered). + // - Whether one of the enodes is a large semantic set such as TAny, + // TUnknown, or TError. + return !( + lhs.index() == rhs.index() || + lhs.get() || rhs.get() || lhs.get() || rhs.get() || lhs.get() || rhs.get() || + lhs.get() || rhs.get() || lhs.get() || rhs.get() + ); +} + static bool isTerminal(const EGraph& egraph, Id eclass) { const auto& nodes = egraph[eclass].nodes; @@ -336,10 +392,20 @@ Id toId( LUAU_ASSERT(tfun->packArguments.empty()); std::vector parts; + parts.reserve(tfun->typeArguments.size()); for (TypeId part : tfun->typeArguments) parts.push_back(toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, part)); - return cache(egraph.add(TTypeFun{tfun->function.get(), std::move(parts)})); + // This looks sily, but we're making a copy of the specific + // `TypeFunctionInstanceType` outside of the provided arena so that + // we can access the members without fear of the specific TFIT being + // overwritten with a bound type. + return cache(egraph.add(TTypeFun{ + std::make_shared( + tfun->function, tfun->typeArguments, tfun->packArguments, tfun->userFuncName, tfun->userFuncData + ), + std::move(parts) + })); } else if (get(ty)) return egraph.add(TNoRefine{}); @@ -574,18 +640,24 @@ TypeId flattenTableNode( // If a TTable is its own basis, it must be the case that some other // node on this eclass is a TImportedTable. Let's use that. + bool found = false; + for (size_t i = 0; i < eclass.nodes.size(); ++i) { if (eclass.nodes[i].get()) { + found = true; index = i; break; } } - // If we couldn't find one, we don't know what to do. Use ErrorType. - LUAU_ASSERT(0); - return builtinTypes->errorType; + if (!found) + { + // If we couldn't find one, we don't know what to do. Use ErrorType. + LUAU_ASSERT(0); + return builtinTypes->errorType; + } } const auto& node = eclass.nodes[index]; @@ -703,7 +775,20 @@ TypeId fromId( if (parts.empty()) return builtinTypes->neverType; else if (parts.size() == 1) - return fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, parts[0]); + { + TypeId placeholder = arena->addType(BlockedType{}); + seen[rootId] = placeholder; + auto result = fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, parts[0]); + if (follow(result) == placeholder) + { + emplaceType(asMutable(placeholder), "EGRAPH-SINGLETON-CYCLE"); + } + else + { + emplaceType(asMutable(placeholder), result); + } + return result; + } else { TypeId res = arena->addType(BlockedType{}); @@ -768,7 +853,11 @@ TypeId fromId( for (Id part : tfun->operands()) args.push_back(fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, part)); - asMutable(res)->ty.emplace(*tfun->value(), std::move(args)); + auto oldInstance = tfun->value(); + + asMutable(res)->ty.emplace( + oldInstance->function, std::move(args), std::vector(), oldInstance->userFuncName, oldInstance->userFuncData + ); newTypeFunctions.push_back(res); @@ -906,7 +995,7 @@ static std::string getNodeName(const StringCache& strings, const EType& node) else if (node.get()) return "never"; else if (auto tfun = node.get()) - return "tfun " + tfun->value()->name; + return "tfun " + tfun->value()->function->name; else if (node.get()) return "~"; else if (node.get()) @@ -1552,11 +1641,6 @@ std::optional intersectOne(EGraph& egraph, Id hereId, const EType* hereNo thereNode->get() || thereNode->get() || hereNode->get() || thereNode->get()) return std::nullopt; - if (hereNode->get()) - return *thereNode; - if (thereNode->get()) - return *hereNode; - if (hereNode->get()) return *thereNode; if (thereNode->get()) @@ -1839,6 +1923,74 @@ void Simplifier::intersectWithNegatedClass(Id id) } } +void Simplifier::intersectWithNegatedAtom(Id id) +{ + // Let I and ~J be two arbitrary distinct operands of an intersection where + // I and J are terminal but are not type variables. (free, generic, or + // otherwise opaque) + // + // If I and J are equal, then the whole intersection is equivalent to never. + // + // If I and J are inequal, then J & ~I == J + + for (const auto [intersection, intersectionIndex] : Query(&egraph, id)) + { + const Slice& intersectionOperands = intersection->operands(); + for (size_t i = 0; i < intersectionOperands.size(); ++i) + { + for (const auto [negation, negationIndex] : Query(&egraph, intersectionOperands[i])) + { + for (size_t negationOperandIndex = 0; negationOperandIndex < egraph[negation->operands()[0]].nodes.size(); ++negationOperandIndex) + { + const EType* negationOperand = &egraph[negation->operands()[0]].nodes[negationOperandIndex]; + if (!isTerminal(*negationOperand) || negationOperand->get()) + continue; + + for (size_t j = 0; j < intersectionOperands.size(); ++j) + { + if (j == i) + continue; + + for (size_t jNodeIndex = 0; jNodeIndex < egraph[intersectionOperands[j]].nodes.size(); ++jNodeIndex) + { + const EType* jNode = &egraph[intersectionOperands[j]].nodes[jNodeIndex]; + if (!isTerminal(*jNode) || jNode->get()) + continue; + + if (*negationOperand == *jNode) + { + // eg "Hello" & ~"Hello" + // or boolean & ~boolean + subst( + id, + egraph.add(TNever{}), + "intersectWithNegatedAtom", + {{id, intersectionIndex}, {intersectionOperands[i], negationIndex}, {intersectionOperands[j], jNodeIndex}} + ); + return; + } + else if (areTerminalAndDefinitelyDisjoint(*jNode, *negationOperand)) + { + // eg "Hello" & ~"World" + // or boolean & ~string + std::vector newOperands(intersectionOperands.begin(), intersectionOperands.end()); + newOperands.erase(newOperands.begin() + std::vector::difference_type(i)); + + subst( + id, + egraph.add(Intersection{newOperands}), + "intersectWithNegatedAtom", + {{id, intersectionIndex}, {intersectionOperands[i], negationIndex}, {intersectionOperands[j], jNodeIndex}} + ); + } + } + } + } + } + } + } +} + void Simplifier::intersectWithNoRefine(Id id) { for (const auto pair : Query(&egraph, id)) @@ -2160,7 +2312,7 @@ void Simplifier::intersectTableProperty(Id id) subst( id, - egraph.add(Intersection{std::move(newIntersectionParts)}), + mkIntersection(egraph, std::move(newIntersectionParts)), "intersectTableProperty", {{id, intersectionIndex}, {iId, table1Index}, {jId, table2Index}} ); @@ -2250,7 +2402,7 @@ void Simplifier::builtinTypeFunctions(Id id) if (args.size() != 2) continue; - const std::string& name = tfun->value()->name; + const std::string& name = tfun->value()->function->name; if (name == "add" || name == "sub" || name == "mul" || name == "div" || name == "idiv" || name == "pow" || name == "mod") { if (isTag(args[0]) && isTag(args[1])) @@ -2272,15 +2424,43 @@ void Simplifier::iffyTypeFunctions(Id id) { const Slice& args = tfun->operands(); - const std::string& name = tfun->value()->name; + const std::string& name = tfun->value()->function->name; if (name == "union") subst(id, add(Union{std::vector(args.begin(), args.end())}), "iffyTypeFunctions", {{id, index}}); - else if (name == "intersect" || name == "refine") + else if (name == "intersect") subst(id, add(Intersection{std::vector(args.begin(), args.end())}), "iffyTypeFunctions", {{id, index}}); } } +// Replace instances of `lt` and `le` when either X or Y is `number` +// or `string` with `boolean`. Lua semantics are that if we see the expression: +// +// x < y +// +// ... we error if `x` and `y` don't have the same type. We know that for +// `string` and `number`, comparisons will always return a boolean. So if either +// of the arguments to `lt<>` are equivalent to `number` or `string`, then the +// type is effectively `boolean`: either the other type is equivalent, in which +// case we eval to `boolean`, or we diverge (raise an error). +void Simplifier::strictMetamethods(Id id) +{ + for (const auto [tfun, index] : Query(&egraph, id)) + { + const Slice& args = tfun->operands(); + + const std::string& name = tfun->value()->function->name; + + if (!(name == "lt" || name == "le") || args.size() != 2) + continue; + + if (isTag(args[0]) || isTag(args[0]) || isTag(args[1]) || isTag(args[1])) + { + subst(id, add(TBoolean{}), __FUNCTION__, {{id, index}}); + } + } +} + static void deleteSimplifier(Simplifier* s) { delete s; @@ -2308,6 +2488,7 @@ std::optional eqSatSimplify(NotNull simpl &Simplifier::simplifyUnion, &Simplifier::uninhabitedIntersection, &Simplifier::intersectWithNegatedClass, + &Simplifier::intersectWithNegatedAtom, &Simplifier::intersectWithNoRefine, &Simplifier::cyclicIntersectionOfUnion, &Simplifier::cyclicUnionOfIntersection, @@ -2318,6 +2499,7 @@ std::optional eqSatSimplify(NotNull simpl &Simplifier::unneededTableModification, &Simplifier::builtinTypeFunctions, &Simplifier::iffyTypeFunctions, + &Simplifier::strictMetamethods, }; std::unordered_set seen; diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 053e99c2..f7164256 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -50,6 +50,8 @@ LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauRunCustomModuleChecks, false) LUAU_FASTFLAG(StudioReportLuauAny2) LUAU_FASTFLAGVARIABLE(LuauStoreSolverTypeOnModule) +LUAU_FASTFLAGVARIABLE(LuauReferenceAllocatorInNewSolver) + namespace Luau { @@ -1317,6 +1319,11 @@ ModulePtr check( result->mode = mode; result->internalTypes.owningModule = result.get(); result->interfaceTypes.owningModule = result.get(); + if (FFlag::LuauReferenceAllocatorInNewSolver) + { + result->allocator = sourceModule.allocator; + result->names = sourceModule.names; + } iceHandler->moduleName = sourceModule.name; @@ -1427,6 +1434,7 @@ ModulePtr check( case Mode::Nonstrict: Luau::checkNonStrict( builtinTypes, + NotNull{simplifier.get()}, NotNull{&typeFunctionRuntime}, iceHandler, NotNull{&unifierState}, @@ -1440,7 +1448,14 @@ ModulePtr check( // fallthrough intentional case Mode::Strict: Luau::check( - builtinTypes, NotNull{&typeFunctionRuntime}, NotNull{&unifierState}, NotNull{&limits}, logger.get(), sourceModule, result.get() + builtinTypes, + NotNull{simplifier.get()}, + NotNull{&typeFunctionRuntime}, + NotNull{&unifierState}, + NotNull{&limits}, + logger.get(), + sourceModule, + result.get() ); break; case Mode::NoCheck: diff --git a/Analysis/src/Generalization.cpp b/Analysis/src/Generalization.cpp index 3eb14fda..506087ba 100644 --- a/Analysis/src/Generalization.cpp +++ b/Analysis/src/Generalization.cpp @@ -10,12 +10,14 @@ #include "Luau/VisitType.h" LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete) +LUAU_FASTFLAGVARIABLE(LuauGeneralizationRemoveRecursiveUpperBound) namespace Luau { struct MutatingGeneralizer : TypeOnceVisitor { + NotNull arena; NotNull builtinTypes; NotNull scope; @@ -29,6 +31,7 @@ struct MutatingGeneralizer : TypeOnceVisitor bool avoidSealingTables = false; MutatingGeneralizer( + NotNull arena, NotNull builtinTypes, NotNull scope, NotNull> cachedTypes, @@ -37,6 +40,7 @@ struct MutatingGeneralizer : TypeOnceVisitor bool avoidSealingTables ) : TypeOnceVisitor(/* skipBoundTypes */ true) + , arena(arena) , builtinTypes(builtinTypes) , scope(scope) , cachedTypes(cachedTypes) @@ -229,6 +233,53 @@ struct MutatingGeneralizer : TypeOnceVisitor else { TypeId ub = follow(ft->upperBound); + if (FFlag::LuauGeneralizationRemoveRecursiveUpperBound) + { + + // If the upper bound is a union type or an intersection type, + // and one of it's members is the free type we're + // generalizing, don't include it in the upper bound. For a + // free type such as: + // + // t1 where t1 = D <: 'a <: (A | B | C | t1) + // + // Naively replacing it with it's upper bound creates: + // + // t1 where t1 = A | B | C | t1 + // + // It makes sense to just optimize this and exclude the + // recursive component by semantic subtyping rules. + + if (auto itv = get(ub)) + { + std::vector newIds; + newIds.reserve(itv->parts.size()); + for (auto part : itv) + { + if (part != ty) + newIds.push_back(part); + } + if (newIds.size() == 1) + ub = newIds[0]; + else if (newIds.size() > 0) + ub = arena->addType(IntersectionType{std::move(newIds)}); + } + else if (auto utv = get(ub)) + { + std::vector newIds; + newIds.reserve(utv->options.size()); + for (auto part : utv) + { + if (part != ty) + newIds.push_back(part); + } + if (newIds.size() == 1) + ub = newIds[0]; + else if (newIds.size() > 0) + ub = arena->addType(UnionType{std::move(newIds)}); + } + } + if (FreeType* upperFree = getMutable(ub); upperFree && upperFree->lowerBound == ty) upperFree->lowerBound = builtinTypes->neverType; else @@ -969,7 +1020,7 @@ std::optional generalize( FreeTypeSearcher fts{scope, cachedTypes}; fts.traverse(ty); - MutatingGeneralizer gen{builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), std::move(fts.negativeTypes), avoidSealingTables}; + MutatingGeneralizer gen{arena, builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), std::move(fts.negativeTypes), avoidSealingTables}; gen.traverse(ty); diff --git a/Analysis/src/NonStrictTypeChecker.cpp b/Analysis/src/NonStrictTypeChecker.cpp index b4a5eaf6..f830d126 100644 --- a/Analysis/src/NonStrictTypeChecker.cpp +++ b/Analysis/src/NonStrictTypeChecker.cpp @@ -19,7 +19,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauUserTypeFunNonstrict) LUAU_FASTFLAGVARIABLE(LuauCountSelfCallsNonstrict) namespace Luau @@ -158,6 +157,7 @@ private: struct NonStrictTypeChecker { NotNull builtinTypes; + NotNull simplifier; NotNull typeFunctionRuntime; const NotNull ice; NotNull arena; @@ -174,6 +174,7 @@ struct NonStrictTypeChecker NonStrictTypeChecker( NotNull arena, NotNull builtinTypes, + NotNull simplifier, NotNull typeFunctionRuntime, const NotNull ice, NotNull unifierState, @@ -182,12 +183,13 @@ struct NonStrictTypeChecker Module* module ) : builtinTypes(builtinTypes) + , simplifier(simplifier) , typeFunctionRuntime(typeFunctionRuntime) , ice(ice) , arena(arena) , module(module) , normalizer{arena, builtinTypes, unifierState, /* cache inhabitance */ true} - , subtyping{builtinTypes, arena, NotNull(&normalizer), typeFunctionRuntime, ice} + , subtyping{builtinTypes, arena, simplifier, NotNull(&normalizer), typeFunctionRuntime, ice} , dfg(dfg) , limits(limits) { @@ -232,13 +234,14 @@ struct NonStrictTypeChecker if (noTypeFunctionErrors.find(instance)) return instance; - ErrorVec errors = reduceTypeFunctions( - instance, - location, - TypeFunctionContext{arena, builtinTypes, stack.back(), NotNull{&normalizer}, typeFunctionRuntime, ice, limits}, - true - ) - .errors; + ErrorVec errors = + reduceTypeFunctions( + instance, + location, + TypeFunctionContext{arena, builtinTypes, stack.back(), simplifier, NotNull{&normalizer}, typeFunctionRuntime, ice, limits}, + true + ) + .errors; if (errors.empty()) noTypeFunctionErrors.insert(instance); @@ -424,9 +427,6 @@ struct NonStrictTypeChecker NonStrictContext visit(AstStatTypeFunction* typeFunc) { - if (!FFlag::LuauUserTypeFunNonstrict) - reportError(GenericError{"This syntax is not supported"}, typeFunc->location); - return {}; } @@ -888,6 +888,7 @@ private: void checkNonStrict( NotNull builtinTypes, + NotNull simplifier, NotNull typeFunctionRuntime, NotNull ice, NotNull unifierState, @@ -899,7 +900,9 @@ void checkNonStrict( { LUAU_TIMETRACE_SCOPE("checkNonStrict", "Typechecking"); - NonStrictTypeChecker typeChecker{NotNull{&module->internalTypes}, builtinTypes, typeFunctionRuntime, ice, unifierState, dfg, limits, module}; + NonStrictTypeChecker typeChecker{ + NotNull{&module->internalTypes}, builtinTypes, simplifier, typeFunctionRuntime, ice, unifierState, dfg, limits, module + }; typeChecker.visit(sourceModule.root); unfreeze(module->interfaceTypes); copyErrors(module->errors, module->interfaceTypes, builtinTypes); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index a2b440b9..2c3cb162 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -3465,7 +3465,14 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) return arena->addType(UnionType{std::move(result)}); } -bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) +bool isSubtype( + TypeId subTy, + TypeId superTy, + NotNull scope, + NotNull builtinTypes, + NotNull simplifier, + InternalErrorReporter& ice +) { UnifierSharedState sharedState{&ice}; TypeArena arena; @@ -3478,7 +3485,7 @@ bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull scope, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) +bool isSubtype( + TypePackId subPack, + TypePackId superPack, + NotNull scope, + NotNull builtinTypes, + NotNull simplifier, + InternalErrorReporter& ice +) { UnifierSharedState sharedState{&ice}; TypeArena arena; @@ -3504,7 +3518,7 @@ bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull scope, N // Subtyping under DCR is not implemented using unification! if (FFlag::LuauSolverV2) { - Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&ice}}; + Subtyping subtyping{builtinTypes, NotNull{&arena}, simplifier, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&ice}}; return subtyping.isSubtype(subPack, superPack, scope).isSubtype; } diff --git a/Analysis/src/OverloadResolution.cpp b/Analysis/src/OverloadResolution.cpp index f5557f2d..e8471264 100644 --- a/Analysis/src/OverloadResolution.cpp +++ b/Analysis/src/OverloadResolution.cpp @@ -16,6 +16,7 @@ namespace Luau OverloadResolver::OverloadResolver( NotNull builtinTypes, NotNull arena, + NotNull simplifier, NotNull normalizer, NotNull typeFunctionRuntime, NotNull scope, @@ -25,12 +26,13 @@ OverloadResolver::OverloadResolver( ) : builtinTypes(builtinTypes) , arena(arena) + , simplifier(simplifier) , normalizer(normalizer) , typeFunctionRuntime(typeFunctionRuntime) , scope(scope) , ice(reporter) , limits(limits) - , subtyping({builtinTypes, arena, normalizer, typeFunctionRuntime, ice}) + , subtyping({builtinTypes, arena, simplifier, normalizer, typeFunctionRuntime, ice}) , callLoc(callLocation) { } @@ -202,7 +204,7 @@ std::pair OverloadResolver::checkOverload_ ) { FunctionGraphReductionResult result = reduceTypeFunctions( - fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, normalizer, typeFunctionRuntime, ice, limits}, /*force=*/true + fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, simplifier, normalizer, typeFunctionRuntime, ice, limits}, /*force=*/true ); if (!result.errors.empty()) return {OverloadIsNonviable, result.errors}; @@ -404,9 +406,10 @@ void OverloadResolver::add(Analysis analysis, TypeId ty, ErrorVec&& errors) // we wrap calling the overload resolver in a separate function to reduce overall stack pressure in `solveFunctionCall`. // this limits the lifetime of `OverloadResolver`, a large type, to only as long as it is actually needed. -std::optional selectOverload( +static std::optional selectOverload( NotNull builtinTypes, NotNull arena, + NotNull simplifier, NotNull normalizer, NotNull typeFunctionRuntime, NotNull scope, @@ -417,7 +420,7 @@ std::optional selectOverload( TypePackId argsPack ) { - auto resolver = std::make_unique(builtinTypes, arena, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location); + auto resolver = std::make_unique(builtinTypes, arena, simplifier, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location); auto [status, overload] = resolver->selectOverload(fn, argsPack); if (status == OverloadResolver::Analysis::Ok) @@ -432,6 +435,7 @@ std::optional selectOverload( SolveResult solveFunctionCall( NotNull arena, NotNull builtinTypes, + NotNull simplifier, NotNull normalizer, NotNull typeFunctionRuntime, NotNull iceReporter, @@ -443,7 +447,7 @@ SolveResult solveFunctionCall( ) { std::optional overloadToUse = - selectOverload(builtinTypes, arena, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location, fn, argsPack); + selectOverload(builtinTypes, arena, simplifier, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location, fn, argsPack); if (!overloadToUse) return {SolveResult::NoMatchingOverload}; diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index 27894505..db99d827 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -211,6 +211,16 @@ void Scope::inheritRefinements(const ScopePtr& childScope) } } +bool Scope::shouldWarnGlobal(std::string name) const +{ + for (const Scope* current = this; current; current = current->parent.get()) + { + if (current->globalsToWarn.contains(name)) + return true; + } + return false; +} + bool subsumesStrict(Scope* left, Scope* right) { while (right) diff --git a/Analysis/src/Subtyping.cpp b/Analysis/src/Subtyping.cpp index a93d910b..6f3a6f26 100644 --- a/Analysis/src/Subtyping.cpp +++ b/Analysis/src/Subtyping.cpp @@ -396,12 +396,14 @@ TypePackId* SubtypingEnvironment::getMappedPackBounds(TypePackId tp) Subtyping::Subtyping( NotNull builtinTypes, NotNull typeArena, + NotNull simplifier, NotNull normalizer, NotNull typeFunctionRuntime, NotNull iceReporter ) : builtinTypes(builtinTypes) , arena(typeArena) + , simplifier(simplifier) , normalizer(normalizer) , typeFunctionRuntime(typeFunctionRuntime) , iceReporter(iceReporter) @@ -1861,7 +1863,7 @@ TypeId Subtyping::makeAggregateType(const Container& container, TypeId orElse) std::pair Subtyping::handleTypeFunctionReductionResult(const TypeFunctionInstanceType* functionInstance, NotNull scope) { - TypeFunctionContext context{arena, builtinTypes, scope, normalizer, typeFunctionRuntime, iceReporter, NotNull{&limits}}; + TypeFunctionContext context{arena, builtinTypes, scope, simplifier, normalizer, typeFunctionRuntime, iceReporter, NotNull{&limits}}; TypeId function = arena->addType(*functionInstance); FunctionGraphReductionResult result = reduceTypeFunctions(function, {}, context, true); ErrorVec errors; diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 6f28b11c..6fc60b2f 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -386,8 +386,12 @@ public: } AstType* operator()(const NegationType& ntv) { - // FIXME: do the same thing we do with ErrorType - throw InternalCompilerError("Cannot convert NegationType into AstNode"); + AstArray params; + params.size = 1; + params.data = static_cast(allocator->allocate(sizeof(AstType*))); + params.data[0] = AstTypeOrPack{Luau::visit(*this, ntv.ty->ty), nullptr}; + + return allocator->alloc(Location(), std::nullopt, AstName("negate"), std::nullopt, Location(), true, params); } AstType* operator()(const TypeFunctionInstanceType& tfit) { diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 5397a0e8..655abfa7 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -32,6 +32,7 @@ LUAU_FASTFLAG(DebugLuauMagicTypes) +LUAU_FASTFLAG(InferGlobalTypes) LUAU_FASTFLAGVARIABLE(LuauTableKeysAreRValues) namespace Luau @@ -268,6 +269,7 @@ struct InternalTypeFunctionFinder : TypeOnceVisitor void check( NotNull builtinTypes, + NotNull simplifier, NotNull typeFunctionRuntime, NotNull unifierState, NotNull limits, @@ -278,7 +280,7 @@ void check( { LUAU_TIMETRACE_SCOPE("check", "Typechecking"); - TypeChecker2 typeChecker{builtinTypes, typeFunctionRuntime, unifierState, limits, logger, &sourceModule, module}; + TypeChecker2 typeChecker{builtinTypes, simplifier, typeFunctionRuntime, unifierState, limits, logger, &sourceModule, module}; typeChecker.visit(sourceModule.root); @@ -295,6 +297,7 @@ void check( TypeChecker2::TypeChecker2( NotNull builtinTypes, + NotNull simplifier, NotNull typeFunctionRuntime, NotNull unifierState, NotNull limits, @@ -303,6 +306,7 @@ TypeChecker2::TypeChecker2( Module* module ) : builtinTypes(builtinTypes) + , simplifier(simplifier) , typeFunctionRuntime(typeFunctionRuntime) , logger(logger) , limits(limits) @@ -310,7 +314,7 @@ TypeChecker2::TypeChecker2( , sourceModule(sourceModule) , module(module) , normalizer{&module->internalTypes, builtinTypes, unifierState, /* cacheInhabitance */ true} - , _subtyping{builtinTypes, NotNull{&module->internalTypes}, NotNull{&normalizer}, typeFunctionRuntime, NotNull{unifierState->iceHandler}} + , _subtyping{builtinTypes, NotNull{&module->internalTypes}, simplifier, NotNull{&normalizer}, typeFunctionRuntime, NotNull{unifierState->iceHandler}} , subtyping(&_subtyping) { } @@ -492,7 +496,9 @@ TypeId TypeChecker2::checkForTypeFunctionInhabitance(TypeId instance, Location l reduceTypeFunctions( instance, location, - TypeFunctionContext{NotNull{&module->internalTypes}, builtinTypes, stack.back(), NotNull{&normalizer}, typeFunctionRuntime, ice, limits}, + TypeFunctionContext{ + NotNull{&module->internalTypes}, builtinTypes, stack.back(), simplifier, NotNull{&normalizer}, typeFunctionRuntime, ice, limits + }, true ) .errors; @@ -1349,7 +1355,17 @@ void TypeChecker2::visit(AstExprGlobal* expr) { NotNull scope = stack.back(); if (!scope->lookup(expr->name)) + { reportError(UnknownSymbol{expr->name.value, UnknownSymbol::Binding}, expr->location); + } + else if (FFlag::InferGlobalTypes) + { + if (scope->shouldWarnGlobal(expr->name.value) && !warnedGlobals.contains(expr->name.value)) + { + reportError(UnknownSymbol{expr->name.value, UnknownSymbol::Binding}, expr->location); + warnedGlobals.insert(expr->name.value); + } + } } void TypeChecker2::visit(AstExprVarargs* expr) @@ -1448,6 +1464,7 @@ void TypeChecker2::visitCall(AstExprCall* call) OverloadResolver resolver{ builtinTypes, NotNull{&module->internalTypes}, + simplifier, NotNull{&normalizer}, typeFunctionRuntime, NotNull{stack.back()}, @@ -3024,7 +3041,7 @@ PropertyType TypeChecker2::hasIndexTypeFromType( { TypeId indexType = follow(tt->indexer->indexType); TypeId givenType = module->internalTypes.addType(SingletonType{StringSingleton{prop}}); - if (isSubtype(givenType, indexType, NotNull{module->getModuleScope().get()}, builtinTypes, *ice)) + if (isSubtype(givenType, indexType, NotNull{module->getModuleScope().get()}, builtinTypes, simplifier, *ice)) return {NormalizationResult::True, {tt->indexer->indexResultType}}; } diff --git a/Analysis/src/TypeFunction.cpp b/Analysis/src/TypeFunction.cpp index e7620cc4..64680eca 100644 --- a/Analysis/src/TypeFunction.cpp +++ b/Analysis/src/TypeFunction.cpp @@ -46,10 +46,11 @@ LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyApplicationCartesianProductLimit, 5'0 LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyUseGuesserDepth, -1); LUAU_FASTFLAGVARIABLE(DebugLuauLogTypeFamilies) -LUAU_FASTFLAG(LuauUserTypeFunFixRegister) +LUAU_FASTFLAG(DebugLuauEqSatSimplification) +LUAU_FASTFLAG(LuauUserTypeFunPrintToError) LUAU_FASTFLAG(LuauRemoveNotAnyHack) -LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionResetState) LUAU_FASTFLAG(LuauUserTypeFunExportedAndLocal) +LUAU_FASTFLAGVARIABLE(LuauUserTypeFunUpdateAllEnvs) namespace Luau { @@ -220,6 +221,12 @@ struct TypeFunctionReducer template void handleTypeFunctionReduction(T subject, TypeFunctionReductionResult reduction) { + if (FFlag::LuauUserTypeFunPrintToError) + { + for (auto& message : reduction.messages) + result.messages.emplace_back(location, UserDefinedTypeFunctionError{std::move(message)}); + } + if (reduction.result) replace(subject, *reduction.result); else @@ -229,7 +236,7 @@ struct TypeFunctionReducer if (reduction.error.has_value()) result.errors.emplace_back(location, UserDefinedTypeFunctionError{*reduction.error}); - if (reduction.uninhabited || force) + if (reduction.reductionStatus != Reduction::MaybeOk || force) { if (FFlag::DebugLuauLogTypeFamilies) printf("%s is uninhabited\n", toString(subject, {true}).c_str()); @@ -239,7 +246,7 @@ struct TypeFunctionReducer else if constexpr (std::is_same_v) result.errors.emplace_back(location, UninhabitedTypePackFunction{subject}); } - else if (!reduction.uninhabited && !force) + else if (reduction.reductionStatus == Reduction::MaybeOk && !force) { if (FFlag::DebugLuauLogTypeFamilies) printf( @@ -528,7 +535,7 @@ static std::optional> tryDistributeTypeFunct ) { // op (a | b) (c | d) ~ (op a (c | d)) | (op b (c | d)) ~ (op a c) | (op a d) | (op b c) | (op b d) - bool uninhabited = false; + Reduction reductionStatus = Reduction::MaybeOk; std::vector blockedTypes; std::vector results; size_t cartesianProductSize = 1; @@ -557,7 +564,7 @@ static std::optional> tryDistributeTypeFunct // TODO: We'd like to report that the type function application is too complex here. if (size_t(DFInt::LuauTypeFamilyApplicationCartesianProductLimit) <= cartesianProductSize) - return {{std::nullopt, true, {}, {}}}; + return {{std::nullopt, Reduction::Erroneous, {}, {}}}; } if (!firstUnion) @@ -572,21 +579,22 @@ static std::optional> tryDistributeTypeFunct TypeFunctionReductionResult result = f(instance, arguments, packParams, ctx, args...); blockedTypes.insert(blockedTypes.end(), result.blockedTypes.begin(), result.blockedTypes.end()); - uninhabited |= result.uninhabited; + if (result.reductionStatus != Reduction::MaybeOk) + reductionStatus = result.reductionStatus; - if (result.uninhabited || !result.result) + if (reductionStatus != Reduction::MaybeOk || !result.result) break; else results.push_back(*result.result); } - if (uninhabited || !blockedTypes.empty()) - return {{std::nullopt, uninhabited, blockedTypes, {}}}; + if (reductionStatus != Reduction::MaybeOk || !blockedTypes.empty()) + return {{std::nullopt, reductionStatus, blockedTypes, {}}}; if (!results.empty()) { if (results.size() == 1) - return {{results[0], false, {}, {}}}; + return {{results[0], Reduction::MaybeOk, {}, {}}}; TypeId resultTy = ctx->arena->addType(TypeFunctionInstanceType{ NotNull{&builtinTypeFunctions().unionFunc}, @@ -594,7 +602,7 @@ static std::optional> tryDistributeTypeFunct {}, }); - return {{resultTy, false, {}, {}}}; + return {{resultTy, Reduction::MaybeOk, {}, {}}}; } return std::nullopt; @@ -614,13 +622,13 @@ TypeFunctionReductionResult userDefinedTypeFunction( if (typeFunction->userFuncData.owner.expired()) { ctx->ice->ice("user-defined type function module has expired"); - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; } if (!typeFunction->userFuncName || !typeFunction->userFuncData.definition) { ctx->ice->ice("all user-defined type functions must have an associated function definition"); - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; } } else @@ -628,13 +636,13 @@ TypeFunctionReductionResult userDefinedTypeFunction( if (!ctx->userFuncName) { ctx->ice->ice("all user-defined type functions must have an associated function definition"); - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; } } // If type functions cannot be evaluated because of errors in the code, we do not generate any additional ones if (!ctx->typeFunctionRuntime->allowEvaluation) - return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + return {ctx->builtins->errorRecoveryType(), Reduction::MaybeOk, {}, {}}; for (auto typeParam : typeParams) { @@ -642,20 +650,34 @@ TypeFunctionReductionResult userDefinedTypeFunction( // block if we need to if (isPending(ty, ctx->solver)) - return {std::nullopt, false, {ty}, {}}; + return {std::nullopt, Reduction::MaybeOk, {ty}, {}}; } - if (FFlag::LuauUserTypeFunExportedAndLocal) + if (FFlag::LuauUserTypeFunExportedAndLocal && FFlag::LuauUserTypeFunUpdateAllEnvs) { // Ensure that whole type function environment is registered for (auto& [name, definition] : typeFunction->userFuncData.environment) + { + if (std::optional error = ctx->typeFunctionRuntime->registerFunction(definition.first)) + { + // Failure to register at this point means that original definition had to error out and should not have been present in the + // environment + ctx->ice->ice("user-defined type function reference cannot be registered"); + return {std::nullopt, Reduction::Erroneous, {}, {}}; + } + } + } + else if (FFlag::LuauUserTypeFunExportedAndLocal) + { + // Ensure that whole type function environment is registered + for (auto& [name, definition] : typeFunction->userFuncData.environment_DEPRECATED) { if (std::optional error = ctx->typeFunctionRuntime->registerFunction(definition)) { // Failure to register at this point means that original definition had to error out and should not have been present in the // environment ctx->ice->ice("user-defined type function reference cannot be registered"); - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; } } } @@ -665,13 +687,65 @@ TypeFunctionReductionResult userDefinedTypeFunction( lua_State* global = ctx->typeFunctionRuntime->state.get(); if (global == nullptr) - return {std::nullopt, true, {}, {}, format("'%s' type function: cannot be evaluated in this context", name.value)}; + return {std::nullopt, Reduction::Erroneous, {}, {}, format("'%s' type function: cannot be evaluated in this context", name.value)}; // Separate sandboxed thread for individual execution and private globals lua_State* L = lua_newthread(global); LuauTempThreadPopper popper(global); - if (FFlag::LuauUserTypeFunExportedAndLocal) + if (FFlag::LuauUserTypeFunExportedAndLocal && FFlag::LuauUserTypeFunUpdateAllEnvs) + { + // Build up the environment table of each function we have visible + for (auto& [_, curr] : typeFunction->userFuncData.environment) + { + // Environment table has to be filled only once in the current execution context + if (ctx->typeFunctionRuntime->initialized.find(curr.first)) + continue; + ctx->typeFunctionRuntime->initialized.insert(curr.first); + + lua_pushlightuserdata(L, curr.first); + lua_gettable(L, LUA_REGISTRYINDEX); + + if (!lua_isfunction(L, -1)) + { + ctx->ice->ice("user-defined type function reference cannot be found in the registry"); + return {std::nullopt, Reduction::Erroneous, {}, {}}; + } + + // Build up the environment of the current function, where some might not be visible + lua_getfenv(L, -1); + lua_setreadonly(L, -1, false); + + for (auto& [name, definition] : typeFunction->userFuncData.environment) + { + // Filter visibility based on original scope depth + if (definition.second >= curr.second) + { + lua_pushlightuserdata(L, definition.first); + lua_gettable(L, LUA_REGISTRYINDEX); + + if (!lua_isfunction(L, -1)) + break; // Don't have to report an error here, we will visit each function in outer loop + + lua_setfield(L, -2, name.c_str()); + } + } + + lua_setreadonly(L, -1, true); + lua_pop(L, 2); + } + + // Fetch the function we want to evaluate + lua_pushlightuserdata(L, typeFunction->userFuncData.definition); + lua_gettable(L, LUA_REGISTRYINDEX); + + if (!lua_isfunction(L, -1)) + { + ctx->ice->ice("user-defined type function reference cannot be found in the registry"); + return {std::nullopt, Reduction::Erroneous, {}, {}}; + } + } + else if (FFlag::LuauUserTypeFunExportedAndLocal) { // Fetch the function we want to evaluate lua_pushlightuserdata(L, typeFunction->userFuncData.definition); @@ -680,14 +754,14 @@ TypeFunctionReductionResult userDefinedTypeFunction( if (!lua_isfunction(L, -1)) { ctx->ice->ice("user-defined type function reference cannot be found in the registry"); - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; } // Build up the environment lua_getfenv(L, -1); lua_setreadonly(L, -1, false); - for (auto& [name, definition] : typeFunction->userFuncData.environment) + for (auto& [name, definition] : typeFunction->userFuncData.environment_DEPRECATED) { lua_pushlightuserdata(L, definition); lua_gettable(L, LUA_REGISTRYINDEX); @@ -695,7 +769,7 @@ TypeFunctionReductionResult userDefinedTypeFunction( if (!lua_isfunction(L, -1)) { ctx->ice->ice("user-defined type function reference cannot be found in the registry"); - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; } lua_setfield(L, -2, name.c_str()); @@ -710,8 +784,7 @@ TypeFunctionReductionResult userDefinedTypeFunction( lua_xmove(global, L, 1); } - if (FFlag::LuauUserDefinedTypeFunctionResetState) - resetTypeFunctionState(L); + resetTypeFunctionState(L); // Push serialized arguments onto the stack @@ -727,7 +800,7 @@ TypeFunctionReductionResult userDefinedTypeFunction( TypeFunctionTypeId serializedTy = serialize(ty, runtimeBuilder.get()); // Check if there were any errors while serializing if (runtimeBuilder->errors.size() != 0) - return {std::nullopt, true, {}, {}, runtimeBuilder->errors.front()}; + return {std::nullopt, Reduction::Erroneous, {}, {}, runtimeBuilder->errors.front()}; allocTypeUserData(L, serializedTy->type); } @@ -743,12 +816,27 @@ TypeFunctionReductionResult userDefinedTypeFunction( throw UserCancelError(ctx->ice->moduleName); }; + if (FFlag::LuauUserTypeFunPrintToError) + ctx->typeFunctionRuntime->messages.clear(); + if (auto error = checkResultForError(L, name.value, lua_pcall(L, int(typeParams.size()), 1, 0))) - return {std::nullopt, true, {}, {}, error}; + { + if (FFlag::LuauUserTypeFunPrintToError) + return {std::nullopt, Reduction::Erroneous, {}, {}, error, ctx->typeFunctionRuntime->messages}; + else + return {std::nullopt, Reduction::Erroneous, {}, {}, error}; + } // If the return value is not a type userdata, return with error message if (!isTypeUserData(L, 1)) - return {std::nullopt, true, {}, {}, format("'%s' type function: returned a non-type value", name.value)}; + { + if (FFlag::LuauUserTypeFunPrintToError) + return { + std::nullopt, Reduction::Erroneous, {}, {}, format("'%s' type function: returned a non-type value", name.value), ctx->typeFunctionRuntime->messages + }; + else + return {std::nullopt, Reduction::Erroneous, {}, {}, format("'%s' type function: returned a non-type value", name.value)}; + } TypeFunctionTypeId retTypeFunctionTypeId = getTypeUserData(L, 1); @@ -759,9 +847,17 @@ TypeFunctionReductionResult userDefinedTypeFunction( // At least 1 error occurred while deserializing if (runtimeBuilder->errors.size() > 0) - return {std::nullopt, true, {}, {}, runtimeBuilder->errors.front()}; + { + if (FFlag::LuauUserTypeFunPrintToError) + return {std::nullopt, Reduction::Erroneous, {}, {}, runtimeBuilder->errors.front(), ctx->typeFunctionRuntime->messages}; + else + return {std::nullopt, Reduction::Erroneous, {}, {}, runtimeBuilder->errors.front()}; + } - return {retTypeId, false, {}, {}}; + if (FFlag::LuauUserTypeFunPrintToError) + return {retTypeId, Reduction::MaybeOk, {}, {}, std::nullopt, ctx->typeFunctionRuntime->messages}; + else + return {retTypeId, Reduction::MaybeOk, {}, {}}; } TypeFunctionReductionResult notTypeFunction( @@ -780,16 +876,16 @@ TypeFunctionReductionResult notTypeFunction( TypeId ty = follow(typeParams.at(0)); if (ty == instance) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; if (isPending(ty, ctx->solver)) - return {std::nullopt, false, {ty}, {}}; + return {std::nullopt, Reduction::MaybeOk, {ty}, {}}; if (auto result = tryDistributeTypeFunctionApp(notTypeFunction, instance, typeParams, packParams, ctx)) return *result; // `not` operates on anything and returns a `boolean` always. - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; } TypeFunctionReductionResult lenTypeFunction( @@ -808,19 +904,19 @@ TypeFunctionReductionResult lenTypeFunction( TypeId operandTy = follow(typeParams.at(0)); if (operandTy == instance) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; // check to see if the operand type is resolved enough, and wait to reduce if not // the use of `typeFromNormal` later necessitates blocking on local types. if (isPending(operandTy, ctx->solver)) - return {std::nullopt, false, {operandTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {operandTy}, {}}; // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. if (ctx->solver) { std::optional maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, operandTy, /* avoidSealingTables */ true); if (!maybeGeneralized) - return {std::nullopt, false, {operandTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {operandTy}, {}}; operandTy = *maybeGeneralized; } @@ -829,21 +925,21 @@ TypeFunctionReductionResult lenTypeFunction( // if the type failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normTy || inhabited == NormalizationResult::HitLimits) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // if the operand type is error suppressing, we can immediately reduce to `number`. if (normTy->shouldSuppressErrors()) - return {ctx->builtins->numberType, false, {}, {}}; + return {ctx->builtins->numberType, Reduction::MaybeOk, {}, {}}; // # always returns a number, even if its operand is never. // if we're checking the length of a string, that works! if (inhabited == NormalizationResult::False || normTy->isSubtypeOfString()) - return {ctx->builtins->numberType, false, {}, {}}; + return {ctx->builtins->numberType, Reduction::MaybeOk, {}, {}}; // we use the normalized operand here in case there was an intersection or union. TypeId normalizedOperand = follow(ctx->normalizer->typeFromNormal(*normTy)); if (normTy->hasTopTable() || get(normalizedOperand)) - return {ctx->builtins->numberType, false, {}, {}}; + return {ctx->builtins->numberType, Reduction::MaybeOk, {}, {}}; if (auto result = tryDistributeTypeFunctionApp(lenTypeFunction, instance, typeParams, packParams, ctx)) return *result; @@ -854,35 +950,35 @@ TypeFunctionReductionResult lenTypeFunction( std::optional mmType = findMetatableEntry(ctx->builtins, dummy, operandTy, "__len", Location{}); if (!mmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; mmType = follow(*mmType); if (isPending(*mmType, ctx->solver)) - return {std::nullopt, false, {*mmType}, {}}; + return {std::nullopt, Reduction::MaybeOk, {*mmType}, {}}; const FunctionType* mmFtv = get(*mmType); if (!mmFtv) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; std::optional instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType); if (!instantiatedMmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; const FunctionType* instantiatedMmFtv = get(*instantiatedMmType); if (!instantiatedMmFtv) - return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + return {ctx->builtins->errorRecoveryType(), Reduction::MaybeOk, {}, {}}; TypePackId inferredArgPack = ctx->arena->addTypePack({operandTy}); Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice}; if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) - return {std::nullopt, true, {}, {}}; // occurs check failed + return {std::nullopt, Reduction::Erroneous, {}, {}}; // occurs check failed - Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->simplifier, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; // `len` must return a `number`. - return {ctx->builtins->numberType, false, {}, {}}; + return {ctx->builtins->numberType, Reduction::MaybeOk, {}, {}}; } TypeFunctionReductionResult unmTypeFunction( @@ -901,18 +997,18 @@ TypeFunctionReductionResult unmTypeFunction( TypeId operandTy = follow(typeParams.at(0)); if (operandTy == instance) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; // check to see if the operand type is resolved enough, and wait to reduce if not if (isPending(operandTy, ctx->solver)) - return {std::nullopt, false, {operandTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {operandTy}, {}}; // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. if (ctx->solver) { std::optional maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, operandTy); if (!maybeGeneralized) - return {std::nullopt, false, {operandTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {operandTy}, {}}; operandTy = *maybeGeneralized; } @@ -920,19 +1016,19 @@ TypeFunctionReductionResult unmTypeFunction( // if the operand failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normTy) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // if the operand is error suppressing, we can just go ahead and reduce. if (normTy->shouldSuppressErrors()) - return {operandTy, false, {}, {}}; + return {operandTy, Reduction::MaybeOk, {}, {}}; // if we have a `never`, we can never observe that the operation didn't work. if (is(operandTy)) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; // If the type is exactly `number`, we can reduce now. if (normTy->isExactlyNumber()) - return {ctx->builtins->numberType, false, {}, {}}; + return {ctx->builtins->numberType, Reduction::MaybeOk, {}, {}}; if (auto result = tryDistributeTypeFunctionApp(unmTypeFunction, instance, typeParams, packParams, ctx)) return *result; @@ -943,37 +1039,37 @@ TypeFunctionReductionResult unmTypeFunction( std::optional mmType = findMetatableEntry(ctx->builtins, dummy, operandTy, "__unm", Location{}); if (!mmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; mmType = follow(*mmType); if (isPending(*mmType, ctx->solver)) - return {std::nullopt, false, {*mmType}, {}}; + return {std::nullopt, Reduction::MaybeOk, {*mmType}, {}}; const FunctionType* mmFtv = get(*mmType); if (!mmFtv) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; std::optional instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType); if (!instantiatedMmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; const FunctionType* instantiatedMmFtv = get(*instantiatedMmType); if (!instantiatedMmFtv) - return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + return {ctx->builtins->errorRecoveryType(), Reduction::MaybeOk, {}, {}}; TypePackId inferredArgPack = ctx->arena->addTypePack({operandTy}); Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice}; if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) - return {std::nullopt, true, {}, {}}; // occurs check failed + return {std::nullopt, Reduction::Erroneous, {}, {}}; // occurs check failed - Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->simplifier, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; if (std::optional ret = first(instantiatedMmFtv->retTypes)) - return {*ret, false, {}, {}}; + return {ret, Reduction::MaybeOk, {}, {}}; else - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; } void dummyStateClose(lua_State*) {} @@ -1096,8 +1192,7 @@ void TypeFunctionRuntime::prepareState() registerTypeUserData(L); - if (FFlag::LuauUserTypeFunFixRegister) - registerTypesLibrary(L); + registerTypesLibrary(L); luaL_sandbox(L); luaL_sandboxthread(L); @@ -1107,6 +1202,7 @@ TypeFunctionContext::TypeFunctionContext(NotNull cs, NotNullarena) , builtins(cs->builtinTypes) , scope(scope) + , simplifier(cs->simplifier) , normalizer(cs->normalizer) , typeFunctionRuntime(cs->typeFunctionRuntime) , ice(NotNull{&cs->iceReporter}) @@ -1148,19 +1244,19 @@ TypeFunctionReductionResult numericBinopTypeFunction( // isPending of `lhsTy` or `rhsTy` would return true, even if it cycles. We want a different answer for that. if (lhsTy == instance || rhsTy == instance) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; // if we have a `never`, we can never observe that the math operator is unreachable. if (is(lhsTy) || is(rhsTy)) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; const Location location = ctx->constraint ? ctx->constraint->location : Location{}; // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(lhsTy, ctx->solver)) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (isPending(rhsTy, ctx->solver)) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. if (ctx->solver) @@ -1169,9 +1265,9 @@ TypeFunctionReductionResult numericBinopTypeFunction( std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); if (!lhsMaybeGeneralized) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (!rhsMaybeGeneralized) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; lhsTy = *lhsMaybeGeneralized; rhsTy = *rhsMaybeGeneralized; @@ -1183,15 +1279,15 @@ TypeFunctionReductionResult numericBinopTypeFunction( // if either failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normLhsTy || !normRhsTy) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // if one of the types is error suppressing, we can reduce to `any` since we should suppress errors in the result of the usage. if (normLhsTy->shouldSuppressErrors() || normRhsTy->shouldSuppressErrors()) - return {ctx->builtins->anyType, false, {}, {}}; + return {ctx->builtins->anyType, Reduction::MaybeOk, {}, {}}; // if we're adding two `number` types, the result is `number`. if (normLhsTy->isExactlyNumber() && normRhsTy->isExactlyNumber()) - return {ctx->builtins->numberType, false, {}, {}}; + return {ctx->builtins->numberType, Reduction::MaybeOk, {}, {}}; if (auto result = tryDistributeTypeFunctionApp(numericBinopTypeFunction, instance, typeParams, packParams, ctx, metamethod)) return *result; @@ -1209,36 +1305,56 @@ TypeFunctionReductionResult numericBinopTypeFunction( } if (!mmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; mmType = follow(*mmType); if (isPending(*mmType, ctx->solver)) - return {std::nullopt, false, {*mmType}, {}}; + return {std::nullopt, Reduction::MaybeOk, {*mmType}, {}}; TypePackId argPack = ctx->arena->addTypePack({lhsTy, rhsTy}); SolveResult solveResult; if (!reversed) solveResult = solveFunctionCall( - ctx->arena, ctx->builtins, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack + ctx->arena, + ctx->builtins, + ctx->simplifier, + ctx->normalizer, + ctx->typeFunctionRuntime, + ctx->ice, + ctx->limits, + ctx->scope, + location, + *mmType, + argPack ); else { TypePack* p = getMutable(argPack); std::swap(p->head.front(), p->head.back()); solveResult = solveFunctionCall( - ctx->arena, ctx->builtins, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack + ctx->arena, + ctx->builtins, + ctx->simplifier, + ctx->normalizer, + ctx->typeFunctionRuntime, + ctx->ice, + ctx->limits, + ctx->scope, + location, + *mmType, + argPack ); } if (!solveResult.typePackId.has_value()) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; TypePack extracted = extendTypePack(*ctx->arena, ctx->builtins, *solveResult.typePackId, 1); if (extracted.head.empty()) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; - return {extracted.head.front(), false, {}, {}}; + return {extracted.head.front(), Reduction::MaybeOk, {}, {}}; } TypeFunctionReductionResult addTypeFunction( @@ -1371,13 +1487,13 @@ TypeFunctionReductionResult concatTypeFunction( // isPending of `lhsTy` or `rhsTy` would return true, even if it cycles. We want a different answer for that. if (lhsTy == instance || rhsTy == instance) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(lhsTy, ctx->solver)) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (isPending(rhsTy, ctx->solver)) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. if (ctx->solver) @@ -1386,9 +1502,9 @@ TypeFunctionReductionResult concatTypeFunction( std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); if (!lhsMaybeGeneralized) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (!rhsMaybeGeneralized) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; lhsTy = *lhsMaybeGeneralized; rhsTy = *rhsMaybeGeneralized; @@ -1399,19 +1515,19 @@ TypeFunctionReductionResult concatTypeFunction( // if either failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normLhsTy || !normRhsTy) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // if one of the types is error suppressing, we can reduce to `any` since we should suppress errors in the result of the usage. if (normLhsTy->shouldSuppressErrors() || normRhsTy->shouldSuppressErrors()) - return {ctx->builtins->anyType, false, {}, {}}; + return {ctx->builtins->anyType, Reduction::MaybeOk, {}, {}}; - // if we have a `never`, we can never observe that the numeric operator didn't work. + // if we have a `never`, we can never observe that the operator didn't work. if (is(lhsTy) || is(rhsTy)) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; // if we're concatenating two elements that are either strings or numbers, the result is `string`. if ((normLhsTy->isSubtypeOfString() || normLhsTy->isExactlyNumber()) && (normRhsTy->isSubtypeOfString() || normRhsTy->isExactlyNumber())) - return {ctx->builtins->stringType, false, {}, {}}; + return {ctx->builtins->stringType, Reduction::MaybeOk, {}, {}}; if (auto result = tryDistributeTypeFunctionApp(concatTypeFunction, instance, typeParams, packParams, ctx)) return *result; @@ -1429,23 +1545,23 @@ TypeFunctionReductionResult concatTypeFunction( } if (!mmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; mmType = follow(*mmType); if (isPending(*mmType, ctx->solver)) - return {std::nullopt, false, {*mmType}, {}}; + return {std::nullopt, Reduction::MaybeOk, {*mmType}, {}}; const FunctionType* mmFtv = get(*mmType); if (!mmFtv) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; std::optional instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType); if (!instantiatedMmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; const FunctionType* instantiatedMmFtv = get(*instantiatedMmType); if (!instantiatedMmFtv) - return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + return {ctx->builtins->errorRecoveryType(), Reduction::MaybeOk, {}, {}}; std::vector inferredArgs; if (!reversed) @@ -1456,13 +1572,13 @@ TypeFunctionReductionResult concatTypeFunction( TypePackId inferredArgPack = ctx->arena->addTypePack(std::move(inferredArgs)); Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice}; if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) - return {std::nullopt, true, {}, {}}; // occurs check failed + return {std::nullopt, Reduction::Erroneous, {}, {}}; // occurs check failed - Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->simplifier, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; - return {ctx->builtins->stringType, false, {}, {}}; + return {ctx->builtins->stringType, Reduction::MaybeOk, {}, {}}; } TypeFunctionReductionResult andTypeFunction( @@ -1483,16 +1599,16 @@ TypeFunctionReductionResult andTypeFunction( // t1 = and ~> lhs if (follow(rhsTy) == instance && lhsTy != rhsTy) - return {lhsTy, false, {}, {}}; + return {lhsTy, Reduction::MaybeOk, {}, {}}; // t1 = and ~> rhs if (follow(lhsTy) == instance && lhsTy != rhsTy) - return {rhsTy, false, {}, {}}; + return {rhsTy, Reduction::MaybeOk, {}, {}}; // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(lhsTy, ctx->solver)) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (isPending(rhsTy, ctx->solver)) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. if (ctx->solver) @@ -1501,9 +1617,9 @@ TypeFunctionReductionResult andTypeFunction( std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); if (!lhsMaybeGeneralized) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (!rhsMaybeGeneralized) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; lhsTy = *lhsMaybeGeneralized; rhsTy = *rhsMaybeGeneralized; @@ -1517,7 +1633,7 @@ TypeFunctionReductionResult andTypeFunction( blockedTypes.push_back(ty); for (auto ty : overallResult.blockedTypes) blockedTypes.push_back(ty); - return {overallResult.result, false, std::move(blockedTypes), {}}; + return {overallResult.result, Reduction::MaybeOk, std::move(blockedTypes), {}}; } TypeFunctionReductionResult orTypeFunction( @@ -1538,16 +1654,16 @@ TypeFunctionReductionResult orTypeFunction( // t1 = or ~> lhs if (follow(rhsTy) == instance && lhsTy != rhsTy) - return {lhsTy, false, {}, {}}; + return {lhsTy, Reduction::MaybeOk, {}, {}}; // t1 = or ~> rhs if (follow(lhsTy) == instance && lhsTy != rhsTy) - return {rhsTy, false, {}, {}}; + return {rhsTy, Reduction::MaybeOk, {}, {}}; // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(lhsTy, ctx->solver)) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (isPending(rhsTy, ctx->solver)) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. if (ctx->solver) @@ -1556,9 +1672,9 @@ TypeFunctionReductionResult orTypeFunction( std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); if (!lhsMaybeGeneralized) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (!rhsMaybeGeneralized) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; lhsTy = *lhsMaybeGeneralized; rhsTy = *rhsMaybeGeneralized; @@ -1572,7 +1688,7 @@ TypeFunctionReductionResult orTypeFunction( blockedTypes.push_back(ty); for (auto ty : overallResult.blockedTypes) blockedTypes.push_back(ty); - return {overallResult.result, false, std::move(blockedTypes), {}}; + return {overallResult.result, Reduction::MaybeOk, std::move(blockedTypes), {}}; } static TypeFunctionReductionResult comparisonTypeFunction( @@ -1594,12 +1710,12 @@ static TypeFunctionReductionResult comparisonTypeFunction( TypeId rhsTy = follow(typeParams.at(1)); if (lhsTy == instance || rhsTy == instance) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; if (isPending(lhsTy, ctx->solver)) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (isPending(rhsTy, ctx->solver)) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; // Algebra Reduction Rules for comparison type functions // Note that comparing to never tells you nothing about the other operand @@ -1642,9 +1758,9 @@ static TypeFunctionReductionResult comparisonTypeFunction( std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); if (!lhsMaybeGeneralized) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (!rhsMaybeGeneralized) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; lhsTy = *lhsMaybeGeneralized; rhsTy = *rhsMaybeGeneralized; @@ -1659,23 +1775,23 @@ static TypeFunctionReductionResult comparisonTypeFunction( // if either failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normLhsTy || !normRhsTy || lhsInhabited == NormalizationResult::HitLimits || rhsInhabited == NormalizationResult::HitLimits) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // if one of the types is error suppressing, we can just go ahead and reduce. if (normLhsTy->shouldSuppressErrors() || normRhsTy->shouldSuppressErrors()) - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; // if we have an uninhabited type (e.g. `never`), we can never observe that the comparison didn't work. if (lhsInhabited == NormalizationResult::False || rhsInhabited == NormalizationResult::False) - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; // If both types are some strict subset of `string`, we can reduce now. if (normLhsTy->isSubtypeOfString() && normRhsTy->isSubtypeOfString()) - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; // If both types are exactly `number`, we can reduce now. if (normLhsTy->isExactlyNumber() && normRhsTy->isExactlyNumber()) - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; if (auto result = tryDistributeTypeFunctionApp(comparisonTypeFunction, instance, typeParams, packParams, ctx, metamethod)) return *result; @@ -1689,34 +1805,34 @@ static TypeFunctionReductionResult comparisonTypeFunction( mmType = findMetatableEntry(ctx->builtins, dummy, rhsTy, metamethod, Location{}); if (!mmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; mmType = follow(*mmType); if (isPending(*mmType, ctx->solver)) - return {std::nullopt, false, {*mmType}, {}}; + return {std::nullopt, Reduction::MaybeOk, {*mmType}, {}}; const FunctionType* mmFtv = get(*mmType); if (!mmFtv) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; std::optional instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType); if (!instantiatedMmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; const FunctionType* instantiatedMmFtv = get(*instantiatedMmType); if (!instantiatedMmFtv) - return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + return {ctx->builtins->errorRecoveryType(), Reduction::MaybeOk, {}, {}}; TypePackId inferredArgPack = ctx->arena->addTypePack({lhsTy, rhsTy}); Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice}; if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) - return {std::nullopt, true, {}, {}}; // occurs check failed + return {std::nullopt, Reduction::Erroneous, {}, {}}; // occurs check failed - Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->simplifier, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; } TypeFunctionReductionResult ltTypeFunction( @@ -1769,9 +1885,9 @@ TypeFunctionReductionResult eqTypeFunction( // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(lhsTy, ctx->solver)) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (isPending(rhsTy, ctx->solver)) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. if (ctx->solver) @@ -1780,9 +1896,9 @@ TypeFunctionReductionResult eqTypeFunction( std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); if (!lhsMaybeGeneralized) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (!rhsMaybeGeneralized) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; lhsTy = *lhsMaybeGeneralized; rhsTy = *rhsMaybeGeneralized; @@ -1795,15 +1911,15 @@ TypeFunctionReductionResult eqTypeFunction( // if either failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normLhsTy || !normRhsTy || lhsInhabited == NormalizationResult::HitLimits || rhsInhabited == NormalizationResult::HitLimits) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // if one of the types is error suppressing, we can just go ahead and reduce. if (normLhsTy->shouldSuppressErrors() || normRhsTy->shouldSuppressErrors()) - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; // if we have a `never`, we can never observe that the comparison didn't work. if (lhsInhabited == NormalizationResult::False || rhsInhabited == NormalizationResult::False) - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; // findMetatableEntry demands the ability to emit errors, so we must give it // the necessary state to do that, even if we intend to just eat the errors. @@ -1818,49 +1934,49 @@ TypeFunctionReductionResult eqTypeFunction( if (!mmType) { if (intersectInhabited == NormalizationResult::True) - return {ctx->builtins->booleanType, false, {}, {}}; // if it's inhabited, everything is okay! + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; // if it's inhabited, everything is okay! // we might be in a case where we still want to accept the comparison... if (intersectInhabited == NormalizationResult::False) { // if they're both subtypes of `string` but have no common intersection, the comparison is allowed but always `false`. if (normLhsTy->isSubtypeOfString() && normRhsTy->isSubtypeOfString()) - return {ctx->builtins->falseType, false, {}, {}}; + return {ctx->builtins->falseType, Reduction::MaybeOk, {}, {}}; // if they're both subtypes of `boolean` but have no common intersection, the comparison is allowed but always `false`. if (normLhsTy->isSubtypeOfBooleans() && normRhsTy->isSubtypeOfBooleans()) - return {ctx->builtins->falseType, false, {}, {}}; + return {ctx->builtins->falseType, Reduction::MaybeOk, {}, {}}; } - return {std::nullopt, true, {}, {}}; // if it's not, then this type function is irreducible! + return {std::nullopt, Reduction::Erroneous, {}, {}}; // if it's not, then this type function is irreducible! } mmType = follow(*mmType); if (isPending(*mmType, ctx->solver)) - return {std::nullopt, false, {*mmType}, {}}; + return {std::nullopt, Reduction::MaybeOk, {*mmType}, {}}; const FunctionType* mmFtv = get(*mmType); if (!mmFtv) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; std::optional instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType); if (!instantiatedMmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; const FunctionType* instantiatedMmFtv = get(*instantiatedMmType); if (!instantiatedMmFtv) - return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + return {ctx->builtins->errorRecoveryType(), Reduction::MaybeOk, {}, {}}; TypePackId inferredArgPack = ctx->arena->addTypePack({lhsTy, rhsTy}); Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice}; if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) - return {std::nullopt, true, {}, {}}; // occurs check failed + return {std::nullopt, Reduction::Erroneous, {}, {}}; // occurs check failed - Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->simplifier, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; } // Collect types that prevent us from reducing a particular refinement. @@ -1905,13 +2021,13 @@ TypeFunctionReductionResult refineTypeFunction( // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(targetTy, ctx->solver)) - return {std::nullopt, false, {targetTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {targetTy}, {}}; else { for (auto t : discriminantTypes) { if (isPending(t, ctx->solver)) - return {std::nullopt, false, {t}, {}}; + return {std::nullopt, Reduction::MaybeOk, {t}, {}}; } } // Refine a target type and a discriminant one at a time. @@ -1940,57 +2056,76 @@ TypeFunctionReductionResult refineTypeFunction( if (!frb.found.empty()) return {nullptr, {frb.found.begin(), frb.found.end()}}; - /* HACK: Refinements sometimes produce a type T & ~any under the assumption - * that ~any is the same as any. This is so so weird, but refinements needs - * some way to say "I may refine this, but I'm not sure." - * - * It does this by refining on a blocked type and deferring the decision - * until it is unblocked. - * - * Refinements also get negated, so we wind up with types like T & ~*blocked* - * - * We need to treat T & ~any as T in this case. - */ - if (auto nt = get(discriminant)) + if (FFlag::DebugLuauEqSatSimplification) { - if (FFlag::LuauRemoveNotAnyHack) + auto simplifyResult = eqSatSimplify(ctx->simplifier, ctx->arena->addType(IntersectionType{{target, discriminant}})); + if (simplifyResult) { - if (get(follow(nt->ty))) - return {target, {}}; + if (ctx->solver) + { + for (TypeId newTf : simplifyResult->newTypeFunctions) + ctx->solver->pushConstraint(ctx->scope, ctx->constraint->location, ReduceConstraint{newTf}); + } + + return {simplifyResult->result, {}}; } else - { - if (get(follow(nt->ty))) - return {target, {}}; - } + return {nullptr, {}}; } - - // If the target type is a table, then simplification already implements the logic to deal with refinements properly since the - // type of the discriminant is guaranteed to only ever be an (arbitrarily-nested) table of a single property type. - if (get(target)) + else { - SimplifyResult result = simplifyIntersection(ctx->builtins, ctx->arena, target, discriminant); - if (!result.blockedTypes.empty()) - return {nullptr, {result.blockedTypes.begin(), result.blockedTypes.end()}}; + /* HACK: Refinements sometimes produce a type T & ~any under the assumption + * that ~any is the same as any. This is so so weird, but refinements needs + * some way to say "I may refine this, but I'm not sure." + * + * It does this by refining on a blocked type and deferring the decision + * until it is unblocked. + * + * Refinements also get negated, so we wind up with types like T & ~*blocked* + * + * We need to treat T & ~any as T in this case. + */ + if (auto nt = get(discriminant)) + { + if (FFlag::LuauRemoveNotAnyHack) + { + if (get(follow(nt->ty))) + return {target, {}}; + } + else + { + if (get(follow(nt->ty))) + return {target, {}}; + } + } - return {result.result, {}}; + // If the target type is a table, then simplification already implements the logic to deal with refinements properly since the + // type of the discriminant is guaranteed to only ever be an (arbitrarily-nested) table of a single property type. + if (get(target)) + { + SimplifyResult result = simplifyIntersection(ctx->builtins, ctx->arena, target, discriminant); + if (!result.blockedTypes.empty()) + return {nullptr, {result.blockedTypes.begin(), result.blockedTypes.end()}}; + + return {result.result, {}}; + } + + // In the general case, we'll still use normalization though. + TypeId intersection = ctx->arena->addType(IntersectionType{{target, discriminant}}); + std::shared_ptr normIntersection = ctx->normalizer->normalize(intersection); + std::shared_ptr normType = ctx->normalizer->normalize(target); + + // if the intersection failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!normIntersection || !normType) + return {nullptr, {}}; + + TypeId resultTy = ctx->normalizer->typeFromNormal(*normIntersection); + // include the error type if the target type is error-suppressing and the intersection we computed is not + if (normType->shouldSuppressErrors() && !normIntersection->shouldSuppressErrors()) + resultTy = ctx->arena->addType(UnionType{{resultTy, ctx->builtins->errorType}}); + + return {resultTy, {}}; } - - // In the general case, we'll still use normalization though. - TypeId intersection = ctx->arena->addType(IntersectionType{{target, discriminant}}); - std::shared_ptr normIntersection = ctx->normalizer->normalize(intersection); - std::shared_ptr normType = ctx->normalizer->normalize(target); - - // if the intersection failed to normalize, we can't reduce, but know nothing about inhabitance. - if (!normIntersection || !normType) - return {nullptr, {}}; - - TypeId resultTy = ctx->normalizer->typeFromNormal(*normIntersection); - // include the error type if the target type is error-suppressing and the intersection we computed is not - if (normType->shouldSuppressErrors() && !normIntersection->shouldSuppressErrors()) - resultTy = ctx->arena->addType(UnionType{{resultTy, ctx->builtins->errorType}}); - - return {resultTy, {}}; }; // refine target with each discriminant type in sequence (reverse of insertion order) @@ -2003,15 +2138,15 @@ TypeFunctionReductionResult refineTypeFunction( auto [refined, blocked] = stepRefine(target, discriminant); if (blocked.empty() && refined == nullptr) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; if (!blocked.empty()) - return {std::nullopt, false, blocked, {}}; + return {std::nullopt, Reduction::MaybeOk, blocked, {}}; target = refined; discriminantTypes.pop_back(); } - return {target, false, {}, {}}; + return {target, Reduction::MaybeOk, {}, {}}; } TypeFunctionReductionResult singletonTypeFunction( @@ -2031,14 +2166,14 @@ TypeFunctionReductionResult singletonTypeFunction( // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(type, ctx->solver)) - return {std::nullopt, false, {type}, {}}; + return {std::nullopt, Reduction::MaybeOk, {type}, {}}; // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. if (ctx->solver) { std::optional maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, type); if (!maybeGeneralized) - return {std::nullopt, false, {type}, {}}; + return {std::nullopt, Reduction::MaybeOk, {type}, {}}; type = *maybeGeneralized; } @@ -2049,10 +2184,10 @@ TypeFunctionReductionResult singletonTypeFunction( // if we have a singleton type or `nil`, which is its own singleton type... if (get(followed) || isNil(followed)) - return {type, false, {}, {}}; + return {type, Reduction::MaybeOk, {}, {}}; // otherwise, we'll return the top type, `unknown`. - return {ctx->builtins->unknownType, false, {}, {}}; + return {ctx->builtins->unknownType, Reduction::MaybeOk, {}, {}}; } TypeFunctionReductionResult unionTypeFunction( @@ -2070,7 +2205,7 @@ TypeFunctionReductionResult unionTypeFunction( // if we only have one parameter, there's nothing to do. if (typeParams.size() == 1) - return {follow(typeParams[0]), false, {}, {}}; + return {follow(typeParams[0]), Reduction::MaybeOk, {}, {}}; // we need to follow all of the type parameters. std::vector types; @@ -2098,12 +2233,12 @@ TypeFunctionReductionResult unionTypeFunction( // if we still have a `lastType` at the end, we're taking the short-circuit and reducing early. if (lastType) - return {lastType, false, {}, {}}; + return {lastType, Reduction::MaybeOk, {}, {}}; // check to see if the operand types are resolved enough, and wait to reduce if not for (auto ty : types) if (isPending(ty, ctx->solver)) - return {std::nullopt, false, {ty}, {}}; + return {std::nullopt, Reduction::MaybeOk, {ty}, {}}; // fold over the types with `simplifyUnion` TypeId resultTy = ctx->builtins->neverType; @@ -2111,12 +2246,12 @@ TypeFunctionReductionResult unionTypeFunction( { SimplifyResult result = simplifyUnion(ctx->builtins, ctx->arena, resultTy, ty); if (!result.blockedTypes.empty()) - return {std::nullopt, false, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; + return {std::nullopt, Reduction::MaybeOk, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; resultTy = result.result; } - return {resultTy, false, {}, {}}; + return {resultTy, Reduction::MaybeOk, {}, {}}; } @@ -2135,7 +2270,7 @@ TypeFunctionReductionResult intersectTypeFunction( // if we only have one parameter, there's nothing to do. if (typeParams.size() == 1) - return {follow(typeParams[0]), false, {}, {}}; + return {follow(typeParams[0]), Reduction::MaybeOk, {}, {}}; // we need to follow all of the type parameters. std::vector types; @@ -2147,9 +2282,9 @@ TypeFunctionReductionResult intersectTypeFunction( { // if we only have two parameters and one is `*no-refine*`, we're all done. if (types.size() == 2 && get(types[1])) - return {types[0], false, {}, {}}; + return {types[0], Reduction::MaybeOk, {}, {}}; else if (types.size() == 2 && get(types[0])) - return {types[1], false, {}, {}}; + return {types[1], Reduction::MaybeOk, {}, {}}; } // check to see if the operand types are resolved enough, and wait to reduce if not @@ -2157,9 +2292,9 @@ TypeFunctionReductionResult intersectTypeFunction( for (auto ty : types) { if (isPending(ty, ctx->solver)) - return {std::nullopt, false, {ty}, {}}; + return {std::nullopt, Reduction::MaybeOk, {ty}, {}}; else if (get(ty)) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; } // fold over the types with `simplifyIntersection` @@ -2172,7 +2307,7 @@ TypeFunctionReductionResult intersectTypeFunction( SimplifyResult result = simplifyIntersection(ctx->builtins, ctx->arena, resultTy, ty); if (!result.blockedTypes.empty()) - return {std::nullopt, false, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; + return {std::nullopt, Reduction::MaybeOk, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; resultTy = result.result; } @@ -2183,10 +2318,10 @@ TypeFunctionReductionResult intersectTypeFunction( if (get(resultTy)) { TypeId intersection = ctx->arena->addType(IntersectionType{typeParams}); - return {intersection, false, {}, {}}; + return {intersection, Reduction::MaybeOk, {}, {}}; } - return {resultTy, false, {}, {}}; + return {resultTy, Reduction::MaybeOk, {}, {}}; } // computes the keys of `ty` into `result` @@ -2286,17 +2421,17 @@ TypeFunctionReductionResult keyofFunctionImpl( // if the operand failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normTy) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // if we don't have either just tables or just classes, we've got nothing to get keys of (at least until a future version perhaps adds classes // as well) if (normTy->hasTables() == normTy->hasClasses()) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; // this is sort of atrocious, but we're trying to reject any type that has not normalized to a table or a union of tables. if (normTy->hasTops() || normTy->hasBooleans() || normTy->hasErrors() || normTy->hasNils() || normTy->hasNumbers() || normTy->hasStrings() || normTy->hasThreads() || normTy->hasBuffers() || normTy->hasFunctions() || normTy->hasTyvars()) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; // we're going to collect the keys in here Set keys{{}}; @@ -2315,7 +2450,7 @@ TypeFunctionReductionResult keyofFunctionImpl( // collect all the properties from the first class type if (!computeKeysOf(*classesIter, keys, seen, isRaw, ctx)) - return {ctx->builtins->stringType, false, {}, {}}; // if it failed, we have a top type! + return {ctx->builtins->stringType, Reduction::MaybeOk, {}, {}}; // if it failed, we have a top type! // we need to look at each class to remove any keys that are not common amongst them all while (++classesIter != classesIterEnd) @@ -2350,7 +2485,7 @@ TypeFunctionReductionResult keyofFunctionImpl( // collect all the properties from the first table type if (!computeKeysOf(*tablesIter, keys, seen, isRaw, ctx)) - return {ctx->builtins->stringType, false, {}, {}}; // if it failed, we have the top table type! + return {ctx->builtins->stringType, Reduction::MaybeOk, {}, {}}; // if it failed, we have the top table type! // we need to look at each tables to remove any keys that are not common amongst them all while (++tablesIter != normTy->tables.end()) @@ -2374,7 +2509,7 @@ TypeFunctionReductionResult keyofFunctionImpl( // if the set of keys is empty, `keyof` is `never` if (keys.empty()) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; // everything is validated, we need only construct our big union of singletons now! std::vector singletons; @@ -2387,9 +2522,9 @@ TypeFunctionReductionResult keyofFunctionImpl( // We can take straight take it from the first entry // because it was added into the type arena already. if (singletons.size() == 1) - return {singletons.front(), false, {}, {}}; + return {singletons.front(), Reduction::MaybeOk, {}, {}}; - return {ctx->arena->addType(UnionType{singletons}), false, {}, {}}; + return {ctx->arena->addType(UnionType{singletons}), Reduction::MaybeOk, {}, {}}; } TypeFunctionReductionResult keyofTypeFunction( @@ -2460,7 +2595,7 @@ bool searchPropsAndIndexer( // index into tbl's indexer if (tblIndexer) { - if (isSubtype(ty, tblIndexer->indexType, ctx->scope, ctx->builtins, *ctx->ice)) + if (isSubtype(ty, tblIndexer->indexType, ctx->scope, ctx->builtins, ctx->simplifier, *ctx->ice)) { TypeId idxResultTy = follow(tblIndexer->indexResultType); @@ -2535,32 +2670,32 @@ TypeFunctionReductionResult indexFunctionImpl( // if the indexee failed to normalize, we can't reduce, but know nothing about inhabitance. if (!indexeeNormTy) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // if we don't have either just tables or just classes, we've got nothing to index into if (indexeeNormTy->hasTables() == indexeeNormTy->hasClasses()) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; // we're trying to reject any type that has not normalized to a table/class or a union of tables/classes. if (indexeeNormTy->hasTops() || indexeeNormTy->hasBooleans() || indexeeNormTy->hasErrors() || indexeeNormTy->hasNils() || indexeeNormTy->hasNumbers() || indexeeNormTy->hasStrings() || indexeeNormTy->hasThreads() || indexeeNormTy->hasBuffers() || indexeeNormTy->hasFunctions() || indexeeNormTy->hasTyvars()) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; TypeId indexerTy = follow(typeParams.at(1)); if (isPending(indexerTy, ctx->solver)) - return {std::nullopt, false, {indexerTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {indexerTy}, {}}; std::shared_ptr indexerNormTy = ctx->normalizer->normalize(indexerTy); // if the indexer failed to normalize, we can't reduce, but know nothing about inhabitance. if (!indexerNormTy) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // we're trying to reject any type that is not a string singleton or primitive (string, number, boolean, thread, nil, function, table, or buffer) if (indexerNormTy->hasTops() || indexerNormTy->hasErrors()) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; // indexer can be a union —> break them down into a vector const std::vector* typesToFind = nullptr; @@ -2577,7 +2712,7 @@ TypeFunctionReductionResult indexFunctionImpl( LUAU_ASSERT(!indexeeNormTy->hasTables()); if (isRaw) // rawget should never reduce for classes (to match the behavior of the rawget global function) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; // at least one class is guaranteed to be in the iterator by .hasClasses() for (auto classesIter = indexeeNormTy->classes.ordering.begin(); classesIter != indexeeNormTy->classes.ordering.end(); ++classesIter) @@ -2586,7 +2721,7 @@ TypeFunctionReductionResult indexFunctionImpl( if (!classTy) { LUAU_ASSERT(false); // this should not be possible according to normalization's spec - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; } for (TypeId ty : *typesToFind) @@ -2615,10 +2750,10 @@ TypeFunctionReductionResult indexFunctionImpl( ErrorVec dummy; std::optional mmType = findMetatableEntry(ctx->builtins, dummy, *classesIter, "__index", Location{}); if (!mmType) // if a metatable does not exist, there is no where else to look - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; if (!tblIndexInto(ty, *mmType, properties, ctx, isRaw)) // if indexer is not in the metatable, we fail to reduce - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; } } } @@ -2632,7 +2767,7 @@ TypeFunctionReductionResult indexFunctionImpl( { for (TypeId ty : *typesToFind) if (!tblIndexInto(ty, *tablesIter, properties, ctx, isRaw)) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; } } @@ -2649,9 +2784,9 @@ TypeFunctionReductionResult indexFunctionImpl( // If the type being reduced to is a single type, no need to union if (properties.size() == 1) - return {*properties.begin(), false, {}, {}}; + return {*properties.begin(), Reduction::MaybeOk, {}, {}}; - return {ctx->arena->addType(UnionType{std::vector(properties.begin(), properties.end())}), false, {}, {}}; + return {ctx->arena->addType(UnionType{std::vector(properties.begin(), properties.end())}), Reduction::MaybeOk, {}, {}}; } TypeFunctionReductionResult indexTypeFunction( diff --git a/Analysis/src/TypeFunctionRuntime.cpp b/Analysis/src/TypeFunctionRuntime.cpp index 8a129462..24c75a51 100644 --- a/Analysis/src/TypeFunctionRuntime.cpp +++ b/Analysis/src/TypeFunctionRuntime.cpp @@ -14,7 +14,7 @@ #include LUAU_DYNAMIC_FASTINT(LuauTypeFunctionSerdeIterationLimit) -LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixRegister) +LUAU_FASTFLAGVARIABLE(LuauUserTypeFunPrintToError) LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixNoReadWrite) LUAU_FASTFLAGVARIABLE(LuauUserTypeFunThreadBuffer) @@ -1408,8 +1408,6 @@ static int isEqualToType(lua_State* L) void registerTypesLibrary(lua_State* L) { - LUAU_ASSERT(FFlag::LuauUserTypeFunFixRegister); - luaL_Reg fields[] = { {"unknown", createUnknown}, {"never", createNever}, @@ -1464,170 +1462,99 @@ static int typeUserdataIndex(lua_State* L) void registerTypeUserData(lua_State* L) { - if (FFlag::LuauUserTypeFunFixRegister) - { - luaL_Reg typeUserdataMethods[] = { - {"is", checkTag}, + luaL_Reg typeUserdataMethods[] = { + {"is", checkTag}, - // Negation type methods - {"inner", getNegatedValue}, + // Negation type methods + {"inner", getNegatedValue}, - // Singleton type methods - {"value", getSingletonValue}, + // Singleton type methods + {"value", getSingletonValue}, - // Table type methods - {"setproperty", setTableProp}, - {"setreadproperty", setReadTableProp}, - {"setwriteproperty", setWriteTableProp}, - {"readproperty", readTableProp}, - {"writeproperty", writeTableProp}, - {"properties", getProps}, - {"setindexer", setTableIndexer}, - {"setreadindexer", setTableReadIndexer}, - {"setwriteindexer", setTableWriteIndexer}, - {"indexer", getIndexer}, - {"readindexer", getReadIndexer}, - {"writeindexer", getWriteIndexer}, - {"setmetatable", setTableMetatable}, - {"metatable", getMetatable}, + // Table type methods + {"setproperty", setTableProp}, + {"setreadproperty", setReadTableProp}, + {"setwriteproperty", setWriteTableProp}, + {"readproperty", readTableProp}, + {"writeproperty", writeTableProp}, + {"properties", getProps}, + {"setindexer", setTableIndexer}, + {"setreadindexer", setTableReadIndexer}, + {"setwriteindexer", setTableWriteIndexer}, + {"indexer", getIndexer}, + {"readindexer", getReadIndexer}, + {"writeindexer", getWriteIndexer}, + {"setmetatable", setTableMetatable}, + {"metatable", getMetatable}, - // Function type methods - {"setparameters", setFunctionParameters}, - {"parameters", getFunctionParameters}, - {"setreturns", setFunctionReturns}, - {"returns", getFunctionReturns}, + // Function type methods + {"setparameters", setFunctionParameters}, + {"parameters", getFunctionParameters}, + {"setreturns", setFunctionReturns}, + {"returns", getFunctionReturns}, - // Union and Intersection type methods - {"components", getComponents}, + // Union and Intersection type methods + {"components", getComponents}, - // Class type methods - {"parent", getClassParent}, + // Class type methods + {"parent", getClassParent}, - {nullptr, nullptr} - }; + {nullptr, nullptr} + }; - // Create and register metatable for type userdata - luaL_newmetatable(L, "type"); + // Create and register metatable for type userdata + luaL_newmetatable(L, "type"); - // Protect metatable from being changed - lua_pushstring(L, "The metatable is locked"); - lua_setfield(L, -2, "__metatable"); + // Protect metatable from being changed + lua_pushstring(L, "The metatable is locked"); + lua_setfield(L, -2, "__metatable"); - lua_pushcfunction(L, isEqualToType, "__eq"); - lua_setfield(L, -2, "__eq"); + lua_pushcfunction(L, isEqualToType, "__eq"); + lua_setfield(L, -2, "__eq"); - // Indexing will be a dynamic function because some type fields are dynamic - lua_newtable(L); - luaL_register(L, nullptr, typeUserdataMethods); - lua_setreadonly(L, -1, true); - lua_pushcclosure(L, typeUserdataIndex, "__index", 1); - lua_setfield(L, -2, "__index"); + // Indexing will be a dynamic function because some type fields are dynamic + lua_newtable(L); + luaL_register(L, nullptr, typeUserdataMethods); + lua_setreadonly(L, -1, true); + lua_pushcclosure(L, typeUserdataIndex, "__index", 1); + lua_setfield(L, -2, "__index"); - lua_setreadonly(L, -1, true); - lua_pop(L, 1); - } - else - { - // List of fields for type userdata - luaL_Reg typeUserdataFields[] = { - {"unknown", createUnknown}, - {"never", createNever}, - {"any", createAny}, - {"boolean", createBoolean}, - {"number", createNumber}, - {"string", createString}, - {nullptr, nullptr} - }; - - // List of methods for type userdata - luaL_Reg typeUserdataMethods[] = { - {"singleton", createSingleton}, - {"negationof", createNegation}, - {"unionof", createUnion}, - {"intersectionof", createIntersection}, - {"newtable", createTable}, - {"newfunction", createFunction}, - {"copy", deepCopy}, - - // Common methods - {"is", checkTag}, - - // Negation type methods - {"inner", getNegatedValue}, - - // Singleton type methods - {"value", getSingletonValue}, - - // Table type methods - {"setproperty", setTableProp}, - {"setreadproperty", setReadTableProp}, - {"setwriteproperty", setWriteTableProp}, - {"readproperty", readTableProp}, - {"writeproperty", writeTableProp}, - {"properties", getProps}, - {"setindexer", setTableIndexer}, - {"setreadindexer", setTableReadIndexer}, - {"setwriteindexer", setTableWriteIndexer}, - {"indexer", getIndexer}, - {"readindexer", getReadIndexer}, - {"writeindexer", getWriteIndexer}, - {"setmetatable", setTableMetatable}, - {"metatable", getMetatable}, - - // Function type methods - {"setparameters", setFunctionParameters}, - {"parameters", getFunctionParameters}, - {"setreturns", setFunctionReturns}, - {"returns", getFunctionReturns}, - - // Union and Intersection type methods - {"components", getComponents}, - - // Class type methods - {"parent", getClassParent}, - {"indexer", getIndexer}, - {nullptr, nullptr} - }; - - // Create and register metatable for type userdata - luaL_newmetatable(L, "type"); - - // Protect metatable from being fetched. - lua_pushstring(L, "The metatable is locked"); - lua_setfield(L, -2, "__metatable"); - - // Set type userdata metatable's __eq to type_equals() - lua_pushcfunction(L, isEqualToType, "__eq"); - lua_setfield(L, -2, "__eq"); - - // Set type userdata metatable's __index to itself - lua_pushvalue(L, -1); // Push a copy of type userdata metatable - lua_setfield(L, -2, "__index"); - - luaL_register(L, nullptr, typeUserdataMethods); - - // Set fields for type userdata - for (luaL_Reg* l = typeUserdataFields; l->name; l++) - { - l->func(L); - lua_setfield(L, -2, l->name); - } - - // Set types library as a global name "types" - lua_setglobal(L, "types"); - } + lua_setreadonly(L, -1, true); + lua_pop(L, 1); // Sets up a destructor for the type userdata. lua_setuserdatadtor(L, kTypeUserdataTag, deallocTypeUserData); } // Used to redirect all the removed global functions to say "this function is unsupported" -int unsupportedFunction(lua_State* L) +static int unsupportedFunction(lua_State* L) { luaL_errorL(L, "this function is not supported in type functions"); return 0; } +static int print(lua_State* L) +{ + std::string result; + + int n = lua_gettop(L); + for (int i = 1; i <= n; i++) + { + size_t l = 0; + const char* s = luaL_tolstring(L, i, &l); // convert to string using __tostring et al + if (i > 1) + result.append('\t', 1); + result.append(s, l); + lua_pop(L, 1); + } + + auto ctx = getTypeFunctionRuntime(L); + + ctx->messages.push_back(std::move(result)); + + return 0; +} + // Add libraries / globals for type function environment void setTypeFunctionEnvironment(lua_State* L) { @@ -1659,12 +1586,28 @@ void setTypeFunctionEnvironment(lua_State* L) luaopen_base(L); lua_pop(L, 1); - // Remove certain global functions from the base library - static const std::string unavailableGlobals[] = {"gcinfo", "getfenv", "newproxy", "setfenv", "pcall", "xpcall"}; - for (auto& name : unavailableGlobals) + if (FFlag::LuauUserTypeFunPrintToError) { - lua_pushcfunction(L, unsupportedFunction, "Removing global function from type function environment"); - lua_setglobal(L, name.c_str()); + // Remove certain global functions from the base library + static const char* unavailableGlobals[] = {"gcinfo", "getfenv", "newproxy", "setfenv", "pcall", "xpcall"}; + for (auto& name : unavailableGlobals) + { + lua_pushcfunction(L, unsupportedFunction, name); + lua_setglobal(L, name); + } + + lua_pushcfunction(L, print, "print"); + lua_setglobal(L, "print"); + } + else + { + // Remove certain global functions from the base library + static const std::string unavailableGlobals[] = {"gcinfo", "getfenv", "newproxy", "setfenv", "pcall", "xpcall"}; + for (auto& name : unavailableGlobals) + { + lua_pushcfunction(L, unsupportedFunction, "Removing global function from type function environment"); + lua_setglobal(L, name.c_str()); + } } } diff --git a/Analysis/src/TypeFunctionRuntimeBuilder.cpp b/Analysis/src/TypeFunctionRuntimeBuilder.cpp index a102e5da..c1ed9ff3 100644 --- a/Analysis/src/TypeFunctionRuntimeBuilder.cpp +++ b/Analysis/src/TypeFunctionRuntimeBuilder.cpp @@ -20,7 +20,6 @@ // currently, controls serialization, deserialization, and `type.copy` LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFunctionSerdeIterationLimit, 100'000); -LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixMetatable) LUAU_FASTFLAG(LuauUserTypeFunThreadBuffer) namespace Luau @@ -383,21 +382,9 @@ private: void serializeChildren(const MetatableType* m1, TypeFunctionTableType* m2) { - if (FFlag::LuauUserTypeFunFixMetatable) - { - // Serialize main part of the metatable immediately - if (auto tableTy = get(m1->table)) - serializeChildren(tableTy, m2); - } - else - { - auto tmpTable = get(shallowSerialize(m1->table)); - if (!tmpTable) - state->ctx->ice->ice("Serializing user defined type function arguments: metatable's table is not a TableType"); - - m2->props = tmpTable->props; - m2->indexer = tmpTable->indexer; - } + // Serialize main part of the metatable immediately + if (auto tableTy = get(m1->table)) + serializeChildren(tableTy, m2); m2->metatable = shallowSerialize(m1->metatable); } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 911d4b5e..1fd8f7ea 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -32,7 +32,7 @@ LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification) LUAU_FASTFLAG(LuauInstantiateInSubtyping) -LUAU_FASTFLAGVARIABLE(LuauMetatableFollow) +LUAU_FASTFLAGVARIABLE(LuauOldSolverCreatesChildScopePointers) namespace Luau { @@ -2866,7 +2866,7 @@ TypeId TypeChecker::checkRelationalOperation( std::optional metamethod = findMetatableEntry(lhsType, metamethodName, expr.location, /* addErrors= */ true); if (metamethod) { - if (const FunctionType* ftv = get(FFlag::LuauMetatableFollow ? follow(*metamethod) : *metamethod)) + if (const FunctionType* ftv = get(follow(*metamethod))) { if (isEquality) { @@ -5205,6 +5205,13 @@ LUAU_NOINLINE void TypeChecker::reportErrorCodeTooComplex(const Location& locati ScopePtr TypeChecker::childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel) { ScopePtr scope = std::make_shared(parent, subLevel); + if (FFlag::LuauOldSolverCreatesChildScopePointers) + { + scope->location = location; + scope->returnType = parent->returnType; + parent->children.emplace_back(scope.get()); + } + currentModule->scopes.push_back(std::make_pair(location, scope)); return scope; } @@ -5215,6 +5222,12 @@ ScopePtr TypeChecker::childScope(const ScopePtr& parent, const Location& locatio ScopePtr scope = std::make_shared(parent); scope->level = parent->level; scope->varargPack = parent->varargPack; + if (FFlag::LuauOldSolverCreatesChildScopePointers) + { + scope->location = location; + scope->returnType = parent->returnType; + parent->children.emplace_back(scope.get()); + } currentModule->scopes.push_back(std::make_pair(location, scope)); return scope; diff --git a/Ast/include/Luau/Lexer.h b/Ast/include/Luau/Lexer.h index f91f6115..3d93cf75 100644 --- a/Ast/include/Luau/Lexer.h +++ b/Ast/include/Luau/Lexer.h @@ -230,17 +230,6 @@ private: bool skipComments; bool readNames; - // This offset represents a column offset to be applied to any positions created by the lexer until the next new line. - // For example: - // local x = 4 - // local y = 5 - // If we start lexing from the position of `l` in `local x = 4`, the line number will be 1, and the column will be 4 - // However, because the lexer calculates line offsets by 'index in source buffer where there is a newline', the column - // count will start at 0. For this reason, for just the first line, we'll need to store the offset. - unsigned int lexResumeOffset; - - - enum class BraceType { InterpolatedString, diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index 03532e06..86b44044 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -8,7 +8,7 @@ #include -LUAU_FASTFLAGVARIABLE(LexerResumesFromPosition) +LUAU_FASTFLAGVARIABLE(LexerResumesFromPosition2) namespace Luau { @@ -308,16 +308,15 @@ Lexer::Lexer(const char* buffer, size_t bufferSize, AstNameTable& names, Positio : buffer(buffer) , bufferSize(bufferSize) , offset(0) - , line(FFlag::LexerResumesFromPosition ? startPosition.line : 0) - , lineOffset(0) + , line(FFlag::LexerResumesFromPosition2 ? startPosition.line : 0) + , lineOffset(FFlag::LexerResumesFromPosition2 ? 0u - startPosition.column : 0) , lexeme( - (FFlag::LexerResumesFromPosition ? Location(Position(startPosition.line, startPosition.column), 0) : Location(Position(0, 0), 0)), + (FFlag::LexerResumesFromPosition2 ? Location(Position(startPosition.line, startPosition.column), 0) : Location(Position(0, 0), 0)), Lexeme::Eof ) , names(names) , skipComments(false) , readNames(true) - , lexResumeOffset(FFlag::LexerResumesFromPosition ? startPosition.column : 0) { } @@ -372,7 +371,6 @@ Lexeme Lexer::lookahead() Location currentPrevLocation = prevLocation; size_t currentBraceStackSize = braceStack.size(); BraceType currentBraceType = braceStack.empty() ? BraceType::Normal : braceStack.back(); - unsigned int currentLexResumeOffset = lexResumeOffset; Lexeme result = next(); @@ -381,7 +379,6 @@ Lexeme Lexer::lookahead() lineOffset = currentLineOffset; lexeme = currentLexeme; prevLocation = currentPrevLocation; - lexResumeOffset = currentLexResumeOffset; if (braceStack.size() < currentBraceStackSize) braceStack.push_back(currentBraceType); @@ -412,9 +409,10 @@ char Lexer::peekch(unsigned int lookahead) const return (offset + lookahead < bufferSize) ? buffer[offset + lookahead] : 0; } +LUAU_FORCEINLINE Position Lexer::position() const { - return Position(line, offset - lineOffset + (FFlag::LexerResumesFromPosition ? lexResumeOffset : 0)); + return Position(line, offset - lineOffset); } LUAU_FORCEINLINE @@ -433,9 +431,6 @@ void Lexer::consumeAny() { line++; lineOffset = offset + 1; - // every new line, we reset - if (FFlag::LexerResumesFromPosition) - lexResumeOffset = 0; } offset++; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 1a533fa5..e821902e 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -18,10 +18,8 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) // flag so that we don't break production games by reverting syntax changes. // See docs/SyntaxChanges.md for an explanation. LUAU_FASTFLAGVARIABLE(LuauSolverV2) -LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionsSyntax2) LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunParseExport) LUAU_FASTFLAGVARIABLE(LuauAllowFragmentParsing) -LUAU_FASTFLAGVARIABLE(LuauPortableStringZeroCheck) LUAU_FASTFLAGVARIABLE(LuauAllowComplexTypesInGenericParams) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryForTableTypes) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryForClassNames) @@ -910,11 +908,8 @@ AstStat* Parser::parseReturn() AstStat* Parser::parseTypeAlias(const Location& start, bool exported) { // parsing a type function - if (FFlag::LuauUserDefinedTypeFunctionsSyntax2) - { - if (lexer.current().type == Lexeme::ReservedFunction) - return parseTypeFunction(start, exported); - } + if (lexer.current().type == Lexeme::ReservedFunction) + return parseTypeFunction(start, exported); // parsing a type alias @@ -1134,8 +1129,7 @@ AstStat* Parser::parseDeclaration(const Location& start, const AstArraydata, 0, chars->size) != nullptr - : strnlen(chars->data, chars->size) < chars->size); + bool containsNull = chars && (memchr(chars->data, 0, chars->size) != nullptr); if (chars && !containsNull) { @@ -1647,8 +1641,7 @@ AstType* Parser::parseTableType(bool inDeclarationContext) AstType* type = parseType(); // since AstName contains a char*, it can't contain null - bool containsNull = chars && (FFlag::LuauPortableStringZeroCheck ? memchr(chars->data, 0, chars->size) != nullptr - : strnlen(chars->data, chars->size) < chars->size); + bool containsNull = chars && (memchr(chars->data, 0, chars->size) != nullptr); if (chars && !containsNull) props.push_back(AstTableProp{AstName(chars->data), begin.location, type, access, accessLocation}); @@ -2352,11 +2345,8 @@ AstExpr* Parser::parseNameExpr(const char* context) { AstLocal* local = *value; - if (FFlag::LuauUserDefinedTypeFunctionsSyntax2) - { - if (local->functionDepth < typeFunctionDepth) - return reportExprError(lexer.current().location, {}, "Type function cannot reference outer local '%s'", local->name.value); - } + if (local->functionDepth < typeFunctionDepth) + return reportExprError(lexer.current().location, {}, "Type function cannot reference outer local '%s'", local->name.value); return allocator.alloc(name->location, local, local->functionDepth != functionStack.size() - 1); } diff --git a/Compiler/include/Luau/Compiler.h b/Compiler/include/Luau/Compiler.h index b37b58ff..2c82116d 100644 --- a/Compiler/include/Luau/Compiler.h +++ b/Compiler/include/Luau/Compiler.h @@ -13,6 +13,16 @@ struct ParseResult; class BytecodeBuilder; class BytecodeEncoder; +using CompileConstant = void*; + +// return a type identifier for a global library member +// values are defined by 'enum LuauBytecodeType' in Bytecode.h +using LibraryMemberTypeCallback = int (*)(const char* library, const char* member); + +// setup a value of a constant for a global library member +// use setCompileConstant*** set of functions for values +using LibraryMemberConstantCallback = void (*)(const char* library, const char* member, CompileConstant* constant); + // Note: this structure is duplicated in luacode.h, don't forget to change these in sync! struct CompileOptions { @@ -49,6 +59,15 @@ struct CompileOptions // null-terminated array of userdata types that will be included in the type information const char* const* userdataTypes = nullptr; + + // null-terminated array of globals which act as libraries and have members with known type and/or constant value + // when an import of one of these libraries is accessed, callbacks below will be called to receive that information + const char* const* librariesWithKnownMembers = nullptr; + LibraryMemberTypeCallback libraryMemberTypeCb = nullptr; + LibraryMemberConstantCallback libraryMemberConstantCb = nullptr; + + // null-terminated array of library functions that should not be compiled into a built-in fastcall ("name" "lib.name") + const char* const* disabledBuiltins = nullptr; }; class CompileError : public std::exception @@ -81,4 +100,10 @@ std::string compile( BytecodeEncoder* encoder = nullptr ); +void setCompileConstantNil(CompileConstant* constant); +void setCompileConstantBoolean(CompileConstant* constant, bool b); +void setCompileConstantNumber(CompileConstant* constant, double n); +void setCompileConstantVector(CompileConstant* constant, float x, float y, float z, float w); +void setCompileConstantString(CompileConstant* constant, const char* s, size_t l); + } // namespace Luau diff --git a/Compiler/include/luacode.h b/Compiler/include/luacode.h index 1eaf28d4..4445af43 100644 --- a/Compiler/include/luacode.h +++ b/Compiler/include/luacode.h @@ -3,12 +3,21 @@ #include -// Can be used to reconfigure visibility/exports for public APIs +// can be used to reconfigure visibility/exports for public APIs #ifndef LUACODE_API #define LUACODE_API extern #endif typedef struct lua_CompileOptions lua_CompileOptions; +typedef void* lua_CompileConstant; + +// return a type identifier for a global library member +// values are defined by 'enum LuauBytecodeType' in Bytecode.h +typedef int (*lua_LibraryMemberTypeCallback)(const char* library, const char* member); + +// setup a value of a constant for a global library member +// use luau_set_compile_constant_*** set of functions for values +typedef void (*lua_LibraryMemberConstantCallback)(const char* library, const char* member, lua_CompileConstant* constant); struct lua_CompileOptions { @@ -45,7 +54,25 @@ struct lua_CompileOptions // null-terminated array of userdata types that will be included in the type information const char* const* userdataTypes; + + // null-terminated array of globals which act as libraries and have members with known type and/or constant value + // when an import of one of these libraries is accessed, callbacks below will be called to receive that information + const char* const* librariesWithKnownMembers; + lua_LibraryMemberTypeCallback libraryMemberTypeCb; + lua_LibraryMemberConstantCallback libraryMemberConstantCb; + + // null-terminated array of library functions that should not be compiled into a built-in fastcall ("name" "lib.name") + const char* const* disabledBuiltins; }; // compile source to bytecode; when source compilation fails, the resulting bytecode contains the encoded error. use free() to destroy LUACODE_API char* luau_compile(const char* source, size_t size, lua_CompileOptions* options, size_t* outsize); + +// when libraryMemberConstantCb is called, these methods can be used to set a value of the opaque lua_CompileConstant struct +// vector component 'w' is not visible to VM runtime configured with LUA_VECTOR_SIZE == 3, but can affect constant folding during compilation +// string storage must outlive the invocation of 'luau_compile' which used the callback +LUACODE_API void luau_set_compile_constant_nil(lua_CompileConstant* constant); +LUACODE_API void luau_set_compile_constant_boolean(lua_CompileConstant* constant, int b); +LUACODE_API void luau_set_compile_constant_number(lua_CompileConstant* constant, double n); +LUACODE_API void luau_set_compile_constant_vector(lua_CompileConstant* constant, float x, float y, float z, float w); +LUACODE_API void luau_set_compile_constant_string(lua_CompileConstant* constant, const char* s, size_t l); diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp index d5d23629..e8b0cd98 100644 --- a/Compiler/src/Builtins.cpp +++ b/Compiler/src/Builtins.cpp @@ -3,8 +3,12 @@ #include "Luau/Bytecode.h" #include "Luau/Compiler.h" +#include "Luau/Lexer.h" + +#include LUAU_FASTFLAGVARIABLE(LuauVectorBuiltins) +LUAU_FASTFLAGVARIABLE(LuauCompileDisabledBuiltins) namespace Luau { @@ -270,23 +274,61 @@ static int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& op struct BuiltinVisitor : AstVisitor { DenseHashMap& result; + std::array builtinIsDisabled; const DenseHashMap& globals; const DenseHashMap& variables; const CompileOptions& options; + const AstNameTable& names; BuiltinVisitor( DenseHashMap& result, const DenseHashMap& globals, const DenseHashMap& variables, - const CompileOptions& options + const CompileOptions& options, + const AstNameTable& names ) : result(result) , globals(globals) , variables(variables) , options(options) + , names(names) { + if (FFlag::LuauCompileDisabledBuiltins) + { + builtinIsDisabled.fill(false); + + if (const char* const* ptr = options.disabledBuiltins) + { + for (; *ptr; ++ptr) + { + if (const char* dot = strchr(*ptr, '.')) + { + AstName library = names.getWithType(*ptr, dot - *ptr).first; + AstName name = names.get(dot + 1); + + if (library.value && name.value && getGlobalState(globals, name) == Global::Default) + { + Builtin builtin = Builtin{library, name}; + + if (int bfid = getBuiltinFunctionId(builtin, options); bfid >= 0) + builtinIsDisabled[bfid] = true; + } + } + else + { + if (AstName name = names.get(*ptr); name.value && getGlobalState(globals, name) == Global::Default) + { + Builtin builtin = Builtin{AstName(), name}; + + if (int bfid = getBuiltinFunctionId(builtin, options); bfid >= 0) + builtinIsDisabled[bfid] = true; + } + } + } + } + } } bool visit(AstExprCall* node) override @@ -297,6 +339,9 @@ struct BuiltinVisitor : AstVisitor int bfid = getBuiltinFunctionId(builtin, options); + if (FFlag::LuauCompileDisabledBuiltins && bfid >= 0 && builtinIsDisabled[bfid]) + bfid = -1; + // getBuiltinFunctionId optimistically assumes all select() calls are builtin but actually the second argument must be a vararg if (bfid == LBF_SELECT_VARARG && !(node->args.size == 2 && node->args.data[1]->is())) bfid = -1; @@ -313,10 +358,11 @@ void analyzeBuiltins( const DenseHashMap& globals, const DenseHashMap& variables, const CompileOptions& options, - AstNode* root + AstNode* root, + const AstNameTable& names ) { - BuiltinVisitor visitor{result, globals, variables, options}; + BuiltinVisitor visitor{result, globals, variables, options, names}; root->visit(&visitor); } diff --git a/Compiler/src/Builtins.h b/Compiler/src/Builtins.h index e6427c2a..cef48fa5 100644 --- a/Compiler/src/Builtins.h +++ b/Compiler/src/Builtins.h @@ -41,7 +41,8 @@ void analyzeBuiltins( const DenseHashMap& globals, const DenseHashMap& variables, const CompileOptions& options, - AstNode* root + AstNode* root, + const AstNameTable& names ); struct BuiltinInfo diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 84700177..da945b35 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -27,6 +27,7 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) LUAU_FASTFLAGVARIABLE(LuauCompileOptimizeRevArith) +LUAU_FASTFLAGVARIABLE(LuauCompileLibraryConstants) namespace Luau { @@ -725,7 +726,7 @@ struct Compiler inlineFrames.push_back({func, oldLocals, target, targetCount}); // fold constant values updated above into expressions in the function body - foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldMathK, func->body); + foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldLibraryK, options.libraryMemberConstantCb, func->body); bool usedFallthrough = false; @@ -770,7 +771,7 @@ struct Compiler var->type = Constant::Type_Unknown; } - foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldMathK, func->body); + foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldLibraryK, options.libraryMemberConstantCb, func->body); } void compileExprCall(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop = false, bool multRet = false) @@ -3052,7 +3053,7 @@ struct Compiler locstants[var].type = Constant::Type_Number; locstants[var].valueNumber = from + iv * step; - foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldMathK, stat); + foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldLibraryK, options.libraryMemberConstantCb, stat); size_t iterJumps = loopJumps.size(); @@ -3080,7 +3081,7 @@ struct Compiler // clean up fold state in case we need to recompile - normally we compile the loop body once, but due to inlining we may need to do it again locstants[var].type = Constant::Type_Unknown; - foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldMathK, stat); + foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldLibraryK, options.libraryMemberConstantCb, stat); } void compileStatFor(AstStatFor* stat) @@ -4141,7 +4142,7 @@ struct Compiler BuiltinAstTypes builtinTypes; const DenseHashMap* builtinsFold = nullptr; - bool builtinsFoldMathK = false; + bool builtinsFoldLibraryK = false; // compileFunction state, gets reset for every function unsigned int regTop = 0; @@ -4221,16 +4222,40 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c compiler.builtinsFold = &compiler.builtins; if (AstName math = names.get("math"); math.value && getGlobalState(compiler.globals, math) == Global::Default) - compiler.builtinsFoldMathK = true; + { + compiler.builtinsFoldLibraryK = true; + } + else if (FFlag::LuauCompileLibraryConstants) + { + if (const char* const* ptr = options.librariesWithKnownMembers) + { + for (; *ptr; ++ptr) + { + if (AstName name = names.get(*ptr); name.value && getGlobalState(compiler.globals, name) == Global::Default) + { + compiler.builtinsFoldLibraryK = true; + break; + } + } + } + } } if (options.optimizationLevel >= 1) { // this pass tracks which calls are builtins and can be compiled more efficiently - analyzeBuiltins(compiler.builtins, compiler.globals, compiler.variables, options, root); + analyzeBuiltins(compiler.builtins, compiler.globals, compiler.variables, options, root, names); // this pass analyzes constantness of expressions - foldConstants(compiler.constants, compiler.variables, compiler.locstants, compiler.builtinsFold, compiler.builtinsFoldMathK, root); + foldConstants( + compiler.constants, + compiler.variables, + compiler.locstants, + compiler.builtinsFold, + compiler.builtinsFoldLibraryK, + options.libraryMemberConstantCb, + root + ); // this pass analyzes table assignments to estimate table shapes for initially empty tables predictTableShapes(compiler.tableShapes, root); @@ -4261,6 +4286,7 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c compiler.builtinTypes, compiler.builtins, compiler.globals, + options.libraryMemberTypeCb, bytecode ); @@ -4277,9 +4303,9 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c AstExprFunction main( root->location, - /*attributes=*/AstArray({nullptr, 0}), - /*generics= */ AstArray(), - /*genericPacks= */ AstArray(), + /* attributes= */ AstArray({nullptr, 0}), + /* generics= */ AstArray(), + /* genericPacks= */ AstArray(), /* self= */ nullptr, AstArray(), /* vararg= */ true, @@ -4340,4 +4366,50 @@ std::string compile(const std::string& source, const CompileOptions& options, co } } +void setCompileConstantNil(CompileConstant* constant) +{ + Compile::Constant* target = reinterpret_cast(constant); + + target->type = Compile::Constant::Type_Nil; +} + +void setCompileConstantBoolean(CompileConstant* constant, bool b) +{ + Compile::Constant* target = reinterpret_cast(constant); + + target->type = Compile::Constant::Type_Boolean; + target->valueBoolean = b; +} + +void setCompileConstantNumber(CompileConstant* constant, double n) +{ + Compile::Constant* target = reinterpret_cast(constant); + + target->type = Compile::Constant::Type_Number; + target->valueNumber = n; +} + +void setCompileConstantVector(CompileConstant* constant, float x, float y, float z, float w) +{ + Compile::Constant* target = reinterpret_cast(constant); + + target->type = Compile::Constant::Type_Vector; + target->valueVector[0] = x; + target->valueVector[1] = y; + target->valueVector[2] = z; + target->valueVector[3] = w; +} + +void setCompileConstantString(CompileConstant* constant, const char* s, size_t l) +{ + Compile::Constant* target = reinterpret_cast(constant); + + if (l > std::numeric_limits::max()) + CompileError::raise({}, "Exceeded custom string constant length limit"); + + target->type = Compile::Constant::Type_String; + target->stringLength = l; + target->valueString = s; +} + } // namespace Luau diff --git a/Compiler/src/ConstantFolding.cpp b/Compiler/src/ConstantFolding.cpp index 2895bf08..24e272d7 100644 --- a/Compiler/src/ConstantFolding.cpp +++ b/Compiler/src/ConstantFolding.cpp @@ -6,6 +6,9 @@ #include #include +LUAU_FASTFLAG(LuauCompileLibraryConstants) +LUAU_FASTFLAGVARIABLE(LuauVectorFolding) + namespace Luau { namespace Compile @@ -57,6 +60,14 @@ static void foldUnary(Constant& result, AstExprUnary::Op op, const Constant& arg result.type = Constant::Type_Number; result.valueNumber = -arg.valueNumber; } + else if (FFlag::LuauVectorFolding && arg.type == Constant::Type_Vector) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = -arg.valueVector[0]; + result.valueVector[1] = -arg.valueVector[1]; + result.valueVector[2] = -arg.valueVector[2]; + result.valueVector[3] = -arg.valueVector[3]; + } break; case AstExprUnary::Len: @@ -82,6 +93,14 @@ static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& l result.type = Constant::Type_Number; result.valueNumber = la.valueNumber + ra.valueNumber; } + else if (FFlag::LuauVectorFolding && la.type == Constant::Type_Vector && ra.type == Constant::Type_Vector) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = la.valueVector[0] + ra.valueVector[0]; + result.valueVector[1] = la.valueVector[1] + ra.valueVector[1]; + result.valueVector[2] = la.valueVector[2] + ra.valueVector[2]; + result.valueVector[3] = la.valueVector[3] + ra.valueVector[3]; + } break; case AstExprBinary::Sub: @@ -90,6 +109,14 @@ static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& l result.type = Constant::Type_Number; result.valueNumber = la.valueNumber - ra.valueNumber; } + else if (FFlag::LuauVectorFolding && la.type == Constant::Type_Vector && ra.type == Constant::Type_Vector) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = la.valueVector[0] - ra.valueVector[0]; + result.valueVector[1] = la.valueVector[1] - ra.valueVector[1]; + result.valueVector[2] = la.valueVector[2] - ra.valueVector[2]; + result.valueVector[3] = la.valueVector[3] - ra.valueVector[3]; + } break; case AstExprBinary::Mul: @@ -98,6 +125,48 @@ static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& l result.type = Constant::Type_Number; result.valueNumber = la.valueNumber * ra.valueNumber; } + else if (FFlag::LuauVectorFolding && la.type == Constant::Type_Vector && ra.type == Constant::Type_Vector) + { + bool hadW = la.valueVector[3] != 0.0f || ra.valueVector[3] != 0.0f; + float resultW = la.valueVector[3] * ra.valueVector[3]; + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = la.valueVector[0] * ra.valueVector[0]; + result.valueVector[1] = la.valueVector[1] * ra.valueVector[1]; + result.valueVector[2] = la.valueVector[2] * ra.valueVector[2]; + result.valueVector[3] = resultW; + } + } + else if (FFlag::LuauVectorFolding && la.type == Constant::Type_Number && ra.type == Constant::Type_Vector) + { + bool hadW = ra.valueVector[3] != 0.0f; + float resultW = float(la.valueNumber) * ra.valueVector[3]; + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = float(la.valueNumber) * ra.valueVector[0]; + result.valueVector[1] = float(la.valueNumber) * ra.valueVector[1]; + result.valueVector[2] = float(la.valueNumber) * ra.valueVector[2]; + result.valueVector[3] = resultW; + } + } + else if (FFlag::LuauVectorFolding && la.type == Constant::Type_Vector && ra.type == Constant::Type_Number) + { + bool hadW = la.valueVector[3] != 0.0f; + float resultW = la.valueVector[3] * float(ra.valueNumber); + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = la.valueVector[0] * float(ra.valueNumber); + result.valueVector[1] = la.valueVector[1] * float(ra.valueNumber); + result.valueVector[2] = la.valueVector[2] * float(ra.valueNumber); + result.valueVector[3] = resultW; + } + } break; case AstExprBinary::Div: @@ -106,6 +175,48 @@ static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& l result.type = Constant::Type_Number; result.valueNumber = la.valueNumber / ra.valueNumber; } + else if (FFlag::LuauVectorFolding && la.type == Constant::Type_Vector && ra.type == Constant::Type_Vector) + { + bool hadW = la.valueVector[3] != 0.0f || ra.valueVector[3] != 0.0f; + float resultW = la.valueVector[3] / ra.valueVector[3]; + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = la.valueVector[0] / ra.valueVector[0]; + result.valueVector[1] = la.valueVector[1] / ra.valueVector[1]; + result.valueVector[2] = la.valueVector[2] / ra.valueVector[2]; + result.valueVector[3] = resultW; + } + } + else if (FFlag::LuauVectorFolding && la.type == Constant::Type_Number && ra.type == Constant::Type_Vector) + { + bool hadW = ra.valueVector[3] != 0.0f; + float resultW = float(la.valueNumber) / ra.valueVector[3]; + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = float(la.valueNumber) / ra.valueVector[0]; + result.valueVector[1] = float(la.valueNumber) / ra.valueVector[1]; + result.valueVector[2] = float(la.valueNumber) / ra.valueVector[2]; + result.valueVector[3] = resultW; + } + } + else if (FFlag::LuauVectorFolding && la.type == Constant::Type_Vector && ra.type == Constant::Type_Number) + { + bool hadW = la.valueVector[3] != 0.0f; + float resultW = la.valueVector[3] / float(ra.valueNumber); + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = la.valueVector[0] / float(ra.valueNumber); + result.valueVector[1] = la.valueVector[1] / float(ra.valueNumber); + result.valueVector[2] = la.valueVector[2] / float(ra.valueNumber); + result.valueVector[3] = resultW; + } + } break; case AstExprBinary::FloorDiv: @@ -114,6 +225,48 @@ static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& l result.type = Constant::Type_Number; result.valueNumber = floor(la.valueNumber / ra.valueNumber); } + else if (FFlag::LuauVectorFolding && la.type == Constant::Type_Vector && ra.type == Constant::Type_Vector) + { + bool hadW = la.valueVector[3] != 0.0f || ra.valueVector[3] != 0.0f; + float resultW = floor(la.valueVector[3] / ra.valueVector[3]); + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = floor(la.valueVector[0] / ra.valueVector[0]); + result.valueVector[1] = floor(la.valueVector[1] / ra.valueVector[1]); + result.valueVector[2] = floor(la.valueVector[2] / ra.valueVector[2]); + result.valueVector[3] = resultW; + } + } + else if (FFlag::LuauVectorFolding && la.type == Constant::Type_Number && ra.type == Constant::Type_Vector) + { + bool hadW = ra.valueVector[3] != 0.0f; + float resultW = floor(float(la.valueNumber) / ra.valueVector[3]); + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = floor(float(la.valueNumber) / ra.valueVector[0]); + result.valueVector[1] = floor(float(la.valueNumber) / ra.valueVector[1]); + result.valueVector[2] = floor(float(la.valueNumber) / ra.valueVector[2]); + result.valueVector[3] = resultW; + } + } + else if (FFlag::LuauVectorFolding && la.type == Constant::Type_Vector && ra.type == Constant::Type_Number) + { + bool hadW = la.valueVector[3] != 0.0f; + float resultW = floor(la.valueVector[3] / float(ra.valueNumber)); + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = floor(la.valueVector[0] / float(ra.valueNumber)); + result.valueVector[1] = floor(la.valueVector[1] / float(ra.valueNumber)); + result.valueVector[2] = floor(la.valueVector[2] / float(ra.valueNumber)); + result.valueVector[3] = floor(la.valueVector[3] / float(ra.valueNumber)); + } + } break; case AstExprBinary::Mod: @@ -209,7 +362,8 @@ struct ConstantVisitor : AstVisitor DenseHashMap& locals; const DenseHashMap* builtins; - bool foldMathK = false; + bool foldLibraryK = false; + LibraryMemberConstantCallback libraryMemberConstantCb; bool wasEmpty = false; @@ -220,13 +374,15 @@ struct ConstantVisitor : AstVisitor DenseHashMap& variables, DenseHashMap& locals, const DenseHashMap* builtins, - bool foldMathK + bool foldLibraryK, + LibraryMemberConstantCallback libraryMemberConstantCb ) : constants(constants) , variables(variables) , locals(locals) , builtins(builtins) - , foldMathK(foldMathK) + , foldLibraryK(foldLibraryK) + , libraryMemberConstantCb(libraryMemberConstantCb) { // since we do a single pass over the tree, if the initial state was empty we don't need to clear out old entries wasEmpty = constants.empty() && locals.empty(); @@ -316,11 +472,26 @@ struct ConstantVisitor : AstVisitor { analyze(expr->expr); - if (foldMathK) + if (foldLibraryK) { - if (AstExprGlobal* eg = expr->expr->as(); eg && eg->name == "math") + if (FFlag::LuauCompileLibraryConstants) { - result = foldBuiltinMath(expr->index); + if (AstExprGlobal* eg = expr->expr->as()) + { + if (eg->name == "math") + result = foldBuiltinMath(expr->index); + + // if we have a custom handler and the constant hasn't been resolved + if (libraryMemberConstantCb && result.type == Constant::Type_Unknown) + libraryMemberConstantCb(eg->name.value, expr->index.value, reinterpret_cast(&result)); + } + } + else + { + if (AstExprGlobal* eg = expr->expr->as(); eg && eg->name == "math") + { + result = foldBuiltinMath(expr->index); + } } } } @@ -468,11 +639,12 @@ void foldConstants( DenseHashMap& variables, DenseHashMap& locals, const DenseHashMap* builtins, - bool foldMathK, + bool foldLibraryK, + LibraryMemberConstantCallback libraryMemberConstantCb, AstNode* root ) { - ConstantVisitor visitor{constants, variables, locals, builtins, foldMathK}; + ConstantVisitor visitor{constants, variables, locals, builtins, foldLibraryK, libraryMemberConstantCb}; root->visit(&visitor); } diff --git a/Compiler/src/ConstantFolding.h b/Compiler/src/ConstantFolding.h index e4eb6428..2653c064 100644 --- a/Compiler/src/ConstantFolding.h +++ b/Compiler/src/ConstantFolding.h @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Compiler.h" + #include "ValueTracking.h" namespace Luau @@ -49,7 +51,8 @@ void foldConstants( DenseHashMap& variables, DenseHashMap& locals, const DenseHashMap* builtins, - bool foldMathK, + bool foldLibraryK, + LibraryMemberConstantCallback libraryMemberConstantCb, AstNode* root ); diff --git a/Compiler/src/Types.cpp b/Compiler/src/Types.cpp index 7f5885a5..02aec11a 100644 --- a/Compiler/src/Types.cpp +++ b/Compiler/src/Types.cpp @@ -4,6 +4,7 @@ #include "Luau/BytecodeBuilder.h" LUAU_FASTFLAGVARIABLE(LuauCompileVectorTypeInfo) +LUAU_FASTFLAG(LuauCompileLibraryConstants) namespace Luau { @@ -175,16 +176,32 @@ static bool isMatchingGlobal(const DenseHashMap& globa return false; } +static bool isMatchingGlobalMember( + const DenseHashMap& globals, + AstExprIndexName* expr, + const char* library, + const char* member +) +{ + LUAU_ASSERT(FFlag::LuauCompileLibraryConstants); + + if (AstExprGlobal* object = expr->expr->as()) + return getGlobalState(globals, object->name) == Compile::Global::Default && object->name == library && expr->index == member; + + return false; +} + struct TypeMapVisitor : AstVisitor { DenseHashMap& functionTypes; DenseHashMap& localTypes; DenseHashMap& exprTypes; - const char* hostVectorType; + const char* hostVectorType = nullptr; const DenseHashMap& userdataTypes; const BuiltinAstTypes& builtinTypes; const DenseHashMap& builtinCalls; const DenseHashMap& globals; + LibraryMemberTypeCallback libraryMemberTypeCb = nullptr; BytecodeBuilder& bytecode; DenseHashMap typeAliases; @@ -201,6 +218,7 @@ struct TypeMapVisitor : AstVisitor const BuiltinAstTypes& builtinTypes, const DenseHashMap& builtinCalls, const DenseHashMap& globals, + LibraryMemberTypeCallback libraryMemberTypeCb, BytecodeBuilder& bytecode ) : functionTypes(functionTypes) @@ -211,6 +229,7 @@ struct TypeMapVisitor : AstVisitor , builtinTypes(builtinTypes) , builtinCalls(builtinCalls) , globals(globals) + , libraryMemberTypeCb(libraryMemberTypeCb) , bytecode(bytecode) , typeAliases(AstName()) , resolvedLocals(nullptr) @@ -461,7 +480,53 @@ struct TypeMapVisitor : AstVisitor if (*typeBcPtr == LBC_TYPE_VECTOR) { if (node->index == "X" || node->index == "Y" || node->index == "Z") + { recordResolvedType(node, &builtinTypes.numberType); + + if (FFlag::LuauCompileLibraryConstants) + return false; + } + } + } + + if (FFlag::LuauCompileLibraryConstants) + { + if (isMatchingGlobalMember(globals, node, "vector", "zero") || isMatchingGlobalMember(globals, node, "vector", "one")) + { + recordResolvedType(node, &builtinTypes.vectorType); + return false; + } + + if (libraryMemberTypeCb) + { + if (AstExprGlobal* object = node->expr->as()) + { + if (LuauBytecodeType ty = LuauBytecodeType(libraryMemberTypeCb(object->name.value, node->index.value)); ty != LBC_TYPE_ANY) + { + // TODO: 'resolvedExprs' is more limited than 'exprTypes' which limits full inference of more complex types that a user + // callback can return + switch (ty) + { + case LBC_TYPE_BOOLEAN: + resolvedExprs[node] = &builtinTypes.booleanType; + break; + case LBC_TYPE_NUMBER: + resolvedExprs[node] = &builtinTypes.numberType; + break; + case LBC_TYPE_STRING: + resolvedExprs[node] = &builtinTypes.stringType; + break; + case LBC_TYPE_VECTOR: + resolvedExprs[node] = &builtinTypes.vectorType; + break; + default: + break; + } + + exprTypes[node] = ty; + return false; + } + } } } @@ -733,10 +798,13 @@ void buildTypeMap( const BuiltinAstTypes& builtinTypes, const DenseHashMap& builtinCalls, const DenseHashMap& globals, + LibraryMemberTypeCallback libraryMemberTypeCb, BytecodeBuilder& bytecode ) { - TypeMapVisitor visitor(functionTypes, localTypes, exprTypes, hostVectorType, userdataTypes, builtinTypes, builtinCalls, globals, bytecode); + TypeMapVisitor visitor( + functionTypes, localTypes, exprTypes, hostVectorType, userdataTypes, builtinTypes, builtinCalls, globals, libraryMemberTypeCb, bytecode + ); root->visit(&visitor); } diff --git a/Compiler/src/Types.h b/Compiler/src/Types.h index 46610db2..e60b3b93 100644 --- a/Compiler/src/Types.h +++ b/Compiler/src/Types.h @@ -3,6 +3,7 @@ #include "Luau/Ast.h" #include "Luau/Bytecode.h" +#include "Luau/Compiler.h" #include "Luau/DenseHash.h" #include "ValueTracking.h" @@ -19,7 +20,7 @@ struct BuiltinAstTypes { } - // AstName use here will not match the AstNameTable, but the was we use them here always force a full string compare + // AstName use here will not match the AstNameTable, but the way we use them here always forces a full string compare AstTypeReference booleanType{{}, std::nullopt, AstName{"boolean"}, std::nullopt, {}}; AstTypeReference numberType{{}, std::nullopt, AstName{"number"}, std::nullopt, {}}; AstTypeReference stringType{{}, std::nullopt, AstName{"string"}, std::nullopt, {}}; @@ -38,6 +39,7 @@ void buildTypeMap( const BuiltinAstTypes& builtinTypes, const DenseHashMap& builtinCalls, const DenseHashMap& globals, + LibraryMemberTypeCallback libraryMemberTypeCb, BytecodeBuilder& bytecode ); diff --git a/Compiler/src/lcode.cpp b/Compiler/src/lcode.cpp index ee150b17..ff2edc3d 100644 --- a/Compiler/src/lcode.cpp +++ b/Compiler/src/lcode.cpp @@ -27,3 +27,28 @@ char* luau_compile(const char* source, size_t size, lua_CompileOptions* options, *outsize = result.size(); return copy; } + +void luau_set_compile_constant_nil(lua_CompileConstant* constant) +{ + Luau::setCompileConstantNil(constant); +} + +void luau_set_compile_constant_boolean(lua_CompileConstant* constant, int b) +{ + Luau::setCompileConstantBoolean(constant, b != 0); +} + +void luau_set_compile_constant_number(lua_CompileConstant* constant, double n) +{ + Luau::setCompileConstantNumber(constant, n); +} + +void luau_set_compile_constant_vector(lua_CompileConstant* constant, float x, float y, float z, float w) +{ + Luau::setCompileConstantVector(constant, x, y, z, w); +} + +void luau_set_compile_constant_string(lua_CompileConstant* constant, const char* s, size_t l) +{ + Luau::setCompileConstantString(constant, s, l); +} diff --git a/Config/include/Luau/Config.h b/Config/include/Luau/Config.h index 3f29a24f..89d018d2 100644 --- a/Config/include/Luau/Config.h +++ b/Config/include/Luau/Config.h @@ -42,6 +42,7 @@ struct Config { std::string value; std::string_view configLocation; + std::string originalCase; // The alias in its original case. }; DenseHashMap aliases{""}; diff --git a/Config/src/Config.cpp b/Config/src/Config.cpp index 15e58e29..5dae6f03 100644 --- a/Config/src/Config.cpp +++ b/Config/src/Config.cpp @@ -26,9 +26,9 @@ Config::Config(const Config& other) , typeErrors(other.typeErrors) , globals(other.globals) { - for (const auto& [alias, aliasInfo] : other.aliases) + for (const auto& [_, aliasInfo] : other.aliases) { - setAlias(alias, aliasInfo.value, std::string(aliasInfo.configLocation)); + setAlias(aliasInfo.originalCase, aliasInfo.value, std::string(aliasInfo.configLocation)); } } @@ -44,8 +44,20 @@ Config& Config::operator=(const Config& other) void Config::setAlias(std::string alias, std::string value, const std::string& configLocation) { - AliasInfo& info = aliases[alias]; + std::string lowercasedAlias = alias; + std::transform( + lowercasedAlias.begin(), + lowercasedAlias.end(), + lowercasedAlias.begin(), + [](unsigned char c) + { + return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c; + } + ); + + AliasInfo& info = aliases[lowercasedAlias]; info.value = std::move(value); + info.originalCase = std::move(alias); if (!configLocationCache.contains(configLocation)) configLocationCache[configLocation] = std::make_unique(configLocation); @@ -175,7 +187,7 @@ bool isValidAlias(const std::string& alias) static Error parseAlias( Config& config, - std::string aliasKey, + const std::string& aliasKey, const std::string& aliasValue, const std::optional& aliasOptions ) @@ -183,21 +195,11 @@ static Error parseAlias( if (!isValidAlias(aliasKey)) return Error{"Invalid alias " + aliasKey}; - std::transform( - aliasKey.begin(), - aliasKey.end(), - aliasKey.begin(), - [](unsigned char c) - { - return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c; - } - ); - if (!aliasOptions) return Error("Cannot parse aliases without alias options"); if (aliasOptions->overwriteAliases || !config.aliases.contains(aliasKey)) - config.setAlias(std::move(aliasKey), aliasValue, aliasOptions->configLocation); + config.setAlias(aliasKey, aliasValue, aliasOptions->configLocation); return std::nullopt; } diff --git a/EqSat/include/Luau/EGraph.h b/EqSat/include/Luau/EGraph.h index 924da974..e8cc2e35 100644 --- a/EqSat/include/Luau/EGraph.h +++ b/EqSat/include/Luau/EGraph.h @@ -198,8 +198,7 @@ private: { // An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where // canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...). - for (Id& id : enode.mutableOperands()) - id = find(id); + Luau::EqSat::canonicalize(enode, [&](Id id) { return find(id); }); } bool isCanonical(const L& enode) const diff --git a/EqSat/include/Luau/Language.h b/EqSat/include/Luau/Language.h index 56fc7202..c4b60f97 100644 --- a/EqSat/include/Luau/Language.h +++ b/EqSat/include/Luau/Language.h @@ -244,6 +244,9 @@ private: template struct NodeSet { + template + friend void canonicalize(NodeSet& node, Find&& find); + template NodeSet(Args&&... args) : vector{std::forward(args)...} @@ -299,6 +302,9 @@ struct Language final template using WithinDomain = std::disjunction, Ts>...>; + template + friend void canonicalize(Language& enode, Find&& find); + template Language(T&& t, std::enable_if_t::value>* = 0) noexcept : v(std::forward(t)) @@ -382,4 +388,37 @@ private: VariantTy v; }; +template +void canonicalize(Node& node, Find&& find) +{ + // An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where + // canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...). + for (Id& id : node.mutableOperands()) + id = find(id); +} + +// Canonicalizing the Ids in a NodeSet may result in the set decreasing in size. +template +void canonicalize(NodeSet& node, Find&& find) +{ + for (Id& id : node.vector) + id = find(id); + + std::sort(begin(node.vector), end(node.vector)); + auto endIt = std::unique(begin(node.vector), end(node.vector)); + node.vector.erase(endIt, end(node.vector)); +} + +template +void canonicalize(Language& enode, Find&& find) +{ + visit( + [&](auto&& v) + { + Luau::EqSat::canonicalize(v, find); + }, + enode.v + ); +} + } // namespace Luau::EqSat diff --git a/VM/src/lcorolib.cpp b/VM/src/lcorolib.cpp index 941137e9..5a372aec 100644 --- a/VM/src/lcorolib.cpp +++ b/VM/src/lcorolib.cpp @@ -6,7 +6,6 @@ #include "lstate.h" #include "lvm.h" -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauCoroCheckStack, false) LUAU_DYNAMIC_FASTFLAG(LuauStackLimit) #define CO_STATUS_ERROR -1 @@ -41,7 +40,7 @@ static int auxresume(lua_State* L, lua_State* co, int narg) luaL_error(L, "too many arguments to resume"); lua_xmove(L, co, narg); } - else if (DFFlag::LuauCoroCheckStack) + else { // coroutine might be completely full already if ((co->top - co->base) > LUAI_MAXCSTACK) diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 6ba758df..513a3a5a 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -14,8 +14,6 @@ #include -LUAU_DYNAMIC_FASTFLAG(LuauCoroCheckStack) - /* * Luau uses an incremental non-generational non-moving mark&sweep garbage collector. * @@ -439,26 +437,13 @@ static void shrinkstack(lua_State* L) if (L->size_ci > LUAI_MAXCALLS) // handling overflow? return; // do not touch the stacks - if (DFFlag::LuauCoroCheckStack) - { - if (3 * size_t(ci_used) < size_t(L->size_ci) && 2 * BASIC_CI_SIZE < L->size_ci) - luaD_reallocCI(L, L->size_ci / 2); // still big enough... - condhardstacktests(luaD_reallocCI(L, ci_used + 1)); + if (3 * size_t(ci_used) < size_t(L->size_ci) && 2 * BASIC_CI_SIZE < L->size_ci) + luaD_reallocCI(L, L->size_ci / 2); // still big enough... + condhardstacktests(luaD_reallocCI(L, ci_used + 1)); - if (3 * size_t(s_used) < size_t(L->stacksize) && 2 * (BASIC_STACK_SIZE + EXTRA_STACK) < L->stacksize) - luaD_reallocstack(L, L->stacksize / 2); // still big enough... - condhardstacktests(luaD_reallocstack(L, s_used)); - } - else - { - if (3 * ci_used < L->size_ci && 2 * BASIC_CI_SIZE < L->size_ci) - luaD_reallocCI(L, L->size_ci / 2); // still big enough... - condhardstacktests(luaD_reallocCI(L, ci_used + 1)); - - if (3 * s_used < L->stacksize && 2 * (BASIC_STACK_SIZE + EXTRA_STACK) < L->stacksize) - luaD_reallocstack(L, L->stacksize / 2); // still big enough... - condhardstacktests(luaD_reallocstack(L, s_used)); - } + if (3 * size_t(s_used) < size_t(L->stacksize) && 2 * (BASIC_STACK_SIZE + EXTRA_STACK) < L->stacksize) + luaD_reallocstack(L, L->stacksize / 2); // still big enough... + condhardstacktests(luaD_reallocstack(L, s_used)); } /* diff --git a/bench/bench_support.lua b/bench/bench_support.lua index da637ac9..b731c2fc 100644 --- a/bench/bench_support.lua +++ b/bench/bench_support.lua @@ -66,7 +66,7 @@ end -- and 'false' otherwise. -- -- Example usage: --- local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +-- local function prequire(name) local success, result = pcall(require, name); return success and result end -- local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") -- function testFunc() -- ... diff --git a/bench/gc/test_BinaryTree.lua b/bench/gc/test_BinaryTree.lua index 36dff9de..b7a36d73 100644 --- a/bench/gc/test_BinaryTree.lua +++ b/bench/gc/test_BinaryTree.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_GC_Boehm_Trees.lua b/bench/gc/test_GC_Boehm_Trees.lua index 8170103d..3a3a3698 100644 --- a/bench/gc/test_GC_Boehm_Trees.lua +++ b/bench/gc/test_GC_Boehm_Trees.lua @@ -1,5 +1,5 @@ --!nonstrict -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local stretchTreeDepth = 18 -- about 16Mb diff --git a/bench/gc/test_GC_Tree_Pruning_Eager.lua b/bench/gc/test_GC_Tree_Pruning_Eager.lua index 38aa7626..7a086254 100644 --- a/bench/gc/test_GC_Tree_Pruning_Eager.lua +++ b/bench/gc/test_GC_Tree_Pruning_Eager.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_GC_Tree_Pruning_Gen.lua b/bench/gc/test_GC_Tree_Pruning_Gen.lua index 85081f70..eb747e77 100644 --- a/bench/gc/test_GC_Tree_Pruning_Gen.lua +++ b/bench/gc/test_GC_Tree_Pruning_Gen.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_GC_Tree_Pruning_Lazy.lua b/bench/gc/test_GC_Tree_Pruning_Lazy.lua index 834ec1ab..16b68083 100644 --- a/bench/gc/test_GC_Tree_Pruning_Lazy.lua +++ b/bench/gc/test_GC_Tree_Pruning_Lazy.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_GC_hashtable_Keyval.lua b/bench/gc/test_GC_hashtable_Keyval.lua index aa7481d3..6e59072c 100644 --- a/bench/gc/test_GC_hashtable_Keyval.lua +++ b/bench/gc/test_GC_hashtable_Keyval.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_LB_mandel.lua b/bench/gc/test_LB_mandel.lua index a8beb4fd..be9977d6 100644 --- a/bench/gc/test_LB_mandel.lua +++ b/bench/gc/test_LB_mandel.lua @@ -21,7 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_LargeTableCtor_array.lua b/bench/gc/test_LargeTableCtor_array.lua index 016dfd2d..35b6f449 100644 --- a/bench/gc/test_LargeTableCtor_array.lua +++ b/bench/gc/test_LargeTableCtor_array.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_LargeTableCtor_hash.lua b/bench/gc/test_LargeTableCtor_hash.lua index c46a7ab4..e2b11b4b 100644 --- a/bench/gc/test_LargeTableCtor_hash.lua +++ b/bench/gc/test_LargeTableCtor_hash.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_Pcall_pcall_yield.lua b/bench/gc/test_Pcall_pcall_yield.lua index ae0a4b46..2ae0baa6 100644 --- a/bench/gc/test_Pcall_pcall_yield.lua +++ b/bench/gc/test_Pcall_pcall_yield.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_SunSpider_3d-raytrace.lua b/bench/gc/test_SunSpider_3d-raytrace.lua index 3c050df7..d8f224c4 100644 --- a/bench/gc/test_SunSpider_3d-raytrace.lua +++ b/bench/gc/test_SunSpider_3d-raytrace.lua @@ -22,7 +22,7 @@ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_TableCreate_nil.lua b/bench/gc/test_TableCreate_nil.lua index 707a2750..546e9d6b 100644 --- a/bench/gc/test_TableCreate_nil.lua +++ b/bench/gc/test_TableCreate_nil.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_TableCreate_number.lua b/bench/gc/test_TableCreate_number.lua index 3e4305bd..fe8437b7 100644 --- a/bench/gc/test_TableCreate_number.lua +++ b/bench/gc/test_TableCreate_number.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_TableCreate_zerofill.lua b/bench/gc/test_TableCreate_zerofill.lua index fed439b4..e2cfda30 100644 --- a/bench/gc/test_TableCreate_zerofill.lua +++ b/bench/gc/test_TableCreate_zerofill.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_TableMarshal_select.lua b/bench/gc/test_TableMarshal_select.lua index 9869da60..df5ebf78 100644 --- a/bench/gc/test_TableMarshal_select.lua +++ b/bench/gc/test_TableMarshal_select.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_TableMarshal_table_pack.lua b/bench/gc/test_TableMarshal_table_pack.lua index 3da855f5..3d0190e7 100644 --- a/bench/gc/test_TableMarshal_table_pack.lua +++ b/bench/gc/test_TableMarshal_table_pack.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_TableMarshal_varargs.lua b/bench/gc/test_TableMarshal_varargs.lua index 64b41b43..b88d8213 100644 --- a/bench/gc/test_TableMarshal_varargs.lua +++ b/bench/gc/test_TableMarshal_varargs.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_AbsSum_abs.lua b/bench/micro_tests/test_AbsSum_abs.lua index 7e85646e..ea473556 100644 --- a/bench/micro_tests/test_AbsSum_abs.lua +++ b/bench/micro_tests/test_AbsSum_abs.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_AbsSum_and_or.lua b/bench/micro_tests/test_AbsSum_and_or.lua index c6ef3dea..6cd5b4d0 100644 --- a/bench/micro_tests/test_AbsSum_and_or.lua +++ b/bench/micro_tests/test_AbsSum_and_or.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_AbsSum_math_abs.lua b/bench/micro_tests/test_AbsSum_math_abs.lua index e95ea674..e02b710a 100644 --- a/bench/micro_tests/test_AbsSum_math_abs.lua +++ b/bench/micro_tests/test_AbsSum_math_abs.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Assert.lua b/bench/micro_tests/test_Assert.lua index 014de8dc..750f411b 100644 --- a/bench/micro_tests/test_Assert.lua +++ b/bench/micro_tests/test_Assert.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Factorial.lua b/bench/micro_tests/test_Factorial.lua index 90cff22a..5dc797ce 100644 --- a/bench/micro_tests/test_Factorial.lua +++ b/bench/micro_tests/test_Factorial.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Failure_pcall_a_bar.lua b/bench/micro_tests/test_Failure_pcall_a_bar.lua index 5b6108ba..95887e58 100644 --- a/bench/micro_tests/test_Failure_pcall_a_bar.lua +++ b/bench/micro_tests/test_Failure_pcall_a_bar.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Failure_pcall_game_Foo.lua b/bench/micro_tests/test_Failure_pcall_game_Foo.lua index 6bd209ae..9966262d 100644 --- a/bench/micro_tests/test_Failure_pcall_game_Foo.lua +++ b/bench/micro_tests/test_Failure_pcall_game_Foo.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Failure_xpcall_a_bar.lua b/bench/micro_tests/test_Failure_xpcall_a_bar.lua index e00a3ca6..44534da4 100644 --- a/bench/micro_tests/test_Failure_xpcall_a_bar.lua +++ b/bench/micro_tests/test_Failure_xpcall_a_bar.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Failure_xpcall_game_Foo.lua b/bench/micro_tests/test_Failure_xpcall_game_Foo.lua index 86dadc90..35659598 100644 --- a/bench/micro_tests/test_Failure_xpcall_game_Foo.lua +++ b/bench/micro_tests/test_Failure_xpcall_game_Foo.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_LargeTableCtor_array.lua b/bench/micro_tests/test_LargeTableCtor_array.lua index 016dfd2d..35b6f449 100644 --- a/bench/micro_tests/test_LargeTableCtor_array.lua +++ b/bench/micro_tests/test_LargeTableCtor_array.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_LargeTableCtor_hash.lua b/bench/micro_tests/test_LargeTableCtor_hash.lua index c46a7ab4..e2b11b4b 100644 --- a/bench/micro_tests/test_LargeTableCtor_hash.lua +++ b/bench/micro_tests/test_LargeTableCtor_hash.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_LargeTableSum_loop_index.lua b/bench/micro_tests/test_LargeTableSum_loop_index.lua index 2aae109e..dd64ca00 100644 --- a/bench/micro_tests/test_LargeTableSum_loop_index.lua +++ b/bench/micro_tests/test_LargeTableSum_loop_index.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_LargeTableSum_loop_ipairs.lua b/bench/micro_tests/test_LargeTableSum_loop_ipairs.lua index 29205e26..54ee888d 100644 --- a/bench/micro_tests/test_LargeTableSum_loop_ipairs.lua +++ b/bench/micro_tests/test_LargeTableSum_loop_ipairs.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_LargeTableSum_loop_iter.lua b/bench/micro_tests/test_LargeTableSum_loop_iter.lua index ea2b157c..fb69470f 100644 --- a/bench/micro_tests/test_LargeTableSum_loop_iter.lua +++ b/bench/micro_tests/test_LargeTableSum_loop_iter.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_LargeTableSum_loop_pairs.lua b/bench/micro_tests/test_LargeTableSum_loop_pairs.lua index 8d789fcf..ffe19a20 100644 --- a/bench/micro_tests/test_LargeTableSum_loop_pairs.lua +++ b/bench/micro_tests/test_LargeTableSum_loop_pairs.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_MethodCalls.lua b/bench/micro_tests/test_MethodCalls.lua index f8b44527..016a4798 100644 --- a/bench/micro_tests/test_MethodCalls.lua +++ b/bench/micro_tests/test_MethodCalls.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_OOP_constructor.lua b/bench/micro_tests/test_OOP_constructor.lua index 9fec3b67..b1c03dfc 100644 --- a/bench/micro_tests/test_OOP_constructor.lua +++ b/bench/micro_tests/test_OOP_constructor.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_OOP_method_call.lua b/bench/micro_tests/test_OOP_method_call.lua index 1e5249c5..09699acb 100644 --- a/bench/micro_tests/test_OOP_method_call.lua +++ b/bench/micro_tests/test_OOP_method_call.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_OOP_virtual_constructor.lua b/bench/micro_tests/test_OOP_virtual_constructor.lua index df99e13b..68dfba61 100644 --- a/bench/micro_tests/test_OOP_virtual_constructor.lua +++ b/bench/micro_tests/test_OOP_virtual_constructor.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Pcall_call_return.lua b/bench/micro_tests/test_Pcall_call_return.lua index 2a612175..45d8ca58 100644 --- a/bench/micro_tests/test_Pcall_call_return.lua +++ b/bench/micro_tests/test_Pcall_call_return.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Pcall_pcall_return.lua b/bench/micro_tests/test_Pcall_pcall_return.lua index 16bdfdd3..09a032df 100644 --- a/bench/micro_tests/test_Pcall_pcall_return.lua +++ b/bench/micro_tests/test_Pcall_pcall_return.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Pcall_pcall_yield.lua b/bench/micro_tests/test_Pcall_pcall_yield.lua index ae0a4b46..2ae0baa6 100644 --- a/bench/micro_tests/test_Pcall_pcall_yield.lua +++ b/bench/micro_tests/test_Pcall_pcall_yield.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Pcall_xpcall_return.lua b/bench/micro_tests/test_Pcall_xpcall_return.lua index 8ac2f0eb..5fb69f1b 100644 --- a/bench/micro_tests/test_Pcall_xpcall_return.lua +++ b/bench/micro_tests/test_Pcall_xpcall_return.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_SqrtSum_exponent.lua b/bench/micro_tests/test_SqrtSum_exponent.lua index bfd6fd72..1bb6a7d2 100644 --- a/bench/micro_tests/test_SqrtSum_exponent.lua +++ b/bench/micro_tests/test_SqrtSum_exponent.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_SqrtSum_math_sqrt.lua b/bench/micro_tests/test_SqrtSum_math_sqrt.lua index 1e1f42c7..7a280460 100644 --- a/bench/micro_tests/test_SqrtSum_math_sqrt.lua +++ b/bench/micro_tests/test_SqrtSum_math_sqrt.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_SqrtSum_sqrt.lua b/bench/micro_tests/test_SqrtSum_sqrt.lua index 96880e7b..ddcddb9d 100644 --- a/bench/micro_tests/test_SqrtSum_sqrt.lua +++ b/bench/micro_tests/test_SqrtSum_sqrt.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_SqrtSum_sqrt_getfenv.lua b/bench/micro_tests/test_SqrtSum_sqrt_getfenv.lua index 55f29e2e..1dd29776 100644 --- a/bench/micro_tests/test_SqrtSum_sqrt_getfenv.lua +++ b/bench/micro_tests/test_SqrtSum_sqrt_getfenv.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_SqrtSum_sqrt_roundabout.lua b/bench/micro_tests/test_SqrtSum_sqrt_roundabout.lua index bbe48a64..0527ea4d 100644 --- a/bench/micro_tests/test_SqrtSum_sqrt_roundabout.lua +++ b/bench/micro_tests/test_SqrtSum_sqrt_roundabout.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_StringInterp.lua b/bench/micro_tests/test_StringInterp.lua index 55430519..d44f5b07 100644 --- a/bench/micro_tests/test_StringInterp.lua +++ b/bench/micro_tests/test_StringInterp.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") bench.runCode(function() diff --git a/bench/micro_tests/test_TableCreate_nil.lua b/bench/micro_tests/test_TableCreate_nil.lua index 707a2750..546e9d6b 100644 --- a/bench/micro_tests/test_TableCreate_nil.lua +++ b/bench/micro_tests/test_TableCreate_nil.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableCreate_number.lua b/bench/micro_tests/test_TableCreate_number.lua index 3e4305bd..fe8437b7 100644 --- a/bench/micro_tests/test_TableCreate_number.lua +++ b/bench/micro_tests/test_TableCreate_number.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableCreate_zerofill.lua b/bench/micro_tests/test_TableCreate_zerofill.lua index fed439b4..e2cfda30 100644 --- a/bench/micro_tests/test_TableCreate_zerofill.lua +++ b/bench/micro_tests/test_TableCreate_zerofill.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableFind_loop_ipairs.lua b/bench/micro_tests/test_TableFind_loop_ipairs.lua index 46560274..ef7f4c81 100644 --- a/bench/micro_tests/test_TableFind_loop_ipairs.lua +++ b/bench/micro_tests/test_TableFind_loop_ipairs.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableFind_table_find.lua b/bench/micro_tests/test_TableFind_table_find.lua index 3f22122f..05882c50 100644 --- a/bench/micro_tests/test_TableFind_table_find.lua +++ b/bench/micro_tests/test_TableFind_table_find.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableInsertion_index_cached.lua b/bench/micro_tests/test_TableInsertion_index_cached.lua index 0c34818f..adb40822 100644 --- a/bench/micro_tests/test_TableInsertion_index_cached.lua +++ b/bench/micro_tests/test_TableInsertion_index_cached.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableInsertion_index_len.lua b/bench/micro_tests/test_TableInsertion_index_len.lua index 120a5e28..797dec80 100644 --- a/bench/micro_tests/test_TableInsertion_index_len.lua +++ b/bench/micro_tests/test_TableInsertion_index_len.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableInsertion_table_insert.lua b/bench/micro_tests/test_TableInsertion_table_insert.lua index 1ad3fe22..632e9080 100644 --- a/bench/micro_tests/test_TableInsertion_table_insert.lua +++ b/bench/micro_tests/test_TableInsertion_table_insert.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableInsertion_table_insert_index.lua b/bench/micro_tests/test_TableInsertion_table_insert_index.lua index 41747139..7b35fe39 100644 --- a/bench/micro_tests/test_TableInsertion_table_insert_index.lua +++ b/bench/micro_tests/test_TableInsertion_table_insert_index.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableIteration.lua b/bench/micro_tests/test_TableIteration.lua index 5f78a48b..2c44f43c 100644 --- a/bench/micro_tests/test_TableIteration.lua +++ b/bench/micro_tests/test_TableIteration.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableMarshal_select.lua b/bench/micro_tests/test_TableMarshal_select.lua index 9869da60..df5ebf78 100644 --- a/bench/micro_tests/test_TableMarshal_select.lua +++ b/bench/micro_tests/test_TableMarshal_select.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableMarshal_table_pack.lua b/bench/micro_tests/test_TableMarshal_table_pack.lua index 3da855f5..3d0190e7 100644 --- a/bench/micro_tests/test_TableMarshal_table_pack.lua +++ b/bench/micro_tests/test_TableMarshal_table_pack.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableMarshal_table_unpack_array.lua b/bench/micro_tests/test_TableMarshal_table_unpack_array.lua index 13d1d1c3..32f2eb9a 100644 --- a/bench/micro_tests/test_TableMarshal_table_unpack_array.lua +++ b/bench/micro_tests/test_TableMarshal_table_unpack_array.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableMarshal_table_unpack_range.lua b/bench/micro_tests/test_TableMarshal_table_unpack_range.lua index e3aa68be..fa53a31c 100644 --- a/bench/micro_tests/test_TableMarshal_table_unpack_range.lua +++ b/bench/micro_tests/test_TableMarshal_table_unpack_range.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableMarshal_varargs.lua b/bench/micro_tests/test_TableMarshal_varargs.lua index 64b41b43..b88d8213 100644 --- a/bench/micro_tests/test_TableMarshal_varargs.lua +++ b/bench/micro_tests/test_TableMarshal_varargs.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableMove_empty_table.lua b/bench/micro_tests/test_TableMove_empty_table.lua index 39335564..18737f74 100644 --- a/bench/micro_tests/test_TableMove_empty_table.lua +++ b/bench/micro_tests/test_TableMove_empty_table.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableMove_same_table.lua b/bench/micro_tests/test_TableMove_same_table.lua index f62022b1..8fc9fa03 100644 --- a/bench/micro_tests/test_TableMove_same_table.lua +++ b/bench/micro_tests/test_TableMove_same_table.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableMove_table_create.lua b/bench/micro_tests/test_TableMove_table_create.lua index f03c4de7..3c0cb9e9 100644 --- a/bench/micro_tests/test_TableMove_table_create.lua +++ b/bench/micro_tests/test_TableMove_table_create.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableRemoval_table_remove.lua b/bench/micro_tests/test_TableRemoval_table_remove.lua index 13410116..3ba3e503 100644 --- a/bench/micro_tests/test_TableRemoval_table_remove.lua +++ b/bench/micro_tests/test_TableRemoval_table_remove.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableSort.lua b/bench/micro_tests/test_TableSort.lua index 502cb2a5..e3276845 100644 --- a/bench/micro_tests/test_TableSort.lua +++ b/bench/micro_tests/test_TableSort.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local arr_months = {"Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"} diff --git a/bench/micro_tests/test_ToNumberString.lua b/bench/micro_tests/test_ToNumberString.lua index 842b7c22..cda886c0 100644 --- a/bench/micro_tests/test_ToNumberString.lua +++ b/bench/micro_tests/test_ToNumberString.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") bench.runCode(function() diff --git a/bench/micro_tests/test_UpvalueCapture.lua b/bench/micro_tests/test_UpvalueCapture.lua index 4a2608c4..6c2f2616 100644 --- a/bench/micro_tests/test_UpvalueCapture.lua +++ b/bench/micro_tests/test_UpvalueCapture.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_VariadicSelect.lua b/bench/micro_tests/test_VariadicSelect.lua index 5a62f2d8..9710e237 100644 --- a/bench/micro_tests/test_VariadicSelect.lua +++ b/bench/micro_tests/test_VariadicSelect.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_string_lib.lua b/bench/micro_tests/test_string_lib.lua index 041f5b15..5f180151 100644 --- a/bench/micro_tests/test_string_lib.lua +++ b/bench/micro_tests/test_string_lib.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") bench.runCode(function() diff --git a/bench/micro_tests/test_table_concat.lua b/bench/micro_tests/test_table_concat.lua index 590b7d4a..879b63fe 100644 --- a/bench/micro_tests/test_table_concat.lua +++ b/bench/micro_tests/test_table_concat.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") bench.runCode(function() diff --git a/bench/tests/base64.lua b/bench/tests/base64.lua index e580c595..13bfd070 100644 --- a/bench/tests/base64.lua +++ b/bench/tests/base64.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/tests/chess.lua b/bench/tests/chess.lua index f551139e..7e6c9c0c 100644 --- a/bench/tests/chess.lua +++ b/bench/tests/chess.lua @@ -1,5 +1,5 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local RANKS = "12345678" diff --git a/bench/tests/life.lua b/bench/tests/life.lua index d050b013..a61730aa 100644 --- a/bench/tests/life.lua +++ b/bench/tests/life.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/tests/matrixmult.lua b/bench/tests/matrixmult.lua index af38cb64..fa04b864 100644 --- a/bench/tests/matrixmult.lua +++ b/bench/tests/matrixmult.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local function mmul(matrix1, matrix2) diff --git a/bench/tests/mesh-normal-scalar.lua b/bench/tests/mesh-normal-scalar.lua index 05bef373..509e1e62 100644 --- a/bench/tests/mesh-normal-scalar.lua +++ b/bench/tests/mesh-normal-scalar.lua @@ -1,5 +1,5 @@ --!strict -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/tests/mesh-normal-vector.lua b/bench/tests/mesh-normal-vector.lua index bfc0f1c7..ff4f2b46 100644 --- a/bench/tests/mesh-normal-vector.lua +++ b/bench/tests/mesh-normal-vector.lua @@ -1,5 +1,5 @@ --!strict -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/tests/pcmmix.lua b/bench/tests/pcmmix.lua index c98cee2c..1e8e27a5 100644 --- a/bench/tests/pcmmix.lua +++ b/bench/tests/pcmmix.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local samples = 100_000 diff --git a/bench/tests/qsort.lua b/bench/tests/qsort.lua index 566c1b98..37413fa2 100644 --- a/bench/tests/qsort.lua +++ b/bench/tests/qsort.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/tests/sha256.lua b/bench/tests/sha256.lua index 2ac0ab33..e478e763 100644 --- a/bench/tests/sha256.lua +++ b/bench/tests/sha256.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/tests/shootout/ack.lua b/bench/tests/shootout/ack.lua index f7fd43a8..ca8913ac 100644 --- a/bench/tests/shootout/ack.lua +++ b/bench/tests/shootout/ack.lua @@ -23,7 +23,7 @@ SOFTWARE. ]] -- http://www.bagley.org/~doug/shootout/ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/binary-trees.lua b/bench/tests/shootout/binary-trees.lua index 89c5933c..50d40597 100644 --- a/bench/tests/shootout/binary-trees.lua +++ b/bench/tests/shootout/binary-trees.lua @@ -25,7 +25,7 @@ SOFTWARE. -- http://benchmarksgame.alioth.debian.org/ -- contributed by Mike Pall -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/fannkuch-redux.lua b/bench/tests/shootout/fannkuch-redux.lua index 43bc9e41..60f7c3c0 100644 --- a/bench/tests/shootout/fannkuch-redux.lua +++ b/bench/tests/shootout/fannkuch-redux.lua @@ -25,7 +25,7 @@ SOFTWARE. -- http://benchmarksgame.alioth.debian.org/ -- contributed by Mike Pall -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/fixpoint-fact.lua b/bench/tests/shootout/fixpoint-fact.lua index 112acb4a..226c78a8 100644 --- a/bench/tests/shootout/fixpoint-fact.lua +++ b/bench/tests/shootout/fixpoint-fact.lua @@ -21,7 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/heapsort.lua b/bench/tests/shootout/heapsort.lua index 0daf97ab..69c1b885 100644 --- a/bench/tests/shootout/heapsort.lua +++ b/bench/tests/shootout/heapsort.lua @@ -21,7 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/mandel.lua b/bench/tests/shootout/mandel.lua index a3bbb7e5..547741e6 100644 --- a/bench/tests/shootout/mandel.lua +++ b/bench/tests/shootout/mandel.lua @@ -21,7 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/n-body.lua b/bench/tests/shootout/n-body.lua index e0f9c63c..082b7fa0 100644 --- a/bench/tests/shootout/n-body.lua +++ b/bench/tests/shootout/n-body.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/qt.lua b/bench/tests/shootout/qt.lua index d9b4a517..c15accd0 100644 --- a/bench/tests/shootout/qt.lua +++ b/bench/tests/shootout/qt.lua @@ -23,7 +23,7 @@ SOFTWARE. ]] -- Julia sets via interval cell-mapping (quadtree version) -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/queen.lua b/bench/tests/shootout/queen.lua index c3508d60..8f27e06f 100644 --- a/bench/tests/shootout/queen.lua +++ b/bench/tests/shootout/queen.lua @@ -21,7 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/scimark.lua b/bench/tests/shootout/scimark.lua index 1b66df53..dd7cae53 100644 --- a/bench/tests/shootout/scimark.lua +++ b/bench/tests/shootout/scimark.lua @@ -33,7 +33,7 @@ -- Modification to be compatible with Lua 5.3 ------------------------------------------------------------------------------ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/spectral-norm.lua b/bench/tests/shootout/spectral-norm.lua index b5116612..f1acd34c 100644 --- a/bench/tests/shootout/spectral-norm.lua +++ b/bench/tests/shootout/spectral-norm.lua @@ -25,7 +25,7 @@ SOFTWARE. -- http://benchmarksgame.alioth.debian.org/ -- contributed by Mike Pall -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/sieve.lua b/bench/tests/sieve.lua index 1bb45d99..8d8cf82a 100644 --- a/bench/tests/sieve.lua +++ b/bench/tests/sieve.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/tests/sunspider/3d-cube.lua b/bench/tests/sunspider/3d-cube.lua index aac7a156..ea132463 100644 --- a/bench/tests/sunspider/3d-cube.lua +++ b/bench/tests/sunspider/3d-cube.lua @@ -2,7 +2,7 @@ -- http://www.speich.net/computer/moztesting/3d.htm -- Created by Simon Speich -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/sunspider/3d-morph.lua b/bench/tests/sunspider/3d-morph.lua index 8263f015..0dbf1c63 100644 --- a/bench/tests/sunspider/3d-morph.lua +++ b/bench/tests/sunspider/3d-morph.lua @@ -23,7 +23,7 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/sunspider/3d-raytrace.lua b/bench/tests/sunspider/3d-raytrace.lua index 33d464b8..83ca7bd9 100644 --- a/bench/tests/sunspider/3d-raytrace.lua +++ b/bench/tests/sunspider/3d-raytrace.lua @@ -22,7 +22,7 @@ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/sunspider/controlflow-recursive.lua b/bench/tests/sunspider/controlflow-recursive.lua index 1c78a3c2..67c77293 100644 --- a/bench/tests/sunspider/controlflow-recursive.lua +++ b/bench/tests/sunspider/controlflow-recursive.lua @@ -3,7 +3,7 @@ http://shootout.alioth.debian.org/ contributed by Isaac Gouy ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/sunspider/crypto-aes.lua b/bench/tests/sunspider/crypto-aes.lua index 9692cf52..6b23719b 100644 --- a/bench/tests/sunspider/crypto-aes.lua +++ b/bench/tests/sunspider/crypto-aes.lua @@ -9,7 +9,7 @@ * returns byte-array encrypted value (16 bytes) */]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") -- Sbox is pre-computed multiplicative inverse in GF(2^8) used in SubBytes and KeyExpansion [§5.1.1] diff --git a/bench/tests/sunspider/fannkuch.lua b/bench/tests/sunspider/fannkuch.lua index 08cdcc24..24098740 100644 --- a/bench/tests/sunspider/fannkuch.lua +++ b/bench/tests/sunspider/fannkuch.lua @@ -3,7 +3,7 @@ http://shootout.alioth.debian.org/ contributed by Isaac Gouy ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/sunspider/math-cordic.lua b/bench/tests/sunspider/math-cordic.lua index 2b622377..861cc51a 100644 --- a/bench/tests/sunspider/math-cordic.lua +++ b/bench/tests/sunspider/math-cordic.lua @@ -23,7 +23,7 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ]] - local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end + local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/sunspider/math-partial-sums.lua b/bench/tests/sunspider/math-partial-sums.lua index f0b4b0b7..21f63295 100644 --- a/bench/tests/sunspider/math-partial-sums.lua +++ b/bench/tests/sunspider/math-partial-sums.lua @@ -3,7 +3,7 @@ http://shootout.alioth.debian.org/ contributed by Isaac Gouy ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/sunspider/n-body-oop.lua b/bench/tests/sunspider/n-body-oop.lua index e04286c8..469e22c1 100644 --- a/bench/tests/sunspider/n-body-oop.lua +++ b/bench/tests/sunspider/n-body-oop.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") local PI = 3.141592653589793 diff --git a/bench/tests/tictactoe.lua b/bench/tests/tictactoe.lua index 673dcd48..bc3282a0 100644 --- a/bench/tests/tictactoe.lua +++ b/bench/tests/tictactoe.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/tests/trig.lua b/bench/tests/trig.lua index 64bf611c..269fd610 100644 --- a/bench/tests/trig.lua +++ b/bench/tests/trig.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/tests/voxelgen.lua b/bench/tests/voxelgen.lua index b50a4592..813838c1 100644 --- a/bench/tests/voxelgen.lua +++ b/bench/tests/voxelgen.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") -- Based on voxel terrain generator by Stickmasterluke diff --git a/tests/AnyTypeSummary.test.cpp b/tests/AnyTypeSummary.test.cpp index 5c3b4aa3..74996071 100644 --- a/tests/AnyTypeSummary.test.cpp +++ b/tests/AnyTypeSummary.test.cpp @@ -111,7 +111,7 @@ TEST_CASE_FIXTURE(ATSFixture, "typepacks_no_ret") -- TODO: if partially typed, we'd want to know too local function fallible(t: number) if t > 0 then - return true, t + return true, t end return false, "must be positive" end @@ -911,7 +911,7 @@ TEST_CASE_FIXTURE(ATSFixture, "type_alias_any") fileResolver.source["game/Gui/Modules/A"] = R"( type Clear = any - local z: Clear = "zip" + local z: Clear = "zip" )"; CheckResult result1 = frontend.check("game/Gui/Modules/A"); @@ -938,7 +938,7 @@ TEST_CASE_FIXTURE(ATSFixture, "multi_module_any") fileResolver.source["game/B"] = R"( local MyFunc = require(script.Parent.A) type Clear = any - local z: Clear = "zip" + local z: Clear = "zip" )"; fileResolver.source["game/Gui/Modules/A"] = R"( @@ -972,7 +972,7 @@ TEST_CASE_FIXTURE(ATSFixture, "cast_on_cyclic_req") fileResolver.source["game/B"] = R"( local MyFunc = require(script.Parent.A) :: any type Clear = any - local z: Clear = "zip" + local z: Clear = "zip" )"; CheckResult result = frontend.check("game/B"); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index b062cbfe..56201b32 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -3,6 +3,8 @@ #include "Luau/BytecodeBuilder.h" #include "Luau/StringUtils.h" +#include "luacode.h" + #include "ScopedFlags.h" #include "doctest.h" @@ -21,12 +23,52 @@ LUAU_FASTINT(LuauCompileInlineThresholdMaxBoost) LUAU_FASTINT(LuauCompileLoopUnrollThreshold) LUAU_FASTINT(LuauCompileLoopUnrollThresholdMaxBoost) LUAU_FASTINT(LuauRecursionLimit) -LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2) LUAU_FASTFLAG(LuauCompileVectorTypeInfo) LUAU_FASTFLAG(LuauCompileOptimizeRevArith) +LUAU_FASTFLAG(LuauCompileLibraryConstants) +LUAU_FASTFLAG(LuauVectorBuiltins) +LUAU_FASTFLAG(LuauVectorFolding) +LUAU_FASTFLAG(LuauCompileDisabledBuiltins) using namespace Luau; +static void luauLibraryConstantLookup(const char* library, const char* member, Luau::CompileConstant* constant) +{ + // While 'vector' is built-in, because of LUA_VECTOR_SIZE VM configuration, compiler cannot provide the right default by itself + if (strcmp(library, "vector") == 0) + { + if (strcmp(member, "zero") == 0) + return Luau::setCompileConstantVector(constant, 0.0f, 0.0f, 0.0f, 0.0f); + + if (strcmp(member, "one") == 0) + return Luau::setCompileConstantVector(constant, 1.0f, 1.0f, 1.0f, 0.0f); + } + + if (strcmp(library, "Vector3") == 0) + { + if (strcmp(member, "one") == 0) + return Luau::setCompileConstantVector(constant, 1.0f, 1.0f, 1.0f, 0.0f); + + if (strcmp(member, "xAxis") == 0) + return Luau::setCompileConstantVector(constant, 1.0f, 0.0f, 0.0f, 0.0f); + } + + if (strcmp(library, "test") == 0) + { + if (strcmp(member, "some_nil") == 0) + return Luau::setCompileConstantNil(constant); + + if (strcmp(member, "some_boolean") == 0) + return Luau::setCompileConstantBoolean(constant, true); + + if (strcmp(member, "some_number") == 0) + return Luau::setCompileConstantNumber(constant, 4.75); + + if (strcmp(member, "some_string") == 0) + return Luau::setCompileConstantString(constant, "test", 4); + } +} + static std::string compileFunction(const char* source, uint32_t id, int optimizationLevel = 1, int typeInfoLevel = 0, bool enableVectors = false) { Luau::BytecodeBuilder bcb; @@ -39,6 +81,12 @@ static std::string compileFunction(const char* source, uint32_t id, int optimiza options.vectorLib = "Vector3"; options.vectorCtor = "new"; } + + static const char* kLibrariesWithConstants[] = {"vector", "Vector3", "test", nullptr}; + options.librariesWithKnownMembers = kLibrariesWithConstants; + + options.libraryMemberConstantCb = luauLibraryConstantLookup; + Luau::compileOrThrow(bcb, source, options); return bcb.dumpFunction(id); @@ -1442,6 +1490,131 @@ RETURN R0 1 )"); } +TEST_CASE("ConstantFoldVectorArith") +{ + ScopedFastFlag luauVectorBuiltins{FFlag::LuauVectorBuiltins, true}; + ScopedFastFlag luauVectorFolding{FFlag::LuauVectorFolding, true}; + + CHECK_EQ("\n" + compileFunction("local n = 2; local a, b = vector.create(1, 2, 3), vector.create(2, 4, 8); return a + b", 0, 2), R"( +LOADK R0 K0 [3, 6, 11] +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction("local n = 2; local a, b = vector.create(1, 2, 3), vector.create(2, 4, 8); return a - b", 0, 2), R"( +LOADK R0 K0 [-1, -2, -5] +RETURN R0 1 +)"); + + // Multiplication by infinity cannot be folded as it creates a non-zero value in W + CHECK_EQ( + "\n" + compileFunction( + "local n = 2; local a, b = vector.create(1, 2, 3), vector.create(2, 4, 8); return a * n, a * b, n * b, a * math.huge", 0, 2 + ), + R"( +LOADK R0 K0 [2, 4, 6] +LOADK R1 K1 [2, 8, 24] +LOADK R2 K2 [4, 8, 16] +LOADK R4 K4 [1, 2, 3] +MULK R3 R4 K3 [inf] +RETURN R0 4 +)" + ); + + // Divisions creating an infinity in W cannot be constant-folded + CHECK_EQ( + "\n" + compileFunction( + "local n = 2; local a, b = vector.create(1, 2, 3), vector.create(2, 4, 8); return a / n, a / b, n / b, a / math.huge", 0, 2 + ), + R"( +LOADK R0 K0 [0.5, 1, 1.5] +LOADK R2 K1 [1, 2, 3] +LOADK R3 K2 [2, 4, 8] +DIV R1 R2 R3 +LOADK R3 K2 [2, 4, 8] +DIVRK R2 K3 [2] R3 +LOADK R3 K4 [0, 0, 0] +RETURN R0 4 +)" + ); + + // Divisions creating an infinity in W cannot be constant-folded + CHECK_EQ( + "\n" + compileFunction("local n = 2; local a, b = vector.create(1, 2, 3), vector.create(2, 4, 8); return a // n, a // b, n // b", 0, 2), + R"( +LOADK R0 K0 [0, 1, 1] +LOADK R2 K1 [1, 2, 3] +LOADK R3 K2 [2, 4, 8] +IDIV R1 R2 R3 +LOADN R3 2 +LOADK R4 K2 [2, 4, 8] +IDIV R2 R3 R4 +RETURN R0 3 +)" + ); + + CHECK_EQ("\n" + compileFunction("local a = vector.create(1, 2, 3); return -a", 0, 2), R"( +LOADK R0 K0 [-1, -2, -3] +RETURN R0 1 +)"); +} + +TEST_CASE("ConstantFoldVectorArith4Wide") +{ + ScopedFastFlag luauVectorBuiltins{FFlag::LuauVectorBuiltins, true}; + ScopedFastFlag luauVectorFolding{FFlag::LuauVectorFolding, true}; + + CHECK_EQ("\n" + compileFunction("local n = 2; local a, b = vector.create(1, 2, 3, 4), vector.create(2, 4, 8, 1); return a + b", 0, 2), R"( +LOADK R0 K0 [3, 6, 11, 5] +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction("local n = 2; local a, b = vector.create(1, 2, 3, 4), vector.create(2, 4, 8, 1); return a - b", 0, 2), R"( +LOADK R0 K0 [-1, -2, -5, 3] +RETURN R0 1 +)"); + + CHECK_EQ( + "\n" + compileFunction( + "local n = 2; local a, b = vector.create(1, 2, 3, 4), vector.create(2, 4, 8, 1); return a * n, a * b, n * b, a * math.huge", 0, 2 + ), + R"( +LOADK R0 K0 [2, 4, 6, 8] +LOADK R1 K1 [2, 8, 24, 4] +LOADK R2 K2 [4, 8, 16, 2] +LOADK R3 K3 [inf, inf, inf, inf] +RETURN R0 4 +)" + ); + + CHECK_EQ( + "\n" + compileFunction( + "local n = 2; local a, b = vector.create(1, 2, 3, 4), vector.create(2, 4, 8, 1); return a / n, a / b, n / b, a / math.huge", 0, 2 + ), + R"( +LOADK R0 K0 [0.5, 1, 1.5, 2] +LOADK R1 K1 [0.5, 0.5, 0.375, 4] +LOADK R2 K2 [1, 0.5, 0.25, 2] +LOADK R3 K3 [0, 0, 0] +RETURN R0 4 +)" + ); + + CHECK_EQ( + "\n" + compileFunction("local n = 2; local a, b = vector.create(1, 2, 3, 4), vector.create(2, 4, 8, 1); return a // n, a // b, n // b", 0, 2), + R"( +LOADK R0 K0 [0, 1, 1, 2] +LOADK R1 K1 [0, 0, 0, 4] +LOADK R2 K2 [1, 0, 0, 2] +RETURN R0 3 +)" + ); + + CHECK_EQ("\n" + compileFunction("local a = vector.create(1, 2, 3, 4); return -a", 0, 2), R"( +LOADK R0 K0 [-1, -2, -3, -4] +RETURN R0 1 +)"); +} + TEST_CASE("ConstantFoldStringLen") { CHECK_EQ("\n" + compileFunction0("return #'string', #'', #'a', #('b')"), R"( @@ -2804,8 +2977,6 @@ TEST_CASE("TypeAliasing") TEST_CASE("TypeFunction") { - ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - Luau::BytecodeBuilder bcb; Luau::CompileOptions options; Luau::ParseOptions parseOptions; @@ -4964,6 +5135,42 @@ RETURN R0 1 )"); } +TEST_CASE("VectorConstantFields") +{ + ScopedFastFlag luauVectorBuiltins{FFlag::LuauVectorBuiltins, true}; + ScopedFastFlag luauCompileLibraryConstants{FFlag::LuauCompileLibraryConstants, true}; + + CHECK_EQ("\n" + compileFunction("return vector.one, vector.zero", 0, 2), R"( +LOADK R0 K0 [1, 1, 1] +LOADK R1 K1 [0, 0, 0] +RETURN R0 2 +)"); + + CHECK_EQ("\n" + compileFunction("return Vector3.one, Vector3.xAxis", 0, 2, 0, /*enableVectors*/ true), R"( +LOADK R0 K0 [1, 1, 1] +LOADK R1 K1 [1, 0, 0] +RETURN R0 2 +)"); + + CHECK_EQ("\n" + compileFunction("return vector.one == vector.create(1, 1, 1)", 0, 2), R"( +LOADB R0 1 +RETURN R0 1 +)"); +} + +TEST_CASE("CustomConstantFields") +{ + ScopedFastFlag luauCompileLibraryConstants{FFlag::LuauCompileLibraryConstants, true}; + + CHECK_EQ("\n" + compileFunction("return test.some_nil, test.some_boolean, test.some_number, test.some_string", 0, 2), R"( +LOADNIL R0 +LOADB R1 1 +LOADK R2 K0 [4.75] +LOADK R3 K1 ['test'] +RETURN R0 4 +)"); +} + TEST_CASE("TypeAssertion") { // validate that type assertions work with the compiler and that the code inside type assertion isn't evaluated @@ -7686,6 +7893,41 @@ RETURN R0 1 ); } +TEST_CASE("BuiltinFoldingProhibitedInOptions") +{ + ScopedFastFlag luauCompileDisabledBuiltins{FFlag::LuauCompileDisabledBuiltins, true}; + + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::CompileOptions options; + options.optimizationLevel = 2; + + // math.floor from the test is excluded in this list on purpose + static const char* kDisabledBuiltins[] = {"tostring", "math.abs", "math.sqrt", nullptr}; + options.disabledBuiltins = kDisabledBuiltins; + + Luau::compileOrThrow(bcb, "return math.abs(-42), math.floor(-1.5), math.sqrt(9), (tostring(2))", options); + + std::string result = bcb.dumpFunction(0); + + CHECK_EQ( + "\n" + result, + R"( +GETIMPORT R0 2 [math.abs] +LOADN R1 -42 +CALL R0 1 1 +LOADN R1 -2 +GETIMPORT R2 4 [math.sqrt] +LOADN R3 9 +CALL R2 1 1 +GETIMPORT R3 6 [tostring] +LOADN R4 2 +CALL R3 1 1 +RETURN R0 4 +)" + ); +} + TEST_CASE("LocalReassign") { // locals can be re-assigned and the register gets reused diff --git a/tests/EqSatSimplification.test.cpp b/tests/EqSatSimplification.test.cpp index aaaec456..d4f57182 100644 --- a/tests/EqSatSimplification.test.cpp +++ b/tests/EqSatSimplification.test.cpp @@ -640,6 +640,20 @@ TEST_CASE_FIXTURE(ESFixture, "string & (\"hi\" | \"bye\")") }}))); } +TEST_CASE_FIXTURE(ESFixture, "(\"err\" | \"ok\") & ~\"ok\"") +{ + TypeId err = arena->addType(SingletonType{StringSingleton{"err"}}); + TypeId ok1 = arena->addType(SingletonType{StringSingleton{"ok"}}); + TypeId ok2 = arena->addType(SingletonType{StringSingleton{"ok"}}); + + TypeId ty = arena->addType(IntersectionType{{ + arena->addType(UnionType{{err, ok1}}), + arena->addType(NegationType{ok2}) + }}); + + CHECK("\"err\"" == simplifyStr(ty)); +} + TEST_CASE_FIXTURE(ESFixture, "(Child | Unrelated) & ~Child") { const TypeId ty = arena->addType(IntersectionType{{ @@ -717,6 +731,38 @@ TEST_CASE_FIXTURE(ESFixture, "Child & intersect == boolean") +{ + std::vector> cases{ + {builtinTypes->numberType, arena->addType(BlockedType{})}, + {builtinTypes->stringType, arena->addType(BlockedType{})}, + {arena->addType(BlockedType{}), builtinTypes->numberType}, + {arena->addType(BlockedType{}), builtinTypes->stringType}, + }; + + for (const auto& [lhs, rhs] : cases) { + const TypeId tfun = arena->addType(TypeFunctionInstanceType{builtinTypeFunctions().ltFunc, {lhs, rhs}}); + CHECK("boolean" == simplifyStr(tfun)); + } +} + +TEST_CASE_FIXTURE(ESFixture, "unknown & ~string") +{ + CHECK_EQ( + "~string", simplifyStr(arena->addType(IntersectionType{{builtinTypes->unknownType, arena->addType(NegationType{builtinTypes->stringType})}})) + ); +} + +TEST_CASE_FIXTURE(ESFixture, "string & ~\"foo\"") +{ + CHECK_EQ( + "string & ~\"foo\"", + simplifyStr(arena->addType( + IntersectionType{{builtinTypes->stringType, arena->addType(NegationType{arena->addType(SingletonType{StringSingleton{"foo"}})})}} + )) + ); +} + // {someKey: ~any} // // Maybe something we could do here is to try to reduce the key, get the diff --git a/tests/FragmentAutocomplete.test.cpp b/tests/FragmentAutocomplete.test.cpp index 42f2bf09..91c388bd 100644 --- a/tests/FragmentAutocomplete.test.cpp +++ b/tests/FragmentAutocomplete.test.cpp @@ -24,7 +24,7 @@ LUAU_FASTFLAG(LuauAllowFragmentParsing); LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete) LUAU_FASTFLAG(LuauSymbolEquality); LUAU_FASTFLAG(LuauStoreSolverTypeOnModule); -LUAU_FASTFLAG(LexerResumesFromPosition) +LUAU_FASTFLAG(LexerResumesFromPosition2) static std::optional nullCallback(std::string tag, std::optional ptr, std::optional contents) { @@ -52,7 +52,7 @@ struct FragmentAutocompleteFixtureImpl : BaseType {FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete, true}, {FFlag::LuauStoreSolverTypeOnModule, true}, {FFlag::LuauSymbolEquality, true}, - {FFlag::LexerResumesFromPosition, true} + {FFlag::LexerResumesFromPosition2, true} }; FragmentAutocompleteFixtureImpl() diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index bfa69fe4..b0ba86ed 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -15,6 +15,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAG(DebugLuauFreezeArena); LUAU_FASTFLAG(DebugLuauMagicTypes); +LUAU_FASTFLAG(LuauReferenceAllocatorInNewSolver); namespace { @@ -1522,4 +1523,22 @@ TEST_CASE_FIXTURE(FrontendFixture, "get_required_scripts_dirty") CHECK(requiredScripts[0] == "game/workspace/MyModuleScript"); } +TEST_CASE_FIXTURE(FrontendFixture, "check_module_references_allocator") +{ + ScopedFastFlag sff{FFlag::LuauReferenceAllocatorInNewSolver, true}; + fileResolver.source["game/workspace/MyScript"] = R"( + print("Hello World") + )"; + + frontend.check("game/workspace/MyScript"); + + ModulePtr module = frontend.moduleResolver.getModule("game/workspace/MyScript"); + SourceModule* source = frontend.getSourceModule("game/workspace/MyScript"); + CHECK(module); + CHECK(source); + + CHECK_EQ(module->allocator.get(), source->allocator.get()); + CHECK_EQ(module->names.get(), source->names.get()); +} + TEST_SUITE_END(); diff --git a/tests/IrLowering.test.cpp b/tests/IrLowering.test.cpp index 27376777..cce167bf 100644 --- a/tests/IrLowering.test.cpp +++ b/tests/IrLowering.test.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "lua.h" #include "lualib.h" +#include "luacode.h" #include "Luau/BytecodeBuilder.h" #include "Luau/CodeGen.h" @@ -15,10 +16,67 @@ #include #include -static std::string getCodegenAssembly(const char* source, bool includeIrTypes = false, int debugLevel = 1) -{ - Luau::CodeGen::AssemblyOptions options; +LUAU_FASTFLAG(LuauCompileLibraryConstants) +static void luauLibraryConstantLookup(const char* library, const char* member, Luau::CompileConstant* constant) +{ + // While 'vector' library constants are a Luau built-in, their constant value depends on the embedder LUA_VECTOR_SIZE value + if (strcmp(library, "vector") == 0) + { + if (strcmp(member, "zero") == 0) + return Luau::setCompileConstantVector(constant, 0.0f, 0.0f, 0.0f, 0.0f); + + if (strcmp(member, "one") == 0) + return Luau::setCompileConstantVector(constant, 1.0f, 1.0f, 1.0f, 0.0f); + } + + if (strcmp(library, "Vector3") == 0) + { + if (strcmp(member, "xAxis") == 0) + return Luau::setCompileConstantVector(constant, 1.0f, 0.0f, 0.0f, 0.0f); + + if (strcmp(member, "yAxis") == 0) + return Luau::setCompileConstantVector(constant, 0.0f, 1.0f, 0.0f, 0.0f); + } +} + +static void luauLibraryConstantLookupC(const char* library, const char* member, lua_CompileConstant* constant) +{ + if (strcmp(library, "test") == 0) + { + if (strcmp(member, "some_nil") == 0) + return luau_set_compile_constant_nil(constant); + + if (strcmp(member, "some_boolean") == 0) + return luau_set_compile_constant_boolean(constant, 1); + + if (strcmp(member, "some_number") == 0) + return luau_set_compile_constant_number(constant, 4.75); + + if (strcmp(member, "some_vector") == 0) + return luau_set_compile_constant_vector(constant, 1.0f, 2.0f, 4.0f, 8.0f); + + if (strcmp(member, "some_string") == 0) + return luau_set_compile_constant_string(constant, "test", 4); + } +} + +static int luauLibraryTypeLookup(const char* library, const char* member) +{ + if (strcmp(library, "Vector3") == 0) + { + if (strcmp(member, "xAxis") == 0) + return LuauBytecodeType::LBC_TYPE_VECTOR; + + if (strcmp(member, "yAxis") == 0) + return LuauBytecodeType::LBC_TYPE_VECTOR; + } + + return LuauBytecodeType::LBC_TYPE_ANY; +} + +static void setupAssemblyOptions(Luau::CodeGen::AssemblyOptions& options, bool includeIrTypes) +{ options.compilationOptions.hooks.vectorAccessBytecodeType = vectorAccessBytecodeType; options.compilationOptions.hooks.vectorNamecallBytecodeType = vectorNamecallBytecodeType; options.compilationOptions.hooks.vectorAccess = vectorAccess; @@ -44,35 +102,10 @@ static std::string getCodegenAssembly(const char* source, bool includeIrTypes = options.includeUseInfo = Luau::CodeGen::IncludeUseInfo::No; options.includeCfgInfo = Luau::CodeGen::IncludeCfgInfo::No; options.includeRegFlowInfo = Luau::CodeGen::IncludeRegFlowInfo::No; +} - Luau::Allocator allocator; - Luau::AstNameTable names(allocator); - Luau::ParseResult result = Luau::Parser::parse(source, strlen(source), names, allocator); - - if (!result.errors.empty()) - throw Luau::ParseErrors(result.errors); - - Luau::CompileOptions copts = {}; - - copts.optimizationLevel = 2; - copts.debugLevel = debugLevel; - copts.typeInfoLevel = 1; - copts.vectorCtor = "vector"; - copts.vectorType = "vector"; - - static const char* kUserdataCompileTypes[] = {"vec2", "color", "mat3", nullptr}; - copts.userdataTypes = kUserdataCompileTypes; - - Luau::BytecodeBuilder bcb; - Luau::compileOrThrow(bcb, result, names, copts); - - std::string bytecode = bcb.getBytecode(); - std::unique_ptr globalState(luaL_newstate(), lua_close); - lua_State* L = globalState.get(); - - // Runtime mapping is specifically created to NOT match the compilation mapping - options.compilationOptions.userdataTypes = kUserdataRunTypes; - +static void initializeCodegen(lua_State* L) +{ if (Luau::CodeGen::isSupported()) { // Type remapper requires the codegen runtime @@ -101,9 +134,95 @@ static std::string getCodegenAssembly(const char* source, bool includeIrTypes = } ); } +} + +static std::string getCodegenAssembly(const char* source, bool includeIrTypes = false, int debugLevel = 1) +{ + Luau::Allocator allocator; + Luau::AstNameTable names(allocator); + Luau::ParseResult result = Luau::Parser::parse(source, strlen(source), names, allocator); + + if (!result.errors.empty()) + throw Luau::ParseErrors(result.errors); + + Luau::CompileOptions copts = {}; + + copts.optimizationLevel = 2; + copts.debugLevel = debugLevel; + copts.typeInfoLevel = 1; + copts.vectorCtor = "vector"; + copts.vectorType = "vector"; + + static const char* kUserdataCompileTypes[] = {"vec2", "color", "mat3", nullptr}; + copts.userdataTypes = kUserdataCompileTypes; + + static const char* kLibrariesWithConstants[] = {"vector", "Vector3", nullptr}; + copts.librariesWithKnownMembers = kLibrariesWithConstants; + + copts.libraryMemberTypeCb = luauLibraryTypeLookup; + copts.libraryMemberConstantCb = luauLibraryConstantLookup; + + Luau::BytecodeBuilder bcb; + Luau::compileOrThrow(bcb, result, names, copts); + + std::string bytecode = bcb.getBytecode(); + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + initializeCodegen(L); if (luau_load(L, "name", bytecode.data(), bytecode.size(), 0) == 0) + { + Luau::CodeGen::AssemblyOptions options; + setupAssemblyOptions(options, includeIrTypes); + + // Runtime mapping is specifically created to NOT match the compilation mapping + options.compilationOptions.userdataTypes = kUserdataRunTypes; + return Luau::CodeGen::getAssembly(L, -1, options, nullptr); + } + + FAIL("Failed to load bytecode"); + return ""; +} + +static std::string getCodegenAssemblyUsingCApi(const char* source, bool includeIrTypes = false, int debugLevel = 1) +{ + lua_CompileOptions copts = {}; + + copts.optimizationLevel = 2; + copts.debugLevel = debugLevel; + copts.typeInfoLevel = 1; + + static const char* kLibrariesWithConstants[] = {"test", nullptr}; + copts.librariesWithKnownMembers = kLibrariesWithConstants; + + copts.libraryMemberTypeCb = luauLibraryTypeLookup; + copts.libraryMemberConstantCb = luauLibraryConstantLookupC; + + size_t bytecodeSize = 0; + char* bytecode = luau_compile(source, strlen(source), &copts, &bytecodeSize); + REQUIRE(bytecode); + + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + initializeCodegen(L); + + if (luau_load(L, "name", bytecode, bytecodeSize, 0) == 0) + { + free(bytecode); + + Luau::CodeGen::AssemblyOptions options; + setupAssemblyOptions(options, includeIrTypes); + + // Runtime mapping is specifically created to NOT match the compilation mapping + options.compilationOptions.userdataTypes = kUserdataRunTypes; + + return Luau::CodeGen::getAssembly(L, -1, options, nullptr); + } + + free(bytecode); FAIL("Failed to load bytecode"); return ""; @@ -1994,4 +2113,109 @@ bb_bytecode_1: ); } +TEST_CASE("LibraryFieldTypesAndConstants") +{ + ScopedFastFlag luauCompileLibraryConstants{FFlag::LuauCompileLibraryConstants, true}; + + CHECK_EQ( + "\n" + getCodegenAssembly( + R"( +local function foo(a: vector) + return Vector3.xAxis * a + Vector3.yAxis +end +)", + /* includeIrTypes */ true + ), + R"( +; function foo($arg0) line 2 +; R0: vector [argument] +; R2: vector from 3 to 4 +; R3: vector from 1 to 2 +; R3: vector from 3 to 4 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %4 = LOAD_TVALUE K0, 0i, tvector + %11 = LOAD_TVALUE R0 + %12 = MUL_VEC %4, %11 + %15 = LOAD_TVALUE K1, 0i, tvector + %23 = ADD_VEC %12, %15 + %24 = TAG_VECTOR %23 + STORE_TVALUE R1, %24 + INTERRUPT 4u + RETURN R1, 1i +)" + ); +} + +TEST_CASE("LibraryFieldTypesAndConstants") +{ + ScopedFastFlag luauCompileLibraryConstants{FFlag::LuauCompileLibraryConstants, true}; + + CHECK_EQ( + "\n" + getCodegenAssembly( + R"( +local function foo(a: vector) + local x = vector.zero + x += a + return x +end +)", + /* includeIrTypes */ true + ), + R"( +; function foo($arg0) line 2 +; R0: vector [argument] +; R1: vector from 0 to 3 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %4 = LOAD_TVALUE K0, 0i, tvector + %11 = LOAD_TVALUE R0 + %12 = ADD_VEC %4, %11 + %13 = TAG_VECTOR %12 + STORE_TVALUE R1, %13 + INTERRUPT 2u + RETURN R1, 1i +)" + ); +} + +TEST_CASE("LibraryFieldTypesAndConstantsCApi") +{ + ScopedFastFlag luauCompileLibraryConstants{FFlag::LuauCompileLibraryConstants, true}; + + CHECK_EQ( + "\n" + getCodegenAssemblyUsingCApi( + R"( +local function foo() + return test.some_nil, test.some_boolean, test.some_number, test.some_vector, test.some_string +end +)", + /* includeIrTypes */ true + ), + R"( +; function foo() line 2 +bb_bytecode_0: + STORE_TAG R0, tnil + STORE_INT R1, 1i + STORE_TAG R1, tboolean + STORE_DOUBLE R2, 4.75 + STORE_TAG R2, tnumber + %5 = LOAD_TVALUE K1, 0i, tvector + STORE_TVALUE R3, %5 + %7 = LOAD_TVALUE K2, 0i, tstring + STORE_TVALUE R4, %7 + INTERRUPT 5u + RETURN R0, 5i +)" + ); +} + TEST_SUITE_END(); diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 025fa7fd..08b0bb0d 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -14,7 +14,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAG(DebugLuauFreezeArena); LUAU_FASTINT(LuauTypeCloneIterationLimit); - +LUAU_FASTFLAG(LuauOldSolverCreatesChildScopePointers) TEST_SUITE_BEGIN("ModuleTests"); TEST_CASE_FIXTURE(Fixture, "is_within_comment") @@ -540,4 +540,28 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "clone_a_bound_typepack_to_a_persistent_typep REQUIRE(res == follow(boundTo)); } +TEST_CASE_FIXTURE(Fixture, "old_solver_correctly_populates_child_scopes") +{ + ScopedFastFlag sff{FFlag::LuauOldSolverCreatesChildScopePointers, true}; + check(R"( +--!strict +if true then +end + +if false then +end + +if true then +else +end + +local x = {} +for i,v in x do +end +)"); + + auto& module = frontend.moduleResolver.getModule("MainModule"); + CHECK(module->getModuleScope()->children.size() == 7); +} + TEST_SUITE_END(); diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 24186c0a..786d57b8 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -27,7 +27,9 @@ struct IsSubtypeFixture : Fixture if (!module->hasModuleScope()) FAIL("isSubtype: module scope data is not available"); - return ::Luau::isSubtype(a, b, NotNull{module->getModuleScope().get()}, builtinTypes, ice); + SimplifierPtr simplifier = newSimplifier(NotNull{&module->internalTypes}, builtinTypes); + + return ::Luau::isSubtype(a, b, NotNull{module->getModuleScope().get()}, builtinTypes, NotNull{simplifier.get()}, ice); } }; } // namespace diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index b35466cb..387c0d10 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -16,7 +16,6 @@ LUAU_FASTINT(LuauRecursionLimit) LUAU_FASTINT(LuauTypeLengthLimit) LUAU_FASTINT(LuauParseErrorLimit) LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2) LUAU_FASTFLAG(LuauUserDefinedTypeFunParseExport) LUAU_FASTFLAG(LuauAllowComplexTypesInGenericParams) LUAU_FASTFLAG(LuauErrorRecoveryForTableTypes) @@ -2342,7 +2341,6 @@ TEST_CASE_FIXTURE(Fixture, "invalid_type_forms") TEST_CASE_FIXTURE(Fixture, "parse_user_defined_type_functions") { - ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; ScopedFastFlag sff2{FFlag::LuauUserDefinedTypeFunParseExport, true}; AstStat* stat = parse(R"( @@ -2363,8 +2361,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_user_defined_type_functions") TEST_CASE_FIXTURE(Fixture, "parse_nested_type_function") { - ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - AstStat* stat = parse(R"( local v1 = 1 type function foo() @@ -2386,8 +2382,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_nested_type_function") TEST_CASE_FIXTURE(Fixture, "invalid_user_defined_type_functions") { - ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - matchParseError("local foo = 1; type function bar() print(foo) end", "Type function cannot reference outer local 'foo'"); matchParseError("type function foo() local v1 = 1; type function bar() print(v1) end end", "Type function cannot reference outer local 'v1'"); } diff --git a/tests/RequireByString.test.cpp b/tests/RequireByString.test.cpp index bc1161b0..574a4561 100644 --- a/tests/RequireByString.test.cpp +++ b/tests/RequireByString.test.cpp @@ -1,5 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Common.h" +#include "Luau/Config.h" + #include "ScopedFlags.h" #include "lua.h" #include "lualib.h" @@ -12,6 +14,8 @@ #include #include #include +#include +#include #if __APPLE__ #include @@ -491,4 +495,44 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "AliasHasIllegalFormat") assertOutputContainsAll({"false", " is not a valid alias"}); } +TEST_CASE("ParseAliases") +{ + std::string configJson = R"({ + "aliases": { + "MyAlias": "/my/alias/path", + } +})"; + + Luau::Config config; + + Luau::ConfigOptions::AliasOptions aliasOptions; + aliasOptions.configLocation = "/default/location"; + aliasOptions.overwriteAliases = true; + + Luau::ConfigOptions options{false, aliasOptions}; + + std::optional error = Luau::parseConfig(configJson, config, options); + REQUIRE(!error); + + auto checkContents = [](Luau::Config& config) -> void + { + CHECK(config.aliases.size() == 1); + REQUIRE(config.aliases.contains("myalias")); + + Luau::Config::AliasInfo& aliasInfo = config.aliases["myalias"]; + CHECK(aliasInfo.value == "/my/alias/path"); + CHECK(aliasInfo.originalCase == "MyAlias"); + }; + + checkContents(config); + + // Ensure that copied Configs retain the same information + Luau::Config copyConstructedConfig = config; + checkContents(copyConstructedConfig); + + Luau::Config copyAssignedConfig; + copyAssignedConfig = config; + checkContents(copyAssignedConfig); +} + TEST_SUITE_END(); diff --git a/tests/Subtyping.test.cpp b/tests/Subtyping.test.cpp index 27b2f6e7..76efc835 100644 --- a/tests/Subtyping.test.cpp +++ b/tests/Subtyping.test.cpp @@ -65,6 +65,7 @@ struct SubtypeFixture : Fixture TypeArena arena; InternalErrorReporter iceReporter; UnifierSharedState sharedState{&ice}; + SimplifierPtr simplifier = newSimplifier(NotNull{&arena}, builtinTypes); Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; TypeCheckLimits limits; TypeFunctionRuntime typeFunctionRuntime{NotNull{&iceReporter}, NotNull{&limits}}; @@ -79,7 +80,9 @@ struct SubtypeFixture : Fixture Subtyping mkSubtyping() { - return Subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter}}; + return Subtyping{ + builtinTypes, NotNull{&arena}, NotNull{simplifier.get()}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter} + }; } TypePackId pack(std::initializer_list tys) diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index 188d9682..dc63be77 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -12,8 +12,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2) - TEST_SUITE_BEGIN("TranspilerTests"); TEST_CASE("test_1") @@ -698,8 +696,6 @@ TEST_CASE_FIXTURE(Fixture, "transpile_string_literal_escape") TEST_CASE_FIXTURE(Fixture, "transpile_type_functions") { - ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - std::string code = R"( type function foo(arg1, arg2) if arg1 == arg2 then return arg1 end return arg2 end )"; CHECK_EQ(code, transpile(code, {}, true).code); diff --git a/tests/TypeFunction.test.cpp b/tests/TypeFunction.test.cpp index 4aa6b680..0d717629 100644 --- a/tests/TypeFunction.test.cpp +++ b/tests/TypeFunction.test.cpp @@ -33,20 +33,20 @@ struct TypeFunctionFixture : Fixture if (isString(param)) { - return TypeFunctionReductionResult{ctx->builtins->numberType, false, {}, {}}; + return TypeFunctionReductionResult{ctx->builtins->numberType, Reduction::MaybeOk, {}, {}}; } else if (isNumber(param)) { - return TypeFunctionReductionResult{ctx->builtins->stringType, false, {}, {}}; + return TypeFunctionReductionResult{ctx->builtins->stringType, Reduction::MaybeOk, {}, {}}; } else if (is(param) || is(param) || is(param) || (ctx->solver && ctx->solver->hasUnresolvedConstraints(param))) { - return TypeFunctionReductionResult{std::nullopt, false, {param}, {}}; + return TypeFunctionReductionResult{std::nullopt, Reduction::MaybeOk, {param}, {}}; } else { - return TypeFunctionReductionResult{std::nullopt, true, {}, {}}; + return TypeFunctionReductionResult{std::nullopt, Reduction::Erroneous, {}, {}}; } } }; diff --git a/tests/TypeFunction.user.test.cpp b/tests/TypeFunction.user.test.cpp index 145772fd..a5af44fb 100644 --- a/tests/TypeFunction.user.test.cpp +++ b/tests/TypeFunction.user.test.cpp @@ -8,22 +8,18 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2) -LUAU_FASTFLAG(LuauUserTypeFunFixRegister) LUAU_FASTFLAG(LuauUserTypeFunFixNoReadWrite) -LUAU_FASTFLAG(LuauUserTypeFunFixMetatable) -LUAU_FASTFLAG(LuauUserDefinedTypeFunctionResetState) -LUAU_FASTFLAG(LuauUserTypeFunNonstrict) +LUAU_FASTFLAG(LuauUserTypeFunPrintToError) LUAU_FASTFLAG(LuauUserTypeFunExportedAndLocal) LUAU_FASTFLAG(LuauUserDefinedTypeFunParseExport) LUAU_FASTFLAG(LuauUserTypeFunThreadBuffer) +LUAU_FASTFLAG(LuauUserTypeFunUpdateAllEnvs) TEST_SUITE_BEGIN("UserDefinedTypeFunctionTests"); TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_nil_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_nil(arg) @@ -39,7 +35,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_nil_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_nil_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getnil() @@ -59,7 +54,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_nil_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_unknown_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_unknown(arg) @@ -75,7 +69,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_unknown_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_unknown_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getunknown() @@ -95,7 +88,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_unknown_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_never_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_never(arg) @@ -111,7 +103,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_never_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_never_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getnever() @@ -131,7 +122,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_never_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_any_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_any(arg) @@ -147,7 +137,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_any_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_any_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getany() @@ -167,7 +156,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_any_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolean_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_bool(arg) @@ -183,7 +171,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolean_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolean_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getboolean() @@ -203,7 +190,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolean_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_number_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_num(arg) @@ -219,7 +205,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_number_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_number_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getnumber() @@ -239,8 +224,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_number_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "thread_and_buffer_types") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; ScopedFastFlag luauUserTypeFunThreadBuffer{FFlag::LuauUserTypeFunThreadBuffer, true}; LUAU_REQUIRE_NO_ERRORS(check(R"( @@ -269,7 +252,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "thread_and_buffer_types") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_string_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_str(arg) @@ -285,7 +267,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_string_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_string_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getstring() @@ -305,7 +286,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_string_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolsingleton_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_boolsingleton(arg) @@ -321,7 +301,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolsingleton_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolsingleton_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getboolsingleton() @@ -341,7 +320,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolsingleton_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_strsingleton_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_strsingleton(arg) @@ -357,7 +335,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_strsingleton_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_strsingleton_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getstrsingleton() @@ -377,7 +354,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_strsingleton_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_union_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_union(arg) @@ -397,7 +373,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_union_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_union_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getunion() @@ -426,7 +401,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_union_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_intersection_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_intersection(arg) @@ -446,7 +420,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_intersection_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_intersection_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getintersection() @@ -481,7 +454,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_intersection_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_negation_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getnegation() @@ -506,7 +478,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_negation_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_table_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_table(arg) @@ -526,7 +497,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_table_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_table_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function gettable() @@ -565,7 +535,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_table_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_metatable_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getmetatable() @@ -598,7 +567,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_metatable_methods_work") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_function_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_func(arg) @@ -614,7 +582,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_function_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_function_methods_work") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getfunction() @@ -644,7 +611,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_function_methods_work") TEST_CASE_FIXTURE(ClassFixture, "udtf_class_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_class(arg) @@ -659,7 +625,6 @@ TEST_CASE_FIXTURE(ClassFixture, "udtf_class_serialization_works") TEST_CASE_FIXTURE(ClassFixture, "udtf_class_methods_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( @@ -682,10 +647,8 @@ TEST_CASE_FIXTURE(ClassFixture, "udtf_class_methods_works") TEST_CASE_FIXTURE(ClassFixture, "write_of_readonly_is_nil") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; ScopedFastFlag udtfRwFix{FFlag::LuauUserTypeFunFixNoReadWrite, true}; - CheckResult result = check(R"( type function getclass(arg) local props = arg:properties() @@ -711,7 +674,6 @@ TEST_CASE_FIXTURE(ClassFixture, "write_of_readonly_is_nil") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_check_mutability") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function checkmut() @@ -743,7 +705,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_check_mutability") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_copy_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function getcopy() @@ -776,7 +737,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_copy_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_simple_cyclic_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_cycle(arg) @@ -797,7 +757,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_simple_cyclic_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_createtable_bad_metatable") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function badmetatable() @@ -806,7 +765,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_createtable_bad_metatable") local function bad(arg: badmetatable<>) end )"); - LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + LUAU_REQUIRE_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error UserDefinedTypeFunctionError* e = get(result.errors[0]); REQUIRE(e); CHECK( @@ -818,7 +777,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_createtable_bad_metatable") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_complex_cyclic_serialization_works") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function serialize_cycle2(arg) @@ -847,7 +805,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_complex_cyclic_serialization_works") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_user_error_is_reported") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function errors_if_string(arg) @@ -860,7 +817,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_user_error_is_reported") local function ok(idx: errors_if_string): nil return idx end )"); - LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + LUAU_REQUIRE_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error UserDefinedTypeFunctionError* e = get(result.errors[0]); REQUIRE(e); CHECK(e->message == "'errors_if_string' type function errored at runtime: [string \"errors_if_string\"]:5: We are in a math class! not english"); @@ -869,7 +826,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_user_error_is_reported") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_type_overrides_call_metamethod") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function hello(arg) @@ -878,7 +834,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_type_overrides_call_metamethod") local function ok(idx: hello): nil return idx end )"); - LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + LUAU_REQUIRE_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error UserDefinedTypeFunctionError* e = get(result.errors[0]); REQUIRE(e); CHECK(e->message == "'hello' type function errored at runtime: [string \"hello\"]:3: userdata"); @@ -887,7 +843,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_type_overrides_call_metamethod") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_type_overrides_eq_metamethod") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function hello() @@ -912,7 +867,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_type_overrides_eq_metamethod") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_function_type_cant_call_get_props") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function hello(arg) @@ -921,7 +875,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_function_type_cant_call_get_props") local function ok(idx: hello<() -> ()>): nil return idx end )"); - LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + LUAU_REQUIRE_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error UserDefinedTypeFunctionError* e = get(result.errors[0]); REQUIRE(e); CHECK( @@ -933,7 +887,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_function_type_cant_call_get_props") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_calling_each_other") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function foo() @@ -945,16 +898,69 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_calling_each_other") local function ok(idx: bar<>): nil return idx end )"); - LUAU_CHECK_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR_COUNT(1, result); TypePackMismatch* tpm = get(result.errors[0]); REQUIRE(tpm); CHECK(toString(tpm->givenTp) == "\"hi\""); } +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_calling_each_other_2") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunUpdateAllEnvs{FFlag::LuauUserTypeFunUpdateAllEnvs, true}; + + CheckResult result = check(R"( + type function first(arg) + return arg + end + type function second(arg) + return types.singleton(first(arg)) + end + type function third() + return second("hi") + end + local function ok(idx: third<>): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "\"hi\""); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_calling_each_other_3") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunExportedAndLocal{FFlag::LuauUserTypeFunExportedAndLocal, true}; + ScopedFastFlag luauUserTypeFunUpdateAllEnvs{FFlag::LuauUserTypeFunUpdateAllEnvs, true}; + + CheckResult result = check(R"( + -- this function should not see 'fourth' function when invoked from 'third' that sees it + type function first(arg) + return fourth(arg) + end + type function second(arg) + return types.singleton(first(arg)) + end + + do + type function fourth(arg) + return arg + end + type function third() + return second("hi") + end + local function ok(idx: third<>): nil return idx end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"('third' type function errored at runtime: [string "first"]:4: attempt to call a nil value)"); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_no_shared_state") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function foo() @@ -974,7 +980,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_no_shared_state") )"); // We are only checking first errors, others are mostly duplicates - LUAU_CHECK_ERROR_COUNT(8, result); + LUAU_REQUIRE_ERROR_COUNT(8, result); CHECK(toString(result.errors[0]) == R"('bar' type function errored at runtime: [string "foo"]:4: attempt to modify a readonly table)"); CHECK(toString(result.errors[1]) == R"(Type function instance bar<"x"> is uninhabited)"); } @@ -982,8 +988,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_no_shared_state") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_math_reset") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - ScopedFastFlag luauUserDefinedTypeFunctionResetState{FFlag::LuauUserDefinedTypeFunctionResetState, true}; CheckResult result = check(R"( type function foo(x) @@ -998,7 +1002,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_math_reset") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_optionify") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function optionify(tbl) @@ -1018,7 +1021,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_optionify") local function ok(idx: optionify): nil return idx end )"); - LUAU_CHECK_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR_COUNT(1, result); TypePackMismatch* tpm = get(result.errors[0]); REQUIRE(tpm); CHECK(toString(tpm->givenTp) == "{ age: number?, alive: boolean?, name: string? }"); @@ -1027,7 +1030,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_optionify") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_calling_illegal_global") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function illegal(arg) @@ -1039,7 +1041,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_calling_illegal_global") local function ok(idx: illegal): nil return idx end )"); - LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + LUAU_REQUIRE_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error UserDefinedTypeFunctionError* e = get(result.errors[0]); REQUIRE(e); CHECK(e->message == "'illegal' type function errored at runtime: [string \"illegal\"]:3: this function is not supported in type functions"); @@ -1048,7 +1050,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_calling_illegal_global") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_recursion_and_gc") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function foo(tbl) @@ -1065,7 +1066,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_recursion_and_gc") local function ok(idx: foo): nil return idx end )"); - LUAU_CHECK_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR_COUNT(1, result); TypePackMismatch* tpm = get(result.errors[0]); REQUIRE(tpm); } @@ -1073,7 +1074,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_recursion_and_gc") TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_recovery_no_upvalues") { ScopedFastFlag solverV2{FFlag::LuauSolverV2, true}; - ScopedFastFlag userDefinedTypeFunctionsSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( local var @@ -1089,14 +1089,13 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_recovery_no_upvalues") end )"); - LUAU_CHECK_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK(toString(result.errors[0]) == R"(Type function cannot reference outer local 'var')"); } TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_follow") { ScopedFastFlag solverV2{FFlag::LuauSolverV2, true}; - ScopedFastFlag userDefinedTypeFunctionsSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type t0 = any @@ -1105,14 +1104,13 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_follow") end )"); - LUAU_CHECK_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK(toString(result.errors[0]) == R"(Redefinition of type 't0', previously defined at line 2)"); } TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_strip_indexer") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; CheckResult result = check(R"( type function stripindexer(tbl) @@ -1137,8 +1135,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_strip_indexer") TEST_CASE_FIXTURE(BuiltinsFixture, "no_type_methods_on_types") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; CheckResult result = check(R"( type function test(x) @@ -1154,8 +1150,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "no_type_methods_on_types") TEST_CASE_FIXTURE(BuiltinsFixture, "no_types_functions_on_type") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; CheckResult result = check(R"( type function test(x) @@ -1171,8 +1165,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "no_types_functions_on_type") TEST_CASE_FIXTURE(BuiltinsFixture, "no_metatable_writes") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; CheckResult result = check(R"( type function test(x) @@ -1190,8 +1182,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "no_metatable_writes") TEST_CASE_FIXTURE(BuiltinsFixture, "no_eq_field") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; CheckResult result = check(R"( type function test(x) @@ -1207,8 +1197,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "no_eq_field") TEST_CASE_FIXTURE(BuiltinsFixture, "tag_field") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; CheckResult result = check(R"( type function test(x) @@ -1229,9 +1217,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "tag_field") TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_serialization") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; - ScopedFastFlag luauUserTypeFunFixMetatable{FFlag::LuauUserTypeFunFixMetatable, true}; CheckResult result = check(R"( type function makemttbl() @@ -1260,9 +1245,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_serialization") TEST_CASE_FIXTURE(BuiltinsFixture, "nonstrict_mode") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; - ScopedFastFlag luauUserTypeFunNonstrict{FFlag::LuauUserTypeFunNonstrict, true}; CheckResult result = check(R"( --!nonstrict @@ -1275,8 +1257,6 @@ local a: foo<> = "a" TEST_CASE_FIXTURE(BuiltinsFixture, "implicit_export") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; ScopedFastFlag luauUserTypeFunExportedAndLocal{FFlag::LuauUserTypeFunExportedAndLocal, true}; fileResolver.source["game/A"] = R"( @@ -1305,8 +1285,6 @@ local b: Test.Concat<'third', 'fourth'> TEST_CASE_FIXTURE(BuiltinsFixture, "local_scope") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; ScopedFastFlag luauUserTypeFunExportedAndLocal{FFlag::LuauUserTypeFunExportedAndLocal, true}; CheckResult result = check(R"( @@ -1330,8 +1308,6 @@ local a = test() TEST_CASE_FIXTURE(BuiltinsFixture, "explicit_export") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; ScopedFastFlag luauUserTypeFunExportedAndLocal{FFlag::LuauUserTypeFunExportedAndLocal, true}; ScopedFastFlag luauUserDefinedTypeFunParseExport{FFlag::LuauUserDefinedTypeFunParseExport, true}; @@ -1357,4 +1333,64 @@ local b: Test.concat<'third', 'fourth'> CHECK(toString(requireType("b")) == R"("thirdfourth")"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "print_to_error") +{ + ScopedFastFlag solverV2{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunPrintToError{FFlag::LuauUserTypeFunPrintToError, true}; + + CheckResult result = check(R"( + type function t0(a) + print("Where does this go") + print(a.tag) + return types.any + end + local a: t0 + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK(toString(result.errors[0]) == R"(Where does this go)"); + CHECK(toString(result.errors[1]) == R"(string)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "print_to_error_plus_error") +{ + ScopedFastFlag solverV2{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunPrintToError{FFlag::LuauUserTypeFunPrintToError, true}; + + CheckResult result = check(R"( + type function t0(a) + print("Where does this go") + print(a.tag) + error("test") + end + local a: t0 + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"(Where does this go)"); + CHECK(toString(result.errors[1]) == R"(string)"); + CHECK(toString(result.errors[2]) == R"('t0' type function errored at runtime: [string "t0"]:5: test)"); + CHECK(toString(result.errors[3]) == R"(Type function instance t0 is uninhabited)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "print_to_error_plus_no_result") +{ + ScopedFastFlag solverV2{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunPrintToError{FFlag::LuauUserTypeFunPrintToError, true}; + + CheckResult result = check(R"( + type function t0(a) + print("Where does this go") + print(a.tag) + end + local a: t0 + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"(Where does this go)"); + CHECK(toString(result.errors[1]) == R"(string)"); + CHECK(toString(result.errors[2]) == R"('t0' type function: returned a non-type value)"); + CHECK(toString(result.errors[3]) == R"(Type function instance t0 is uninhabited)"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 2b5e64be..a9e5951a 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -9,7 +9,6 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2) TEST_SUITE_BEGIN("TypeAliases"); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 3686f2d4..83910e1b 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -21,6 +21,7 @@ LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauRetrySubtypingWithoutHiddenPack) LUAU_FASTFLAG(LuauDontRefCountTypesInTypeFunctions) +LUAU_FASTFLAG(DebugLuauEqSatSimplification) TEST_SUITE_BEGIN("TypeInferFunctions"); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 5f4b730e..6bed0476 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -17,7 +17,6 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauMetatableFollow) TEST_SUITE_BEGIN("TypeInferOperators"); @@ -1614,8 +1613,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "compound_operator_on_upvalue") TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_operator_follow") { - ScopedFastFlag luauMetatableFollow{FFlag::LuauMetatableFollow, true}; - CheckResult result = check(R"( local t1 = {} local t2 = {} diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index b2cc6713..005f2291 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -11,6 +11,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2); +LUAU_FASTFLAG(DebugLuauEqSatSimplification); LUAU_FASTINT(LuauNormalizeCacheLimit); LUAU_FASTINT(LuauTarjanChildLimit); LUAU_FASTINT(LuauTypeInferIterationLimit); @@ -65,8 +66,20 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") end )"; - if (FFlag::LuauSolverV2) + const std::string expectedWithEqSat = R"( + function f(a:{fn:()->(unknown,...unknown)}): () + if type(a) == 'boolean'then + local a1:never=a + elseif a.fn()then + local a2:{fn:()->(unknown,...unknown)}=a + end + end + )"; + + if (FFlag::LuauSolverV2 && !FFlag::DebugLuauEqSatSimplification) CHECK_EQ(expectedWithNewSolver, decorateWithTypes(code)); + else if (FFlag::LuauSolverV2 && FFlag::DebugLuauEqSatSimplification) + CHECK_EQ(expectedWithEqSat, decorateWithTypes(code)); else CHECK_EQ(expected, decorateWithTypes(code)); } @@ -653,13 +666,15 @@ struct IsSubtypeFixture : Fixture { bool isSubtype(TypeId a, TypeId b) { + SimplifierPtr simplifier = newSimplifier(NotNull{&getMainModule()->internalTypes}, builtinTypes); + ModulePtr module = getMainModule(); REQUIRE(module); if (!module->hasModuleScope()) FAIL("isSubtype: module scope data is not available"); - return ::Luau::isSubtype(a, b, NotNull{module->getModuleScope().get()}, builtinTypes, ice); + return ::Luau::isSubtype(a, b, NotNull{module->getModuleScope().get()}, builtinTypes, NotNull{simplifier.get()}, ice); } }; } // namespace diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index dcbc712e..cc7123cf 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -8,6 +8,9 @@ #include "doctest.h" LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(DebugLuauEqSatSimplification) +LUAU_FASTFLAG(InferGlobalTypes) +LUAU_FASTFLAG(LuauGeneralizationRemoveRecursiveUpperBound) using namespace Luau; @@ -488,8 +491,15 @@ TEST_CASE_FIXTURE(Fixture, "truthy_constraint_on_properties") if (FFlag::LuauSolverV2) { - // CLI-115281 - Types produced by refinements don't always get simplified - CHECK("{ x: number? } & { x: ~(false?) }" == toString(requireTypeAtPosition({4, 23}))); + if (FFlag::DebugLuauEqSatSimplification) + { + CHECK("{ x: number }" == toString(requireTypeAtPosition({4, 23}))); + } + else + { + // CLI-115281 - Types produced by refinements don't always get simplified + CHECK("{ x: number? } & { x: ~(false?) }" == toString(requireTypeAtPosition({4, 23}))); + } CHECK("number" == toString(requireTypeAtPosition({5, 26}))); } @@ -1857,6 +1867,8 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "refine_a_param_that_got_resolved_duri TEST_CASE_FIXTURE(Fixture, "refine_a_property_of_some_global") { + ScopedFastFlag sff{FFlag::InferGlobalTypes, true}; + CheckResult result = check(R"( foo = { bar = 5 :: number? } @@ -1867,9 +1879,8 @@ TEST_CASE_FIXTURE(Fixture, "refine_a_property_of_some_global") if (FFlag::LuauSolverV2) { - LUAU_REQUIRE_ERROR_COUNT(3, result); - - CHECK_EQ("*error-type* | buffer | class | function | number | string | table | thread | true", toString(requireTypeAtPosition({4, 30}))); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("number", toString(requireTypeAtPosition({4, 30}))); } } @@ -2263,37 +2274,37 @@ TEST_CASE_FIXTURE(Fixture, "more_complex_long_disjunction_of_refinements_shouldn { CHECK_NOTHROW(check(R"( script:connect(function(obj) - if script.Parent.SeatNumber.Value == "1D" or - script.Parent.SeatNumber.Value == "2D" or - script.Parent.SeatNumber.Value == "3D" or - script.Parent.SeatNumber.Value == "4D" or - script.Parent.SeatNumber.Value == "5D" or - script.Parent.SeatNumber.Value == "6D" or - script.Parent.SeatNumber.Value == "7D" or - script.Parent.SeatNumber.Value == "8D" or - script.Parent.SeatNumber.Value == "9D" or - script.Parent.SeatNumber.Value == "10D" or - script.Parent.SeatNumber.Value == "11D" or - script.Parent.SeatNumber.Value == "12D" or - script.Parent.SeatNumber.Value == "13D" or - script.Parent.SeatNumber.Value == "14D" or - script.Parent.SeatNumber.Value == "15D" or - script.Parent.SeatNumber.Value == "16D" or - script.Parent.SeatNumber.Value == "1C" or - script.Parent.SeatNumber.Value == "2C" or - script.Parent.SeatNumber.Value == "3C" or - script.Parent.SeatNumber.Value == "4C" or - script.Parent.SeatNumber.Value == "5C" or - script.Parent.SeatNumber.Value == "6C" or - script.Parent.SeatNumber.Value == "7C" or - script.Parent.SeatNumber.Value == "8C" or - script.Parent.SeatNumber.Value == "9C" or - script.Parent.SeatNumber.Value == "10C" or - script.Parent.SeatNumber.Value == "11C" or - script.Parent.SeatNumber.Value == "12C" or - script.Parent.SeatNumber.Value == "13C" or - script.Parent.SeatNumber.Value == "14C" or - script.Parent.SeatNumber.Value == "15C" or + if script.Parent.SeatNumber.Value == "1D" or + script.Parent.SeatNumber.Value == "2D" or + script.Parent.SeatNumber.Value == "3D" or + script.Parent.SeatNumber.Value == "4D" or + script.Parent.SeatNumber.Value == "5D" or + script.Parent.SeatNumber.Value == "6D" or + script.Parent.SeatNumber.Value == "7D" or + script.Parent.SeatNumber.Value == "8D" or + script.Parent.SeatNumber.Value == "9D" or + script.Parent.SeatNumber.Value == "10D" or + script.Parent.SeatNumber.Value == "11D" or + script.Parent.SeatNumber.Value == "12D" or + script.Parent.SeatNumber.Value == "13D" or + script.Parent.SeatNumber.Value == "14D" or + script.Parent.SeatNumber.Value == "15D" or + script.Parent.SeatNumber.Value == "16D" or + script.Parent.SeatNumber.Value == "1C" or + script.Parent.SeatNumber.Value == "2C" or + script.Parent.SeatNumber.Value == "3C" or + script.Parent.SeatNumber.Value == "4C" or + script.Parent.SeatNumber.Value == "5C" or + script.Parent.SeatNumber.Value == "6C" or + script.Parent.SeatNumber.Value == "7C" or + script.Parent.SeatNumber.Value == "8C" or + script.Parent.SeatNumber.Value == "9C" or + script.Parent.SeatNumber.Value == "10C" or + script.Parent.SeatNumber.Value == "11C" or + script.Parent.SeatNumber.Value == "12C" or + script.Parent.SeatNumber.Value == "13C" or + script.Parent.SeatNumber.Value == "14C" or + script.Parent.SeatNumber.Value == "15C" or script.Parent.SeatNumber.Value == "16C" then end) )")); @@ -2418,4 +2429,23 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeof_instance_isa_refinement") CHECK_EQ("string", toString(requireTypeAtPosition({8, 28}))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "remove_recursive_upper_bound_when_generalizing") +{ + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::DebugLuauEqSatSimplification, true}, + {FFlag::LuauGeneralizationRemoveRecursiveUpperBound, true}, + }; + + LUAU_REQUIRE_NO_ERRORS(check(R"( + local t = {"hello"} + local v = t[2] + if type(v) == "nil" then + local foo = v + end + )")); + + CHECK_EQ("(nil & string)?", toString(requireTypeAtPosition({4, 24}))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 80dddc67..2eb96152 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -25,6 +25,7 @@ LUAU_FASTINT(LuauRecursionLimit); LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTFLAG(LuauNewSolverVisitErrorExprLvalues) LUAU_FASTFLAG(LuauDontRefCountTypesInTypeFunctions) +LUAU_FASTFLAG(InferGlobalTypes) using namespace Luau; @@ -877,7 +878,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions1") CheckResult result = check(R"(local a = if true then "true" else "false")"); LUAU_REQUIRE_NO_ERRORS(result); TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveType::String); + CHECK("string" == toString(aType)); } TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions2") @@ -888,7 +889,7 @@ local a = if false then "a" elseif false then "b" else "c" )"); LUAU_REQUIRE_NO_ERRORS(result); TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveType::String); + CHECK("string" == toString(aType)); } TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_type_union") @@ -1727,7 +1728,7 @@ TEST_CASE_FIXTURE(Fixture, "visit_error_nodes_in_lvalue") // in lvalue positions. LUAU_REQUIRE_ERRORS(check(R"( --!strict - (::, + (::, )")); } @@ -1763,4 +1764,21 @@ TEST_CASE_FIXTURE(Fixture, "avoid_double_reference_to_free_type") )")); } +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_types_of_globals") +{ + ScopedFastFlag sff_LuauSolverV2{FFlag::LuauSolverV2, true}; + ScopedFastFlag sff_InferGlobalTypes{FFlag::InferGlobalTypes, true}; + + CheckResult result = check(R"( + --!strict + foo = 5 + print(foo) + )"); + + CHECK_EQ("number", toString(requireTypeAtPosition({3, 14}))); + + REQUIRE_EQ(1, result.errors.size()); + CHECK_EQ("Unknown global 'foo'", toString(result.errors[0])); +} + TEST_SUITE_END(); diff --git a/tools/heapsnapshot.py b/tools/heapsnapshot.py new file mode 100644 index 00000000..d3c0c92d --- /dev/null +++ b/tools/heapsnapshot.py @@ -0,0 +1,221 @@ +#!/usr/bin/python3 +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# Given a Luau heap dump, this tool generates a heap snapshot which can be imported by Chrome's DevTools Memory panel +# To generate a snapshot, use luaC_dump, ideally preceded by luaC_fullgc +# To import in Chrome, ensure the snapshot has the .heapsnapshot extension and go to: Inspect -> Memory -> Load Profile +# A reference for the heap snapshot schema can be found here: https://learn.microsoft.com/en-us/microsoft-edge/devtools-guide-chromium/memory-problems/heap-snapshot-schema + +# Usage: python3 heapsnapshot.py luauDump.json heapSnapshot.heapsnapshot + +import json +import sys + +# Header describing the snapshot format, copied from a real Chrome heap snapshot +snapshotMeta = { + "node_fields": ["type", "name", "id", "self_size", "edge_count", "trace_node_id", "detachedness"], + "node_types": [ + ["hidden", "array", "string", "object", "code", "closure", "regexp", "number", "native", "synthetic", "concatenated string", "sliced string", "symbol", "bigint", "object shape"], + "string", "number", "number", "number", "number", "number" + ], + "edge_fields": ["type", "name_or_index", "to_node"], + "edge_types": [ + ["context", "element", "property", "internal", "hidden", "shortcut", "weak"], + "string_or_number", "node" + ], + "trace_function_info_fields": ["function_id", "name", "script_name", "script_id", "line", "column"], + "trace_node_fields": ["id", "function_info_index", "count", "size", "children"], + "sample_fields": ["timestamp_us", "last_assigned_id"], + "location_fields": ["object_index", "script_id", "line", "column"], +} + +# These indices refer to the index in the snapshot's metadata header +nodeTypeToMetaIndex = {type: i for i, type in enumerate(snapshotMeta["node_types"][0])} +edgeTypeToMetaIndex = {type: i for i, type in enumerate(snapshotMeta["edge_types"][0])} + +nodeFieldCount = len(snapshotMeta["node_fields"]) +edgeFieldCount = len(snapshotMeta["edge_fields"]) + + +def readAddresses(data): + # Ordered list of addresses to ensure the registry is the first node, and also so we can process nodes in index order + addresses = [] + addressToNodeIndex = {} + + def addAddress(address): + assert address not in addressToNodeIndex, f"Address already exists in the snapshot: '{address}'" + addresses.append(address) + addressToNodeIndex[address] = len(addresses) - 1 + + # The registry is a special case that needs to be either the first or last node to ensure gc "distances" are calculated correctly + registryAddress = data["roots"]["registry"] + addAddress(registryAddress) + + for address, obj in data["objects"].items(): + if address == registryAddress: + continue + addAddress(address) + + return addresses, addressToNodeIndex + + +def convertToSnapshot(data): + addresses, addressToNodeIndex = readAddresses(data) + + # Some notable idiosyncrasies with the heap snapshot format: + # 1. The snapshot format contains a flat array of nodes and edges. Oddly, edges must reference the "absolute" index of a node's first element after flattening. + # 2. A node's outgoing edges are implicitly represented by a contiguous block of edges in the edges array which correspond to the node's position + # in the nodes array and its edge count. So if the first node has 3 edges, the first 3 edges in the edges array are its edges, and so on. + + nodes = [] + edges = [] + strings = [] + + stringToSnapshotIndex = {} + + def getUniqueId(address): + # TODO: we should hash this to an int32 instead of using the address directly + # Addresses are hexadecimal strings + return int(address, 16) + + def addNode(node): + assert len(node) == nodeFieldCount, f"Expected {nodeFieldCount} fields, got {len(node)}" + nodes.append(node) + + def addEdge(edge): + assert len(edge) == edgeFieldCount, f"Expected {edgeFieldCount} fields, got {len(edge)}" + edges.append(edge) + + def getStringSnapshotIndex(string): + assert isinstance(string, str), f"'{string}' is not of type string" + if string not in stringToSnapshotIndex: + strings.append(string) + stringToSnapshotIndex[string] = len(strings) - 1 + return stringToSnapshotIndex[string] + + def getNodeSnapshotIndex(address): + # This is the index of the first element of the node in the flattened nodes array + return addressToNodeIndex[address] * nodeFieldCount + + for address in addresses: + obj = data["objects"][address] + edgeCount = 0 + + if obj["type"] == "table": + # TODO: support weak references + name = f"Registry ({address})" if address == data["roots"]["registry"] else f"Luau table ({address})" + if "pairs" in obj: + for i in range(0, len(obj["pairs"]), 2): + key = obj["pairs"][i] + value = obj["pairs"][i + 1] + if key is None and value is None: + # Both the key and value are value types, nothing meaningful to add here + continue + elif key is None: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["property"], getStringSnapshotIndex("(Luau table key value type)"), getNodeSnapshotIndex(value)]) + elif value is None: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["internal"], getStringSnapshotIndex(f'Luau table key ref: {data["objects"][key]["type"]} ({key})'), getNodeSnapshotIndex(key)]) + elif data["objects"][key]["type"] == "string": + edgeCount += 2 + # This is a special case where the key is a string, so we can use it as the edge name + addEdge([edgeTypeToMetaIndex["property"], getStringSnapshotIndex(data["objects"][key]["data"]), getNodeSnapshotIndex(value)]) + addEdge([edgeTypeToMetaIndex["internal"], getStringSnapshotIndex(f'Luau table key ref: {data["objects"][key]["type"]} ({key})'), getNodeSnapshotIndex(key)]) + else: + edgeCount += 2 + addEdge([edgeTypeToMetaIndex["property"], getStringSnapshotIndex(f'{data["objects"][key]["type"]} ({key})'), getNodeSnapshotIndex(value)]) + addEdge([edgeTypeToMetaIndex["internal"], getStringSnapshotIndex(f'Luau table key ref: {data["objects"][key]["type"]} ({key})'), getNodeSnapshotIndex(key)]) + if "array" in obj: + for i, element in enumerate(obj["array"]): + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["element"], i, getNodeSnapshotIndex(element)]) + if "metatable" in obj: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["internal"], getStringSnapshotIndex(f'metatable ({obj["metatable"]})'), getNodeSnapshotIndex(obj["metatable"])]) + # TODO: consider distinguishing "object" and "array" node types + addNode([nodeTypeToMetaIndex["object"], getStringSnapshotIndex(name), getUniqueId(address), obj["size"], edgeCount, 0, 0]) + elif obj["type"] == "thread": + name = f'Luau thread: {obj["source"]}:{obj["line"]} ({address})' if "source" in obj else f"Luau thread ({address})" + if address == data["roots"]["mainthread"]: + name += " (main thread)" + if "env" in obj: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["context"], getStringSnapshotIndex(f'env ({obj["env"]})'), getNodeSnapshotIndex(obj["env"])]) + if "stack" in obj: + for i, frame in enumerate(obj["stack"]): + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["context"], getStringSnapshotIndex(f"callstack[{i}]"), getNodeSnapshotIndex(frame)]) + addNode([nodeTypeToMetaIndex["native"], getStringSnapshotIndex(name), getUniqueId(address), obj["size"], edgeCount, 0, 0]) + elif obj["type"] == "function": + name = f'Luau function: {obj["name"]} ({address})' if "name" in obj else f"Luau anonymous function ({address})" + if "env" in obj: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["context"], getStringSnapshotIndex(f'env ({obj["env"]})'), getNodeSnapshotIndex(obj["env"])]) + if "proto" in obj: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["context"], getStringSnapshotIndex(f'proto ({obj["proto"]})'), getNodeSnapshotIndex(obj["proto"])]) + if "upvalues" in obj: + for i, upvalue in enumerate(obj["upvalues"]): + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["context"], getStringSnapshotIndex(f"up value ({upvalue})"), getNodeSnapshotIndex(upvalue)]) + addNode([nodeTypeToMetaIndex["closure"], getStringSnapshotIndex(name), getUniqueId(address), obj["size"], edgeCount, 0, 0]) + elif obj["type"] == "upvalue": + if "object" in obj: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["context"], getStringSnapshotIndex(f'upvalue object ({obj["object"]})'), getNodeSnapshotIndex(obj["object"])]) + addNode([nodeTypeToMetaIndex["native"], getStringSnapshotIndex(f"Luau upvalue ({address})"), getUniqueId(address), obj["size"], edgeCount, 0, 0]) + elif obj["type"] == "userdata": + if "metatable" in obj: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["internal"], getStringSnapshotIndex(f'metatable ({obj["metatable"]})'), getNodeSnapshotIndex(obj["metatable"])]) + addNode([nodeTypeToMetaIndex["native"], getStringSnapshotIndex(f"Luau userdata ({address})"), getUniqueId(address), obj["size"], edgeCount, 0, 0]) + elif obj["type"] == "proto": + name = f'Luau proto: {obj["source"]}:{obj["line"]} ({address})' if "source" in obj else f"Luau proto ({address})" + if "constants" in obj: + for constant in obj["constants"]: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["context"], getStringSnapshotIndex(constant), getNodeSnapshotIndex(constant)]) + if "protos" in obj: + for proto in obj["protos"]: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["context"], getStringSnapshotIndex(proto), getNodeSnapshotIndex(proto)]) + addNode([nodeTypeToMetaIndex["code"], getStringSnapshotIndex(name), getUniqueId(address), obj["size"], edgeCount, 0, 0]) + elif obj["type"] == "string": + addNode([nodeTypeToMetaIndex["string"], getStringSnapshotIndex(obj["data"]), getUniqueId(address), obj["size"], 0, 0, 0]) + elif obj["type"] == "buffer": + addNode([nodeTypeToMetaIndex["native"], getStringSnapshotIndex(f'buffer ({address})'), getUniqueId(address), obj["size"], 0, 0, 0]) + else: + raise Exception(f"Unknown object type: '{obj['type']}'") + + return { + "snapshot": { + "meta": snapshotMeta, + "node_count": len(nodes), + "edge_count": len(edges), + "trace_function_count": 0, + }, + # flatten the nodes and edges arrays + "nodes": [field for node in nodes for field in node], + "edges": [field for edge in edges for field in edge], + "trace_function_infos": [], + "trace_tree": [], + "samples": [], + "locations": [], + "strings": strings, + } + + +if __name__ == "__main__": + luauDump = sys.argv[1] + heapSnapshot = sys.argv[2] + + with open(luauDump, "r") as file: + dump = json.load(file) + + snapshot = convertToSnapshot(dump) + + with open(heapSnapshot, "w") as file: + json.dump(snapshot, file) + + print(f"Heap snapshot written to: '{heapSnapshot}'")