diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index c6b4a828..4d38118a 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -12,6 +12,7 @@ #include "Luau/ToString.h" #include "Luau/Type.h" #include "Luau/TypeCheckLimits.h" +#include "Luau/TypeFunction.h" #include "Luau/TypeFwd.h" #include "Luau/Variant.h" @@ -62,6 +63,7 @@ struct ConstraintSolver NotNull builtinTypes; InternalErrorReporter iceReporter; NotNull normalizer; + NotNull typeFunctionRuntime; // The entire set of constraints that the solver is trying to resolve. std::vector> constraints; NotNull rootScope; @@ -111,6 +113,7 @@ struct ConstraintSolver explicit ConstraintSolver( NotNull normalizer, + NotNull typeFunctionRuntime, NotNull rootScope, std::vector> constraints, ModuleName moduleName, @@ -278,18 +281,18 @@ public: /** * @returns true if the TypeId is in a blocked state. */ - bool isBlocked(TypeId ty); + bool isBlocked(TypeId ty) const; /** * @returns true if the TypePackId is in a blocked state. */ - bool isBlocked(TypePackId tp); + bool isBlocked(TypePackId tp) const; /** * Returns whether the constraint is blocked on anything. * @param constraint the constraint to check. */ - bool isBlocked(NotNull constraint); + bool isBlocked(NotNull constraint) const; /** Pushes a new solver constraint to the solver. * @param cv the body of the constraint. @@ -381,8 +384,8 @@ public: TypePackId anyifyModuleReturnTypePackGenerics(TypePackId tp); - void throwTimeLimitError(); - void throwUserCancelError(); + void throwTimeLimitError() const; + void throwUserCancelError() const; ToStringOptions opts; }; diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index baf3318c..fe9d7924 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -448,6 +448,13 @@ struct UnexpectedTypePackInSubtyping bool operator==(const UnexpectedTypePackInSubtyping& rhs) const; }; +struct UserDefinedTypeFunctionError +{ + std::string message; + + bool operator==(const UserDefinedTypeFunctionError& rhs) const; +}; + using TypeErrorData = Variant< TypeMismatch, UnknownSymbol, @@ -496,7 +503,8 @@ using TypeErrorData = Variant< CheckedFunctionIncorrectArgs, UnexpectedTypeInSubtyping, UnexpectedTypePackInSubtyping, - ExplicitFunctionAnnotationRecommended>; + ExplicitFunctionAnnotationRecommended, + UserDefinedTypeFunctionError>; struct TypeErrorSummary { diff --git a/Analysis/include/Luau/FragmentAutocomplete.h b/Analysis/include/Luau/FragmentAutocomplete.h new file mode 100644 index 00000000..53e301c1 --- /dev/null +++ b/Analysis/include/Luau/FragmentAutocomplete.h @@ -0,0 +1,23 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/DenseHash.h" +#include "Luau/Ast.h" + +#include + + +namespace Luau +{ + +struct FragmentAutocompleteAncestryResult +{ + DenseHashMap localMap{AstName()}; + std::vector localStack; + std::vector ancestry; + AstStat* nearestStatement; +}; + +FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos); + +} // namespace Luau diff --git a/Analysis/include/Luau/OverloadResolution.h b/Analysis/include/Luau/OverloadResolution.h index 9a2974a5..83a33215 100644 --- a/Analysis/include/Luau/OverloadResolution.h +++ b/Analysis/include/Luau/OverloadResolution.h @@ -35,6 +35,7 @@ struct OverloadResolver NotNull builtinTypes, NotNull arena, NotNull normalizer, + NotNull typeFunctionRuntime, NotNull scope, NotNull reporter, NotNull limits, @@ -44,6 +45,7 @@ struct OverloadResolver NotNull builtinTypes; NotNull arena; NotNull normalizer; + NotNull typeFunctionRuntime; NotNull scope; NotNull ice; NotNull limits; @@ -109,6 +111,7 @@ SolveResult solveFunctionCall( NotNull arena, NotNull builtinTypes, NotNull normalizer, + NotNull typeFunctionRuntime, NotNull iceReporter, NotNull limits, NotNull scope, diff --git a/Analysis/include/Luau/Subtyping.h b/Analysis/include/Luau/Subtyping.h index 09f46c4d..1e781056 100644 --- a/Analysis/include/Luau/Subtyping.h +++ b/Analysis/include/Luau/Subtyping.h @@ -135,6 +135,7 @@ struct Subtyping NotNull builtinTypes; NotNull arena; NotNull normalizer; + NotNull typeFunctionRuntime; NotNull iceReporter; TypeCheckLimits limits; @@ -155,6 +156,7 @@ struct Subtyping NotNull builtinTypes, NotNull typeArena, NotNull normalizer, + NotNull typeFunctionRuntime, NotNull iceReporter ); diff --git a/Analysis/include/Luau/TypeChecker2.h b/Analysis/include/Luau/TypeChecker2.h index 0faf036d..e7db9411 100644 --- a/Analysis/include/Luau/TypeChecker2.h +++ b/Analysis/include/Luau/TypeChecker2.h @@ -83,6 +83,7 @@ struct TypeChecker2 DenseHashSet seenTypeFunctionInstances{nullptr}; Normalizer normalizer; + TypeFunctionRuntime typeFunctionRuntime; Subtyping _subtyping; NotNull subtyping; diff --git a/Analysis/include/Luau/TypeFunction.h b/Analysis/include/Luau/TypeFunction.h index c686f482..252b4c9a 100644 --- a/Analysis/include/Luau/TypeFunction.h +++ b/Analysis/include/Luau/TypeFunction.h @@ -1,10 +1,11 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/ConstraintSolver.h" +#include "Luau/Constraint.h" #include "Luau/Error.h" #include "Luau/NotNull.h" #include "Luau/TypeCheckLimits.h" +#include "Luau/TypeFunctionRuntime.h" #include "Luau/TypeFwd.h" #include @@ -16,14 +17,23 @@ namespace Luau struct TypeArena; struct TxnLog; +struct ConstraintSolver; class Normalizer; +struct TypeFunctionRuntime +{ + // For user-defined type functions, we store all generated types and packs for the duration of the typecheck + TypedAllocator typeArena; + TypedAllocator typePackArena; +}; + struct TypeFunctionContext { NotNull arena; NotNull builtins; NotNull scope; NotNull normalizer; + NotNull typeFunctionRuntime; NotNull ice; NotNull limits; @@ -35,23 +45,14 @@ struct TypeFunctionContext std::optional userFuncName; // Name of the user-defined type function; only available for UDTFs std::optional userFuncBody; // Body of the user-defined type function; only available for UDTFs - TypeFunctionContext(NotNull cs, NotNull scope, NotNull constraint) - : arena(cs->arena) - , builtins(cs->builtinTypes) - , scope(scope) - , normalizer(cs->normalizer) - , ice(NotNull{&cs->iceReporter}) - , limits(NotNull{&cs->limits}) - , solver(cs.get()) - , constraint(constraint.get()) - { - } + TypeFunctionContext(NotNull cs, NotNull scope, NotNull constraint); TypeFunctionContext( NotNull arena, NotNull builtins, NotNull scope, NotNull normalizer, + NotNull typeFunctionRuntime, NotNull ice, NotNull limits ) @@ -59,6 +60,7 @@ struct TypeFunctionContext , builtins(builtins) , scope(scope) , normalizer(normalizer) + , typeFunctionRuntime(typeFunctionRuntime) , ice(ice) , limits(limits) , solver(nullptr) @@ -66,7 +68,7 @@ struct TypeFunctionContext { } - NotNull pushConstraint(ConstraintV&& c); + NotNull pushConstraint(ConstraintV&& c) const; }; /// Represents a reduction result, which may have successfully reduced the type, @@ -88,6 +90,8 @@ struct TypeFunctionReductionResult /// Any type packs that need to be progressed or mutated before the /// reduction may proceed. std::vector blockedPacks; + /// A runtime error message from user-defined type functions + std::optional error; }; template diff --git a/Analysis/include/Luau/TypeFunctionRuntime.h b/Analysis/include/Luau/TypeFunctionRuntime.h new file mode 100644 index 00000000..eb5d19ee --- /dev/null +++ b/Analysis/include/Luau/TypeFunctionRuntime.h @@ -0,0 +1,267 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" +#include "Luau/Variant.h" + +#include +#include +#include +#include + +using lua_State = struct lua_State; + +namespace Luau +{ + +void* typeFunctionAlloc(void* ud, void* ptr, size_t osize, size_t nsize); + +// Replica of types from Type.h +struct TypeFunctionType; +using TypeFunctionTypeId = const TypeFunctionType*; + +struct TypeFunctionTypePackVar; +using TypeFunctionTypePackId = const TypeFunctionTypePackVar*; + +struct TypeFunctionPrimitiveType +{ + enum Type + { + NilType, + Boolean, + Number, + String, + }; + + Type type; + + TypeFunctionPrimitiveType(Type type) + : type(type) + { + } +}; + +struct TypeFunctionBooleanSingleton +{ + bool value = false; +}; + +struct TypeFunctionStringSingleton +{ + std::string value; +}; + +using TypeFunctionSingletonVariant = Variant; + +struct TypeFunctionSingletonType +{ + TypeFunctionSingletonVariant variant; + + explicit TypeFunctionSingletonType(TypeFunctionSingletonVariant variant) + : variant(std::move(variant)) + { + } +}; + +template +const T* get(const TypeFunctionSingletonType* tv) +{ + LUAU_ASSERT(tv); + + return tv ? get_if(&tv->variant) : nullptr; +} + +template +T* getMutable(const TypeFunctionSingletonType* tv) +{ + LUAU_ASSERT(tv); + + return tv ? get_if(&const_cast(tv)->variant) : nullptr; +} + +struct TypeFunctionUnionType +{ + std::vector components; +}; + +struct TypeFunctionIntersectionType +{ + std::vector components; +}; + +struct TypeFunctionAnyType +{ +}; + +struct TypeFunctionUnknownType +{ +}; + +struct TypeFunctionNeverType +{ +}; + +struct TypeFunctionNegationType +{ + TypeFunctionTypeId type; +}; + +struct TypeFunctionTypePack +{ + std::vector head; + std::optional tail; +}; + +struct TypeFunctionVariadicTypePack +{ + TypeFunctionTypeId type; +}; + +using TypeFunctionTypePackVariant = Variant; + +struct TypeFunctionTypePackVar +{ + TypeFunctionTypePackVariant type; + + TypeFunctionTypePackVar(TypeFunctionTypePackVariant type) + : type(std::move(type)) + { + } + + bool operator==(const TypeFunctionTypePackVar& rhs) const; +}; + +struct TypeFunctionFunctionType +{ + TypeFunctionTypePackId argTypes; + TypeFunctionTypePackId retTypes; +}; + +template +const T* get(TypeFunctionTypePackId tv) +{ + LUAU_ASSERT(tv); + + return tv ? get_if(&tv->type) : nullptr; +} + +template +T* getMutable(TypeFunctionTypePackId tv) +{ + LUAU_ASSERT(tv); + + return tv ? get_if(&const_cast(tv)->type) : nullptr; +} + +struct TypeFunctionTableIndexer +{ + TypeFunctionTableIndexer(TypeFunctionTypeId keyType, TypeFunctionTypeId valueType) + : keyType(keyType) + , valueType(valueType) + { + } + + TypeFunctionTypeId keyType; + TypeFunctionTypeId valueType; +}; + +struct TypeFunctionProperty +{ + static TypeFunctionProperty readonly(TypeFunctionTypeId ty); + static TypeFunctionProperty writeonly(TypeFunctionTypeId ty); + static TypeFunctionProperty rw(TypeFunctionTypeId ty); // Shared read-write type. + static TypeFunctionProperty rw(TypeFunctionTypeId read, TypeFunctionTypeId write); // Separate read-write type. + + bool isReadOnly() const; + bool isWriteOnly() const; + + std::optional readTy; + std::optional writeTy; +}; + +struct TypeFunctionTableType +{ + using Name = std::string; + using Props = std::unordered_map; + + Props props; + + std::optional indexer; + + // Should always be a TypeFunctionTableType + std::optional metatable; +}; + +struct TypeFunctionClassType +{ + using Name = std::string; + using Props = std::unordered_map; + + Props props; + + std::optional indexer; + + std::optional metatable; // metaclass? + + std::optional parent; + + std::string name; +}; + +using TypeFunctionTypeVariant = Luau::Variant< + TypeFunctionPrimitiveType, + TypeFunctionAnyType, + TypeFunctionUnknownType, + TypeFunctionNeverType, + TypeFunctionSingletonType, + TypeFunctionUnionType, + TypeFunctionIntersectionType, + TypeFunctionNegationType, + TypeFunctionFunctionType, + TypeFunctionTableType, + TypeFunctionClassType>; + +struct TypeFunctionType +{ + TypeFunctionTypeVariant type; + + TypeFunctionType(TypeFunctionTypeVariant type) + : type(std::move(type)) + { + } + + bool operator==(const TypeFunctionType& rhs) const; +}; + +template +const T* get(TypeFunctionTypeId tv) +{ + LUAU_ASSERT(tv); + + return tv ? Luau::get_if(&tv->type) : nullptr; +} + +template +T* getMutable(TypeFunctionTypeId tv) +{ + LUAU_ASSERT(tv); + + return tv ? Luau::get_if(&const_cast(tv)->type) : nullptr; +} + +std::optional checkResultForError(lua_State* L, const char* typeFunctionName, int luaResult); + +TypeFunctionType* allocateTypeFunctionType(lua_State* L, TypeFunctionTypeVariant type); +TypeFunctionTypePackVar* allocateTypeFunctionTypePack(lua_State* L, TypeFunctionTypePackVariant type); + +void allocTypeUserData(lua_State* L, TypeFunctionTypeVariant type); + +bool isTypeUserData(lua_State* L, int idx); +TypeFunctionTypeId getTypeUserData(lua_State* L, int idx); +std::optional optionalTypeUserData(lua_State* L, int idx); + +void registerTypeUserData(lua_State* L); + +void setTypeFunctionEnvironment(lua_State* L); + +} // namespace Luau diff --git a/Analysis/include/Luau/TypeFunctionRuntimeBuilder.h b/Analysis/include/Luau/TypeFunctionRuntimeBuilder.h new file mode 100644 index 00000000..c9e1152f --- /dev/null +++ b/Analysis/include/Luau/TypeFunctionRuntimeBuilder.h @@ -0,0 +1,52 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Type.h" +#include "Luau/TypeFunction.h" +#include "Luau/TypeFunctionRuntime.h" + +namespace Luau +{ + +using Kind = Variant; + +template +const T* get(const Kind& kind) +{ + return get_if(&kind); +} + +using TypeFunctionKind = Variant; + +template +const T* get(const TypeFunctionKind& tfkind) +{ + return get_if(&tfkind); +} + +struct TypeFunctionRuntimeBuilderState +{ + NotNull ctx; + + // Mapping of class name to ClassType + // Invariant: users can not create a new class types -> any class types that get deserialized must have been an argument to the type function + // Using this invariant, whenever a ClassType is serialized, we can put it into this map + // whenever a ClassType is deserialized, we can use this map to return the corresponding value + DenseHashMap classesSerialized{{}}; + + // List of errors that occur during serialization/deserialization + // At every iteration of serialization/deserialzation, if this list.size() != 0, we halt the process + std::vector errors{}; + + TypeFunctionRuntimeBuilderState(NotNull ctx) + : ctx(ctx) + , classesSerialized({}) + , errors({}) + { + } +}; + +TypeFunctionTypeId serialize(TypeId ty, TypeFunctionRuntimeBuilderState* state); +TypeId deserialize(TypeFunctionTypeId ty, TypeFunctionRuntimeBuilderState* state); + +} // namespace Luau diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 868e31f1..0cb14879 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -149,13 +149,15 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, T if (FFlag::LuauSolverV2) { + TypeFunctionRuntime typeFunctionRuntime; // TODO: maybe subtyping checks should not invoke user-defined type function runtime + if (FFlag::LuauAutocompleteNewSolverLimit) { unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit; } - Subtyping subtyping{builtinTypes, NotNull{typeArena}, NotNull{&normalizer}, NotNull{&iceReporter}}; + Subtyping subtyping{builtinTypes, NotNull{typeArena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter}}; return subtyping.isSubtype(subTy, superTy, scope).isSubtype; } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index f7c4fb5e..7db74cfb 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -321,6 +321,7 @@ struct InstantiationQueuer : TypeOnceVisitor ConstraintSolver::ConstraintSolver( NotNull normalizer, + NotNull typeFunctionRuntime, NotNull rootScope, std::vector> constraints, ModuleName moduleName, @@ -332,11 +333,12 @@ ConstraintSolver::ConstraintSolver( : arena(normalizer->arena) , builtinTypes(normalizer->builtinTypes) , normalizer(normalizer) + , typeFunctionRuntime(typeFunctionRuntime) , constraints(std::move(constraints)) , rootScope(rootScope) , currentModuleName(std::move(moduleName)) , moduleResolver(moduleResolver) - , requireCycles(requireCycles) + , requireCycles(std::move(requireCycles)) , logger(logger) , limits(std::move(limits)) { @@ -344,7 +346,7 @@ ConstraintSolver::ConstraintSolver( for (NotNull c : this->constraints) { - unsolvedConstraints.push_back(c); + unsolvedConstraints.emplace_back(c); // initialize the reference counts for the free types in this constraint. for (auto ty : c->getMaybeMutatedFreeTypes()) @@ -1240,7 +1242,14 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullscope, NotNull{&iceReporter}, NotNull{&limits}, constraint->location + builtinTypes, + NotNull{arena}, + normalizer, + typeFunctionRuntime, + constraint->scope, + NotNull{&iceReporter}, + NotNull{&limits}, + constraint->location }; auto [status, overload] = resolver.selectOverload(fn, argsPack); TypeId overloadToUse = fn; @@ -1270,7 +1279,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulllocation, addition)); + upperBoundContributors[expanded].emplace_back(constraint->location, addition); } if (occursCheckPassed && c.callSite) @@ -1437,8 +1446,17 @@ bool ConstraintSolver::tryDispatch(const PrimitiveTypeConstraint& c, NotNulllowerBound; - shiftReferences(c.freeType, bindTo); - bind(constraint, c.freeType, bindTo); + if (DFInt::LuauTypeSolverRelease >= 645) + { + auto ty = follow(c.freeType); + shiftReferences(ty, bindTo); + bind(constraint, ty, bindTo); + } + else + { + shiftReferences(c.freeType, bindTo); + bind(constraint, c.freeType, bindTo); + } return true; } @@ -2603,7 +2621,7 @@ bool ConstraintSolver::unify(NotNull constraint, TID subTy, TI for (const auto& [expanded, additions] : u2.expandedFreeTypes) { for (TypeId addition : additions) - upperBoundContributors[expanded].push_back(std::make_pair(constraint->location, addition)); + upperBoundContributors[expanded].emplace_back(constraint->location, addition); } } else @@ -2820,7 +2838,7 @@ void ConstraintSolver::reproduceConstraints(NotNull scope, const Location } } -bool ConstraintSolver::isBlocked(TypeId ty) +bool ConstraintSolver::isBlocked(TypeId ty) const { ty = follow(ty); @@ -2830,7 +2848,7 @@ bool ConstraintSolver::isBlocked(TypeId ty) return nullptr != get(ty) || nullptr != get(ty); } -bool ConstraintSolver::isBlocked(TypePackId tp) +bool ConstraintSolver::isBlocked(TypePackId tp) const { tp = follow(tp); @@ -2840,7 +2858,7 @@ bool ConstraintSolver::isBlocked(TypePackId tp) return nullptr != get(tp); } -bool ConstraintSolver::isBlocked(NotNull constraint) +bool ConstraintSolver::isBlocked(NotNull constraint) const { auto blockedIt = blockedConstraints.find(constraint); return blockedIt != blockedConstraints.end() && blockedIt->second > 0; @@ -2851,7 +2869,7 @@ NotNull ConstraintSolver::pushConstraint(NotNull scope, const std::unique_ptr c = std::make_unique(scope, location, std::move(cv)); NotNull borrow = NotNull(c.get()); solverConstraints.push_back(std::move(c)); - unsolvedConstraints.push_back(borrow); + unsolvedConstraints.emplace_back(borrow); return borrow; } @@ -2997,12 +3015,12 @@ TypePackId ConstraintSolver::anyifyModuleReturnTypePackGenerics(TypePackId tp) return arena->addTypePack(resultTypes, resultTail); } -LUAU_NOINLINE void ConstraintSolver::throwTimeLimitError() +LUAU_NOINLINE void ConstraintSolver::throwTimeLimitError() const { throw TimeLimitError(currentModuleName); } -LUAU_NOINLINE void ConstraintSolver::throwUserCancelError() +LUAU_NOINLINE void ConstraintSolver::throwUserCancelError() const { throw UserCancelError(currentModuleName); } diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 60058d99..c91ce00d 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -793,6 +793,11 @@ struct ErrorConverter return "Encountered an unexpected type pack in subtyping: " + toString(e.tp); } + std::string operator()(const UserDefinedTypeFunctionError& e) const + { + return e.message; + } + std::string operator()(const CannotAssignToNever& e) const { std::string result = "Cannot assign a value of type " + toString(e.rhsType) + " to a field of type never"; @@ -1175,6 +1180,11 @@ bool UnexpectedTypePackInSubtyping::operator==(const UnexpectedTypePackInSubtypi return tp == rhs.tp; } +bool UserDefinedTypeFunctionError::operator==(const UserDefinedTypeFunctionError& rhs) const +{ + return message == rhs.message; +} + bool CannotAssignToNever::operator==(const CannotAssignToNever& rhs) const { if (cause.size() != rhs.cause.size()) @@ -1384,6 +1394,9 @@ void copyError(T& e, TypeArena& destArena, CloneState& cloneState) e.ty = clone(e.ty); else if constexpr (std::is_same_v) e.tp = clone(e.tp); + else if constexpr (std::is_same_v) + { + } else if constexpr (std::is_same_v) { e.rhsType = clone(e.rhsType); diff --git a/Analysis/src/FragmentAutocomplete.cpp b/Analysis/src/FragmentAutocomplete.cpp new file mode 100644 index 00000000..4088c500 --- /dev/null +++ b/Analysis/src/FragmentAutocomplete.cpp @@ -0,0 +1,48 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/FragmentAutocomplete.h" + +#include "Luau/Ast.h" +#include "Luau/AstQuery.h" + +namespace Luau +{ + +FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos) +{ + std::vector ancestry = findAncestryAtPositionForAutocomplete(root, cursorPos); + DenseHashMap localMap{AstName()}; + std::vector localStack; + AstStat* nearestStatement = nullptr; + for (AstNode* node : ancestry) + { + if (auto block = node->as()) + { + for (auto stat : block->body) + { + if (stat->location.begin <= cursorPos) + nearestStatement = stat; + if (stat->location.begin <= cursorPos) + { + // This statement precedes the current one + if (auto loc = stat->as()) + { + for (auto v : loc->vars) + { + localStack.push_back(v); + localMap[v->name] = v; + } + } + else if (auto locFun = stat->as()) + { + localStack.push_back(locFun->name); + localMap[locFun->name->name] = locFun->name; + } + } + } + } + } + + return {std::move(localMap), std::move(localStack), std::move(ancestry), std::move(nearestStatement)}; +} + +} // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 8c439181..ca627728 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1383,6 +1383,7 @@ ModulePtr check( unifierState.counters.iterationLimit = limits.unifierIterationLimit.value_or(FInt::LuauTypeInferIterationLimit); Normalizer normalizer{&result->internalTypes, builtinTypes, NotNull{&unifierState}}; + TypeFunctionRuntime typeFunctionRuntime; ConstraintGenerator cg{ result, @@ -1402,6 +1403,7 @@ ModulePtr check( ConstraintSolver cs{ NotNull{&normalizer}, + NotNull{&typeFunctionRuntime}, NotNull(cg.rootScope), borrowConstraints(cg.constraints), result->name, diff --git a/Analysis/src/Generalization.cpp b/Analysis/src/Generalization.cpp index d209cb81..a79814ec 100644 --- a/Analysis/src/Generalization.cpp +++ b/Analysis/src/Generalization.cpp @@ -9,6 +9,8 @@ #include "Luau/TypePack.h" #include "Luau/VisitType.h" +LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) + namespace Luau { @@ -871,6 +873,17 @@ struct TypeCacher : TypeOnceVisitor markUncacheable(tp); return false; } + + bool visit(TypePackId tp, const BoundTypePack& btp) override { + if (DFInt::LuauTypeSolverRelease >= 645) { + traverse(btp.boundTo); + if (isUncacheable(btp.boundTo)) + markUncacheable(tp); + return false; + } + return true; + } + }; std::optional generalize( diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index a3d8b4e3..64e05993 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -227,6 +227,8 @@ static void errorToString(std::ostream& stream, const T& err) stream << "UnexpectedTypeInSubtyping { ty = '" + toString(err.ty) + "' }"; else if constexpr (std::is_same_v) stream << "UnexpectedTypePackInSubtyping { tp = '" + toString(err.tp) + "' }"; + else if constexpr (std::is_same_v) + stream << "UserDefinedTypeFunctionError { " << err.message << " }"; else if constexpr (std::is_same_v) { stream << "CannotAssignToNever { rvalueType = '" << toString(err.rhsType) << "', reason = '" << err.reason << "', cause = { "; diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 3a049216..564a3c35 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -15,6 +15,7 @@ #include LUAU_FASTFLAG(LuauSolverV2); +LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) namespace Luau { @@ -131,10 +132,26 @@ struct ClonePublicInterface : Substitution } ftv->level = TypeLevel{0, 0}; + if (FFlag::LuauSolverV2 && DFInt::LuauTypeSolverRelease >= 645) + ftv->scope = nullptr; } else if (TableType* ttv = getMutable(result)) { ttv->level = TypeLevel{0, 0}; + if (FFlag::LuauSolverV2 && DFInt::LuauTypeSolverRelease >= 645) + ttv->scope = nullptr; + } + + if (FFlag::LuauSolverV2 && DFInt::LuauTypeSolverRelease >= 645) + { + if (auto freety = getMutable(result)) + { + freety->scope = nullptr; + } + else if (auto genericty = getMutable(result)) + { + genericty->scope = nullptr; + } } return result; diff --git a/Analysis/src/NonStrictTypeChecker.cpp b/Analysis/src/NonStrictTypeChecker.cpp index 116cf5cb..2131887a 100644 --- a/Analysis/src/NonStrictTypeChecker.cpp +++ b/Analysis/src/NonStrictTypeChecker.cpp @@ -160,6 +160,7 @@ struct NonStrictTypeChecker NotNull arena; Module* module; Normalizer normalizer; + TypeFunctionRuntime typeFunctionRuntime; Subtyping subtyping; NotNull dfg; DenseHashSet noTypeFunctionErrors{nullptr}; @@ -182,7 +183,7 @@ struct NonStrictTypeChecker , arena(arena) , module(module) , normalizer{arena, builtinTypes, unifierState, /* cache inhabitance */ true} - , subtyping{builtinTypes, arena, NotNull(&normalizer), ice} + , subtyping{builtinTypes, arena, NotNull(&normalizer), NotNull(&typeFunctionRuntime), ice} , dfg(dfg) , limits(limits) { @@ -228,7 +229,12 @@ struct NonStrictTypeChecker return instance; ErrorVec errors = - reduceTypeFunctions(instance, location, TypeFunctionContext{arena, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, true) + reduceTypeFunctions( + instance, + location, + TypeFunctionContext{arena, builtinTypes, stack.back(), NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, ice, limits}, + true + ) .errors; if (errors.empty()) diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index c768f02c..7ca57e61 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -3434,11 +3434,12 @@ bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull scope, N UnifierSharedState sharedState{&ice}; TypeArena arena; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; + TypeFunctionRuntime typeFunctionRuntime; // TODO: maybe subtyping checks should not invoke user-defined type function runtime // Subtyping under DCR is not implemented using unification! if (FFlag::LuauSolverV2) { - Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&ice}}; + Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&ice}}; return subtyping.isSubtype(subPack, superPack, scope).isSubtype; } diff --git a/Analysis/src/OverloadResolution.cpp b/Analysis/src/OverloadResolution.cpp index 972c9e3a..fbcce2b7 100644 --- a/Analysis/src/OverloadResolution.cpp +++ b/Analysis/src/OverloadResolution.cpp @@ -17,6 +17,7 @@ OverloadResolver::OverloadResolver( NotNull builtinTypes, NotNull arena, NotNull normalizer, + NotNull typeFunctionRuntime, NotNull scope, NotNull reporter, NotNull limits, @@ -25,10 +26,11 @@ OverloadResolver::OverloadResolver( : builtinTypes(builtinTypes) , arena(arena) , normalizer(normalizer) + , typeFunctionRuntime(typeFunctionRuntime) , scope(scope) , ice(reporter) , limits(limits) - , subtyping({builtinTypes, arena, normalizer, ice}) + , subtyping({builtinTypes, arena, normalizer, typeFunctionRuntime, ice}) , callLoc(callLocation) { } @@ -199,8 +201,9 @@ std::pair OverloadResolver::checkOverload_ const std::vector* argExprs ) { - FunctionGraphReductionResult result = - reduceTypeFunctions(fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, normalizer, ice, limits}, /*force=*/true); + FunctionGraphReductionResult result = reduceTypeFunctions( + fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, normalizer, typeFunctionRuntime, ice, limits}, /*force=*/true + ); if (!result.errors.empty()) return {OverloadIsNonviable, result.errors}; @@ -405,6 +408,7 @@ std::optional selectOverload( NotNull builtinTypes, NotNull arena, NotNull normalizer, + NotNull typeFunctionRuntime, NotNull scope, NotNull iceReporter, NotNull limits, @@ -413,7 +417,7 @@ std::optional selectOverload( TypePackId argsPack ) { - OverloadResolver resolver{builtinTypes, arena, normalizer, scope, iceReporter, limits, location}; + OverloadResolver resolver{builtinTypes, arena, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location}; auto [status, overload] = resolver.selectOverload(fn, argsPack); if (status == OverloadResolver::Analysis::Ok) @@ -429,6 +433,7 @@ SolveResult solveFunctionCall( NotNull arena, NotNull builtinTypes, NotNull normalizer, + NotNull typeFunctionRuntime, NotNull iceReporter, NotNull limits, NotNull scope, @@ -437,7 +442,8 @@ SolveResult solveFunctionCall( TypePackId argsPack ) { - std::optional overloadToUse = selectOverload(builtinTypes, arena, normalizer, scope, iceReporter, limits, location, fn, argsPack); + std::optional overloadToUse = + selectOverload(builtinTypes, arena, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location, fn, argsPack); if (!overloadToUse) return {SolveResult::NoMatchingOverload}; diff --git a/Analysis/src/Subtyping.cpp b/Analysis/src/Subtyping.cpp index b13a2327..f8347c72 100644 --- a/Analysis/src/Subtyping.cpp +++ b/Analysis/src/Subtyping.cpp @@ -440,11 +440,13 @@ Subtyping::Subtyping( NotNull builtinTypes, NotNull typeArena, NotNull normalizer, + NotNull typeFunctionRuntime, NotNull iceReporter ) : builtinTypes(builtinTypes) , arena(typeArena) , normalizer(normalizer) + , typeFunctionRuntime(typeFunctionRuntime) , iceReporter(iceReporter) { } @@ -1911,7 +1913,7 @@ TypeId Subtyping::makeAggregateType(const Container& container, TypeId orElse) std::pair Subtyping::handleTypeFunctionReductionResult(const TypeFunctionInstanceType* functionInstance, NotNull scope) { - TypeFunctionContext context{arena, builtinTypes, scope, normalizer, iceReporter, NotNull{&limits}}; + TypeFunctionContext context{arena, builtinTypes, scope, normalizer, typeFunctionRuntime, iceReporter, NotNull{&limits}}; TypeId function = arena->addType(*functionInstance); FunctionGraphReductionResult result = reduceTypeFunctions(function, {}, context, true); ErrorVec errors; diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index f0850835..66d037ed 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1040,6 +1040,7 @@ struct TypeStringifier state.emit(tfitv.userFuncName->value); else state.emit(tfitv.function->name); + state.emit("<"); bool comma = false; diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index ed66453d..3dc708a2 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -31,6 +31,7 @@ #include LUAU_FASTFLAG(DebugLuauMagicTypes) +LUAU_FASTFLAG(LuauUserDefinedTypeFunctions) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) namespace Luau @@ -306,7 +307,7 @@ TypeChecker2::TypeChecker2( , sourceModule(sourceModule) , module(module) , normalizer{&module->internalTypes, builtinTypes, unifierState, /* cacheInhabitance */ true} - , _subtyping{builtinTypes, NotNull{&module->internalTypes}, NotNull{&normalizer}, NotNull{unifierState->iceHandler}} + , _subtyping{builtinTypes, NotNull{&module->internalTypes}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{unifierState->iceHandler}} , subtyping(&_subtyping) { } @@ -484,13 +485,16 @@ TypeId TypeChecker2::checkForTypeFunctionInhabitance(TypeId instance, Location l return instance; seenTypeFunctionInstances.insert(instance); - ErrorVec errors = reduceTypeFunctions( - instance, - location, - TypeFunctionContext{NotNull{&module->internalTypes}, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, - true - ) - .errors; + ErrorVec errors = + reduceTypeFunctions( + instance, + location, + TypeFunctionContext{ + NotNull{&module->internalTypes}, builtinTypes, stack.back(), NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, ice, limits + }, + true + ) + .errors; if (!isErrorSuppressing(location, instance)) reportErrors(std::move(errors)); return instance; @@ -1194,8 +1198,8 @@ void TypeChecker2::visit(AstStatTypeAlias* stat) void TypeChecker2::visit(AstStatTypeFunction* stat) { // TODO: add type checking for user-defined type functions - - reportError(TypeError{stat->location, GenericError{"This syntax is not supported"}}); + if (!FFlag::LuauUserDefinedTypeFunctions) + reportError(TypeError{stat->location, GenericError{"This syntax is not supported"}}); } void TypeChecker2::visit(AstTypeList types) @@ -1446,6 +1450,7 @@ void TypeChecker2::visitCall(AstExprCall* call) builtinTypes, NotNull{&module->internalTypes}, NotNull{&normalizer}, + NotNull{&typeFunctionRuntime}, NotNull{stack.back()}, ice, limits, diff --git a/Analysis/src/TypeFunction.cpp b/Analysis/src/TypeFunction.cpp index 31154cc2..6d928faa 100644 --- a/Analysis/src/TypeFunction.cpp +++ b/Analysis/src/TypeFunction.cpp @@ -2,7 +2,9 @@ #include "Luau/TypeFunction.h" +#include "Luau/BytecodeBuilder.h" #include "Luau/Common.h" +#include "Luau/Compiler.h" #include "Luau/ConstraintSolver.h" #include "Luau/DenseHash.h" #include "Luau/Instantiation.h" @@ -12,17 +14,25 @@ #include "Luau/Set.h" #include "Luau/Simplify.h" #include "Luau/Subtyping.h" +#include "Luau/TimeTrace.h" #include "Luau/ToString.h" #include "Luau/TxnLog.h" #include "Luau/Type.h" #include "Luau/TypeFunctionReductionGuesser.h" +#include "Luau/TypeFunctionRuntime.h" +#include "Luau/TypeFunctionRuntimeBuilder.h" #include "Luau/TypeFwd.h" #include "Luau/TypeUtils.h" #include "Luau/Unifier2.h" #include "Luau/VecDeque.h" #include "Luau/VisitType.h" +#include "lua.h" +#include "lualib.h" + #include +#include +#include // used to control emitting CodeTooComplex warnings on type function reduction LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyGraphReductionMaximumSteps, 1'000'000); @@ -35,7 +45,8 @@ LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyApplicationCartesianProductLimit, 5'0 // when this value is set to a negative value, guessing will be totally disabled. LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyUseGuesserDepth, -1); -LUAU_FASTFLAGVARIABLE(DebugLuauLogTypeFamilies, false); +LUAU_FASTFLAGVARIABLE(DebugLuauLogTypeFamilies, false) +LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctions, false) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) @@ -166,7 +177,7 @@ struct TypeFunctionReducer return SkipTestResult::Okay; } - SkipTestResult testForSkippability(TypePackId ty) + SkipTestResult testForSkippability(TypePackId ty) const { ty = follow(ty); @@ -214,15 +225,18 @@ struct TypeFunctionReducer { irreducible.insert(subject); + if (reduction.error.has_value()) + result.errors.emplace_back(location, UserDefinedTypeFunctionError{*reduction.error}); + if (reduction.uninhabited || force) { if (FFlag::DebugLuauLogTypeFamilies) printf("%s is uninhabited\n", toString(subject, {true}).c_str()); if constexpr (std::is_same_v) - result.errors.push_back(TypeError{location, UninhabitedTypeFunction{subject}}); + result.errors.emplace_back(location, UninhabitedTypeFunction{subject}); else if constexpr (std::is_same_v) - result.errors.push_back(TypeError{location, UninhabitedTypePackFunction{subject}}); + result.errors.emplace_back(location, UninhabitedTypePackFunction{subject}); } else if (!reduction.uninhabited && !force) { @@ -243,7 +257,7 @@ struct TypeFunctionReducer } } - bool done() + bool done() const { return queuedTys.empty() && queuedTps.empty(); } @@ -422,7 +436,7 @@ static FunctionGraphReductionResult reduceFunctionsInternal( ++iterationCount; if (iterationCount > DFInt::LuauTypeFamilyGraphReductionMaximumSteps) { - reducer.result.errors.push_back(TypeError{location, CodeTooComplex{}}); + reducer.result.errors.emplace_back(location, CodeTooComplex{}); break; } } @@ -506,7 +520,7 @@ static std::optional> tryDistributeTypeFunct size_t cartesianProductSize = 1; const UnionType* firstUnion = nullptr; - size_t unionIndex; + size_t unionIndex = 0; std::vector arguments = typeParams; for (size_t i = 0; i < arguments.size(); ++i) @@ -572,6 +586,8 @@ static std::optional> tryDistributeTypeFunct return std::nullopt; } +using StateRef = std::unique_ptr; + TypeFunctionReductionResult userDefinedTypeFunction( TypeId instance, const std::vector& typeParams, @@ -585,9 +601,122 @@ TypeFunctionReductionResult userDefinedTypeFunction( return {std::nullopt, true, {}, {}}; } - // TODO: implementation of user-defined type functions goes here + for (auto typeParam : typeParams) + { + TypeId ty = follow(typeParam); - return {std::nullopt, true, {}, {}}; + // block if we need to + if (isPending(ty, ctx->solver)) + return {std::nullopt, false, {ty}, {}}; + } + + AstName name = *ctx->userFuncName; + AstExprFunction* function = *ctx->userFuncBody; + + // Construct ParseResult containing the type function + Allocator allocator; + AstNameTable names(allocator); + + AstExprGlobal globalName{Location{}, name}; + AstStatFunction typeFunction{Location{}, &globalName, function}; + AstStat* stmtArray[] = {&typeFunction}; + AstArray stmts{stmtArray, 1}; + AstStatBlock exec{Location{}, stmts}; + ParseResult parseResult{&exec, 1}; + + BytecodeBuilder builder; + try + { + compileOrThrow(builder, parseResult, names); + } + catch (CompileError& e) + { + std::string errMsg = format("'%s' type function failed to compile with error message: %s", name.value, e.what()); + return {std::nullopt, true, {}, {}, errMsg}; + } + + std::string bytecode = builder.getBytecode(); + + // Initialize Lua state + StateRef globalState(lua_newstate(typeFunctionAlloc, nullptr), lua_close); + lua_State* L = globalState.get(); + + lua_setthreaddata(L, ctx.get()); + + setTypeFunctionEnvironment(L); + + // Register type userdata + registerTypeUserData(L); + + luaL_sandbox(L); + luaL_sandboxthread(L); + + // Load bytecode into Luau state + if (auto error = checkResultForError(L, name.value, luau_load(L, name.value, bytecode.data(), bytecode.size(), 0))) + return {std::nullopt, true, {}, {}, error}; + + // Execute the loaded chunk to register the function in the global environment + if (auto error = checkResultForError(L, name.value, lua_pcall(L, 0, 0, 0))) + return {std::nullopt, true, {}, {}, error}; + + // Get type function from the global environment + lua_getglobal(L, name.value); + if (!lua_isfunction(L, -1)) + { + std::string errMsg = format("Could not find '%s' type function in the global scope", name.value); + + return {std::nullopt, true, {}, {}, errMsg}; + } + + // Push serialized arguments onto the stack + + // Since there aren't any new class types being created in type functions, there isn't a deserialization function + // class types. Instead, we can keep this map and return the mapping as the "deserialized value" + std::unique_ptr runtimeBuilder = std::make_unique(ctx); + for (auto typeParam : typeParams) + { + TypeId ty = follow(typeParam); + // This is checked at the top of the function, and should still be true. + LUAU_ASSERT(!isPending(ty, ctx->solver)); + + TypeFunctionTypeId serializedTy = serialize(ty, runtimeBuilder.get()); + // Check if there were any errors while serializing + if (runtimeBuilder->errors.size() != 0) + return {std::nullopt, true, {}, {}, runtimeBuilder->errors.front()}; + + allocTypeUserData(L, serializedTy->type); + } + + // Set up an interrupt handler for type functions to respect type checking limits and LSP cancellation requests. + lua_callbacks(L)->interrupt = [](lua_State* L, int gc) + { + auto ctx = static_cast(lua_getthreaddata(lua_mainthread(L))); + if (ctx->limits->finishTime && TimeTrace::getClock() > *ctx->limits->finishTime) + ctx->solver->throwTimeLimitError(); + + if (ctx->limits->cancellationToken && ctx->limits->cancellationToken->requested()) + ctx->solver->throwUserCancelError(); + }; + + if (auto error = checkResultForError(L, name.value, lua_resume(L, nullptr, int(typeParams.size())))) + return {std::nullopt, true, {}, {}, error}; + + // If the return value is not a type userdata, return with error message + if (!isTypeUserData(L, 1)) + return {std::nullopt, true, {}, {}, format("'%s' type function: returned a non-type value", name.value)}; + + TypeFunctionTypeId retTypeFunctionTypeId = getTypeUserData(L, 1); + + // No errors should be present here since we should've returned already if any were raised during serialization. + LUAU_ASSERT(runtimeBuilder->errors.size() == 0); + + TypeId retTypeId = deserialize(retTypeFunctionTypeId, runtimeBuilder.get()); + + // At least 1 error occured while deserializing + if (runtimeBuilder->errors.size() > 0) + return {std::nullopt, true, {}, {}, runtimeBuilder->errors.front()}; + + return {retTypeId, false, {}, {}}; } TypeFunctionReductionResult notTypeFunction( @@ -711,7 +840,7 @@ TypeFunctionReductionResult lenTypeFunction( if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) return {std::nullopt, true, {}, {}}; // occurs check failed - Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice}; + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? return {std::nullopt, true, {}, {}}; @@ -808,7 +937,7 @@ TypeFunctionReductionResult unmTypeFunction( if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) return {std::nullopt, true, {}, {}}; // occurs check failed - Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice}; + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? return {std::nullopt, true, {}, {}}; @@ -818,7 +947,20 @@ TypeFunctionReductionResult unmTypeFunction( return {std::nullopt, true, {}, {}}; } -NotNull TypeFunctionContext::pushConstraint(ConstraintV&& c) +TypeFunctionContext::TypeFunctionContext(NotNull cs, NotNull scope, NotNull constraint) + : arena(cs->arena) + , builtins(cs->builtinTypes) + , scope(scope) + , normalizer(cs->normalizer) + , typeFunctionRuntime(cs->typeFunctionRuntime) + , ice(NotNull{&cs->iceReporter}) + , limits(NotNull{&cs->limits}) + , solver(cs.get()) + , constraint(constraint.get()) +{ +} + +NotNull TypeFunctionContext::pushConstraint(ConstraintV&& c) const { LUAU_ASSERT(solver); NotNull newConstraint = solver->pushConstraint(scope, constraint ? constraint->location : Location{}, std::move(c)); @@ -921,12 +1063,16 @@ TypeFunctionReductionResult numericBinopTypeFunction( SolveResult solveResult; if (!reversed) - solveResult = solveFunctionCall(ctx->arena, ctx->builtins, ctx->normalizer, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack); + solveResult = solveFunctionCall( + ctx->arena, ctx->builtins, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack + ); else { TypePack* p = getMutable(argPack); std::swap(p->head.front(), p->head.back()); - solveResult = solveFunctionCall(ctx->arena, ctx->builtins, ctx->normalizer, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack); + solveResult = solveFunctionCall( + ctx->arena, ctx->builtins, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack + ); } if (!solveResult.typePackId.has_value()) @@ -1156,7 +1302,7 @@ TypeFunctionReductionResult concatTypeFunction( if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) return {std::nullopt, true, {}, {}}; // occurs check failed - Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice}; + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? return {std::nullopt, true, {}, {}}; @@ -1410,7 +1556,7 @@ static TypeFunctionReductionResult comparisonTypeFunction( if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) return {std::nullopt, true, {}, {}}; // occurs check failed - Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice}; + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? return {std::nullopt, true, {}, {}}; @@ -1554,7 +1700,7 @@ TypeFunctionReductionResult eqTypeFunction( if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) return {std::nullopt, true, {}, {}}; // occurs check failed - Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice}; + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? return {std::nullopt, true, {}, {}}; @@ -2004,7 +2150,7 @@ TypeFunctionReductionResult keyofFunctionImpl( if (!computeKeysOf(*classesIter, localKeys, seen, isRaw, ctx)) continue; - for (auto key : keys) + for (auto& key : keys) { // remove any keys that are not present in each class if (!localKeys.contains(key)) @@ -2039,7 +2185,7 @@ TypeFunctionReductionResult keyofFunctionImpl( if (!computeKeysOf(*tablesIter, localKeys, seen, isRaw, ctx)) continue; - for (auto key : keys) + for (auto& key : keys) { // remove any keys that are not present in each table if (!localKeys.contains(key)) @@ -2239,7 +2385,7 @@ TypeFunctionReductionResult indexFunctionImpl( return {std::nullopt, true, {}, {}}; // indexer can be a union —> break them down into a vector - const std::vector* typesToFind; + const std::vector* typesToFind = nullptr; const std::vector singleType{indexerTy}; if (auto unionTy = get(indexerTy)) typesToFind = &unionTy->options; diff --git a/Analysis/src/TypeFunctionReductionGuesser.cpp b/Analysis/src/TypeFunctionReductionGuesser.cpp index d4a7c7c0..389a797d 100644 --- a/Analysis/src/TypeFunctionReductionGuesser.cpp +++ b/Analysis/src/TypeFunctionReductionGuesser.cpp @@ -3,6 +3,7 @@ #include "Luau/DenseHash.h" #include "Luau/Normalize.h" +#include "Luau/ToString.h" #include "Luau/TypeFunction.h" #include "Luau/Type.h" #include "Luau/TypePack.h" diff --git a/Analysis/src/TypeFunctionRuntime.cpp b/Analysis/src/TypeFunctionRuntime.cpp new file mode 100644 index 00000000..d3a33d07 --- /dev/null +++ b/Analysis/src/TypeFunctionRuntime.cpp @@ -0,0 +1,2192 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/TypeFunctionRuntime.h" + +#include "Luau/DenseHash.h" +#include "Luau/StringUtils.h" +#include "Luau/TypeFunction.h" + +#include "lua.h" +#include "lualib.h" + +#include +#include +#include + +// defined in TypeFunctionRuntimeBuilder.cpp +LUAU_DYNAMIC_FASTINT(LuauTypeFunctionSerdeIterationLimit); + +namespace Luau +{ + +constexpr int kTypeUserdataTag = 42; + +void* typeFunctionAlloc(void* ud, void* ptr, size_t osize, size_t nsize) +{ + if (nsize == 0) + { + ::operator delete(ptr); + return nullptr; + } + else if (osize == 0) + { + return ::operator new(nsize); + } + else + { + void* data = ::operator new(nsize); + memcpy(data, ptr, nsize < osize ? nsize : osize); + + ::operator delete(ptr); + + return data; + } +} + +std::optional checkResultForError(lua_State* L, const char* typeFunctionName, int luaResult) +{ + switch (luaResult) + { + case LUA_OK: + return std::nullopt; + case LUA_YIELD: + case LUA_BREAK: + return format("'%s' type function errored: unexpected yield or break", typeFunctionName); + default: + if (!lua_gettop(L)) + return format("'%s' type function errored unexpectedly", typeFunctionName); + + if (lua_isstring(L, -1)) + return format("'%s' type function errored at runtime: %s", typeFunctionName, lua_tostring(L, -1)); + + return format("'%s' type function errored at runtime: raised an error of type %s", typeFunctionName, lua_typename(L, -1)); + } +} + +static const TypeFunctionContext* getTypeFunctionContext(lua_State* L) +{ + return static_cast(lua_getthreaddata(lua_mainthread(L))); +} + +TypeFunctionType* allocateTypeFunctionType(lua_State* L, TypeFunctionTypeVariant type) +{ + auto ctx = getTypeFunctionContext(L); + return ctx->typeFunctionRuntime->typeArena.allocate(std::move(type)); +} + +TypeFunctionTypePackVar* allocateTypeFunctionTypePack(lua_State* L, TypeFunctionTypePackVariant type) +{ + auto ctx = getTypeFunctionContext(L); + return ctx->typeFunctionRuntime->typePackArena.allocate(std::move(type)); +} + +// Pushes a new type userdata onto the stack +void allocTypeUserData(lua_State* L, TypeFunctionTypeVariant type) +{ + // allocate a new type userdata + TypeFunctionTypeId* ptr = static_cast(lua_newuserdatatagged(L, sizeof(TypeFunctionTypeId), kTypeUserdataTag)); + *ptr = allocateTypeFunctionType(L, std::move(type)); + + // set the new userdata's metatable to type metatable + luaL_getmetatable(L, "type"); + lua_setmetatable(L, -2); +} + +void deallocTypeUserData(lua_State* L, void* data) +{ + // only non-owning pointers into an arena is stored +} + +bool isTypeUserData(lua_State* L, int idx) +{ + if (!lua_isuserdata(L, idx)) + return false; + + return lua_touserdatatagged(L, idx, kTypeUserdataTag) != nullptr; +} + +TypeFunctionTypeId getTypeUserData(lua_State* L, int idx) +{ + if (auto typ = static_cast(lua_touserdatatagged(L, idx, kTypeUserdataTag))) + return *typ; + + luaL_typeerrorL(L, idx, "type"); +} + +std::optional optionalTypeUserData(lua_State* L, int idx) +{ + if (lua_isnoneornil(L, idx)) + return std::nullopt; + else + return getTypeUserData(L, idx); +} + +// returns a string tag of TypeFunctionTypeId +static std::string getTag(lua_State* L, TypeFunctionTypeId ty) +{ + if (auto n = get(ty); n && n->type == TypeFunctionPrimitiveType::Type::NilType) + return "nil"; + else if (auto b = get(ty); b && b->type == TypeFunctionPrimitiveType::Type::Boolean) + return "boolean"; + else if (auto n = get(ty); n && n->type == TypeFunctionPrimitiveType::Type::Number) + return "number"; + else if (auto s = get(ty); s && s->type == TypeFunctionPrimitiveType::Type::String) + return "string"; + else if (get(ty)) + return "unknown"; + else if (get(ty)) + return "never"; + else if (get(ty)) + return "any"; + else if (auto s = get(ty)) + return "singleton"; + else if (get(ty)) + return "negation"; + else if (get(ty)) + return "union"; + else if (get(ty)) + return "intersection"; + else if (get(ty)) + return "table"; + else if (get(ty)) + return "function"; + else if (get(ty)) + return "class"; + + LUAU_UNREACHABLE(); + luaL_error(L, "VM encountered unexpected type variant when determining tag"); +} + +// Luau: `type.unknown` +// Returns the type instance representing unknown +static int createUnknown(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionUnknownType{}); + + return 1; +} + +// Luau: `type.never` +// Returns the type instance representing never +static int createNever(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionNeverType{}); + + return 1; +} + +// Luau: `type.any` +// Returns the type instance representing any +static int createAny(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionAnyType{}); + + return 1; +} + +// Luau: `type.boolean` +// Returns the type instance representing boolean +static int createBoolean(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionPrimitiveType{TypeFunctionPrimitiveType::Boolean}); + + return 1; +} + +// Luau: `type.number` +// Returns the type instance representing number +static int createNumber(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionPrimitiveType{TypeFunctionPrimitiveType::Number}); + + return 1; +} + +// Luau: `type.string` +// Returns the type instance representing string +static int createString(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionPrimitiveType{TypeFunctionPrimitiveType::String}); + + return 1; +} + +// Luau: `type.singleton(value: string | boolean | nil) -> type` +// Returns the type instance representing string or boolean singleton or nil +static int createSingleton(lua_State* L) +{ + if (lua_isboolean(L, 1)) // Create boolean singleton + { + bool value = luaL_checkboolean(L, 1); + allocTypeUserData(L, TypeFunctionSingletonType{TypeFunctionBooleanSingleton{value}}); + + return 1; + } + + // n.b. we cannot use lua_isstring here because lua committed the cardinal sin of calling a number a string + if (lua_type(L, 1) == LUA_TSTRING) // Create string singleton + { + const char* value = luaL_checkstring(L, 1); + allocTypeUserData(L, TypeFunctionSingletonType{TypeFunctionStringSingleton{value}}); + + return 1; + } + + if (lua_isnil(L, 1)) + { + allocTypeUserData(L, TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::NilType)); + + return 1; + } + + luaL_error(L, "types.singleton: can't create singleton from `%s` type", lua_typename(L, 1)); +} + +// Luau: `self:value() -> type` +// Returns the value of a singleton +static int getSingletonValue(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.value: expected 1 argument, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tfpt = get(self)) + { + if (tfpt->type != TypeFunctionPrimitiveType::NilType) + luaL_error(L, "type.value: expected self to be a singleton, but got %s instead", getTag(L, self).c_str()); + + lua_pushnil(L); + return 1; + } + + auto tfst = get(self); + if (!tfst) + luaL_error(L, "type.value: expected self to be a singleton, but got %s instead", getTag(L, self).c_str()); + + if (auto tfbst = get(tfst)) + { + lua_pushboolean(L, tfbst->value); + return 1; + } + + if (auto tfsst = get(tfst)) + { + lua_pushlstring(L, tfsst->value.c_str(), tfsst->value.length()); + return 1; + } + + luaL_error(L, "type.value: can't call `value` method on `%s` type", getTag(L, self).c_str()); +} + +// Luau: `types.unionof(...: type) -> type` +// Returns the type instance representing union +static int createUnion(lua_State* L) +{ + // get the number of arguments for union + int argSize = lua_gettop(L); + if (argSize < 2) + luaL_error(L, "types.unionof: expected at least 2 types to union, but got %d", argSize); + + std::vector components; + components.reserve(argSize); + + for (int i = 1; i <= argSize; i++) + components.push_back(getTypeUserData(L, i)); + + allocTypeUserData(L, TypeFunctionUnionType{components}); + + return 1; +} + +// Luau: `types.intersectionof(...: type) -> type` +// Returns the type instance representing intersection +static int createIntersection(lua_State* L) +{ + // get the number of arguments for intersection + int argSize = lua_gettop(L); + if (argSize < 2) + luaL_error(L, "types.intersectionof: expected at least 2 types to intersection, but got %d", argSize); + + std::vector components; + components.reserve(argSize); + + for (int i = 1; i <= argSize; i++) + components.push_back(getTypeUserData(L, i)); + + allocTypeUserData(L, TypeFunctionIntersectionType{components}); + + return 1; +} + +// Luau: `self:components() -> {type}` +// Returns the components of union or intersection +static int getComponents(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.components: expected 1 argument, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfut = get(self); + if (tfut) + { + int argSize = int(tfut->components.size()); + + lua_createtable(L, argSize, 0); + for (int i = 0; i < argSize; i++) + { + TypeFunctionTypeId component = tfut->components[i]; + allocTypeUserData(L, component->type); + lua_rawseti(L, -2, i + 1); // Luau is 1-indexed while C++ is 0-indexed + } + + return 1; + } + + auto tfit = get(self); + if (tfit) + { + int argSize = int(tfit->components.size()); + + lua_createtable(L, argSize, 0); + for (int i = 0; i < argSize; i++) + { + TypeFunctionTypeId component = tfit->components[i]; + allocTypeUserData(L, component->type); + lua_rawseti(L, -2, i + 1); // Luau is 1-indexed while C++ is 0-indexed + } + + return 1; + } + + luaL_error(L, "type.components: cannot call components of `%s` type", getTag(L, self).c_str()); +} + +// Luau: `types.negationof(arg: type) -> type` +// Returns the type instance representing negation +static int createNegation(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "types.negationof: expected 1 argument, but got %d", argumentCount); + + TypeFunctionTypeId arg = getTypeUserData(L, 1); + if (get(arg) || get(arg)) + luaL_error(L, "types.negationof: cannot perform negation on `%s` type", getTag(L, arg).c_str()); + + allocTypeUserData(L, TypeFunctionNegationType{arg}); + + return 1; +} + +// Luau: `self:inner() -> type` +// Returns the type instance being negated +static int getNegatedValue(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.inner: expected 1 argument, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tfnt = get(self); !tfnt) + allocTypeUserData(L, tfnt->type->type); + else + luaL_error(L, "type.inner: cannot call inner method on non-negation type: `%s` type", getTag(L, self).c_str()); + + return 1; +} + +// Luau: `types.newtable(props: {[type]: type | { read: type, write: type }}?, indexer: {index: type, readresult: type, writeresult: type}?, +// metatable: type?) -> type` Returns the type instance representing table +static int createTable(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount > 3) + luaL_error(L, "types.newtable: expected 0-3 arguments, but got %d", argumentCount); + + // Parse prop + TypeFunctionTableType::Props props{}; + if (lua_istable(L, 1)) + { + lua_pushnil(L); + while (lua_next(L, 1) != 0) + { + TypeFunctionTypeId key = getTypeUserData(L, -2); + + auto tfst = get(key); + if (!tfst) + luaL_error(L, "types.newtable: expected to be given a singleton type, but got %s instead", getTag(L, key).c_str()); + + auto tfsst = get(tfst); + if (!tfsst) + luaL_error(L, "types.newtable: expected to be given a string singleton type, but got %s instead", getTag(L, key).c_str()); + + if (lua_istable(L, -1)) + { + lua_getfield(L, -1, "read"); + std::optional readTy; + if (!lua_isnil(L, -1)) + readTy = getTypeUserData(L, -1); + lua_pop(L, 1); + + lua_getfield(L, -1, "write"); + std::optional writeTy; + if (!lua_isnil(L, -1)) + writeTy = getTypeUserData(L, -1); + lua_pop(L, 1); + + props[tfsst->value] = TypeFunctionProperty{readTy, writeTy}; + } + else + { + TypeFunctionTypeId value = getTypeUserData(L, -1); + props[tfsst->value] = TypeFunctionProperty::rw(value); + } + + lua_pop(L, 1); + } + } + else if (!lua_isnoneornil(L, 1)) + luaL_typeerrorL(L, 1, "table"); + + // Parse indexer + std::optional indexer; + if (lua_istable(L, 2)) + { + // Parse keyType and valueType + lua_getfield(L, 2, "index"); + TypeFunctionTypeId keyType = getTypeUserData(L, -1); + lua_pop(L, 1); + + lua_getfield(L, 2, "readresult"); + TypeFunctionTypeId valueType = getTypeUserData(L, -1); + lua_pop(L, 1); + + indexer = TypeFunctionTableIndexer(keyType, valueType); + } + else if (!lua_isnoneornil(L, 2)) + luaL_typeerrorL(L, 2, "table"); + + // Parse metatable + std::optional metatable = optionalTypeUserData(L, 3); + if (metatable && !get(*metatable)) + luaL_error(L, "types.newtable: expected to be given a table type as a metatable, but got %s instead", getTag(L, *metatable).c_str()); + + allocTypeUserData(L, TypeFunctionTableType{props, indexer, metatable}); + return 1; +} + +// Luau: `self:setproperty(key: type, value: type?)` +// Sets the properties of a table +static int setTableProp(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount < 2 || argumentCount > 3) + luaL_error(L, "type.setproperty: expected 2-3 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tftt = getMutable(self); + if (!tftt) + luaL_error(L, "type.setproperty: expected self to be a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId key = getTypeUserData(L, 2); + auto tfst = get(key); + if (!tfst) + luaL_error(L, "type.setproperty: expected to be given a singleton type, but got %s instead", getTag(L, key).c_str()); + + auto tfsst = get(tfst); + if (!tfsst) + luaL_error(L, "type.setproperty: expected to be given a string singleton type, but got %s instead", getTag(L, key).c_str()); + + if (argumentCount == 2 || lua_isnil(L, 3)) + { + tftt->props.erase(tfsst->value); + return 0; + } + + TypeFunctionTypeId value = getTypeUserData(L, 3); + tftt->props[tfsst->value] = TypeFunctionProperty::rw(value, value); + + return 0; +} + +// Luau: `self:setreadproperty(key: type, value: type?)` +// Sets the properties of a table +static int setReadTableProp(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount < 2 || argumentCount > 3) + luaL_error(L, "type.setreadproperty: expected 2-3 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tftt = getMutable(self); + if (!tftt) + luaL_error(L, "type.setreadproperty: expected self to be a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId key = getTypeUserData(L, 2); + auto tfst = get(key); + if (!tfst) + luaL_error(L, "type.setreadproperty: expected to be given a singleton type, but got %s instead", getTag(L, key).c_str()); + + auto tfsst = get(tfst); + if (!tfsst) + luaL_error(L, "type.setreadproperty: expected to be given a string singleton type, but got %s instead", getTag(L, key).c_str()); + + auto iter = tftt->props.find(tfsst->value); + + if (argumentCount == 2 || lua_isnil(L, 3)) + { + // if it's read-only, remove it altogether + if (iter != tftt->props.end() && iter->second.isReadOnly()) + tftt->props.erase(tfsst->value); + // but if it's not, just null out the read type. + else if (iter != tftt->props.end()) + iter->second.readTy = std::nullopt; + + return 0; + } + + TypeFunctionTypeId value = getTypeUserData(L, 3); + if (iter == tftt->props.end()) + tftt->props[tfsst->value] = TypeFunctionProperty::readonly(value); + else + iter->second.readTy = value; + + return 0; +} + +// Luau: `self:setwriteproperty(key: type, value: type?)` +// Sets the properties of a table +static int setWriteTableProp(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount < 2 || argumentCount > 3) + luaL_error(L, "type.setwriteproperty: expected 2-3 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tftt = getMutable(self); + if (!tftt) + luaL_error(L, "type.setwriteproperty: expected self to be a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId key = getTypeUserData(L, 2); + auto tfst = get(key); + if (!tfst) + luaL_error(L, "type.setwriteproperty: expected to be given a singleton type, but got %s instead", getTag(L, key).c_str()); + + auto tfsst = get(tfst); + if (!tfsst) + luaL_error(L, "type.setwriteproperty: expected to be given a string singleton type, but got %s instead", getTag(L, key).c_str()); + + auto iter = tftt->props.find(tfsst->value); + + if (argumentCount == 2 || lua_isnil(L, 3)) + { + // if it's write-only, remove it altogether + if (iter != tftt->props.end() && iter->second.isWriteOnly()) + tftt->props.erase(tfsst->value); + // but if it's not, just null out the write type. + else if (iter != tftt->props.end()) + iter->second.writeTy = std::nullopt; + + return 0; + } + + TypeFunctionTypeId value = getTypeUserData(L, 3); + if (iter == tftt->props.end()) + tftt->props[tfsst->value] = TypeFunctionProperty::writeonly(value); + else + iter->second.writeTy = value; + + return 0; +} + +// Luau: `self:readproperty(key: type) -> type` +// Returns the property of a table associated with the key +static int readTableProp(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 2) + luaL_error(L, "type.readproperty: expected 2 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tftt = get(self); + if (!tftt) + luaL_error(L, "type.readproperty: expected self to be either a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId key = getTypeUserData(L, 2); + auto tfst = get(key); + if (!tfst) + luaL_error(L, "type.readproperty: expected to be given a singleton type, but got %s instead", getTag(L, key).c_str()); + + auto tfsst = get(tfst); + if (!tfsst) + luaL_error(L, "type.readproperty: expected to be given a string singleton type, but got %s instead", getTag(L, key).c_str()); + + // Check if key is a valid prop + if (tftt->props.find(tfsst->value) == tftt->props.end()) + { + lua_pushnil(L); + return 1; + } + + auto prop = tftt->props.at(tfsst->value); + if (prop.readTy) + allocTypeUserData(L, (*prop.readTy)->type); + else + luaL_error(L, "type.readproperty: property %s is write-only, and therefore does not have a read type.", tfsst->value.c_str()); + + return 1; +} +// +// Luau: `self:writeproperty(key: type) -> type` +// Returns the property of a table associated with the key +static int writeTableProp(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 2) + luaL_error(L, "type.writeproperty: expected 2 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tftt = get(self); + if (!tftt) + luaL_error(L, "type.writeproperty: expected self to be either a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId key = getTypeUserData(L, 2); + auto tfst = get(key); + if (!tfst) + luaL_error(L, "type.writeproperty: expected to be given a singleton type, but got %s instead", getTag(L, key).c_str()); + + auto tfsst = get(tfst); + if (!tfsst) + luaL_error(L, "type.writeproperty: expected to be given a string singleton type, but got %s instead", getTag(L, key).c_str()); + + // Check if key is a valid prop + if (tftt->props.find(tfsst->value) == tftt->props.end()) + { + lua_pushnil(L); + return 1; + } + + auto prop = tftt->props.at(tfsst->value); + if (prop.writeTy) + allocTypeUserData(L, (*prop.writeTy)->type); + else + luaL_error(L, "type.writeproperty: property %s is read-only, and therefore does not have a write type.", tfsst->value.c_str()); + + return 1; +} + +// Luau: `self:setindexer(key: type, value: type)` +// Sets the indexer of the table +static int setTableIndexer(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 3) + luaL_error(L, "type.setindexer: expected 3 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tftt = getMutable(self); + if (!tftt) + luaL_error(L, "type.setindexer: expected self to be either a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId key = getTypeUserData(L, 2); + TypeFunctionTypeId value = getTypeUserData(L, 3); + + tftt->indexer = TypeFunctionTableIndexer{key, value}; + + return 0; +} + +// Luau: `self:setreadindexer(key: type, value: type)` +// Sets the read indexer of the table +static int setTableReadIndexer(lua_State* L) +{ + luaL_error(L, "type.setreadindexer: luau does not yet support separate read/write types for indexers."); +} + +// Luau: `self:setwriteindexer(key: type, value: type)` +// Sets the write indexer of the table +static int setTableWriteIndexer(lua_State* L) +{ + luaL_error(L, "type.setwriteindexer: luau does not yet support separate read/write types for indexers."); +} + +// Luau: `self:setmetatable(arg: type)` +// Sets the metatable of the table +static int setTableMetatable(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 2) + luaL_error(L, "type.setmetatable: expected 2 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + + auto tftt = getMutable(self); + if (!tftt) + luaL_error(L, "type.setmetatable: expected self to be a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId arg = getTypeUserData(L, 2); + if (!get(arg)) + luaL_error(L, "type.setmetatable: expected the argument to be a table, but got %s instead", getTag(L, self).c_str()); + + tftt->metatable = arg; + + return 0; +} + +// Luau: `types.newfunction(parameters: {head: {type}?, tail: type?}, returns: {head: {type}?, tail: type?}) -> type` +// Returns the type instance representing a function +static int createFunction(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount > 2) + luaL_error(L, "types.newfunction: expected 0-2 arguments, but got %d", argumentCount); + + TypeFunctionTypePackId argTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{}); + if (lua_istable(L, 1)) + { + std::vector head{}; + lua_getfield(L, 1, "head"); + if (lua_istable(L, -1)) + { + int argSize = lua_objlen(L, -1); + for (int i = 1; i <= argSize; i++) + { + lua_pushinteger(L, i); + lua_gettable(L, -2); + + if (lua_isnil(L, -1)) + { + lua_pop(L, 1); + break; + } + + TypeFunctionTypeId ty = getTypeUserData(L, -1); + head.push_back(ty); + + lua_pop(L, 1); // Remove `ty` from stack + } + } + lua_pop(L, 1); // Pop the "head" field + + std::optional tail; + lua_getfield(L, 1, "tail"); + if (auto type = optionalTypeUserData(L, -1)) + tail = allocateTypeFunctionTypePack(L, TypeFunctionVariadicTypePack{*type}); + lua_pop(L, 1); // Pop the "tail" field + + if (head.size() == 0 && tail.has_value()) + argTypes = *tail; + else + argTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{head, tail}); + } + else if (!lua_isnoneornil(L, 1)) + luaL_typeerrorL(L, 1, "table"); + + TypeFunctionTypePackId retTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{}); + if (lua_istable(L, 2)) + { + std::vector head{}; + lua_getfield(L, 2, "head"); + if (lua_istable(L, -1)) + { + int argSize = lua_objlen(L, -1); + for (int i = 1; i <= argSize; i++) + { + lua_pushinteger(L, i); + lua_gettable(L, -2); + + if (lua_isnil(L, -1)) + { + lua_pop(L, 1); + break; + } + + TypeFunctionTypeId ty = getTypeUserData(L, -1); + head.push_back(ty); + + lua_pop(L, 1); // Remove `ty` from stack + } + } + lua_pop(L, 1); // Pop the "head" field + + std::optional tail; + lua_getfield(L, 2, "tail"); + if (auto type = optionalTypeUserData(L, -1)) + tail = allocateTypeFunctionTypePack(L, TypeFunctionVariadicTypePack{*type}); + lua_pop(L, 1); // Pop the "tail" field + + if (head.size() == 0 && tail.has_value()) + retTypes = *tail; + else + retTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{head, tail}); + } + else if (!lua_isnoneornil(L, 2)) + luaL_typeerrorL(L, 2, "table"); + + allocTypeUserData(L, TypeFunctionFunctionType{argTypes, retTypes}); + + return 1; +} + +// Luau: `self:setparameters(head: {type}?, tail: type?)` +// Sets the parameters of the function +static int setFunctionParameters(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount > 3 || argumentCount < 1) + luaL_error(L, "type.setparameters: expected 1-3, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfft = getMutable(self); + if (!tfft) + luaL_error(L, "type.setparameters: expected self to be a function, but got %s instead", getTag(L, self).c_str()); + + std::vector head{}; + if (lua_istable(L, 2)) + { + int argSize = lua_objlen(L, 2); + for (int i = 1; i <= argSize; i++) + { + lua_pushinteger(L, i); + lua_gettable(L, 2); + + if (lua_isnil(L, -1)) + { + lua_pop(L, 1); + break; + } + + TypeFunctionTypeId ty = getTypeUserData(L, -1); + head.push_back(ty); + + lua_pop(L, 1); // Remove `ty` from stack + } + } + else if (!lua_isnoneornil(L, 2)) + luaL_typeerrorL(L, 2, "table"); + + std::optional tail; + if (auto type = optionalTypeUserData(L, 3)) + tail = allocateTypeFunctionTypePack(L, TypeFunctionVariadicTypePack{*type}); + + if (head.size() == 0 && tail.has_value()) // Make argTypes a variadic type pack + tfft->argTypes = *tail; + else // Make argTypes a type pack + tfft->argTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{head, tail}); + + return 0; +} + +// Luau: `self:parameters() -> {head: {type}?, tail: type?}` +// Returns the parameters of the function +static int getFunctionParameters(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.parameters: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfft = get(self); + if (!tfft) + luaL_error(L, "type.parameters: expected self to be a function, but got %s instead", getTag(L, self).c_str()); + + if (auto tftp = get(tfft->argTypes)) + { + int size = 0; + if (tftp->head.size() > 0) + size++; + if (tftp->tail.has_value()) + size++; + + lua_createtable(L, 0, size); + + int argSize = (int)tftp->head.size(); + if (argSize > 0) + { + lua_createtable(L, argSize, 0); + for (int i = 0; i < argSize; i++) + { + allocTypeUserData(L, tftp->head[i]->type); + lua_rawseti(L, -2, i + 1); // Luau is 1-indexed while C++ is 0-indexed + } + lua_setfield(L, -2, "head"); + } + + if (tftp->tail.has_value()) + { + auto tfvp = get(*tftp->tail); + if (!tfvp) + LUAU_ASSERT(!"We should only be supporting variadic packs as TypeFunctionTypePack.tail at the moment"); + + allocTypeUserData(L, tfvp->type->type); + lua_setfield(L, -2, "tail"); + } + + return 1; + } + + if (auto tfvp = get(tfft->argTypes)) + { + lua_createtable(L, 0, 1); + + allocTypeUserData(L, tfvp->type->type); + lua_setfield(L, -2, "tail"); + + return 1; + } + + lua_createtable(L, 0, 0); + return 1; +} + +// Luau: `self:setreturns(head: {type}?, tail: type?)` +// Sets the returns of the function +static int setFunctionReturns(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount < 2 || argumentCount > 3) + luaL_error(L, "type.setreturns: expected 1-3 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfft = getMutable(self); + if (!tfft) + luaL_error(L, "type.setreturns: expected self to be a function, but got %s instead", getTag(L, self).c_str()); + + std::vector head{}; + if (lua_istable(L, 2)) + { + int argSize = lua_objlen(L, 2); + for (int i = 1; i <= argSize; i++) + { + lua_pushinteger(L, i); + lua_gettable(L, 2); + + if (lua_isnil(L, -1)) + { + lua_pop(L, 1); + break; + } + + TypeFunctionTypeId ty = getTypeUserData(L, -1); + head.push_back(ty); + + lua_pop(L, 1); // Remove `ty` from stack + } + } + else if (!lua_isnoneornil(L, 2)) + luaL_typeerrorL(L, 2, "table"); + + std::optional tail; + if (auto type = optionalTypeUserData(L, 3)) + tail = allocateTypeFunctionTypePack(L, TypeFunctionVariadicTypePack{*type}); + + if (head.size() == 0 && tail.has_value()) // Make retTypes a variadic type pack + tfft->retTypes = *tail; + else // Make retTypes a type pack + tfft->retTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{head, tail}); + + return 0; +} + +// Luau: `self:returns() -> {head: {type}?, tail: type?}` +// Returns the returns of the function +static int getFunctionReturns(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.returns: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfft = get(self); + if (!tfft) + luaL_error(L, "type.returns: expected self to be a function, but got %s instead", getTag(L, self).c_str()); + + if (auto tftp = get(tfft->retTypes)) + { + int size = 0; + if (tftp->head.size() > 0) + size++; + if (tftp->tail.has_value()) + size++; + + lua_createtable(L, 0, size); + + int argSize = (int)tftp->head.size(); + if (argSize > 0) + { + lua_createtable(L, argSize, 0); + for (int i = 0; i < argSize; i++) + { + allocTypeUserData(L, tftp->head[i]->type); + lua_rawseti(L, -2, i + 1); // Luau is 1-indexed while C++ is 0-indexed + } + lua_setfield(L, -2, "head"); + } + + if (tftp->tail.has_value()) + { + auto tfvp = get(*tftp->tail); + if (!tfvp) + LUAU_ASSERT(!"We should only be supporting variadic packs as TypeFunctionTypePack.tail at the moment"); + + allocTypeUserData(L, tfvp->type->type); + lua_setfield(L, -2, "tail"); + } + + return 1; + } + + if (auto tfvp = get(tfft->retTypes)) + { + lua_createtable(L, 0, 1); + + allocTypeUserData(L, tfvp->type->type); + lua_setfield(L, -2, "tail"); + + return 1; + } + + lua_createtable(L, 0, 0); + return 1; +} + +// Luau: `self:parent() -> type` +// Returns the parent of a class type +static int getClassParent(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.parent: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfct = get(self); + if (!tfct) + luaL_error(L, "type.parent: expected self to be a class, but got %s instead", getTag(L, self).c_str()); + + // If the parent does not exist, we should return nil + if (!tfct->parent) + lua_pushnil(L); + else + allocTypeUserData(L, (*tfct->parent)->type); + + return 1; +} + +// Luau: `self:properties() -> {[type]: { read: type?, write: type? }}` +// Returns the properties of a table or class type +static int getProps(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.properties: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tftt = get(self)) + { + lua_createtable(L, int(tftt->props.size()), 0); + for (auto& [name, prop] : tftt->props) + { + allocTypeUserData(L, TypeFunctionSingletonType{TypeFunctionStringSingleton{name}}); + + int size = 0; + if (prop.readTy) + size++; + if (prop.writeTy) + size++; + + lua_createtable(L, 0, size); + if (prop.readTy) + { + allocTypeUserData(L, (*prop.readTy)->type); + lua_setfield(L, -2, "read"); + } + + if (prop.writeTy) + { + allocTypeUserData(L, (*prop.writeTy)->type); + lua_setfield(L, -2, "write"); + } + + lua_settable(L, -3); + } + + return 1; + } + + if (auto tfct = get(self)) + { + lua_createtable(L, int(tfct->props.size()), 0); + for (auto& [name, prop] : tfct->props) + { + allocTypeUserData(L, TypeFunctionSingletonType{TypeFunctionStringSingleton{name}}); + + int size = 0; + if (prop.readTy) + size++; + if (prop.writeTy) + size++; + + lua_createtable(L, 0, size); + if (prop.readTy) + { + allocTypeUserData(L, (*prop.readTy)->type); + lua_setfield(L, -2, "read"); + } + + if (prop.writeTy) + { + allocTypeUserData(L, (*prop.writeTy)->type); + lua_setfield(L, -2, "write"); + } + + lua_settable(L, -3); + } + + return 1; + } + + luaL_error(L, "type.properties: expected self to be either a table or class, but got %s instead", getTag(L, self).c_str()); +} + +// Luau: `self:indexer() -> {index: type, readresult: type, writeresult: type}?` +// Returns the indexer of a table or class type +static int getIndexer(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.indexer: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tftt = get(self)) + { + // if the indexer does not exist, we should return nil + if (!tftt->indexer.has_value()) + lua_pushnil(L); + else + { + lua_createtable(L, 0, 3); + allocTypeUserData(L, tftt->indexer->keyType->type); + lua_setfield(L, -2, "index"); + allocTypeUserData(L, tftt->indexer->valueType->type); + lua_setfield(L, -2, "readresult"); + allocTypeUserData(L, tftt->indexer->valueType->type); + lua_setfield(L, -2, "writeresult"); + } + + return 1; + } + + if (auto tfct = get(self)) + { + // if the indexer does not exist, we should return nil + if (!tfct->indexer.has_value()) + lua_pushnil(L); + else + { + lua_createtable(L, 0, 3); + allocTypeUserData(L, tfct->indexer->keyType->type); + lua_setfield(L, -2, "index"); + allocTypeUserData(L, tfct->indexer->valueType->type); + lua_setfield(L, -2, "readresult"); + allocTypeUserData(L, tfct->indexer->valueType->type); + lua_setfield(L, -2, "writeresult"); + } + + return 1; + } + + luaL_error(L, "type.indexer: self to be either a table or class, but got %s instead", getTag(L, self).c_str()); +} + +// Luau: `self:readindexer() -> {index: type, result: type}?` +// Returns the read indexer of a table or class type +static int getReadIndexer(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.readindexer: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tftt = get(self)) + { + // if the indexer does not exist, we should return nil + if (!tftt->indexer.has_value()) + lua_pushnil(L); + else + { + lua_createtable(L, 0, 2); + allocTypeUserData(L, tftt->indexer->keyType->type); + lua_setfield(L, -2, "index"); + allocTypeUserData(L, tftt->indexer->valueType->type); + lua_setfield(L, -2, "result"); + } + + return 1; + } + + if (auto tfct = get(self)) + { + // if the indexer does not exist, we should return nil + if (!tfct->indexer.has_value()) + lua_pushnil(L); + else + { + lua_createtable(L, 0, 2); + allocTypeUserData(L, tfct->indexer->keyType->type); + lua_setfield(L, -2, "index"); + allocTypeUserData(L, tfct->indexer->valueType->type); + lua_setfield(L, -2, "result"); + } + + return 1; + } + + luaL_error(L, "type.readindexer: expected self to be either a table or class, but got %s instead", getTag(L, self).c_str()); +} + +// Luau: `self:writeindexer() -> {index: type, result: type}?` +// Returns the write indexer of a table or class type +static int getWriteIndexer(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.writeindexer: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tftt = get(self)) + { + // if the indexer does not exist, we should return nil + if (!tftt->indexer.has_value()) + lua_pushnil(L); + else + { + lua_createtable(L, 0, 2); + allocTypeUserData(L, tftt->indexer->keyType->type); + lua_setfield(L, -2, "index"); + allocTypeUserData(L, tftt->indexer->valueType->type); + lua_setfield(L, -2, "result"); + } + + return 1; + } + + if (auto tfct = get(self)) + { + // if the indexer does not exist, we should return nil + if (!tfct->indexer.has_value()) + lua_pushnil(L); + else + { + lua_createtable(L, 0, 2); + allocTypeUserData(L, tfct->indexer->keyType->type); + lua_setfield(L, -2, "index"); + allocTypeUserData(L, tfct->indexer->valueType->type); + lua_setfield(L, -2, "result"); + } + + return 1; + } + + luaL_error(L, "type.writeindexer: expected self to be either a table or class, but got %s instead", getTag(L, self).c_str()); +} + +// Luau: `self:metatable() -> type?` +// Returns the metatable of a table or class type +static int getMetatable(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.metatable: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tfmt = get(self)) + { + // if the metatable does not exist, we should return nil + if (!tfmt->metatable.has_value()) + lua_pushnil(L); + else + allocTypeUserData(L, (*tfmt->metatable)->type); + + return 1; + } + + if (auto tfct = get(self)) + { + // if the metatable does not exist, we should return nil + if (!tfct->metatable.has_value()) + lua_pushnil(L); + else + allocTypeUserData(L, (*tfct->metatable)->type); + + return 1; + } + + luaL_error(L, "type.metatable: expected self to be a table or class, but got %s instead", getTag(L, self).c_str()); +} + +// Luau: `self:is(arg: string) -> boolean` +// Returns true if given argument is a tag of self +static int checkTag(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 2) + luaL_error(L, "type.is: expected 2 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + std::string arg = luaL_checkstring(L, 2); + + lua_pushboolean(L, getTag(L, self) == arg); + return 1; +} + +TypeFunctionTypeId deepClone(NotNull runtime, TypeFunctionTypeId ty); // Forward declaration + +// Luau: `types.copy(arg: string) -> type` +// Returns a deep copy of the argument +static int deepCopy(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "types.copy: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId arg = getTypeUserData(L, 1); + + TypeFunctionTypeId copy = deepClone(getTypeFunctionContext(L)->typeFunctionRuntime, arg); + allocTypeUserData(L, copy->type); + return 1; +} + +// Luau: `self == arg -> boolean` +// Used to set the __eq metamethod +static int isEqualToType(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 2) + luaL_error(L, "expected 2 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + TypeFunctionTypeId arg = getTypeUserData(L, 2); + + lua_pushboolean(L, *self == *arg); + return 1; +} + +// Register the type userdata +void registerTypeUserData(lua_State* L) +{ + // List of fields for type userdata + luaL_Reg typeUserdataFields[] = { + {"unknown", createUnknown}, + {"never", createNever}, + {"any", createAny}, + {"boolean", createBoolean}, + {"number", createNumber}, + {"string", createString}, + {nullptr, nullptr} + }; + + // List of methods for type userdata + luaL_Reg typeUserdataMethods[] = { + {"singleton", createSingleton}, + {"negationof", createNegation}, + {"unionof", createUnion}, + {"intersectionof", createIntersection}, + {"newtable", createTable}, + {"newfunction", createFunction}, + {"copy", deepCopy}, + + // Common methods + {"is", checkTag}, + + // Negation type methods + {"inner", getNegatedValue}, + + // Singleton type methods + {"value", getSingletonValue}, + + // Table type methods + {"setproperty", setTableProp}, + {"setreadproperty", setReadTableProp}, + {"setwriteproperty", setWriteTableProp}, + {"readproperty", readTableProp}, + {"writeproperty", writeTableProp}, + {"properties", getProps}, + {"setindexer", setTableIndexer}, + {"setreadindexer", setTableReadIndexer}, + {"setwriteindexer", setTableWriteIndexer}, + {"indexer", getIndexer}, + {"readindexer", getReadIndexer}, + {"writeindexer", getWriteIndexer}, + {"setmetatable", setTableMetatable}, + {"metatable", getMetatable}, + + // Function type methods + {"setparameters", setFunctionParameters}, + {"parameters", getFunctionParameters}, + {"setreturns", setFunctionReturns}, + {"returns", getFunctionReturns}, + + // Union and Intersection type methods + {"components", getComponents}, + + // Class type methods + {"parent", getClassParent}, + {"indexer", getIndexer}, + {nullptr, nullptr} + }; + + // Create and register metatable for type userdata + luaL_newmetatable(L, "type"); + + // Protect metatable from being fetched. + lua_pushstring(L, "The metatable is locked"); + lua_setfield(L, -2, "__metatable"); + + // Set type userdata metatable's __eq to type_equals() + lua_pushcfunction(L, isEqualToType, "__eq"); + lua_setfield(L, -2, "__eq"); + + // Set type userdata metatable's __index to itself + lua_pushvalue(L, -1); // Push a copy of type userdata metatable + lua_setfield(L, -2, "__index"); + + luaL_register(L, nullptr, typeUserdataMethods); + + // Set fields for type userdata + for (luaL_Reg* l = typeUserdataFields; l->name; l++) + { + l->func(L); + lua_setfield(L, -2, l->name); + } + + // Set types library as a global name "types" + lua_setglobal(L, "types"); + + // Sets up a destructor for the type userdata. + lua_setuserdatadtor(L, kTypeUserdataTag, deallocTypeUserData); +} + +// Used to redirect all the removed global functions to say "this function is unsupported" +int unsupportedFunction(lua_State* L) +{ + luaL_errorL(L, "this function is not supported in type functions"); + return 0; +} + +// Add libraries / globals for type function environment +void setTypeFunctionEnvironment(lua_State* L) +{ + // Register math library + luaopen_math(L); + lua_pop(L, 1); + + // Register table library + luaopen_table(L); + lua_pop(L, 1); + + // Register string library + luaopen_string(L); + lua_pop(L, 1); + + // Register bit32 library + luaopen_bit32(L); + lua_pop(L, 1); + + // Register utf8 library + luaopen_utf8(L); + lua_pop(L, 1); + + // Register buffer library + luaopen_buffer(L); + lua_pop(L, 1); + + // Register base library + luaopen_base(L); + lua_pop(L, 1); + + // Remove certain global functions from the base library + static const std::string unavailableGlobals[] = {"gcinfo", "getfenv", "newproxy", "setfenv", "pcall", "xpcall"}; + for (auto& name : unavailableGlobals) + { + lua_pushcfunction(L, unsupportedFunction, "Removing global function from type function environment"); + lua_setglobal(L, name.c_str()); + } +} + +/* + * Below are helper methods for __eq + * Same as one from Type.cpp + */ +using SeenSet = std::set>; +bool areEqual(SeenSet& seen, const TypeFunctionType& lhs, const TypeFunctionType& rhs); +bool areEqual(SeenSet& seen, const TypeFunctionTypePackVar& lhs, const TypeFunctionTypePackVar& rhs); + +bool seenSetContains(SeenSet& seen, const void* lhs, const void* rhs) +{ + if (lhs == rhs) + return true; + + auto p = std::make_pair(lhs, rhs); + if (seen.find(p) != seen.end()) + return true; + + seen.insert(p); + return false; +} + +bool areEqual(SeenSet& seen, const TypeFunctionSingletonType& lhs, const TypeFunctionSingletonType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + { + const TypeFunctionBooleanSingleton* lp = get(&lhs); + const TypeFunctionBooleanSingleton* rp = get(&lhs); + if (lp && rp) + return lp->value == rp->value; + } + + { + const TypeFunctionStringSingleton* lp = get(&lhs); + const TypeFunctionStringSingleton* rp = get(&lhs); + if (lp && rp) + return lp->value == rp->value; + } + + return false; +} + +bool areEqual(SeenSet& seen, const TypeFunctionUnionType& lhs, const TypeFunctionUnionType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + if (lhs.components.size() != rhs.components.size()) + return false; + + auto l = lhs.components.begin(); + auto r = rhs.components.begin(); + + while (l != lhs.components.end()) + { + if (!areEqual(seen, **l, **r)) + return false; + ++l; + ++r; + } + + return true; +} + +bool areEqual(SeenSet& seen, const TypeFunctionIntersectionType& lhs, const TypeFunctionIntersectionType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + if (lhs.components.size() != rhs.components.size()) + return false; + + auto l = lhs.components.begin(); + auto r = rhs.components.begin(); + + while (l != lhs.components.end()) + { + if (!areEqual(seen, **l, **r)) + return false; + ++l; + ++r; + } + + return true; +} + +bool areEqual(SeenSet& seen, const TypeFunctionNegationType& lhs, const TypeFunctionNegationType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + return areEqual(seen, *lhs.type, *rhs.type); +} + +bool areEqual(SeenSet& seen, const TypeFunctionTableType& lhs, const TypeFunctionTableType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + if (lhs.props.size() != rhs.props.size()) + return false; + + if (bool(lhs.indexer) != bool(rhs.indexer)) + return false; + + if (lhs.indexer && rhs.indexer) + { + if (!areEqual(seen, *lhs.indexer->keyType, *rhs.indexer->keyType)) + return false; + + if (!areEqual(seen, *lhs.indexer->valueType, *rhs.indexer->valueType)) + return false; + } + + auto l = lhs.props.begin(); + auto r = rhs.props.begin(); + + while (l != lhs.props.end()) + { + if ((l->second.readTy && !r->second.readTy) || (!l->second.readTy && r->second.readTy)) + return false; + + if (l->second.readTy && r->second.readTy && !areEqual(seen, **(l->second.readTy), **(r->second.readTy))) + return false; + + if ((l->second.writeTy && !r->second.writeTy) || (!l->second.writeTy && r->second.writeTy)) + return false; + + if (l->second.writeTy && r->second.writeTy && !areEqual(seen, **(l->second.writeTy), **(r->second.writeTy))) + return false; + + ++l; + ++r; + } + + return true; +} + +bool areEqual(SeenSet& seen, const TypeFunctionFunctionType& lhs, const TypeFunctionFunctionType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + if (bool(lhs.argTypes) != bool(rhs.argTypes)) + return false; + + if (lhs.argTypes && rhs.argTypes) + { + if (!areEqual(seen, *lhs.argTypes, *rhs.argTypes)) + return false; + } + + if (bool(lhs.retTypes) != bool(rhs.retTypes)) + return false; + + if (lhs.retTypes && rhs.retTypes) + { + if (!areEqual(seen, *lhs.retTypes, *rhs.retTypes)) + return false; + } + + return true; +} + +bool areEqual(SeenSet& seen, const TypeFunctionClassType& lhs, const TypeFunctionClassType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + return lhs.name == rhs.name; +} + +bool areEqual(SeenSet& seen, const TypeFunctionType& lhs, const TypeFunctionType& rhs) +{ + + if (lhs.type.index() != rhs.type.index()) + return false; + + { + const TypeFunctionPrimitiveType* lp = get(&lhs); + const TypeFunctionPrimitiveType* rp = get(&rhs); + if (lp && rp) + return lp->type == rp->type; + } + + if (get(&lhs) && get(&rhs)) + return true; + + if (get(&lhs) && get(&rhs)) + return true; + + if (get(&lhs) && get(&rhs)) + return true; + + { + const TypeFunctionSingletonType* lf = get(&lhs); + const TypeFunctionSingletonType* rf = get(&rhs); + if (lf && rf) + return areEqual(seen, *lf, *rf); + } + + { + const TypeFunctionUnionType* lf = get(&lhs); + const TypeFunctionUnionType* rf = get(&rhs); + if (lf && rf) + return areEqual(seen, *lf, *rf); + } + + { + const TypeFunctionIntersectionType* lf = get(&lhs); + const TypeFunctionIntersectionType* rf = get(&rhs); + if (lf && rf) + return areEqual(seen, *lf, *rf); + } + + { + const TypeFunctionNegationType* lf = get(&lhs); + const TypeFunctionNegationType* rf = get(&rhs); + if (lf && rf) + return areEqual(seen, *lf, *rf); + } + + { + const TypeFunctionTableType* lt = get(&lhs); + const TypeFunctionTableType* rt = get(&rhs); + if (lt && rt) + return areEqual(seen, *lt, *rt); + } + + { + const TypeFunctionFunctionType* lf = get(&lhs); + const TypeFunctionFunctionType* rf = get(&rhs); + if (lf && rf) + return areEqual(seen, *lf, *rf); + } + + { + const TypeFunctionClassType* lf = get(&lhs); + const TypeFunctionClassType* rf = get(&rhs); + if (lf && rf) + return areEqual(seen, *lf, *rf); + } + + return false; +} + +bool areEqual(SeenSet& seen, const TypeFunctionTypePack& lhs, const TypeFunctionTypePack& rhs) +{ + if (lhs.head.size() != rhs.head.size()) + return false; + + auto l = lhs.head.begin(); + auto r = rhs.head.begin(); + + while (l != lhs.head.end()) + { + if (!areEqual(seen, **l, **r)) + return false; + ++l; + ++r; + } + + return true; +} + +bool areEqual(SeenSet& seen, const TypeFunctionVariadicTypePack& lhs, const TypeFunctionVariadicTypePack& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + return areEqual(seen, *lhs.type, *rhs.type); +} + +bool areEqual(SeenSet& seen, const TypeFunctionTypePackVar& lhs, const TypeFunctionTypePackVar& rhs) +{ + { + const TypeFunctionTypePack* lb = get(&lhs); + const TypeFunctionTypePack* rb = get(&rhs); + if (lb && rb) + return areEqual(seen, *lb, *rb); + } + + { + const TypeFunctionVariadicTypePack* lv = get(&lhs); + const TypeFunctionVariadicTypePack* rv = get(&rhs); + if (lv && rv) + return areEqual(seen, *lv, *rv); + } + + return false; +} + +bool TypeFunctionType::operator==(const TypeFunctionType& rhs) const +{ + SeenSet seen; + return areEqual(seen, *this, rhs); +} + +bool TypeFunctionTypePackVar::operator==(const TypeFunctionTypePackVar& rhs) const +{ + SeenSet seen; + return areEqual(seen, *this, rhs); +} + + +TypeFunctionProperty TypeFunctionProperty::readonly(TypeFunctionTypeId ty) +{ + TypeFunctionProperty p; + p.readTy = ty; + return p; +} + +TypeFunctionProperty TypeFunctionProperty::writeonly(TypeFunctionTypeId ty) +{ + TypeFunctionProperty p; + p.writeTy = ty; + return p; +} + +TypeFunctionProperty TypeFunctionProperty::rw(TypeFunctionTypeId ty) +{ + return TypeFunctionProperty::rw(ty, ty); +} + +TypeFunctionProperty TypeFunctionProperty::rw(TypeFunctionTypeId read, TypeFunctionTypeId write) +{ + TypeFunctionProperty p; + p.readTy = read; + p.writeTy = write; + return p; +} + +bool TypeFunctionProperty::isReadOnly() const +{ + return readTy && !writeTy; +} + +bool TypeFunctionProperty::isWriteOnly() const +{ + return writeTy && !readTy; +} + +/* + * Below is a helper class for type.copy() + * Forked version of Clone.cpp + */ +using TypeFunctionKind = Variant; + +template +const T* get(const TypeFunctionKind& kind) +{ + return get_if(&kind); +} + +class TypeFunctionCloner +{ + using SeenTypes = DenseHashMap; + using SeenTypePacks = DenseHashMap; + + NotNull typeFunctionRuntime; + + // A queue of TypeFunctionTypeIds that have been cloned, but whose interior types hasn't + // been updated to point to itself. Once all of its interior types + // has been updated, it gets removed from the queue. + + // queue.back() should always return two of same type in their respective sides + // For example `auto [first, second] = queue.back()`: if first is TypeFunctionPrimitiveType, + // second must be TypeFunctionPrimitiveType; `second` is trying to copy `first` + std::vector> queue; + + SeenTypes types{{}}; // Mapping of TypeFunctionTypeIds that have been shallow cloned to TypeFunctionTypeIds + SeenTypePacks packs{{}}; // Mapping of TypeFunctionTypePackIds that have been shallow cloned to TypeFunctionTypePackIds + + int steps = 0; + +public: + explicit TypeFunctionCloner(TypeFunctionRuntime* typeFunctionRuntime) + : typeFunctionRuntime(typeFunctionRuntime) + { + } + + TypeFunctionTypeId clone(TypeFunctionTypeId ty) + { + shallowClone(ty); + run(); + + if (hasExceededIterationLimit()) + return nullptr; + + return find(ty).value_or(nullptr); + } + + TypeFunctionTypePackId clone(TypeFunctionTypePackId tp) + { + shallowClone(tp); + run(); + + if (hasExceededIterationLimit()) + return nullptr; + + return find(tp).value_or(nullptr); + } + +private: + bool hasExceededIterationLimit() const + { + return steps + queue.size() >= (size_t)DFInt::LuauTypeFunctionSerdeIterationLimit; + } + + void run() + { + while (!queue.empty()) + { + ++steps; + + if (hasExceededIterationLimit()) + break; + + auto [ty, tfti] = queue.back(); + queue.pop_back(); + + cloneChildren(ty, tfti); + } + } + + std::optional find(TypeFunctionTypeId ty) const + { + if (auto result = types.find(ty)) + return *result; + + return std::nullopt; + } + + std::optional find(TypeFunctionTypePackId tp) const + { + if (auto result = packs.find(tp)) + return *result; + + return std::nullopt; + } + + std::optional find(TypeFunctionKind kind) const + { + if (auto ty = get(kind)) + return find(*ty); + else if (auto tp = get(kind)) + return find(*tp); + else + { + LUAU_ASSERT(!"Unknown kind?"); + return std::nullopt; + } + } + + TypeFunctionTypeId shallowClone(TypeFunctionTypeId ty) + { + if (auto it = find(ty)) + return *it; + + // Create a shallow serialization + TypeFunctionTypeId target = {}; + if (auto p = get(ty)) + { + switch (p->type) + { + case TypeFunctionPrimitiveType::Type::NilType: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::NilType)); + break; + case TypeFunctionPrimitiveType::Type::Boolean: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Boolean)); + break; + case TypeFunctionPrimitiveType::Number: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Number)); + break; + case TypeFunctionPrimitiveType::String: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::String)); + break; + default: + break; + } + } + else if (auto u = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionUnknownType{}); + else if (auto a = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionNeverType{}); + else if (auto a = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionAnyType{}); + else if (auto s = get(ty)) + { + if (auto bs = get(s)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionSingletonType{TypeFunctionBooleanSingleton{bs->value}}); + else if (auto ss = get(s)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionSingletonType{TypeFunctionStringSingleton{ss->value}}); + } + else if (auto u = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionUnionType{{}}); + else if (auto i = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionIntersectionType{{}}); + else if (auto n = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionNegationType{{}}); + else if (auto t = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionTableType{{}, std::nullopt, std::nullopt}); + else if (auto f = get(ty)) + { + TypeFunctionTypePackId emptyTypePack = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{}); + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionFunctionType{emptyTypePack, emptyTypePack}); + } + else if (auto c = get(ty)) + target = ty; // Don't copy a class since they are immutable + + types[ty] = target; + queue.emplace_back(ty, target); + return target; + } + + TypeFunctionTypePackId shallowClone(TypeFunctionTypePackId tp) + { + if (auto it = find(tp)) + return *it; + + // Create a shallow serialization + TypeFunctionTypePackId target = {}; + if (auto tPack = get(tp)) + target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{{}}); + else if (auto vPack = get(tp)) + target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionVariadicTypePack{}); + + packs[tp] = target; + queue.emplace_back(tp, target); + return target; + } + + void cloneChildren(TypeFunctionTypeId ty, TypeFunctionTypeId tfti) + { + if (auto [p1, p2] = std::tuple{getMutable(ty), getMutable(tfti)}; p1 && p2) + cloneChildren(p1, p2); + else if (auto [u1, u2] = std::tuple{getMutable(ty), getMutable(tfti)}; u1 && u2) + cloneChildren(u1, u2); + else if (auto [n1, n2] = std::tuple{getMutable(ty), getMutable(tfti)}; n1 && n2) + cloneChildren(n1, n2); + else if (auto [a1, a2] = std::tuple{getMutable(ty), getMutable(tfti)}; a1 && a2) + cloneChildren(a1, a2); + else if (auto [s1, s2] = std::tuple{getMutable(ty), getMutable(tfti)}; s1 && s2) + cloneChildren(s1, s2); + else if (auto [u1, u2] = std::tuple{getMutable(ty), getMutable(tfti)}; u1 && u2) + cloneChildren(u1, u2); + else if (auto [i1, i2] = std::tuple{getMutable(ty), getMutable(tfti)}; i1 && i2) + cloneChildren(i1, i2); + else if (auto [n1, n2] = std::tuple{getMutable(ty), getMutable(tfti)}; n1 && n2) + cloneChildren(n1, n2); + else if (auto [t1, t2] = std::tuple{getMutable(ty), getMutable(tfti)}; t1 && t2) + cloneChildren(t1, t2); + else if (auto [f1, f2] = std::tuple{getMutable(ty), getMutable(tfti)}; f1 && f2) + cloneChildren(f1, f2); + else if (auto [c1, c2] = std::tuple{getMutable(ty), getMutable(tfti)}; c1 && c2) + cloneChildren(c1, c2); + else + LUAU_ASSERT(!"Unknown pair?"); // First and argument should always represent the same types + } + + void cloneChildren(TypeFunctionTypePackId tp, TypeFunctionTypePackId tftp) + { + if (auto [tPack1, tPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; tPack1 && tPack2) + cloneChildren(tPack1, tPack2); + else if (auto [vPack1, vPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; + vPack1 && vPack2) + cloneChildren(vPack1, vPack2); + else + LUAU_ASSERT(!"Unknown pair?"); // First and argument should always represent the same types + } + + void cloneChildren(TypeFunctionKind kind, TypeFunctionKind tfkind) + { + if (auto [ty, tfty] = std::tuple{get(kind), get(tfkind)}; ty && tfty) + cloneChildren(*ty, *tfty); + else if (auto [tp, tftp] = std::tuple{get(kind), get(tfkind)}; tp && tftp) + cloneChildren(*tp, *tftp); + else + LUAU_ASSERT(!"Unknown pair?"); // First and argument should always represent the same types + } + + void cloneChildren(TypeFunctionPrimitiveType* p1, TypeFunctionPrimitiveType* p2) + { + // noop. + } + + void cloneChildren(TypeFunctionUnknownType* u1, TypeFunctionUnknownType* u2) + { + // noop. + } + + void cloneChildren(TypeFunctionNeverType* n1, TypeFunctionNeverType* n2) + { + // noop. + } + + void cloneChildren(TypeFunctionAnyType* a1, TypeFunctionAnyType* a2) + { + // noop. + } + + void cloneChildren(TypeFunctionSingletonType* s1, TypeFunctionSingletonType* s2) + { + // noop. + } + + void cloneChildren(TypeFunctionUnionType* u1, TypeFunctionUnionType* u2) + { + for (TypeFunctionTypeId& ty : u1->components) + u2->components.push_back(shallowClone(ty)); + } + + void cloneChildren(TypeFunctionIntersectionType* i1, TypeFunctionIntersectionType* i2) + { + for (TypeFunctionTypeId& ty : i1->components) + i2->components.push_back(shallowClone(ty)); + } + + void cloneChildren(TypeFunctionNegationType* n1, TypeFunctionNegationType* n2) + { + n2->type = shallowClone(n1->type); + } + + void cloneChildren(TypeFunctionTableType* t1, TypeFunctionTableType* t2) + { + for (auto& [k, p] : t1->props) + { + std::optional readTy; + if (p.readTy) + readTy = shallowClone(*p.readTy); + + std::optional writeTy; + if (p.writeTy) + writeTy = shallowClone(*p.writeTy); + + t2->props[k] = TypeFunctionProperty{readTy, writeTy}; + } + + if (t1->indexer.has_value()) + t2->indexer = TypeFunctionTableIndexer(shallowClone(t1->indexer->keyType), shallowClone(t1->indexer->valueType)); + + if (t1->metatable.has_value()) + t2->metatable = shallowClone(*t1->metatable); + } + + void cloneChildren(TypeFunctionFunctionType* f1, TypeFunctionFunctionType* f2) + { + f2->argTypes = shallowClone(f1->argTypes); + f2->retTypes = shallowClone(f1->retTypes); + } + + void cloneChildren(TypeFunctionClassType* c1, TypeFunctionClassType* c2) + { + // noop. + } + + void cloneChildren(TypeFunctionTypePack* t1, TypeFunctionTypePack* t2) + { + for (TypeFunctionTypeId& ty : t1->head) + t2->head.push_back(shallowClone(ty)); + } + + void cloneChildren(TypeFunctionVariadicTypePack* v1, TypeFunctionVariadicTypePack* v2) + { + v2->type = shallowClone(v1->type); + } +}; + +TypeFunctionTypeId deepClone(NotNull runtime, TypeFunctionTypeId ty) +{ + return TypeFunctionCloner(runtime).clone(ty); +} + +} // namespace Luau diff --git a/Analysis/src/TypeFunctionRuntimeBuilder.cpp b/Analysis/src/TypeFunctionRuntimeBuilder.cpp new file mode 100644 index 00000000..e14c3773 --- /dev/null +++ b/Analysis/src/TypeFunctionRuntimeBuilder.cpp @@ -0,0 +1,788 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/TypeFunctionRuntimeBuilder.h" + +#include "Luau/Ast.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Common.h" +#include "Luau/DenseHash.h" +#include "Luau/StringUtils.h" +#include "Luau/Type.h" +#include "Luau/TypeArena.h" +#include "Luau/TypeFwd.h" +#include "Luau/TypeFunctionRuntime.h" +#include "Luau/TypePack.h" +#include "Luau/ToString.h" + +#include + +// used to control the recursion limit of any operations done by user-defined type functions +// currently, controls serialization, deserialization, and `type.copy` +LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFunctionSerdeIterationLimit, 100'000); + +namespace Luau +{ + +// Forked version of Clone.cpp +class TypeFunctionSerializer +{ + using SeenTypes = DenseHashMap; + using SeenTypePacks = DenseHashMap; + + TypeFunctionRuntimeBuilderState* state = nullptr; + NotNull typeFunctionRuntime; + + // A queue of TypeFunctionTypeIds that have been serialized, but whose interior types hasn't + // been updated to point to itself. Once all of its interior types + // has been updated, it gets removed from the queue. + + // queue.back() should always return two of same type in their respective sides + // For example `auto [first, second] = queue.back()`: if first is PrimitiveType, + // second must be TypeFunctionPrimitiveType; else there should be an error + std::vector> queue; + + SeenTypes types; // Mapping of TypeIds that have been shallow serialized to TypeFunctionTypeIds + SeenTypePacks packs; // Mapping of TypePackIds that have been shallow serialized to TypeFunctionTypePackIds + + int steps = 0; + +public: + explicit TypeFunctionSerializer(TypeFunctionRuntimeBuilderState* state) + : state(state) + , typeFunctionRuntime(state->ctx->typeFunctionRuntime) + , queue({}) + , types({}) + , packs({}) + { + } + + TypeFunctionTypeId serialize(TypeId ty) + { + shallowSerialize(ty); + run(); + + if (hasExceededIterationLimit() || state->errors.size() != 0) + return nullptr; + + return find(ty).value_or(nullptr); + } + + TypeFunctionTypePackId serialize(TypePackId tp) + { + shallowSerialize(tp); + run(); + + if (hasExceededIterationLimit() || state->errors.size() != 0) + return nullptr; + + return find(tp).value_or(nullptr); + } + +private: + bool hasExceededIterationLimit() const + { + if (DFInt::LuauTypeFunctionSerdeIterationLimit == 0) + return false; + + return steps + queue.size() >= size_t(DFInt::LuauTypeFunctionSerdeIterationLimit); + } + + void run() + { + while (!queue.empty()) + { + ++steps; + + if (hasExceededIterationLimit() || state->errors.size() != 0) + break; + + auto [ty, tfti] = queue.back(); + queue.pop_back(); + + serializeChildren(ty, tfti); + } + } + + std::optional find(TypeId ty) const + { + if (auto result = types.find(ty)) + return *result; + + return std::nullopt; + } + + std::optional find(TypePackId tp) const + { + if (auto result = packs.find(tp)) + return *result; + + return std::nullopt; + } + + std::optional find(Kind kind) const + { + if (auto ty = get(kind)) + return find(*ty); + else if (auto tp = get(kind)) + return find(*tp); + else + { + LUAU_ASSERT(!"Unknown kind found at TypeFunctionRuntimeSerializer"); + return std::nullopt; + } + } + + TypeFunctionTypeId shallowSerialize(TypeId ty) + { + ty = follow(ty); + + if (auto it = find(ty)) + return *it; + + // Create a shallow serialization + TypeFunctionTypeId target = {}; + if (auto p = get(ty)) + { + switch (p->type) + { + case PrimitiveType::Type::NilType: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::NilType)); + break; + case PrimitiveType::Type::Boolean: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Boolean)); + break; + case PrimitiveType::Number: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Number)); + break; + case PrimitiveType::String: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::String)); + break; + case PrimitiveType::Thread: + case PrimitiveType::Function: + case PrimitiveType::Table: + case PrimitiveType::Buffer: + default: + { + std::string error = format("Argument of primitive type %s is not currently serializable by type functions", toString(ty).c_str()); + state->errors.push_back(error); + } + } + } + else if (auto u = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionUnknownType{}); + else if (auto a = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionNeverType{}); + else if (auto a = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionAnyType{}); + else if (auto s = get(ty)) + { + if (auto bs = get(s)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionSingletonType{TypeFunctionBooleanSingleton{bs->value}}); + else if (auto ss = get(s)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionSingletonType{TypeFunctionStringSingleton{ss->value}}); + else + { + std::string error = format("Argument of singleton type %s is not currently serializable by type functions", toString(ty).c_str()); + state->errors.push_back(error); + } + } + else if (auto u = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionUnionType{{}}); + else if (auto i = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionIntersectionType{{}}); + else if (auto n = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionNegationType{{}}); + else if (auto t = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionTableType{{}, std::nullopt, std::nullopt}); + else if (auto m = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionTableType{{}, std::nullopt, std::nullopt}); + else if (auto f = get(ty)) + { + TypeFunctionTypePackId emptyTypePack = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{}); + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionFunctionType{emptyTypePack, emptyTypePack}); + } + else if (auto c = get(ty)) + { + state->classesSerialized[c->name] = ty; + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionClassType{{}, std::nullopt, std::nullopt, std::nullopt, c->name}); + } + else + { + std::string error = format("Argument of type %s is not currently serializable by type functions", toString(ty).c_str()); + state->errors.push_back(error); + } + + types[ty] = target; + queue.emplace_back(ty, target); + return target; + } + + TypeFunctionTypePackId shallowSerialize(TypePackId tp) + { + tp = follow(tp); + + if (auto it = find(tp)) + return *it; + + // Create a shallow serialization + TypeFunctionTypePackId target = {}; + if (auto tPack = get(tp)) + target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{{}}); + else if (auto vPack = get(tp)) + target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionVariadicTypePack{}); + else + { + std::string error = format("Argument of type pack %s is not currently serializable by type functions", toString(tp).c_str()); + state->errors.push_back(error); + } + + packs[tp] = target; + queue.emplace_back(tp, target); + return target; + } + + void serializeChildren(TypeId ty, TypeFunctionTypeId tfti) + { + if (auto [p1, p2] = std::tuple{getMutable(ty), getMutable(tfti)}; p1 && p2) + serializeChildren(p1, p2); + else if (auto [u1, u2] = std::tuple{getMutable(ty), getMutable(tfti)}; u1 && u2) + serializeChildren(u1, u2); + else if (auto [n1, n2] = std::tuple{getMutable(ty), getMutable(tfti)}; n1 && n2) + serializeChildren(n1, n2); + else if (auto [a1, a2] = std::tuple{getMutable(ty), getMutable(tfti)}; a1 && a2) + serializeChildren(a1, a2); + else if (auto [s1, s2] = std::tuple{getMutable(ty), getMutable(tfti)}; s1 && s2) + serializeChildren(s1, s2); + else if (auto [u1, u2] = std::tuple{getMutable(ty), getMutable(tfti)}; u1 && u2) + serializeChildren(u1, u2); + else if (auto [i1, i2] = std::tuple{getMutable(ty), getMutable(tfti)}; i1 && i2) + serializeChildren(i1, i2); + else if (auto [n1, n2] = std::tuple{getMutable(ty), getMutable(tfti)}; n1 && n2) + serializeChildren(n1, n2); + else if (auto [t1, t2] = std::tuple{getMutable(ty), getMutable(tfti)}; t1 && t2) + serializeChildren(t1, t2); + else if (auto [m1, m2] = std::tuple{getMutable(ty), getMutable(tfti)}; m1 && m2) + serializeChildren(m1, m2); + else if (auto [f1, f2] = std::tuple{getMutable(ty), getMutable(tfti)}; f1 && f2) + serializeChildren(f1, f2); + else if (auto [c1, c2] = std::tuple{getMutable(ty), getMutable(tfti)}; c1 && c2) + serializeChildren(c1, c2); + else + { // Either this or ty and tfti do not represent the same type + std::string error = format("Argument of type %s is not currently serializable by type functions", toString(ty).c_str()); + state->errors.push_back(error); + } + } + + void serializeChildren(TypePackId tp, TypeFunctionTypePackId tftp) + { + if (auto [tPack1, tPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; tPack1 && tPack2) + serializeChildren(tPack1, tPack2); + else if (auto [vPack1, vPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; + vPack1 && vPack2) + serializeChildren(vPack1, vPack2); + else + { // Either this or ty and tfti do not represent the same type + std::string error = format("Argument of type pack %s is not currently serializable by type functions", toString(tp).c_str()); + state->errors.push_back(error); + } + } + + void serializeChildren(Kind kind, TypeFunctionKind tfkind) + { + if (auto [ty, tfty] = std::tuple{get(kind), get(tfkind)}; ty && tfty) + serializeChildren(*ty, *tfty); + else if (auto [tp, tftp] = std::tuple{get(kind), get(tfkind)}; tp && tftp) + serializeChildren(*tp, *tftp); + else + state->ctx->ice->ice("Serializing user defined type function arguments: kind and tfkind do not represent the same type"); + } + + void serializeChildren(PrimitiveType* p1, TypeFunctionPrimitiveType* p2) + { + // noop. + } + + void serializeChildren(UnknownType* u1, TypeFunctionUnknownType* u2) + { + // noop. + } + + void serializeChildren(NeverType* n1, TypeFunctionNeverType* n2) + { + // noop. + } + + void serializeChildren(AnyType* a1, TypeFunctionAnyType* a2) + { + // noop. + } + + void serializeChildren(SingletonType* s1, TypeFunctionSingletonType* s2) + { + // noop. + } + + void serializeChildren(UnionType* u1, TypeFunctionUnionType* u2) + { + for (TypeId& ty : u1->options) + u2->components.push_back(shallowSerialize(ty)); + } + + void serializeChildren(IntersectionType* i1, TypeFunctionIntersectionType* i2) + { + for (TypeId& ty : i1->parts) + i2->components.push_back(shallowSerialize(ty)); + } + + void serializeChildren(NegationType* n1, TypeFunctionNegationType* n2) + { + n2->type = shallowSerialize(n1->ty); + } + + void serializeChildren(TableType* t1, TypeFunctionTableType* t2) + { + for (const auto& [k, p] : t1->props) + { + std::optional readTy = std::nullopt; + if (p.readTy) + readTy = shallowSerialize(*p.readTy); + + std::optional writeTy = std::nullopt; + if (p.writeTy) + writeTy = shallowSerialize(*p.writeTy); + + t2->props[k] = TypeFunctionProperty{readTy, writeTy}; + } + + if (t1->indexer) + t2->indexer = TypeFunctionTableIndexer(shallowSerialize(t1->indexer->indexType), shallowSerialize(t1->indexer->indexResultType)); + } + + void serializeChildren(MetatableType* m1, TypeFunctionTableType* m2) + { + auto tmpTable = get(shallowSerialize(m1->table)); + if (!tmpTable) + state->ctx->ice->ice("Serializing user defined type function arguments: metatable's table is not a TableType"); + + m2->props = tmpTable->props; + m2->indexer = tmpTable->indexer; + + m2->metatable = shallowSerialize(m1->metatable); + } + + void serializeChildren(FunctionType* f1, TypeFunctionFunctionType* f2) + { + f2->argTypes = shallowSerialize(f1->argTypes); + f2->retTypes = shallowSerialize(f1->retTypes); + } + + void serializeChildren(ClassType* c1, TypeFunctionClassType* c2) + { + for (const auto& [k, p] : c1->props) + { + std::optional readTy = std::nullopt; + if (p.readTy) + readTy = shallowSerialize(*p.readTy); + + std::optional writeTy = std::nullopt; + if (p.writeTy) + writeTy = shallowSerialize(*p.writeTy); + + c2->props[k] = TypeFunctionProperty{readTy, writeTy}; + } + + if (c1->indexer) + c2->indexer = TypeFunctionTableIndexer(shallowSerialize(c1->indexer->indexType), shallowSerialize(c1->indexer->indexResultType)); + + if (c1->metatable) + c2->metatable = shallowSerialize(*c1->metatable); + + if (c1->parent) + c2->parent = shallowSerialize(*c1->parent); + } + + void serializeChildren(TypePack* t1, TypeFunctionTypePack* t2) + { + for (TypeId& ty : t1->head) + t2->head.push_back(shallowSerialize(ty)); + + if (t1->tail.has_value()) + t2->tail = shallowSerialize(*t1->tail); + } + + void serializeChildren(VariadicTypePack* v1, TypeFunctionVariadicTypePack* v2) + { + v2->type = shallowSerialize(v1->ty); + } +}; + +// Complete inverse of TypeFunctionSerializer +class TypeFunctionDeserializer +{ + using SeenTypes = DenseHashMap; + using SeenTypePacks = DenseHashMap; + + TypeFunctionRuntimeBuilderState* state = nullptr; + NotNull typeFunctionRuntime; + + // A queue of TypeIds that have been deserialized, but whose interior types hasn't + // been updated to point to itself. Once all of its interior types + // has been updated, it gets removed from the queue. + + // queue.back() should always return two of same type in their respective sides + // For example `auto [first, second] = queue.back()`: if first is TypeFunctionPrimitiveType, + // second must be PrimitiveType; else there should be an error + std::vector> queue; + + SeenTypes types; // Mapping of TypeFunctionTypeIds that have been shallow deserialized to TypeIds + SeenTypePacks packs; // Mapping of TypeFunctionTypePackIds that have been shallow deserialized to TypePackIds + + int steps = 0; + +public: + explicit TypeFunctionDeserializer(TypeFunctionRuntimeBuilderState* state) + : state(state) + , typeFunctionRuntime(state->ctx->typeFunctionRuntime) + , queue({}) + , types({}) + , packs({}){}; + + TypeId deserialize(TypeFunctionTypeId ty) + { + shallowDeserialize(ty); + run(); + + if (hasExceededIterationLimit() || state->errors.size() != 0) + { + TypeId error = state->ctx->builtins->errorRecoveryType(); + types[ty] = error; + return error; + } + + return find(ty).value_or(state->ctx->builtins->errorRecoveryType()); + } + + TypePackId deserialize(TypeFunctionTypePackId tp) + { + shallowDeserialize(tp); + run(); + + if (hasExceededIterationLimit() || state->errors.size() != 0) + { + TypePackId error = state->ctx->builtins->errorRecoveryTypePack(); + packs[tp] = error; + return error; + } + + return find(tp).value_or(state->ctx->builtins->errorRecoveryTypePack()); + } + +private: + bool hasExceededIterationLimit() const + { + if (DFInt::LuauTypeFunctionSerdeIterationLimit == 0) + return false; + + return steps + queue.size() >= size_t(DFInt::LuauTypeFunctionSerdeIterationLimit); + } + + void run() + { + while (!queue.empty()) + { + ++steps; + + if (hasExceededIterationLimit() || state->errors.size() != 0) + break; + + auto [tfti, ty] = queue.back(); + queue.pop_back(); + + deserializeChildren(tfti, ty); + } + } + + std::optional find(TypeFunctionTypeId ty) const + { + if (auto result = types.find(ty)) + return *result; + + return std::nullopt; + } + + std::optional find(TypeFunctionTypePackId tp) const + { + if (auto result = packs.find(tp)) + return *result; + + return std::nullopt; + } + + std::optional find(TypeFunctionKind kind) const + { + if (auto ty = get(kind)) + return find(*ty); + else if (auto tp = get(kind)) + return find(*tp); + else + { + LUAU_ASSERT(!"Unknown kind found at TypeFunctionDeserializer"); + return std::nullopt; + } + } + + TypeId shallowDeserialize(TypeFunctionTypeId ty) + { + if (auto it = find(ty)) + return *it; + + // Create a shallow deserialization + TypeId target = {}; + if (auto p = get(ty)) + { + switch (p->type) + { + case TypeFunctionPrimitiveType::Type::NilType: + target = state->ctx->builtins->nilType; + break; + case TypeFunctionPrimitiveType::Type::Boolean: + target = state->ctx->builtins->booleanType; + break; + case TypeFunctionPrimitiveType::Type::Number: + target = state->ctx->builtins->numberType; + break; + case TypeFunctionPrimitiveType::Type::String: + target = state->ctx->builtins->stringType; + break; + default: + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + } + } + else if (auto u = get(ty)) + target = state->ctx->builtins->unknownType; + else if (auto n = get(ty)) + target = state->ctx->builtins->neverType; + else if (auto a = get(ty)) + target = state->ctx->builtins->anyType; + else if (auto s = get(ty)) + { + if (auto bs = get(s)) + target = state->ctx->arena->addType(SingletonType{BooleanSingleton{bs->value}}); + else if (auto ss = get(s)) + target = state->ctx->arena->addType(SingletonType{StringSingleton{ss->value}}); + else + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + } + else if (auto u = get(ty)) + target = state->ctx->arena->addTV(Type(UnionType{{}})); + else if (auto i = get(ty)) + target = state->ctx->arena->addTV(Type(IntersectionType{{}})); + else if (auto n = get(ty)) + target = state->ctx->arena->addType(NegationType{state->ctx->builtins->unknownType}); + else if (auto t = get(ty); t && !t->metatable.has_value()) + target = state->ctx->arena->addType(TableType{TableType::Props{}, std::nullopt, TypeLevel{}, TableState::Sealed}); + else if (auto m = get(ty); m && m->metatable.has_value()) + { + TypeId emptyTable = state->ctx->arena->addType(TableType{TableType::Props{}, std::nullopt, TypeLevel{}, TableState::Sealed}); + target = state->ctx->arena->addType(MetatableType{emptyTable, emptyTable}); + } + else if (auto f = get(ty)) + { + TypePackId emptyTypePack = state->ctx->arena->addTypePack(TypePack{}); + target = state->ctx->arena->addType(FunctionType{emptyTypePack, emptyTypePack, {}, false}); + } + else if (auto c = get(ty)) + { + if (auto result = state->classesSerialized.find(c->name)) + target = *result; + else + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious class type is being deserialized"); + } + else + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + + types[ty] = target; + queue.emplace_back(ty, target); + return target; + } + + TypePackId shallowDeserialize(TypeFunctionTypePackId tp) + { + if (auto it = find(tp)) + return *it; + + // Create a shallow deserialization + TypePackId target = {}; + if (auto tPack = get(tp)) + target = state->ctx->arena->addTypePack(TypePack{}); + else if (auto vPack = get(tp)) + target = state->ctx->arena->addTypePack(VariadicTypePack{}); + else + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + + packs[tp] = target; + queue.emplace_back(tp, target); + return target; + } + + void deserializeChildren(TypeFunctionTypeId tfti, TypeId ty) + { + if (auto [p1, p2] = std::tuple{getMutable(ty), getMutable(tfti)}; p1 && p2) + deserializeChildren(p2, p1); + else if (auto [u1, u2] = std::tuple{getMutable(ty), getMutable(tfti)}; u1 && u2) + deserializeChildren(u2, u1); + else if (auto [n1, n2] = std::tuple{getMutable(ty), getMutable(tfti)}; n1 && n2) + deserializeChildren(n2, n1); + else if (auto [a1, a2] = std::tuple{getMutable(ty), getMutable(tfti)}; a1 && a2) + deserializeChildren(a2, a1); + else if (auto [s1, s2] = std::tuple{getMutable(ty), getMutable(tfti)}; s1 && s2) + deserializeChildren(s2, s1); + else if (auto [u1, u2] = std::tuple{getMutable(ty), getMutable(tfti)}; u1 && u2) + deserializeChildren(u2, u1); + else if (auto [i1, i2] = std::tuple{getMutable(ty), getMutable(tfti)}; i1 && i2) + deserializeChildren(i2, i1); + else if (auto [n1, n2] = std::tuple{getMutable(ty), getMutable(tfti)}; n1 && n2) + deserializeChildren(n2, n1); + else if (auto [t1, t2] = std::tuple{getMutable(ty), getMutable(tfti)}; + t1 && t2 && !t2->metatable.has_value()) + deserializeChildren(t2, t1); + else if (auto [m1, m2] = std::tuple{getMutable(ty), getMutable(tfti)}; + m1 && m2 && m2->metatable.has_value()) + deserializeChildren(m2, m1); + else if (auto [f1, f2] = std::tuple{getMutable(ty), getMutable(tfti)}; f1 && f2) + deserializeChildren(f2, f1); + else if (auto [c1, c2] = std::tuple{getMutable(ty), getMutable(tfti)}; c1 && c2) + deserializeChildren(c2, c1); + else + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + } + + void deserializeChildren(TypeFunctionTypePackId tftp, TypePackId tp) + { + if (auto [tPack1, tPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; tPack1 && tPack2) + deserializeChildren(tPack2, tPack1); + else if (auto [vPack1, vPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; + vPack1 && vPack2) + deserializeChildren(vPack2, vPack1); + else + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + } + + void deserializeChildren(TypeFunctionKind tfkind, Kind kind) + { + if (auto [ty, tfty] = std::tuple{get(kind), get(tfkind)}; ty && tfty) + deserializeChildren(*tfty, *ty); + else if (auto [tp, tftp] = std::tuple{get(kind), get(tfkind)}; tp && tftp) + deserializeChildren(*tftp, *tp); + else + state->ctx->ice->ice("Deserializing user defined type function arguments: tfkind and kind do not represent the same type"); + } + + void deserializeChildren(TypeFunctionPrimitiveType* p2, PrimitiveType* p1) + { + // noop. + } + + void deserializeChildren(TypeFunctionUnknownType* u2, UnknownType* u1) + { + // noop. + } + + void deserializeChildren(TypeFunctionNeverType* n2, NeverType* n1) + { + // noop. + } + + void deserializeChildren(TypeFunctionAnyType* a2, AnyType* a1) + { + // noop. + } + + void deserializeChildren(TypeFunctionSingletonType* s2, SingletonType* s1) + { + // noop. + } + + void deserializeChildren(TypeFunctionUnionType* u2, UnionType* u1) + { + for (TypeFunctionTypeId& ty : u2->components) + u1->options.push_back(shallowDeserialize(ty)); + } + + void deserializeChildren(TypeFunctionIntersectionType* i2, IntersectionType* i1) + { + for (TypeFunctionTypeId& ty : i2->components) + i1->parts.push_back(shallowDeserialize(ty)); + } + + void deserializeChildren(TypeFunctionNegationType* n2, NegationType* n1) + { + n1->ty = shallowDeserialize(n2->type); + } + + void deserializeChildren(TypeFunctionTableType* t2, TableType* t1) + { + for (const auto& [k, p] : t2->props) + { + if (p.readTy && p.writeTy) + t1->props[k] = Property::rw(shallowDeserialize(*p.readTy), shallowDeserialize(*p.writeTy)); + else if (p.readTy) + t1->props[k] = Property::readonly(shallowDeserialize(*p.readTy)); + else if (p.writeTy) + t1->props[k] = Property::writeonly(shallowDeserialize(*p.writeTy)); + } + + if (t2->indexer.has_value()) + t1->indexer = TableIndexer(shallowDeserialize(t2->indexer->keyType), shallowDeserialize(t2->indexer->valueType)); + } + + void deserializeChildren(TypeFunctionTableType* m2, MetatableType* m1) + { + TypeFunctionTypeId temp = typeFunctionRuntime->typeArena.allocate(TypeFunctionTableType{m2->props, m2->indexer}); + m1->table = shallowDeserialize(temp); + + if (m2->metatable.has_value()) + m1->metatable = shallowDeserialize(*m2->metatable); + } + + void deserializeChildren(TypeFunctionFunctionType* f2, FunctionType* f1) + { + if (f2->argTypes) + f1->argTypes = shallowDeserialize(f2->argTypes); + + if (f2->retTypes) + f1->retTypes = shallowDeserialize(f2->retTypes); + } + + void deserializeChildren(TypeFunctionClassType* c2, ClassType* c1) + { + // noop. + } + + void deserializeChildren(TypeFunctionTypePack* t2, TypePack* t1) + { + for (TypeFunctionTypeId& ty : t2->head) + t1->head.push_back(shallowDeserialize(ty)); + + if (t2->tail.has_value()) + t1->tail = shallowDeserialize(*t2->tail); + } + + void deserializeChildren(TypeFunctionVariadicTypePack* v2, VariadicTypePack* v1) + { + v1->ty = shallowDeserialize(v2->type); + } +}; + +TypeFunctionTypeId serialize(TypeId ty, TypeFunctionRuntimeBuilderState* state) +{ + return TypeFunctionSerializer(state).serialize(ty); +} + +TypeId deserialize(TypeFunctionTypeId ty, TypeFunctionRuntimeBuilderState* state) +{ + return TypeFunctionDeserializer(state).deserialize(ty); +} + +} // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 6b2e861d..7a7be71d 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -33,7 +33,6 @@ LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAGVARIABLE(LuauRemoveBadRelationalOperatorWarning, false) -LUAU_FASTFLAGVARIABLE(LuauOkWithIteratingOverTableProperties, false) LUAU_FASTFLAGVARIABLE(LuauAcceptIndexingTableUnionsIntersections, false) namespace Luau @@ -1284,20 +1283,11 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) for (size_t i = 2; i < varTypes.size(); ++i) unify(nilType, varTypes[i], scope, forin.location); } - else if (isNonstrictMode() || FFlag::LuauOkWithIteratingOverTableProperties) + else { for (TypeId var : varTypes) unify(unknownType, var, scope, forin.location); } - else - { - TypeId varTy = errorRecoveryType(loopScope); - - for (TypeId var : varTypes) - unify(varTy, var, scope, forin.location); - - reportError(firstValue->location, GenericError{"Cannot iterate over a table without indexer"}); - } return check(loopScope, *forin.body); } diff --git a/Ast/include/Luau/ParseOptions.h b/Ast/include/Luau/ParseOptions.h index 01f2a74f..804d16fc 100644 --- a/Ast/include/Luau/ParseOptions.h +++ b/Ast/include/Luau/ParseOptions.h @@ -1,6 +1,11 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Ast.h" +#include "Luau/DenseHash.h" + +#include + namespace Luau { @@ -12,10 +17,17 @@ enum class Mode Definition, // Type definition module, has special parsing rules }; +struct FragmentParseResumeSettings +{ + DenseHashMap localMap{AstName()}; + std::vector localStack; +}; + struct ParseOptions { bool allowDeclarationSyntax = false; bool captureComments = false; + std::optional parseFragment = std::nullopt; }; } // namespace Luau diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 4e49028a..83d6eefd 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -452,4 +452,4 @@ private: std::string scratchData; }; -} // namespace Luau \ No newline at end of file +} // namespace Luau diff --git a/Ast/include/Luau/TimeTrace.h b/Ast/include/Luau/TimeTrace.h index bd2ca86b..2259f21c 100644 --- a/Ast/include/Luau/TimeTrace.h +++ b/Ast/include/Luau/TimeTrace.h @@ -7,6 +7,7 @@ #include #include +#include LUAU_FASTFLAG(DebugLuauTimeTracing) diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index a5e1d40e..54540215 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -7,8 +7,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauLexerLookaheadRemembersBraceType, false) - namespace Luau { @@ -434,13 +432,11 @@ Lexeme Lexer::lookahead() lineOffset = currentLineOffset; lexeme = currentLexeme; prevLocation = currentPrevLocation; - if (FFlag::LuauLexerLookaheadRemembersBraceType) - { - if (braceStack.size() < currentBraceStackSize) - braceStack.push_back(currentBraceType); - else if (braceStack.size() > currentBraceStackSize) - braceStack.pop_back(); - } + + if (braceStack.size() < currentBraceStackSize) + braceStack.push_back(currentBraceType); + else if (braceStack.size() > currentBraceStackSize) + braceStack.pop_back(); return result; } diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 4b9eddda..44a40abf 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -19,7 +19,8 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauSolverV2, false) LUAU_FASTFLAGVARIABLE(LuauNativeAttribute, false) LUAU_FASTFLAGVARIABLE(LuauAttributeSyntaxFunExpr, false) -LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctions, false) +LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionsSyntax, false) +LUAU_FASTFLAGVARIABLE(LuauAllowFragmentParsing, false) namespace Luau { @@ -211,6 +212,15 @@ Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Alloc scratchExpr.reserve(16); scratchLocal.reserve(16); scratchBinding.reserve(16); + + if (FFlag::LuauAllowFragmentParsing) + { + if (options.parseFragment) + { + localMap = options.parseFragment->localMap; + localStack = options.parseFragment->localStack; + } + } } bool Parser::blockFollow(const Lexeme& l) @@ -891,7 +901,7 @@ AstStat* Parser::parseReturn() AstStat* Parser::parseTypeAlias(const Location& start, bool exported) { // parsing a type function - if (FFlag::LuauUserDefinedTypeFunctions) + if (FFlag::LuauUserDefinedTypeFunctionsSyntax) { if (lexer.current().type == Lexeme::ReservedFunction) return parseTypeFunction(start); diff --git a/Ast/src/TimeTrace.cpp b/Ast/src/TimeTrace.cpp index e8be59eb..8bccffce 100644 --- a/Ast/src/TimeTrace.cpp +++ b/Ast/src/TimeTrace.cpp @@ -3,6 +3,7 @@ #include "Luau/StringUtils.h" +#include #include #include diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index cd73bcbb..a63655cc 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -11,8 +11,6 @@ #include "lstate.h" #include "lgc.h" -LUAU_FASTFLAGVARIABLE(LuauCodegenArmNumToVecFix, false) - namespace Luau { namespace CodeGen @@ -1121,7 +1119,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) else { RegisterA64 tempd = tempDouble(inst.a); - RegisterA64 temps = FFlag::LuauCodegenArmNumToVecFix ? regs.allocTemp(KindA64::s) : castReg(KindA64::s, tempd); + RegisterA64 temps = regs.allocTemp(KindA64::s); build.fcvt(temps, tempd); build.dup_4s(inst.regA64, castReg(KindA64::q, temps), 0); diff --git a/Makefile b/Makefile index 3e6b85ad..cb199de8 100644 --- a/Makefile +++ b/Makefile @@ -142,7 +142,7 @@ endif $(AST_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include $(COMPILER_OBJECTS): CXXFLAGS+=-std=c++17 -ICompiler/include -ICommon/include -IAst/include $(CONFIG_OBJECTS): CXXFLAGS+=-std=c++17 -IConfig/include -ICommon/include -IAst/include -$(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IEqSat/include -IConfig/include +$(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IEqSat/include -IConfig/include -ICompiler/include -IVM/include $(EQSAT_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IEqSat/include $(CODEGEN_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -ICodeGen/include -IVM/include -IVM/src # Code generation needs VM internals $(VM_OBJECTS): CXXFLAGS+=-std=c++11 -ICommon/include -IVM/include @@ -227,7 +227,7 @@ luau-tests: $(TESTS_TARGET) # executable targets $(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(COMPILER_TARGET) $(CONFIG_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET) $(REPL_CLI_TARGET): $(REPL_CLI_OBJECTS) $(COMPILER_TARGET) $(CONFIG_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET) -$(ANALYZE_CLI_TARGET): $(ANALYZE_CLI_OBJECTS) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(AST_TARGET) $(CONFIG_TARGET) +$(ANALYZE_CLI_TARGET): $(ANALYZE_CLI_OBJECTS) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(COMPILER_TARGET) $(VM_TARGET) $(COMPILE_CLI_TARGET): $(COMPILE_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(BYTECODE_CLI_TARGET): $(BYTECODE_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) diff --git a/Sources.cmake b/Sources.cmake index 80bcd5b2..103ea280 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -182,6 +182,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Documentation.h Analysis/include/Luau/Error.h Analysis/include/Luau/FileResolver.h + Analysis/include/Luau/FragmentAutocomplete.h Analysis/include/Luau/Frontend.h Analysis/include/Luau/Generalization.h Analysis/include/Luau/GlobalTypes.h @@ -223,6 +224,8 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/TypedAllocator.h Analysis/include/Luau/TypeFunction.h Analysis/include/Luau/TypeFunctionReductionGuesser.h + Analysis/include/Luau/TypeFunctionRuntime.h + Analysis/include/Luau/TypeFunctionRuntimeBuilder.h Analysis/include/Luau/TypeFwd.h Analysis/include/Luau/TypeInfer.h Analysis/include/Luau/TypeOrPack.h @@ -253,6 +256,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Differ.cpp Analysis/src/EmbeddedBuiltinDefinitions.cpp Analysis/src/Error.cpp + Analysis/src/FragmentAutocomplete.cpp Analysis/src/Frontend.cpp Analysis/src/Generalization.cpp Analysis/src/GlobalTypes.cpp @@ -287,6 +291,8 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/TypedAllocator.cpp Analysis/src/TypeFunction.cpp Analysis/src/TypeFunctionReductionGuesser.cpp + Analysis/src/TypeFunctionRuntime.cpp + Analysis/src/TypeFunctionRuntimeBuilder.cpp Analysis/src/TypeInfer.cpp Analysis/src/TypeOrPack.cpp Analysis/src/TypePack.cpp @@ -440,6 +446,7 @@ if(TARGET Luau.UnitTest) tests/Error.test.cpp tests/Fixture.cpp tests/Fixture.h + tests/FragmentAutocomplete.test.cpp tests/Frontend.test.cpp tests/Generalization.test.cpp tests/InsertionOrderedMap.test.cpp @@ -474,6 +481,7 @@ if(TARGET Luau.UnitTest) tests/Transpiler.test.cpp tests/TxnLog.test.cpp tests/TypeFunction.test.cpp + tests/TypeFunction.user.test.cpp tests/TypeInfer.aliases.test.cpp tests/TypeInfer.annotations.test.cpp tests/TypeInfer.anyerror.test.cpp diff --git a/VM/src/ltm.cpp b/VM/src/ltm.cpp index 16775f9b..f38ab80b 100644 --- a/VM/src/ltm.cpp +++ b/VM/src/ltm.cpp @@ -10,8 +10,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauPreserveLudataRenaming, false) - // clang-format off const char* const luaT_typenames[] = { // ORDER TYPE @@ -124,74 +122,40 @@ const TValue* luaT_gettmbyobj(lua_State* L, const TValue* o, TMS event) const TString* luaT_objtypenamestr(lua_State* L, const TValue* o) { - if (FFlag::LuauPreserveLudataRenaming) + // Userdata created by the environment can have a custom type name set in the individual metatable + // If there is no custom name, 'userdata' is returned + if (ttisuserdata(o) && uvalue(o)->tag != UTAG_PROXY && uvalue(o)->metatable) { - // Userdata created by the environment can have a custom type name set in the individual metatable - // If there is no custom name, 'userdata' is returned - if (ttisuserdata(o) && uvalue(o)->tag != UTAG_PROXY && uvalue(o)->metatable) - { - const TValue* type = luaH_getstr(uvalue(o)->metatable, L->global->tmname[TM_TYPE]); + const TValue* type = luaH_getstr(uvalue(o)->metatable, L->global->tmname[TM_TYPE]); - if (ttisstring(type)) - return tsvalue(type); - - return L->global->ttname[ttype(o)]; - } - - // Tagged lightuserdata can be named using lua_setlightuserdataname - if (ttislightuserdata(o)) - { - int tag = lightuserdatatag(o); - - if (unsigned(tag) < LUA_LUTAG_LIMIT) - { - if (const TString* name = L->global->lightuserdataname[tag]) - return name; - } - } - - // For all types except userdata and table, a global metatable can be set with a global name override - if (Table* mt = L->global->mt[ttype(o)]) - { - const TValue* type = luaH_getstr(mt, L->global->tmname[TM_TYPE]); - - if (ttisstring(type)) - return tsvalue(type); - } + if (ttisstring(type)) + return tsvalue(type); return L->global->ttname[ttype(o)]; } - else + + // Tagged lightuserdata can be named using lua_setlightuserdataname + if (ttislightuserdata(o)) { - if (ttisuserdata(o) && uvalue(o)->tag != UTAG_PROXY && uvalue(o)->metatable) + int tag = lightuserdatatag(o); + + if (unsigned(tag) < LUA_LUTAG_LIMIT) { - const TValue* type = luaH_getstr(uvalue(o)->metatable, L->global->tmname[TM_TYPE]); - - if (ttisstring(type)) - return tsvalue(type); + if (const TString* name = L->global->lightuserdataname[tag]) + return name; } - else if (ttislightuserdata(o)) - { - int tag = lightuserdatatag(o); - - if (unsigned(tag) < LUA_LUTAG_LIMIT) - { - const TString* name = L->global->lightuserdataname[tag]; - - if (name) - return name; - } - } - else if (Table* mt = L->global->mt[ttype(o)]) - { - const TValue* type = luaH_getstr(mt, L->global->tmname[TM_TYPE]); - - if (ttisstring(type)) - return tsvalue(type); - } - - return L->global->ttname[ttype(o)]; } + + // For all types except userdata and table, a global metatable can be set with a global name override + if (Table* mt = L->global->mt[ttype(o)]) + { + const TValue* type = luaH_getstr(mt, L->global->tmname[TM_TYPE]); + + if (ttisstring(type)) + return tsvalue(type); + } + + return L->global->ttname[ttype(o)]; } const char* luaT_objtypename(lua_State* L, const TValue* o) diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 376caa44..df6e5332 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -34,8 +34,6 @@ void luaC_validate(lua_State* L); LUAU_FASTFLAG(DebugLuauAbortingChecks) LUAU_FASTINT(CodegenHeuristicsInstructionLimit) LUAU_FASTFLAG(LuauNativeAttribute) -LUAU_FASTFLAG(LuauPreserveLudataRenaming) -LUAU_FASTFLAG(LuauCodegenArmNumToVecFix) static lua_CompileOptions defaultOptions() { @@ -825,8 +823,6 @@ TEST_CASE("Pack") TEST_CASE("Vector") { - ScopedFastFlag luauCodegenArmNumToVecFix{FFlag::LuauCodegenArmNumToVecFix, true}; - lua_CompileOptions copts = defaultOptions(); Luau::CodeGen::CompilationOptions nativeOpts = defaultCodegenOptions(); @@ -2251,20 +2247,17 @@ TEST_CASE("LightuserdataApi") lua_pop(L, 1); - if (FFlag::LuauPreserveLudataRenaming) - { - // Still possible to rename the global lightuserdata name using a metatable - lua_pushlightuserdata(L, value); - CHECK(strcmp(luaL_typename(L, -1), "userdata") == 0); + // Still possible to rename the global lightuserdata name using a metatable + lua_pushlightuserdata(L, value); + CHECK(strcmp(luaL_typename(L, -1), "userdata") == 0); - lua_createtable(L, 0, 1); - lua_pushstring(L, "luserdata"); - lua_setfield(L, -2, "__type"); - lua_setmetatable(L, -2); + lua_createtable(L, 0, 1); + lua_pushstring(L, "luserdata"); + lua_setfield(L, -2, "__type"); + lua_setmetatable(L, -2); - CHECK(strcmp(luaL_typename(L, -1), "luserdata") == 0); - lua_pop(L, 1); - } + CHECK(strcmp(luaL_typename(L, -1), "luserdata") == 0); + lua_pop(L, 1); globalState.reset(); } diff --git a/tests/ConstraintGeneratorFixture.cpp b/tests/ConstraintGeneratorFixture.cpp index 7f168465..f595d6ec 100644 --- a/tests/ConstraintGeneratorFixture.cpp +++ b/tests/ConstraintGeneratorFixture.cpp @@ -42,7 +42,9 @@ void ConstraintGeneratorFixture::generateConstraints(const std::string& code) void ConstraintGeneratorFixture::solve(const std::string& code) { generateConstraints(code); - ConstraintSolver cs{NotNull{&normalizer}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger, {}}; + ConstraintSolver cs{ + NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger, {} + }; cs.run(); } diff --git a/tests/ConstraintGeneratorFixture.h b/tests/ConstraintGeneratorFixture.h index ff362be1..acf616e0 100644 --- a/tests/ConstraintGeneratorFixture.h +++ b/tests/ConstraintGeneratorFixture.h @@ -20,6 +20,7 @@ struct ConstraintGeneratorFixture : Fixture DcrLogger logger; UnifierSharedState sharedState{&ice}; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; + TypeFunctionRuntime typeFunctionRuntime; std::unique_ptr dfg; std::unique_ptr cg; diff --git a/tests/FragmentAutocomplete.test.cpp b/tests/FragmentAutocomplete.test.cpp new file mode 100644 index 00000000..b8b7829d --- /dev/null +++ b/tests/FragmentAutocomplete.test.cpp @@ -0,0 +1,139 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/FragmentAutocomplete.h" +#include "Fixture.h" +#include "Luau/Ast.h" +#include "Luau/AstQuery.h" + + +using namespace Luau; + +struct FragmentAutocompleteFixture : Fixture +{ + + FragmentAutocompleteAncestryResult runAutocompleteVisitor(const std::string& source, const Position& cursorPos) + { + ParseResult p = tryParse(source); // We don't care about parsing incomplete asts + REQUIRE(p.root); + return findAncestryForFragmentParse(p.root, cursorPos); + } +}; + +TEST_SUITE_BEGIN("FragmentAutocompleteTraversalTest"); + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "just_two_locals") +{ + auto result = runAutocompleteVisitor( + R"( +local x = 4 +local y = 5 +)", + {2, 11} + ); + + CHECK_EQ(3, result.ancestry.size()); + CHECK_EQ(2, result.localStack.size()); + CHECK_EQ(result.localMap.size(), result.localStack.size()); + REQUIRE(result.nearestStatement); + + AstStatLocal* local = result.nearestStatement->as(); + REQUIRE(local); + CHECK(1 == local->vars.size); + CHECK_EQ("y", std::string(local->vars.data[0]->name.value)); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "cursor_within_scope_tracks_locals_from_previous_scope") +{ + auto result = runAutocompleteVisitor( + R"( +local x = 4 +local y = 5 +if x == 4 then + local e = y +end +)", + {4, 15} + ); + + CHECK_EQ(5, result.ancestry.size()); + CHECK_EQ(3, result.localStack.size()); + CHECK_EQ(result.localMap.size(), result.localStack.size()); + REQUIRE(result.nearestStatement); + CHECK_EQ("e", std::string(result.localStack.back()->name.value)); + + AstStatLocal* local = result.nearestStatement->as(); + REQUIRE(local); + CHECK(1 == local->vars.size); + CHECK_EQ("e", std::string(local->vars.data[0]->name.value)); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "cursor_that_comes_later_shouldnt_capture_locals_in_unavailable_scope") +{ + auto result = runAutocompleteVisitor( + R"( +local x = 4 +local y = 5 +if x == 4 then + local e = y +end +local z = x + x +if y == 5 then + local q = x + y + z +end +)", + {8, 23} + ); + + CHECK_EQ(6, result.ancestry.size()); + CHECK_EQ(4, result.localStack.size()); + CHECK_EQ(result.localMap.size(), result.localStack.size()); + REQUIRE(result.nearestStatement); + CHECK_EQ("q", std::string(result.localStack.back()->name.value)); + + AstStatLocal* local = result.nearestStatement->as(); + REQUIRE(local); + CHECK(1 == local->vars.size); + CHECK_EQ("q", std::string(local->vars.data[0]->name.value)); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "nearest_enclosing_statement_can_be_non_local") +{ + auto result = runAutocompleteVisitor( + R"( +local x = 4 +local y = 5 +if x == 4 then +)", + {3, 4} + ); + + CHECK_EQ(4, result.ancestry.size()); + CHECK_EQ(2, result.localStack.size()); + CHECK_EQ(result.localMap.size(), result.localStack.size()); + REQUIRE(result.nearestStatement); + CHECK_EQ("y", std::string(result.localStack.back()->name.value)); + + AstStatIf* ifS = result.nearestStatement->as(); + CHECK(ifS != nullptr); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "local_funcs_show_up_in_local_stack") +{ + auto result = runAutocompleteVisitor( + R"( +local function foo() return 4 end +local x = foo() +local function bar() return x + foo() end +)", + {3, 32} + ); + + CHECK_EQ(8, result.ancestry.size()); + CHECK_EQ(3, result.localStack.size()); + CHECK_EQ(result.localMap.size(), result.localStack.size()); + CHECK_EQ("bar", std::string(result.localStack.back()->name.value)); + auto returnSt = result.nearestStatement->as(); + CHECK(returnSt != nullptr); +} + +TEST_SUITE_END(); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index dfcf0ded..74d7a920 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -3,6 +3,7 @@ #include "AstQueryDsl.h" #include "Fixture.h" +#include "Luau/Common.h" #include "ScopedFlags.h" #include "doctest.h" @@ -11,13 +12,12 @@ using namespace Luau; -LUAU_FASTFLAG(LuauLexerLookaheadRemembersBraceType); -LUAU_FASTINT(LuauRecursionLimit); -LUAU_FASTINT(LuauTypeLengthLimit); -LUAU_FASTINT(LuauParseErrorLimit); -LUAU_FASTFLAG(LuauSolverV2); -LUAU_FASTFLAG(LuauAttributeSyntaxFunExpr); -LUAU_FASTFLAG(LuauUserDefinedTypeFunctions); +LUAU_FASTINT(LuauRecursionLimit) +LUAU_FASTINT(LuauTypeLengthLimit) +LUAU_FASTINT(LuauParseErrorLimit) +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauAttributeSyntaxFunExpr) +LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax) namespace { @@ -2380,7 +2380,7 @@ TEST_CASE_FIXTURE(Fixture, "invalid_type_forms") TEST_CASE_FIXTURE(Fixture, "parse_user_defined_type_functions") { - ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctions, true}; + ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; AstStat* stat = parse(R"( type function foo() @@ -3138,8 +3138,6 @@ TEST_CASE_FIXTURE(Fixture, "do_block_with_no_end") TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_with_lookahead_involved") { - ScopedFastFlag sff{FFlag::LuauLexerLookaheadRemembersBraceType, true}; - ParseResult result = tryParse(R"( local x = `{ {y} }` )"); @@ -3149,8 +3147,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_with_lookahead_involved") TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_with_lookahead_involved2") { - ScopedFastFlag sff{FFlag::LuauLexerLookaheadRemembersBraceType, true}; - ParseResult result = tryParse(R"( local x = `{ { y{} } }` )"); diff --git a/tests/Subtyping.test.cpp b/tests/Subtyping.test.cpp index a59312ac..05bea2f7 100644 --- a/tests/Subtyping.test.cpp +++ b/tests/Subtyping.test.cpp @@ -66,6 +66,7 @@ struct SubtypeFixture : Fixture InternalErrorReporter iceReporter; UnifierSharedState sharedState{&ice}; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; + TypeFunctionRuntime typeFunctionRuntime; ScopedFastFlag sff{FFlag::LuauSolverV2, true}; @@ -77,7 +78,7 @@ struct SubtypeFixture : Fixture Subtyping mkSubtyping() { - return Subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&iceReporter}}; + return Subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter}}; } TypePackId pack(std::initializer_list tys) diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index edc2bf47..f6208c1b 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -12,7 +12,7 @@ using namespace Luau; -LUAU_FASTFLAG(LuauUserDefinedTypeFunctions); +LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax) TEST_SUITE_BEGIN("TranspilerTests"); @@ -698,7 +698,7 @@ TEST_CASE_FIXTURE(Fixture, "transpile_string_literal_escape") TEST_CASE_FIXTURE(Fixture, "transpile_type_functions") { - ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctions, true}; + ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; std::string code = R"( type function foo(arg1, arg2) if arg1 == arg2 then return arg1 end return arg2 end )"; diff --git a/tests/TypeFunction.test.cpp b/tests/TypeFunction.test.cpp index d3732d60..18d8f17b 100644 --- a/tests/TypeFunction.test.cpp +++ b/tests/TypeFunction.test.cpp @@ -1247,18 +1247,4 @@ TEST_CASE_FIXTURE(ClassFixture, "rawget_type_function_errors_w_classes") CHECK(toString(result.errors[0]) == "Property '\"BaseField\"' does not exist on type 'BaseClass'"); } -TEST_CASE_FIXTURE(Fixture, "user_defined_type_function_errors") -{ - if (!FFlag::LuauUserDefinedTypeFunctions) - return; - - CheckResult result = check(R"( - type function foo() - return nil - end - )"); - LUAU_CHECK_ERROR_COUNT(1, result); - CHECK(toString(result.errors[0]) == "This syntax is not supported"); -} - TEST_SUITE_END(); diff --git a/tests/TypeFunction.user.test.cpp b/tests/TypeFunction.user.test.cpp new file mode 100644 index 00000000..fbce4df2 --- /dev/null +++ b/tests/TypeFunction.user.test.cpp @@ -0,0 +1,1007 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "ClassFixture.h" +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax) +LUAU_FASTFLAG(LuauUserDefinedTypeFunctions) + +TEST_SUITE_BEGIN("UserDefinedTypeFunctionTests"); + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_nil_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_nil(arg) + return arg + end + type type_being_serialized = nil + local function ok(idx: serialize_nil): nil return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_nil_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getnil() + local ty = types.singleton(nil) + if ty:is("nil") then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getnil<>): nil return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_unknown_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_unknown(arg) + return arg + end + type type_being_serialized = unknown + local function ok(idx: serialize_unknown): unknown return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_unknown_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getunknown() + local ty = types.unknown + if ty:is("unknown") then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getunknown<>): unknown return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_never_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_never(arg) + return arg + end + type type_being_serialized = never + local function ok(idx: serialize_never): never return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_never_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getnever() + local ty = types.never + if ty:is("never") then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getnever<>): never return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_any_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_any(arg) + return arg + end + type type_being_serialized = any + local function ok(idx: serialize_any): any return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_any_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getany() + local ty = types.any + if ty:is("any") then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getany<>): any return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolean_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_bool(arg) + return arg + end + type type_being_serialized = boolean + local function ok(idx: serialize_bool): boolean return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolean_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getboolean() + local ty = types.boolean + if ty:is("boolean") then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getboolean<>): boolean return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_number_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_num(arg) + return arg + end + type type_being_serialized = number + local function ok(idx: serialize_num): number return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_number_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getnumber() + local ty = types.number + if ty:is("number") then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getnumber<>): number return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_string_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_str(arg) + return arg + end + type type_being_serialized = string + local function ok(idx: serialize_str): string return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_string_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getstring() + local ty = types.string + if ty:is("string") then + return ty + end + -- this should never be returned + return types.boolean + end + local function ok(idx: getstring<>): string return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolsingleton_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_boolsingleton(arg) + return arg + end + type type_being_serialized = true + local function ok(idx: serialize_boolsingleton): true return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolsingleton_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getboolsingleton() + local ty = types.singleton(true) + if ty:is("singleton") and ty:value() then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getboolsingleton<>): true return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_strsingleton_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_strsingleton(arg) + return arg + end + type type_being_serialized = "popcorn and movies!" + local function ok(idx: serialize_strsingleton): "popcorn and movies!" return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_strsingleton_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getstrsingleton() + local ty = types.singleton("hungry hippo") + if ty:is("singleton") and ty:value() == "hungry hippo" then + return ty + end + -- this should never be returned + return types.number + end + local function ok(idx: getstrsingleton<>): "hungry hippo" return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_union_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_union(arg) + return arg + end + type type_being_serialized = number | string | boolean + -- forcing an error here to check the exact type of the union + local function ok(idx: serialize_union): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "boolean | number | string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_union_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getunion() + local ty = types.unionof(types.string, types.number, types.boolean) + if ty:is("union") then + -- creating a copy of `ty` + local arr = {} + for _, value in ty:components() do + table.insert(arr, value) + end + return types.unionof(table.unpack(arr)) + end + -- this should never be returned + return types.number + end + -- forcing an error here to check the exact type of the union + local function ok(idx: getunion<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "boolean | number | string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_intersection_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_intersection(arg) + return arg + end + type type_being_serialized = { boolean: boolean, number: number } & { boolean: boolean, string: string } + -- forcing an error here to check the exact type of the intersection + local function ok(idx: serialize_intersection): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ boolean: boolean, number: number } & { boolean: boolean, string: string }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_intersection_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getintersection() + local tbl1 = types.newtable(nil, nil, nil) + tbl1:setproperty(types.singleton("boolean"), types.boolean) -- {boolean: boolean} + tbl1:setproperty(types.singleton("number"), types.number) -- {boolean: boolean, number: number} + local tbl2 = types.newtable(nil, nil, nil) + tbl2:setproperty(types.singleton("boolean"), types.boolean) -- {boolean: boolean} + tbl2:setproperty(types.singleton("string"), types.string) -- {boolean: boolean, string: string} + local ty = types.intersectionof(tbl1, tbl2) + if ty:is("intersection") then + -- creating a copy of `ty` + local arr = {} + for index, value in ty:components() do + table.insert(arr, value) + end + return types.intersectionof(table.unpack(arr)) + end + -- this should never be returned + return types.string + end + -- forcing an error here to check the exact type of the intersection + local function ok(idx: getintersection<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ boolean: boolean, number: number } & { boolean: boolean, string: string }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_negation_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getnegation() + local ty = types.negationof(types.string) + if ty:is("negation") then + return ty + end + -- this should never be returned + return types.number + end + + -- forcing an error here to check the exact type of the negation + local function ok(idx: getnegation<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "~string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_table_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_table(arg) + return arg + end + type type_being_serialized = { boolean: boolean, number: number, [string]: number } + -- forcing an error here to check the exact type of the table + local function ok(idx: serialize_table): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ [string]: number, boolean: boolean, number: number }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_table_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function gettable() + local indexer = { + index = types.number, + readresult = types.boolean, + writeresult = types.boolean, + } + local ty = types.newtable(nil, indexer, nil) -- {[number]: boolean} + ty:setproperty(types.singleton("string"), types.number) -- {string: number, [number] = boolean} + ty:setproperty(types.singleton("number"), types.string) -- {string: number, number: string, [number] = boolean} + ty:setproperty(types.singleton("string"), nil) -- {number: string, [number] = boolean} + local ret = types.newtable(nil, nil, nil) -- {} + -- creating a copy of `ty` + for k, v in ty:properties() do + ret:setreadproperty(k, v.read) + ret:setwriteproperty(k, v.write) + end + if ret:is("table") then + ret:setindexer(types.boolean, types.string) -- {number: string, [boolean] = string} + return ret -- {number: string, [boolean] = string} + end + -- this should never be returned + return types.number + end + -- forcing an error here to check the exact type of the table + local function ok(idx: gettable<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ [boolean]: string, number: string }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_metatable_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getmetatable() + local indexer = { + index = types.number, + readresult = types.boolean, + writeresult = types.boolean, + } + local ty = types.newtable(nil, indexer, nil) -- {[number]: boolean} + ty:setproperty(types.singleton("string"), types.number) -- {string: number, [number]: boolean} + local metatbl = types.newtable(nil, nil, ty) -- { { }, @metatable { [number]: boolean, string: number } } + metatbl:setmetatable(types.newtable(nil, indexer, nil)) -- { { }, @metatable { [number]: boolean } } + local ret = metatbl:metatable() + if metatbl:is("table") and metatbl:metatable() then + return ret -- { @metatable { [number]: boolean } } + end + -- this should never be returned + return types.number + end + -- forcing an error here to check the exact type of the metatable + local function ok(idx: getmetatable<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{boolean}"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_function_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_func(arg) + return arg + end + type type_being_serialized = (boolean, number, nil) -> (...string) + local function ok(idx: serialize_func): (boolean, number, nil) -> (...string) return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_function_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getfunction() + local ty = types.newfunction(nil, nil) -- () -> () + ty:setparameters({types.string, types.number}, nil) -- (string, number) -> () + ty:setreturns(nil, types.boolean) -- (string, number) -> (...boolean) + if ty:is("function") then + -- creating a copy of `ty` parameters + local arr = {} + for index, val in ty:parameters().head do + table.insert(arr, val) + end + return types.newfunction({head = arr}, ty:returns()) -- (string, number) -> (...boolean) + end + -- this should never be returned + return types.number + end + local function ok(idx: getfunction<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "(string, number) -> (...boolean)"); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_class_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_class(arg) + return arg + end + local function ok(idx: serialize_class): BaseClass return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_class_methods_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + + CheckResult result = check(R"( + type function getclass(arg) + local props = arg:properties() + local indexer = arg:indexer() + local metatable = arg:metatable() + return types.newtable(props, indexer, metatable) + end + -- forcing an error here to check the exact type of the metatable + local function ok(idx: getclass): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ BaseField: number, read BaseMethod: (BaseClass, number) -> (), read Touched: Connection }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_check_mutability") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function checkmut() + local indexer = { + index = types.number, + readresult = types.boolean, + writeresult = types.boolean, + } + local ty = types.newtable(props, indexer, nil) -- {[number]: boolean} + ty:setproperty(types.singleton("string"), types.number) -- {string: number, [number]: boolean} + local metatbl = types.newtable(nil, nil, ty) -- { { }, @metatable { [number]: boolean, string: number } } + -- mutate the table + ty:setproperty(types.singleton("string"), nil) -- {[number]: boolean} + if metatbl:is("table") and metatbl:metatable() then + return metatbl -- { @metatable { [number]: boolean }, { } } + end + -- this should never be returned + return types.number + end + local function ok(idx: checkmut<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ @metatable {boolean}, { } }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_copy_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function getcopy() + local indexer = { + index = types.number, + readresult = types.boolean, + writeresult = types.boolean, + } + local ty = types.newtable(nil, indexer, nil) -- {[number]: boolean} + ty:setproperty(types.singleton("string"), types.number) -- {string: number, [number]: boolean} + local metaty = types.newtable(nil, nil, ty) -- { { }, @metatable { [number]: boolean, string: number } } + local copy = types.copy(metaty) + -- mutate the table + ty:setproperty(types.singleton("string"), nil) -- {[number]: boolean} + if copy:is("table") and copy:metatable() then + return copy -- { { }, @metatable { [number]: boolean, string: number } } + end + -- this should never be returned + return types.number + end + local function ok(idx: getcopy<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ @metatable { [number]: boolean, string: number }, { } }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_simple_cyclic_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_cycle(arg) + return arg + end + type basety = { + first: basety2 + } + type basety2 = { + second: basety + } + local function ok(idx: serialize_cycle): basety return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_createtable_bad_metatable") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function badmetatable() + return types.newtable(nil, nil, types.number) + end + local function bad(arg: badmetatable<>) end + )"); + + LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + UserDefinedTypeFunctionError* e = get(result.errors[0]); + REQUIRE(e); + CHECK( + e->message == "'badmetatable' type function errored at runtime: [string \"badmetatable\"]:3: types.newtable: expected to be given a table " + "type as a metatable, but got number instead" + ); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_complex_cyclic_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function serialize_cycle2(arg) + return arg + end + type Employee = { + name: string, + department: Department? + } + type Department = { + name: string, + manager: Employee?, + employees: { Employee }, + company: Company? + } + type Company = { + name: string, + departments: { Department } + } + local function ok(idx: serialize_cycle2): Company return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_user_error_is_reported") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function errors_if_string(arg) + if arg:is("string") then + local a = 1 + error("We are in a math class! not english") + end + return arg + end + local function ok(idx: errors_if_string): nil return idx end + )"); + + LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + UserDefinedTypeFunctionError* e = get(result.errors[0]); + REQUIRE(e); + CHECK(e->message == "'errors_if_string' type function errored at runtime: [string \"errors_if_string\"]:5: We are in a math class! not english"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_type_overrides_call_metamethod") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function hello(arg) + error(type(arg)) + end + local function ok(idx: hello): nil return idx end + )"); + + LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + UserDefinedTypeFunctionError* e = get(result.errors[0]); + REQUIRE(e); + CHECK(e->message == "'hello' type function errored at runtime: [string \"hello\"]:3: userdata"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_type_overrides_eq_metamethod") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function hello() + local p1 = types.string + local p2 = types.string + local t1 = types.newtable(nil, nil, nil) + t1:setproperty(types.singleton("string"), types.boolean) + t1:setmetatable(t1) + local t2 = types.newtable(nil, nil, nil) + t2:setproperty(types.singleton("string"), types.boolean) + t1:setmetatable(t1) + if p1 == p2 and t1 == t2 then + return types.number + end + end + local function ok(idx: hello<>): number return idx end + )"); + + LUAU_CHECK_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_function_type_cant_call_get_props") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function hello(arg) + local arr = arg:properties() + end + local function ok(idx: hello<() -> ()>): nil return idx end + )"); + + LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + UserDefinedTypeFunctionError* e = get(result.errors[0]); + REQUIRE(e); + CHECK( + e->message == "'hello' type function errored at runtime: [string \"hello\"]:3: type.properties: expected self to be either a table or class, " + "but got function instead" + ); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_cannot_call_other") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function foo() + return "hi" + end + local x = true; + type function cannot_call_others() + return foo() + end + local function ok(idx: cannot_call_others<>): string return idx end + )"); + + LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + UserDefinedTypeFunctionError* e = get(result.errors[0]); + REQUIRE(e); + CHECK(e->message == "'cannot_call_others' type function errored at runtime: [string \"cannot_call_others\"]:7: attempt to call a nil value"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_optionify") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function optionify(tbl) + if not tbl:is("table") then + error("Argument is not a table") + end + for k, v in tbl:properties() do + tbl:setproperty(k, types.unionof(v.read, types.singleton(nil))) + end + return tbl + end + type Person = { + name: string, + age: number, + alive: boolean + } + local function ok(idx: optionify): nil return idx end + )"); + + LUAU_CHECK_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ age: number?, alive: boolean?, name: string? }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_calling_illegal_global") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions, true}; + + CheckResult result = check(R"( + type function illegal(arg) + gcinfo() -- this should error + + return arg -- this should not be reached + end + + local function ok(idx: illegal): nil return idx end + )"); + + LUAU_CHECK_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + UserDefinedTypeFunctionError* e = get(result.errors[0]); + REQUIRE(e); + CHECK(e->message == "'illegal' type function errored at runtime: [string \"illegal\"]:3: this function is not supported in type functions"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_recursion_and_gc") +{ + ScopedFastFlag newSolver{ FFlag::LuauSolverV2, true }; + ScopedFastFlag udtfSyntax{ FFlag::LuauUserDefinedTypeFunctionsSyntax, true }; + ScopedFastFlag udtf{ FFlag::LuauUserDefinedTypeFunctions, true }; + + CheckResult result = check(R"( + type function foo(tbl) + local count = 0 + for k,v in tbl:properties() do count += 1 end + if count < 100 then + tbl:setproperty(types.singleton(`m{count}`), types.string) + foo(tbl) + end + for i = 1,100 do table.create(10000) end + return tbl + end + type Test = {} + local function ok(idx: foo): nil return idx end + )"); + + LUAU_CHECK_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index bb5a2cdd..15eed392 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -9,6 +9,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax) LUAU_FASTFLAG(LuauUserDefinedTypeFunctions) TEST_SUITE_BEGIN("TypeAliases"); @@ -1169,8 +1170,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_adds_reduce_constraint_for_type_f TEST_CASE_FIXTURE(Fixture, "user_defined_type_function_errors") { - if (!FFlag::LuauUserDefinedTypeFunctions) - return; + ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax, true}; + ScopedFastFlag noUDTFimpl{FFlag::LuauUserDefinedTypeFunctions, false}; CheckResult result = check(R"( type function foo() diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 9912cc35..25f3d113 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -1427,4 +1427,18 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types3") CHECK_EQ(toString(requireType("e")), "number?"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "string_find_should_not_crash") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + LUAU_REQUIRE_NO_ERRORS(check(R"( + local function StringSplit(input, separator) + string.find(input, separator) + if not separator then + separator = "%s+" + end + end + )")); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index de79654b..ec36b30e 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -15,7 +15,6 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauOkWithIteratingOverTableProperties) LUAU_DYNAMIC_FASTFLAG(LuauImproveNonFunctionCallError) @@ -699,8 +698,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "loop_typecheck_crash_on_empty_optional") if (FFlag::LuauSolverV2) return; - ScopedFastFlag sff{FFlag::LuauOkWithIteratingOverTableProperties, true}; - CheckResult result = check(R"( local t = {} for _ in t do @@ -784,7 +781,6 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer_strict") // CLI-116498 Sometimes you can iterate over tables with no indexers. ScopedFastFlag sff[] = { {FFlag::LuauSolverV2, false}, - {FFlag::LuauOkWithIteratingOverTableProperties, true} }; CheckResult result = check(R"( @@ -937,8 +933,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "cli_68448_iterators_need_not_accept_nil") TEST_CASE_FIXTURE(Fixture, "iterate_over_free_table") { - ScopedFastFlag sff{FFlag::LuauOkWithIteratingOverTableProperties, true}; - CheckResult result = check(R"( function print(x) end @@ -1095,8 +1089,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "iterate_over_properties") // CLI-116498 - Sometimes you can iterate over tables with no indexer. ScopedFastFlag sff0{FFlag::LuauSolverV2, false}; - ScopedFastFlag sff{FFlag::LuauOkWithIteratingOverTableProperties, true}; - CheckResult result = check(R"( local function f() local t = { p = 5, q = "hello" } @@ -1118,8 +1110,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "iterate_over_properties") TEST_CASE_FIXTURE(BuiltinsFixture, "iterate_over_properties_nonstrict") { - ScopedFastFlag sff{FFlag::LuauOkWithIteratingOverTableProperties, true}; - CheckResult result = check(R"( --!nonstrict local function f() diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index 42f1229f..ba54aca0 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -530,4 +530,82 @@ return l0 CHECK(mod->scopes[3].second->importedModules["l1"] == "game/A"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "ensure_scope_is_nullptr_after_shallow_copy") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + frontend.options.retainFullTypeGraphs = false; + + fileResolver.source["game/A"] = R"( +-- Roughly taken from ReactTypes.lua +type CoreBinding = {} +type BindingMap = {} +export type Binding = CoreBinding & BindingMap + +return {} + )"; + + LUAU_REQUIRE_NO_ERRORS(check(R"( +local Types = require(game.A) +type Binding = Types.Binding + )")); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "ensure_free_variables_are_generialized_across_function_boundaries") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + fileResolver.source["game/A"] = R"( +-- Roughly taken from react-shallow-renderer +function createUpdater(renderer) + local updater = { + _renderer = renderer, + } + + function updater.enqueueForceUpdate(publicInstance, callback, _callerName) + updater._renderer.render( + updater._renderer, + updater._renderer._element, + updater._renderer._context + ) + end + + function updater.enqueueReplaceState( + publicInstance, + completeState, + callback, + _callerName + ) + updater._renderer.render( + updater._renderer, + updater._renderer._element, + updater._renderer._context + ) + end + + function updater.enqueueSetState(publicInstance, partialState, callback, _callerName) + local currentState = updater._renderer._newState or publicInstance.state + updater._renderer.render( + updater._renderer, + updater._renderer._element, + updater._renderer._context + ) + end + + return updater +end + +local ReactShallowRenderer = {} + +function ReactShallowRenderer:_reset() + self._updater = createUpdater(self) +end + +return ReactShallowRenderer + )"; + + LUAU_REQUIRE_NO_ERRORS(check(R"( +local ReactShallowRenderer = require(game.A); + )")); +} + TEST_SUITE_END();