Sync to upstream/release/645 (#1440)

In this update, we continue to improve the overall stability of the new
type solver. We're also shipping some early bits of two new features,
one of the language and one of the analysis API: user-defined type
functions and an incremental typechecking API.

If you use the new solver and want to use all new fixes included in this
release, you have to reference an additional Luau flag:
```c++
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
```
And set its value to `645`:
```c++
DFInt::LuauTypeSolverRelease.value = 645; // Or a higher value for future updates
```

## New Solver

* Fix a crash where scopes are incorrectly accessed cross-module after
they've been deallocated by appropriately zeroing out associated scope
pointers for free types, generic types, table types, etc.
* Fix a crash where we were incorrectly caching results for bound types
in generalization.
* Eliminated some unnecessary intermediate allocations in the constraint
solver and type function infrastructure.
* Built some initial groundwork for an incremental typecheck API for use
by language servers.
* Built an initial technical preview for [user-defined type
functions](https://rfcs.luau-lang.org/user-defined-type-functions.html),
more work still to come (including calling type functions from other
type functions), but adventurous folks wanting to experiment with it can
try it out by enabling `FFlag::LuauUserDefinedTypeFunctionsSyntax` and
`FFlag::LuauUserDefinedTypeFunction` in their local environment. Special
thanks to @joonyoo181 who built up all the initial infrastructure for
this during his internship!

## Miscellaneous changes

* Fix a compilation error on Ubuntu (fixes #1437)

---

Internal Contributors:

Co-authored-by: Aaron Weiss <aaronweiss@roblox.com>
Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: Jeremy Yoo <jyoo@roblox.com>
Co-authored-by: Vighnesh Vijay <vvijay@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>

---------

Co-authored-by: Alexander McCord <amccord@roblox.com>
Co-authored-by: Andy Friesen <afriesen@roblox.com>
Co-authored-by: Vighnesh <vvijay@roblox.com>
Co-authored-by: Aviral Goel <agoel@roblox.com>
Co-authored-by: David Cope <dcope@roblox.com>
Co-authored-by: Lily Brown <lbrown@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>
Co-authored-by: Junseo Yoo <jyoo@roblox.com>
This commit is contained in:
aaron 2024-09-27 11:58:21 -07:00 committed by GitHub
parent c188715605
commit 02241b6d24
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
52 changed files with 5037 additions and 221 deletions

View file

@ -12,6 +12,7 @@
#include "Luau/ToString.h"
#include "Luau/Type.h"
#include "Luau/TypeCheckLimits.h"
#include "Luau/TypeFunction.h"
#include "Luau/TypeFwd.h"
#include "Luau/Variant.h"
@ -62,6 +63,7 @@ struct ConstraintSolver
NotNull<BuiltinTypes> builtinTypes;
InternalErrorReporter iceReporter;
NotNull<Normalizer> normalizer;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
// The entire set of constraints that the solver is trying to resolve.
std::vector<NotNull<Constraint>> constraints;
NotNull<Scope> rootScope;
@ -111,6 +113,7 @@ struct ConstraintSolver
explicit ConstraintSolver(
NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> rootScope,
std::vector<NotNull<Constraint>> constraints,
ModuleName moduleName,
@ -278,18 +281,18 @@ public:
/**
* @returns true if the TypeId is in a blocked state.
*/
bool isBlocked(TypeId ty);
bool isBlocked(TypeId ty) const;
/**
* @returns true if the TypePackId is in a blocked state.
*/
bool isBlocked(TypePackId tp);
bool isBlocked(TypePackId tp) const;
/**
* Returns whether the constraint is blocked on anything.
* @param constraint the constraint to check.
*/
bool isBlocked(NotNull<const Constraint> constraint);
bool isBlocked(NotNull<const Constraint> constraint) const;
/** Pushes a new solver constraint to the solver.
* @param cv the body of the constraint.
@ -381,8 +384,8 @@ public:
TypePackId anyifyModuleReturnTypePackGenerics(TypePackId tp);
void throwTimeLimitError();
void throwUserCancelError();
void throwTimeLimitError() const;
void throwUserCancelError() const;
ToStringOptions opts;
};

View file

@ -448,6 +448,13 @@ struct UnexpectedTypePackInSubtyping
bool operator==(const UnexpectedTypePackInSubtyping& rhs) const;
};
struct UserDefinedTypeFunctionError
{
std::string message;
bool operator==(const UserDefinedTypeFunctionError& rhs) const;
};
using TypeErrorData = Variant<
TypeMismatch,
UnknownSymbol,
@ -496,7 +503,8 @@ using TypeErrorData = Variant<
CheckedFunctionIncorrectArgs,
UnexpectedTypeInSubtyping,
UnexpectedTypePackInSubtyping,
ExplicitFunctionAnnotationRecommended>;
ExplicitFunctionAnnotationRecommended,
UserDefinedTypeFunctionError>;
struct TypeErrorSummary
{

View file

@ -0,0 +1,23 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/DenseHash.h"
#include "Luau/Ast.h"
#include <vector>
namespace Luau
{
struct FragmentAutocompleteAncestryResult
{
DenseHashMap<AstName, AstLocal*> localMap{AstName()};
std::vector<AstLocal*> localStack;
std::vector<AstNode*> ancestry;
AstStat* nearestStatement;
};
FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos);
} // namespace Luau

View file

@ -35,6 +35,7 @@ struct OverloadResolver
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> scope,
NotNull<InternalErrorReporter> reporter,
NotNull<TypeCheckLimits> limits,
@ -44,6 +45,7 @@ struct OverloadResolver
NotNull<BuiltinTypes> builtinTypes;
NotNull<TypeArena> arena;
NotNull<Normalizer> normalizer;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
NotNull<Scope> scope;
NotNull<InternalErrorReporter> ice;
NotNull<TypeCheckLimits> limits;
@ -109,6 +111,7 @@ SolveResult solveFunctionCall(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> iceReporter,
NotNull<TypeCheckLimits> limits,
NotNull<Scope> scope,

View file

@ -135,6 +135,7 @@ struct Subtyping
NotNull<BuiltinTypes> builtinTypes;
NotNull<TypeArena> arena;
NotNull<Normalizer> normalizer;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
NotNull<InternalErrorReporter> iceReporter;
TypeCheckLimits limits;
@ -155,6 +156,7 @@ struct Subtyping
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> typeArena,
NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> iceReporter
);

View file

@ -83,6 +83,7 @@ struct TypeChecker2
DenseHashSet<TypeId> seenTypeFunctionInstances{nullptr};
Normalizer normalizer;
TypeFunctionRuntime typeFunctionRuntime;
Subtyping _subtyping;
NotNull<Subtyping> subtyping;

View file

@ -1,10 +1,11 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/ConstraintSolver.h"
#include "Luau/Constraint.h"
#include "Luau/Error.h"
#include "Luau/NotNull.h"
#include "Luau/TypeCheckLimits.h"
#include "Luau/TypeFunctionRuntime.h"
#include "Luau/TypeFwd.h"
#include <functional>
@ -16,14 +17,23 @@ namespace Luau
struct TypeArena;
struct TxnLog;
struct ConstraintSolver;
class Normalizer;
struct TypeFunctionRuntime
{
// For user-defined type functions, we store all generated types and packs for the duration of the typecheck
TypedAllocator<TypeFunctionType> typeArena;
TypedAllocator<TypeFunctionTypePackVar> typePackArena;
};
struct TypeFunctionContext
{
NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtins;
NotNull<Scope> scope;
NotNull<Normalizer> normalizer;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
NotNull<InternalErrorReporter> ice;
NotNull<TypeCheckLimits> limits;
@ -35,23 +45,14 @@ struct TypeFunctionContext
std::optional<AstName> userFuncName; // Name of the user-defined type function; only available for UDTFs
std::optional<AstExprFunction*> userFuncBody; // Body of the user-defined type function; only available for UDTFs
TypeFunctionContext(NotNull<ConstraintSolver> cs, NotNull<Scope> scope, NotNull<const Constraint> constraint)
: arena(cs->arena)
, builtins(cs->builtinTypes)
, scope(scope)
, normalizer(cs->normalizer)
, ice(NotNull{&cs->iceReporter})
, limits(NotNull{&cs->limits})
, solver(cs.get())
, constraint(constraint.get())
{
}
TypeFunctionContext(NotNull<ConstraintSolver> cs, NotNull<Scope> scope, NotNull<const Constraint> constraint);
TypeFunctionContext(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtins,
NotNull<Scope> scope,
NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> ice,
NotNull<TypeCheckLimits> limits
)
@ -59,6 +60,7 @@ struct TypeFunctionContext
, builtins(builtins)
, scope(scope)
, normalizer(normalizer)
, typeFunctionRuntime(typeFunctionRuntime)
, ice(ice)
, limits(limits)
, solver(nullptr)
@ -66,7 +68,7 @@ struct TypeFunctionContext
{
}
NotNull<Constraint> pushConstraint(ConstraintV&& c);
NotNull<Constraint> pushConstraint(ConstraintV&& c) const;
};
/// Represents a reduction result, which may have successfully reduced the type,
@ -88,6 +90,8 @@ struct TypeFunctionReductionResult
/// Any type packs that need to be progressed or mutated before the
/// reduction may proceed.
std::vector<TypePackId> blockedPacks;
/// A runtime error message from user-defined type functions
std::optional<std::string> error;
};
template<typename T>

View file

@ -0,0 +1,267 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Common.h"
#include "Luau/Variant.h"
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>
using lua_State = struct lua_State;
namespace Luau
{
void* typeFunctionAlloc(void* ud, void* ptr, size_t osize, size_t nsize);
// Replica of types from Type.h
struct TypeFunctionType;
using TypeFunctionTypeId = const TypeFunctionType*;
struct TypeFunctionTypePackVar;
using TypeFunctionTypePackId = const TypeFunctionTypePackVar*;
struct TypeFunctionPrimitiveType
{
enum Type
{
NilType,
Boolean,
Number,
String,
};
Type type;
TypeFunctionPrimitiveType(Type type)
: type(type)
{
}
};
struct TypeFunctionBooleanSingleton
{
bool value = false;
};
struct TypeFunctionStringSingleton
{
std::string value;
};
using TypeFunctionSingletonVariant = Variant<TypeFunctionBooleanSingleton, TypeFunctionStringSingleton>;
struct TypeFunctionSingletonType
{
TypeFunctionSingletonVariant variant;
explicit TypeFunctionSingletonType(TypeFunctionSingletonVariant variant)
: variant(std::move(variant))
{
}
};
template<typename T>
const T* get(const TypeFunctionSingletonType* tv)
{
LUAU_ASSERT(tv);
return tv ? get_if<T>(&tv->variant) : nullptr;
}
template<typename T>
T* getMutable(const TypeFunctionSingletonType* tv)
{
LUAU_ASSERT(tv);
return tv ? get_if<T>(&const_cast<TypeFunctionSingletonType*>(tv)->variant) : nullptr;
}
struct TypeFunctionUnionType
{
std::vector<TypeFunctionTypeId> components;
};
struct TypeFunctionIntersectionType
{
std::vector<TypeFunctionTypeId> components;
};
struct TypeFunctionAnyType
{
};
struct TypeFunctionUnknownType
{
};
struct TypeFunctionNeverType
{
};
struct TypeFunctionNegationType
{
TypeFunctionTypeId type;
};
struct TypeFunctionTypePack
{
std::vector<TypeFunctionTypeId> head;
std::optional<TypeFunctionTypePackId> tail;
};
struct TypeFunctionVariadicTypePack
{
TypeFunctionTypeId type;
};
using TypeFunctionTypePackVariant = Variant<TypeFunctionTypePack, TypeFunctionVariadicTypePack>;
struct TypeFunctionTypePackVar
{
TypeFunctionTypePackVariant type;
TypeFunctionTypePackVar(TypeFunctionTypePackVariant type)
: type(std::move(type))
{
}
bool operator==(const TypeFunctionTypePackVar& rhs) const;
};
struct TypeFunctionFunctionType
{
TypeFunctionTypePackId argTypes;
TypeFunctionTypePackId retTypes;
};
template<typename T>
const T* get(TypeFunctionTypePackId tv)
{
LUAU_ASSERT(tv);
return tv ? get_if<T>(&tv->type) : nullptr;
}
template<typename T>
T* getMutable(TypeFunctionTypePackId tv)
{
LUAU_ASSERT(tv);
return tv ? get_if<T>(&const_cast<TypeFunctionTypePackVar*>(tv)->type) : nullptr;
}
struct TypeFunctionTableIndexer
{
TypeFunctionTableIndexer(TypeFunctionTypeId keyType, TypeFunctionTypeId valueType)
: keyType(keyType)
, valueType(valueType)
{
}
TypeFunctionTypeId keyType;
TypeFunctionTypeId valueType;
};
struct TypeFunctionProperty
{
static TypeFunctionProperty readonly(TypeFunctionTypeId ty);
static TypeFunctionProperty writeonly(TypeFunctionTypeId ty);
static TypeFunctionProperty rw(TypeFunctionTypeId ty); // Shared read-write type.
static TypeFunctionProperty rw(TypeFunctionTypeId read, TypeFunctionTypeId write); // Separate read-write type.
bool isReadOnly() const;
bool isWriteOnly() const;
std::optional<TypeFunctionTypeId> readTy;
std::optional<TypeFunctionTypeId> writeTy;
};
struct TypeFunctionTableType
{
using Name = std::string;
using Props = std::unordered_map<Name, TypeFunctionProperty>;
Props props;
std::optional<TypeFunctionTableIndexer> indexer;
// Should always be a TypeFunctionTableType
std::optional<TypeFunctionTypeId> metatable;
};
struct TypeFunctionClassType
{
using Name = std::string;
using Props = std::unordered_map<Name, TypeFunctionProperty>;
Props props;
std::optional<TypeFunctionTableIndexer> indexer;
std::optional<TypeFunctionTypeId> metatable; // metaclass?
std::optional<TypeFunctionTypeId> parent;
std::string name;
};
using TypeFunctionTypeVariant = Luau::Variant<
TypeFunctionPrimitiveType,
TypeFunctionAnyType,
TypeFunctionUnknownType,
TypeFunctionNeverType,
TypeFunctionSingletonType,
TypeFunctionUnionType,
TypeFunctionIntersectionType,
TypeFunctionNegationType,
TypeFunctionFunctionType,
TypeFunctionTableType,
TypeFunctionClassType>;
struct TypeFunctionType
{
TypeFunctionTypeVariant type;
TypeFunctionType(TypeFunctionTypeVariant type)
: type(std::move(type))
{
}
bool operator==(const TypeFunctionType& rhs) const;
};
template<typename T>
const T* get(TypeFunctionTypeId tv)
{
LUAU_ASSERT(tv);
return tv ? Luau::get_if<T>(&tv->type) : nullptr;
}
template<typename T>
T* getMutable(TypeFunctionTypeId tv)
{
LUAU_ASSERT(tv);
return tv ? Luau::get_if<T>(&const_cast<TypeFunctionType*>(tv)->type) : nullptr;
}
std::optional<std::string> checkResultForError(lua_State* L, const char* typeFunctionName, int luaResult);
TypeFunctionType* allocateTypeFunctionType(lua_State* L, TypeFunctionTypeVariant type);
TypeFunctionTypePackVar* allocateTypeFunctionTypePack(lua_State* L, TypeFunctionTypePackVariant type);
void allocTypeUserData(lua_State* L, TypeFunctionTypeVariant type);
bool isTypeUserData(lua_State* L, int idx);
TypeFunctionTypeId getTypeUserData(lua_State* L, int idx);
std::optional<TypeFunctionTypeId> optionalTypeUserData(lua_State* L, int idx);
void registerTypeUserData(lua_State* L);
void setTypeFunctionEnvironment(lua_State* L);
} // namespace Luau

View file

@ -0,0 +1,52 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Type.h"
#include "Luau/TypeFunction.h"
#include "Luau/TypeFunctionRuntime.h"
namespace Luau
{
using Kind = Variant<TypeId, TypePackId>;
template<typename T>
const T* get(const Kind& kind)
{
return get_if<T>(&kind);
}
using TypeFunctionKind = Variant<TypeFunctionTypeId, TypeFunctionTypePackId>;
template<typename T>
const T* get(const TypeFunctionKind& tfkind)
{
return get_if<T>(&tfkind);
}
struct TypeFunctionRuntimeBuilderState
{
NotNull<TypeFunctionContext> ctx;
// Mapping of class name to ClassType
// Invariant: users can not create a new class types -> any class types that get deserialized must have been an argument to the type function
// Using this invariant, whenever a ClassType is serialized, we can put it into this map
// whenever a ClassType is deserialized, we can use this map to return the corresponding value
DenseHashMap<std::string, TypeId> classesSerialized{{}};
// List of errors that occur during serialization/deserialization
// At every iteration of serialization/deserialzation, if this list.size() != 0, we halt the process
std::vector<std::string> errors{};
TypeFunctionRuntimeBuilderState(NotNull<TypeFunctionContext> ctx)
: ctx(ctx)
, classesSerialized({})
, errors({})
{
}
};
TypeFunctionTypeId serialize(TypeId ty, TypeFunctionRuntimeBuilderState* state);
TypeId deserialize(TypeFunctionTypeId ty, TypeFunctionRuntimeBuilderState* state);
} // namespace Luau

View file

@ -149,13 +149,15 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull<Scope> scope, T
if (FFlag::LuauSolverV2)
{
TypeFunctionRuntime typeFunctionRuntime; // TODO: maybe subtyping checks should not invoke user-defined type function runtime
if (FFlag::LuauAutocompleteNewSolverLimit)
{
unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;
unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit;
}
Subtyping subtyping{builtinTypes, NotNull{typeArena}, NotNull{&normalizer}, NotNull{&iceReporter}};
Subtyping subtyping{builtinTypes, NotNull{typeArena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter}};
return subtyping.isSubtype(subTy, superTy, scope).isSubtype;
}

View file

@ -321,6 +321,7 @@ struct InstantiationQueuer : TypeOnceVisitor
ConstraintSolver::ConstraintSolver(
NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> rootScope,
std::vector<NotNull<Constraint>> constraints,
ModuleName moduleName,
@ -332,11 +333,12 @@ ConstraintSolver::ConstraintSolver(
: arena(normalizer->arena)
, builtinTypes(normalizer->builtinTypes)
, normalizer(normalizer)
, typeFunctionRuntime(typeFunctionRuntime)
, constraints(std::move(constraints))
, rootScope(rootScope)
, currentModuleName(std::move(moduleName))
, moduleResolver(moduleResolver)
, requireCycles(requireCycles)
, requireCycles(std::move(requireCycles))
, logger(logger)
, limits(std::move(limits))
{
@ -344,7 +346,7 @@ ConstraintSolver::ConstraintSolver(
for (NotNull<Constraint> c : this->constraints)
{
unsolvedConstraints.push_back(c);
unsolvedConstraints.emplace_back(c);
// initialize the reference counts for the free types in this constraint.
for (auto ty : c->getMaybeMutatedFreeTypes())
@ -1240,7 +1242,14 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
}
OverloadResolver resolver{
builtinTypes, NotNull{arena}, normalizer, constraint->scope, NotNull{&iceReporter}, NotNull{&limits}, constraint->location
builtinTypes,
NotNull{arena},
normalizer,
typeFunctionRuntime,
constraint->scope,
NotNull{&iceReporter},
NotNull{&limits},
constraint->location
};
auto [status, overload] = resolver.selectOverload(fn, argsPack);
TypeId overloadToUse = fn;
@ -1270,7 +1279,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
for (const auto& [expanded, additions] : u2.expandedFreeTypes)
{
for (TypeId addition : additions)
upperBoundContributors[expanded].push_back(std::make_pair(constraint->location, addition));
upperBoundContributors[expanded].emplace_back(constraint->location, addition);
}
if (occursCheckPassed && c.callSite)
@ -1437,8 +1446,17 @@ bool ConstraintSolver::tryDispatch(const PrimitiveTypeConstraint& c, NotNull<con
else if (expectedType && maybeSingleton(*expectedType))
bindTo = freeType->lowerBound;
shiftReferences(c.freeType, bindTo);
bind(constraint, c.freeType, bindTo);
if (DFInt::LuauTypeSolverRelease >= 645)
{
auto ty = follow(c.freeType);
shiftReferences(ty, bindTo);
bind(constraint, ty, bindTo);
}
else
{
shiftReferences(c.freeType, bindTo);
bind(constraint, c.freeType, bindTo);
}
return true;
}
@ -2603,7 +2621,7 @@ bool ConstraintSolver::unify(NotNull<const Constraint> constraint, TID subTy, TI
for (const auto& [expanded, additions] : u2.expandedFreeTypes)
{
for (TypeId addition : additions)
upperBoundContributors[expanded].push_back(std::make_pair(constraint->location, addition));
upperBoundContributors[expanded].emplace_back(constraint->location, addition);
}
}
else
@ -2820,7 +2838,7 @@ void ConstraintSolver::reproduceConstraints(NotNull<Scope> scope, const Location
}
}
bool ConstraintSolver::isBlocked(TypeId ty)
bool ConstraintSolver::isBlocked(TypeId ty) const
{
ty = follow(ty);
@ -2830,7 +2848,7 @@ bool ConstraintSolver::isBlocked(TypeId ty)
return nullptr != get<BlockedType>(ty) || nullptr != get<PendingExpansionType>(ty);
}
bool ConstraintSolver::isBlocked(TypePackId tp)
bool ConstraintSolver::isBlocked(TypePackId tp) const
{
tp = follow(tp);
@ -2840,7 +2858,7 @@ bool ConstraintSolver::isBlocked(TypePackId tp)
return nullptr != get<BlockedTypePack>(tp);
}
bool ConstraintSolver::isBlocked(NotNull<const Constraint> constraint)
bool ConstraintSolver::isBlocked(NotNull<const Constraint> constraint) const
{
auto blockedIt = blockedConstraints.find(constraint);
return blockedIt != blockedConstraints.end() && blockedIt->second > 0;
@ -2851,7 +2869,7 @@ NotNull<Constraint> ConstraintSolver::pushConstraint(NotNull<Scope> scope, const
std::unique_ptr<Constraint> c = std::make_unique<Constraint>(scope, location, std::move(cv));
NotNull<Constraint> borrow = NotNull(c.get());
solverConstraints.push_back(std::move(c));
unsolvedConstraints.push_back(borrow);
unsolvedConstraints.emplace_back(borrow);
return borrow;
}
@ -2997,12 +3015,12 @@ TypePackId ConstraintSolver::anyifyModuleReturnTypePackGenerics(TypePackId tp)
return arena->addTypePack(resultTypes, resultTail);
}
LUAU_NOINLINE void ConstraintSolver::throwTimeLimitError()
LUAU_NOINLINE void ConstraintSolver::throwTimeLimitError() const
{
throw TimeLimitError(currentModuleName);
}
LUAU_NOINLINE void ConstraintSolver::throwUserCancelError()
LUAU_NOINLINE void ConstraintSolver::throwUserCancelError() const
{
throw UserCancelError(currentModuleName);
}

View file

@ -793,6 +793,11 @@ struct ErrorConverter
return "Encountered an unexpected type pack in subtyping: " + toString(e.tp);
}
std::string operator()(const UserDefinedTypeFunctionError& e) const
{
return e.message;
}
std::string operator()(const CannotAssignToNever& e) const
{
std::string result = "Cannot assign a value of type " + toString(e.rhsType) + " to a field of type never";
@ -1175,6 +1180,11 @@ bool UnexpectedTypePackInSubtyping::operator==(const UnexpectedTypePackInSubtypi
return tp == rhs.tp;
}
bool UserDefinedTypeFunctionError::operator==(const UserDefinedTypeFunctionError& rhs) const
{
return message == rhs.message;
}
bool CannotAssignToNever::operator==(const CannotAssignToNever& rhs) const
{
if (cause.size() != rhs.cause.size())
@ -1384,6 +1394,9 @@ void copyError(T& e, TypeArena& destArena, CloneState& cloneState)
e.ty = clone(e.ty);
else if constexpr (std::is_same_v<T, UnexpectedTypePackInSubtyping>)
e.tp = clone(e.tp);
else if constexpr (std::is_same_v<T, UserDefinedTypeFunctionError>)
{
}
else if constexpr (std::is_same_v<T, CannotAssignToNever>)
{
e.rhsType = clone(e.rhsType);

View file

@ -0,0 +1,48 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/FragmentAutocomplete.h"
#include "Luau/Ast.h"
#include "Luau/AstQuery.h"
namespace Luau
{
FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos)
{
std::vector<AstNode*> ancestry = findAncestryAtPositionForAutocomplete(root, cursorPos);
DenseHashMap<AstName, AstLocal*> localMap{AstName()};
std::vector<AstLocal*> localStack;
AstStat* nearestStatement = nullptr;
for (AstNode* node : ancestry)
{
if (auto block = node->as<AstStatBlock>())
{
for (auto stat : block->body)
{
if (stat->location.begin <= cursorPos)
nearestStatement = stat;
if (stat->location.begin <= cursorPos)
{
// This statement precedes the current one
if (auto loc = stat->as<AstStatLocal>())
{
for (auto v : loc->vars)
{
localStack.push_back(v);
localMap[v->name] = v;
}
}
else if (auto locFun = stat->as<AstStatLocalFunction>())
{
localStack.push_back(locFun->name);
localMap[locFun->name->name] = locFun->name;
}
}
}
}
}
return {std::move(localMap), std::move(localStack), std::move(ancestry), std::move(nearestStatement)};
}
} // namespace Luau

View file

@ -1383,6 +1383,7 @@ ModulePtr check(
unifierState.counters.iterationLimit = limits.unifierIterationLimit.value_or(FInt::LuauTypeInferIterationLimit);
Normalizer normalizer{&result->internalTypes, builtinTypes, NotNull{&unifierState}};
TypeFunctionRuntime typeFunctionRuntime;
ConstraintGenerator cg{
result,
@ -1402,6 +1403,7 @@ ModulePtr check(
ConstraintSolver cs{
NotNull{&normalizer},
NotNull{&typeFunctionRuntime},
NotNull(cg.rootScope),
borrowConstraints(cg.constraints),
result->name,

View file

@ -9,6 +9,8 @@
#include "Luau/TypePack.h"
#include "Luau/VisitType.h"
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
namespace Luau
{
@ -871,6 +873,17 @@ struct TypeCacher : TypeOnceVisitor
markUncacheable(tp);
return false;
}
bool visit(TypePackId tp, const BoundTypePack& btp) override {
if (DFInt::LuauTypeSolverRelease >= 645) {
traverse(btp.boundTo);
if (isUncacheable(btp.boundTo))
markUncacheable(tp);
return false;
}
return true;
}
};
std::optional<TypeId> generalize(

View file

@ -227,6 +227,8 @@ static void errorToString(std::ostream& stream, const T& err)
stream << "UnexpectedTypeInSubtyping { ty = '" + toString(err.ty) + "' }";
else if constexpr (std::is_same_v<T, UnexpectedTypePackInSubtyping>)
stream << "UnexpectedTypePackInSubtyping { tp = '" + toString(err.tp) + "' }";
else if constexpr (std::is_same_v<T, UserDefinedTypeFunctionError>)
stream << "UserDefinedTypeFunctionError { " << err.message << " }";
else if constexpr (std::is_same_v<T, CannotAssignToNever>)
{
stream << "CannotAssignToNever { rvalueType = '" << toString(err.rhsType) << "', reason = '" << err.reason << "', cause = { ";

View file

@ -15,6 +15,7 @@
#include <algorithm>
LUAU_FASTFLAG(LuauSolverV2);
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
namespace Luau
{
@ -131,10 +132,26 @@ struct ClonePublicInterface : Substitution
}
ftv->level = TypeLevel{0, 0};
if (FFlag::LuauSolverV2 && DFInt::LuauTypeSolverRelease >= 645)
ftv->scope = nullptr;
}
else if (TableType* ttv = getMutable<TableType>(result))
{
ttv->level = TypeLevel{0, 0};
if (FFlag::LuauSolverV2 && DFInt::LuauTypeSolverRelease >= 645)
ttv->scope = nullptr;
}
if (FFlag::LuauSolverV2 && DFInt::LuauTypeSolverRelease >= 645)
{
if (auto freety = getMutable<FreeType>(result))
{
freety->scope = nullptr;
}
else if (auto genericty = getMutable<GenericType>(result))
{
genericty->scope = nullptr;
}
}
return result;

View file

@ -160,6 +160,7 @@ struct NonStrictTypeChecker
NotNull<TypeArena> arena;
Module* module;
Normalizer normalizer;
TypeFunctionRuntime typeFunctionRuntime;
Subtyping subtyping;
NotNull<const DataFlowGraph> dfg;
DenseHashSet<TypeId> noTypeFunctionErrors{nullptr};
@ -182,7 +183,7 @@ struct NonStrictTypeChecker
, arena(arena)
, module(module)
, normalizer{arena, builtinTypes, unifierState, /* cache inhabitance */ true}
, subtyping{builtinTypes, arena, NotNull(&normalizer), ice}
, subtyping{builtinTypes, arena, NotNull(&normalizer), NotNull(&typeFunctionRuntime), ice}
, dfg(dfg)
, limits(limits)
{
@ -228,7 +229,12 @@ struct NonStrictTypeChecker
return instance;
ErrorVec errors =
reduceTypeFunctions(instance, location, TypeFunctionContext{arena, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, true)
reduceTypeFunctions(
instance,
location,
TypeFunctionContext{arena, builtinTypes, stack.back(), NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, ice, limits},
true
)
.errors;
if (errors.empty())

View file

@ -3434,11 +3434,12 @@ bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<Built
UnifierSharedState sharedState{&ice};
TypeArena arena;
Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}};
TypeFunctionRuntime typeFunctionRuntime; // TODO: maybe subtyping checks should not invoke user-defined type function runtime
// Subtyping under DCR is not implemented using unification!
if (FFlag::LuauSolverV2)
{
Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&ice}};
Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&ice}};
return subtyping.isSubtype(subTy, superTy, scope).isSubtype;
}
@ -3456,11 +3457,12 @@ bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull<Scope> scope, N
UnifierSharedState sharedState{&ice};
TypeArena arena;
Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}};
TypeFunctionRuntime typeFunctionRuntime; // TODO: maybe subtyping checks should not invoke user-defined type function runtime
// Subtyping under DCR is not implemented using unification!
if (FFlag::LuauSolverV2)
{
Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&ice}};
Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&ice}};
return subtyping.isSubtype(subPack, superPack, scope).isSubtype;
}

View file

@ -17,6 +17,7 @@ OverloadResolver::OverloadResolver(
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> scope,
NotNull<InternalErrorReporter> reporter,
NotNull<TypeCheckLimits> limits,
@ -25,10 +26,11 @@ OverloadResolver::OverloadResolver(
: builtinTypes(builtinTypes)
, arena(arena)
, normalizer(normalizer)
, typeFunctionRuntime(typeFunctionRuntime)
, scope(scope)
, ice(reporter)
, limits(limits)
, subtyping({builtinTypes, arena, normalizer, ice})
, subtyping({builtinTypes, arena, normalizer, typeFunctionRuntime, ice})
, callLoc(callLocation)
{
}
@ -199,8 +201,9 @@ std::pair<OverloadResolver::Analysis, ErrorVec> OverloadResolver::checkOverload_
const std::vector<AstExpr*>* argExprs
)
{
FunctionGraphReductionResult result =
reduceTypeFunctions(fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, normalizer, ice, limits}, /*force=*/true);
FunctionGraphReductionResult result = reduceTypeFunctions(
fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, normalizer, typeFunctionRuntime, ice, limits}, /*force=*/true
);
if (!result.errors.empty())
return {OverloadIsNonviable, result.errors};
@ -405,6 +408,7 @@ std::optional<TypeId> selectOverload(
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> scope,
NotNull<InternalErrorReporter> iceReporter,
NotNull<TypeCheckLimits> limits,
@ -413,7 +417,7 @@ std::optional<TypeId> selectOverload(
TypePackId argsPack
)
{
OverloadResolver resolver{builtinTypes, arena, normalizer, scope, iceReporter, limits, location};
OverloadResolver resolver{builtinTypes, arena, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location};
auto [status, overload] = resolver.selectOverload(fn, argsPack);
if (status == OverloadResolver::Analysis::Ok)
@ -429,6 +433,7 @@ SolveResult solveFunctionCall(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> iceReporter,
NotNull<TypeCheckLimits> limits,
NotNull<Scope> scope,
@ -437,7 +442,8 @@ SolveResult solveFunctionCall(
TypePackId argsPack
)
{
std::optional<TypeId> overloadToUse = selectOverload(builtinTypes, arena, normalizer, scope, iceReporter, limits, location, fn, argsPack);
std::optional<TypeId> overloadToUse =
selectOverload(builtinTypes, arena, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location, fn, argsPack);
if (!overloadToUse)
return {SolveResult::NoMatchingOverload};

View file

@ -440,11 +440,13 @@ Subtyping::Subtyping(
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> typeArena,
NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> iceReporter
)
: builtinTypes(builtinTypes)
, arena(typeArena)
, normalizer(normalizer)
, typeFunctionRuntime(typeFunctionRuntime)
, iceReporter(iceReporter)
{
}
@ -1911,7 +1913,7 @@ TypeId Subtyping::makeAggregateType(const Container& container, TypeId orElse)
std::pair<TypeId, ErrorVec> Subtyping::handleTypeFunctionReductionResult(const TypeFunctionInstanceType* functionInstance, NotNull<Scope> scope)
{
TypeFunctionContext context{arena, builtinTypes, scope, normalizer, iceReporter, NotNull{&limits}};
TypeFunctionContext context{arena, builtinTypes, scope, normalizer, typeFunctionRuntime, iceReporter, NotNull{&limits}};
TypeId function = arena->addType(*functionInstance);
FunctionGraphReductionResult result = reduceTypeFunctions(function, {}, context, true);
ErrorVec errors;

View file

@ -1040,6 +1040,7 @@ struct TypeStringifier
state.emit(tfitv.userFuncName->value);
else
state.emit(tfitv.function->name);
state.emit("<");
bool comma = false;

View file

@ -31,6 +31,7 @@
#include <ostream>
LUAU_FASTFLAG(DebugLuauMagicTypes)
LUAU_FASTFLAG(LuauUserDefinedTypeFunctions)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
namespace Luau
@ -306,7 +307,7 @@ TypeChecker2::TypeChecker2(
, sourceModule(sourceModule)
, module(module)
, normalizer{&module->internalTypes, builtinTypes, unifierState, /* cacheInhabitance */ true}
, _subtyping{builtinTypes, NotNull{&module->internalTypes}, NotNull{&normalizer}, NotNull{unifierState->iceHandler}}
, _subtyping{builtinTypes, NotNull{&module->internalTypes}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{unifierState->iceHandler}}
, subtyping(&_subtyping)
{
}
@ -484,13 +485,16 @@ TypeId TypeChecker2::checkForTypeFunctionInhabitance(TypeId instance, Location l
return instance;
seenTypeFunctionInstances.insert(instance);
ErrorVec errors = reduceTypeFunctions(
instance,
location,
TypeFunctionContext{NotNull{&module->internalTypes}, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits},
true
)
.errors;
ErrorVec errors =
reduceTypeFunctions(
instance,
location,
TypeFunctionContext{
NotNull{&module->internalTypes}, builtinTypes, stack.back(), NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, ice, limits
},
true
)
.errors;
if (!isErrorSuppressing(location, instance))
reportErrors(std::move(errors));
return instance;
@ -1194,8 +1198,8 @@ void TypeChecker2::visit(AstStatTypeAlias* stat)
void TypeChecker2::visit(AstStatTypeFunction* stat)
{
// TODO: add type checking for user-defined type functions
reportError(TypeError{stat->location, GenericError{"This syntax is not supported"}});
if (!FFlag::LuauUserDefinedTypeFunctions)
reportError(TypeError{stat->location, GenericError{"This syntax is not supported"}});
}
void TypeChecker2::visit(AstTypeList types)
@ -1446,6 +1450,7 @@ void TypeChecker2::visitCall(AstExprCall* call)
builtinTypes,
NotNull{&module->internalTypes},
NotNull{&normalizer},
NotNull{&typeFunctionRuntime},
NotNull{stack.back()},
ice,
limits,

View file

@ -2,7 +2,9 @@
#include "Luau/TypeFunction.h"
#include "Luau/BytecodeBuilder.h"
#include "Luau/Common.h"
#include "Luau/Compiler.h"
#include "Luau/ConstraintSolver.h"
#include "Luau/DenseHash.h"
#include "Luau/Instantiation.h"
@ -12,17 +14,25 @@
#include "Luau/Set.h"
#include "Luau/Simplify.h"
#include "Luau/Subtyping.h"
#include "Luau/TimeTrace.h"
#include "Luau/ToString.h"
#include "Luau/TxnLog.h"
#include "Luau/Type.h"
#include "Luau/TypeFunctionReductionGuesser.h"
#include "Luau/TypeFunctionRuntime.h"
#include "Luau/TypeFunctionRuntimeBuilder.h"
#include "Luau/TypeFwd.h"
#include "Luau/TypeUtils.h"
#include "Luau/Unifier2.h"
#include "Luau/VecDeque.h"
#include "Luau/VisitType.h"
#include "lua.h"
#include "lualib.h"
#include <iterator>
#include <memory>
#include <unordered_map>
// used to control emitting CodeTooComplex warnings on type function reduction
LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyGraphReductionMaximumSteps, 1'000'000);
@ -35,7 +45,8 @@ LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyApplicationCartesianProductLimit, 5'0
// when this value is set to a negative value, guessing will be totally disabled.
LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyUseGuesserDepth, -1);
LUAU_FASTFLAGVARIABLE(DebugLuauLogTypeFamilies, false);
LUAU_FASTFLAGVARIABLE(DebugLuauLogTypeFamilies, false)
LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctions, false)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
@ -166,7 +177,7 @@ struct TypeFunctionReducer
return SkipTestResult::Okay;
}
SkipTestResult testForSkippability(TypePackId ty)
SkipTestResult testForSkippability(TypePackId ty) const
{
ty = follow(ty);
@ -214,15 +225,18 @@ struct TypeFunctionReducer
{
irreducible.insert(subject);
if (reduction.error.has_value())
result.errors.emplace_back(location, UserDefinedTypeFunctionError{*reduction.error});
if (reduction.uninhabited || force)
{
if (FFlag::DebugLuauLogTypeFamilies)
printf("%s is uninhabited\n", toString(subject, {true}).c_str());
if constexpr (std::is_same_v<T, TypeId>)
result.errors.push_back(TypeError{location, UninhabitedTypeFunction{subject}});
result.errors.emplace_back(location, UninhabitedTypeFunction{subject});
else if constexpr (std::is_same_v<T, TypePackId>)
result.errors.push_back(TypeError{location, UninhabitedTypePackFunction{subject}});
result.errors.emplace_back(location, UninhabitedTypePackFunction{subject});
}
else if (!reduction.uninhabited && !force)
{
@ -243,7 +257,7 @@ struct TypeFunctionReducer
}
}
bool done()
bool done() const
{
return queuedTys.empty() && queuedTps.empty();
}
@ -422,7 +436,7 @@ static FunctionGraphReductionResult reduceFunctionsInternal(
++iterationCount;
if (iterationCount > DFInt::LuauTypeFamilyGraphReductionMaximumSteps)
{
reducer.result.errors.push_back(TypeError{location, CodeTooComplex{}});
reducer.result.errors.emplace_back(location, CodeTooComplex{});
break;
}
}
@ -506,7 +520,7 @@ static std::optional<TypeFunctionReductionResult<TypeId>> tryDistributeTypeFunct
size_t cartesianProductSize = 1;
const UnionType* firstUnion = nullptr;
size_t unionIndex;
size_t unionIndex = 0;
std::vector<TypeId> arguments = typeParams;
for (size_t i = 0; i < arguments.size(); ++i)
@ -572,6 +586,8 @@ static std::optional<TypeFunctionReductionResult<TypeId>> tryDistributeTypeFunct
return std::nullopt;
}
using StateRef = std::unique_ptr<lua_State, void (*)(lua_State*)>;
TypeFunctionReductionResult<TypeId> userDefinedTypeFunction(
TypeId instance,
const std::vector<TypeId>& typeParams,
@ -585,9 +601,122 @@ TypeFunctionReductionResult<TypeId> userDefinedTypeFunction(
return {std::nullopt, true, {}, {}};
}
// TODO: implementation of user-defined type functions goes here
for (auto typeParam : typeParams)
{
TypeId ty = follow(typeParam);
return {std::nullopt, true, {}, {}};
// block if we need to
if (isPending(ty, ctx->solver))
return {std::nullopt, false, {ty}, {}};
}
AstName name = *ctx->userFuncName;
AstExprFunction* function = *ctx->userFuncBody;
// Construct ParseResult containing the type function
Allocator allocator;
AstNameTable names(allocator);
AstExprGlobal globalName{Location{}, name};
AstStatFunction typeFunction{Location{}, &globalName, function};
AstStat* stmtArray[] = {&typeFunction};
AstArray<AstStat*> stmts{stmtArray, 1};
AstStatBlock exec{Location{}, stmts};
ParseResult parseResult{&exec, 1};
BytecodeBuilder builder;
try
{
compileOrThrow(builder, parseResult, names);
}
catch (CompileError& e)
{
std::string errMsg = format("'%s' type function failed to compile with error message: %s", name.value, e.what());
return {std::nullopt, true, {}, {}, errMsg};
}
std::string bytecode = builder.getBytecode();
// Initialize Lua state
StateRef globalState(lua_newstate(typeFunctionAlloc, nullptr), lua_close);
lua_State* L = globalState.get();
lua_setthreaddata(L, ctx.get());
setTypeFunctionEnvironment(L);
// Register type userdata
registerTypeUserData(L);
luaL_sandbox(L);
luaL_sandboxthread(L);
// Load bytecode into Luau state
if (auto error = checkResultForError(L, name.value, luau_load(L, name.value, bytecode.data(), bytecode.size(), 0)))
return {std::nullopt, true, {}, {}, error};
// Execute the loaded chunk to register the function in the global environment
if (auto error = checkResultForError(L, name.value, lua_pcall(L, 0, 0, 0)))
return {std::nullopt, true, {}, {}, error};
// Get type function from the global environment
lua_getglobal(L, name.value);
if (!lua_isfunction(L, -1))
{
std::string errMsg = format("Could not find '%s' type function in the global scope", name.value);
return {std::nullopt, true, {}, {}, errMsg};
}
// Push serialized arguments onto the stack
// Since there aren't any new class types being created in type functions, there isn't a deserialization function
// class types. Instead, we can keep this map and return the mapping as the "deserialized value"
std::unique_ptr<TypeFunctionRuntimeBuilderState> runtimeBuilder = std::make_unique<TypeFunctionRuntimeBuilderState>(ctx);
for (auto typeParam : typeParams)
{
TypeId ty = follow(typeParam);
// This is checked at the top of the function, and should still be true.
LUAU_ASSERT(!isPending(ty, ctx->solver));
TypeFunctionTypeId serializedTy = serialize(ty, runtimeBuilder.get());
// Check if there were any errors while serializing
if (runtimeBuilder->errors.size() != 0)
return {std::nullopt, true, {}, {}, runtimeBuilder->errors.front()};
allocTypeUserData(L, serializedTy->type);
}
// Set up an interrupt handler for type functions to respect type checking limits and LSP cancellation requests.
lua_callbacks(L)->interrupt = [](lua_State* L, int gc)
{
auto ctx = static_cast<const TypeFunctionContext*>(lua_getthreaddata(lua_mainthread(L)));
if (ctx->limits->finishTime && TimeTrace::getClock() > *ctx->limits->finishTime)
ctx->solver->throwTimeLimitError();
if (ctx->limits->cancellationToken && ctx->limits->cancellationToken->requested())
ctx->solver->throwUserCancelError();
};
if (auto error = checkResultForError(L, name.value, lua_resume(L, nullptr, int(typeParams.size()))))
return {std::nullopt, true, {}, {}, error};
// If the return value is not a type userdata, return with error message
if (!isTypeUserData(L, 1))
return {std::nullopt, true, {}, {}, format("'%s' type function: returned a non-type value", name.value)};
TypeFunctionTypeId retTypeFunctionTypeId = getTypeUserData(L, 1);
// No errors should be present here since we should've returned already if any were raised during serialization.
LUAU_ASSERT(runtimeBuilder->errors.size() == 0);
TypeId retTypeId = deserialize(retTypeFunctionTypeId, runtimeBuilder.get());
// At least 1 error occured while deserializing
if (runtimeBuilder->errors.size() > 0)
return {std::nullopt, true, {}, {}, runtimeBuilder->errors.front()};
return {retTypeId, false, {}, {}};
}
TypeFunctionReductionResult<TypeId> notTypeFunction(
@ -711,7 +840,7 @@ TypeFunctionReductionResult<TypeId> lenTypeFunction(
if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes))
return {std::nullopt, true, {}, {}}; // occurs check failed
Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice};
Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice};
if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance?
return {std::nullopt, true, {}, {}};
@ -808,7 +937,7 @@ TypeFunctionReductionResult<TypeId> unmTypeFunction(
if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes))
return {std::nullopt, true, {}, {}}; // occurs check failed
Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice};
Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice};
if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance?
return {std::nullopt, true, {}, {}};
@ -818,7 +947,20 @@ TypeFunctionReductionResult<TypeId> unmTypeFunction(
return {std::nullopt, true, {}, {}};
}
NotNull<Constraint> TypeFunctionContext::pushConstraint(ConstraintV&& c)
TypeFunctionContext::TypeFunctionContext(NotNull<ConstraintSolver> cs, NotNull<Scope> scope, NotNull<const Constraint> constraint)
: arena(cs->arena)
, builtins(cs->builtinTypes)
, scope(scope)
, normalizer(cs->normalizer)
, typeFunctionRuntime(cs->typeFunctionRuntime)
, ice(NotNull{&cs->iceReporter})
, limits(NotNull{&cs->limits})
, solver(cs.get())
, constraint(constraint.get())
{
}
NotNull<Constraint> TypeFunctionContext::pushConstraint(ConstraintV&& c) const
{
LUAU_ASSERT(solver);
NotNull<Constraint> newConstraint = solver->pushConstraint(scope, constraint ? constraint->location : Location{}, std::move(c));
@ -921,12 +1063,16 @@ TypeFunctionReductionResult<TypeId> numericBinopTypeFunction(
SolveResult solveResult;
if (!reversed)
solveResult = solveFunctionCall(ctx->arena, ctx->builtins, ctx->normalizer, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack);
solveResult = solveFunctionCall(
ctx->arena, ctx->builtins, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack
);
else
{
TypePack* p = getMutable<TypePack>(argPack);
std::swap(p->head.front(), p->head.back());
solveResult = solveFunctionCall(ctx->arena, ctx->builtins, ctx->normalizer, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack);
solveResult = solveFunctionCall(
ctx->arena, ctx->builtins, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack
);
}
if (!solveResult.typePackId.has_value())
@ -1156,7 +1302,7 @@ TypeFunctionReductionResult<TypeId> concatTypeFunction(
if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes))
return {std::nullopt, true, {}, {}}; // occurs check failed
Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice};
Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice};
if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance?
return {std::nullopt, true, {}, {}};
@ -1410,7 +1556,7 @@ static TypeFunctionReductionResult<TypeId> comparisonTypeFunction(
if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes))
return {std::nullopt, true, {}, {}}; // occurs check failed
Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice};
Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice};
if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance?
return {std::nullopt, true, {}, {}};
@ -1554,7 +1700,7 @@ TypeFunctionReductionResult<TypeId> eqTypeFunction(
if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes))
return {std::nullopt, true, {}, {}}; // occurs check failed
Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice};
Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice};
if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance?
return {std::nullopt, true, {}, {}};
@ -2004,7 +2150,7 @@ TypeFunctionReductionResult<TypeId> keyofFunctionImpl(
if (!computeKeysOf(*classesIter, localKeys, seen, isRaw, ctx))
continue;
for (auto key : keys)
for (auto& key : keys)
{
// remove any keys that are not present in each class
if (!localKeys.contains(key))
@ -2039,7 +2185,7 @@ TypeFunctionReductionResult<TypeId> keyofFunctionImpl(
if (!computeKeysOf(*tablesIter, localKeys, seen, isRaw, ctx))
continue;
for (auto key : keys)
for (auto& key : keys)
{
// remove any keys that are not present in each table
if (!localKeys.contains(key))
@ -2239,7 +2385,7 @@ TypeFunctionReductionResult<TypeId> indexFunctionImpl(
return {std::nullopt, true, {}, {}};
// indexer can be a union —> break them down into a vector
const std::vector<TypeId>* typesToFind;
const std::vector<TypeId>* typesToFind = nullptr;
const std::vector<TypeId> singleType{indexerTy};
if (auto unionTy = get<UnionType>(indexerTy))
typesToFind = &unionTy->options;

View file

@ -3,6 +3,7 @@
#include "Luau/DenseHash.h"
#include "Luau/Normalize.h"
#include "Luau/ToString.h"
#include "Luau/TypeFunction.h"
#include "Luau/Type.h"
#include "Luau/TypePack.h"

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,788 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TypeFunctionRuntimeBuilder.h"
#include "Luau/Ast.h"
#include "Luau/BuiltinDefinitions.h"
#include "Luau/Common.h"
#include "Luau/DenseHash.h"
#include "Luau/StringUtils.h"
#include "Luau/Type.h"
#include "Luau/TypeArena.h"
#include "Luau/TypeFwd.h"
#include "Luau/TypeFunctionRuntime.h"
#include "Luau/TypePack.h"
#include "Luau/ToString.h"
#include <optional>
// used to control the recursion limit of any operations done by user-defined type functions
// currently, controls serialization, deserialization, and `type.copy`
LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFunctionSerdeIterationLimit, 100'000);
namespace Luau
{
// Forked version of Clone.cpp
class TypeFunctionSerializer
{
using SeenTypes = DenseHashMap<TypeId, TypeFunctionTypeId>;
using SeenTypePacks = DenseHashMap<TypePackId, TypeFunctionTypePackId>;
TypeFunctionRuntimeBuilderState* state = nullptr;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
// A queue of TypeFunctionTypeIds that have been serialized, but whose interior types hasn't
// been updated to point to itself. Once all of its interior types
// has been updated, it gets removed from the queue.
// queue.back() should always return two of same type in their respective sides
// For example `auto [first, second] = queue.back()`: if first is PrimitiveType,
// second must be TypeFunctionPrimitiveType; else there should be an error
std::vector<std::tuple<Kind, TypeFunctionKind>> queue;
SeenTypes types; // Mapping of TypeIds that have been shallow serialized to TypeFunctionTypeIds
SeenTypePacks packs; // Mapping of TypePackIds that have been shallow serialized to TypeFunctionTypePackIds
int steps = 0;
public:
explicit TypeFunctionSerializer(TypeFunctionRuntimeBuilderState* state)
: state(state)
, typeFunctionRuntime(state->ctx->typeFunctionRuntime)
, queue({})
, types({})
, packs({})
{
}
TypeFunctionTypeId serialize(TypeId ty)
{
shallowSerialize(ty);
run();
if (hasExceededIterationLimit() || state->errors.size() != 0)
return nullptr;
return find(ty).value_or(nullptr);
}
TypeFunctionTypePackId serialize(TypePackId tp)
{
shallowSerialize(tp);
run();
if (hasExceededIterationLimit() || state->errors.size() != 0)
return nullptr;
return find(tp).value_or(nullptr);
}
private:
bool hasExceededIterationLimit() const
{
if (DFInt::LuauTypeFunctionSerdeIterationLimit == 0)
return false;
return steps + queue.size() >= size_t(DFInt::LuauTypeFunctionSerdeIterationLimit);
}
void run()
{
while (!queue.empty())
{
++steps;
if (hasExceededIterationLimit() || state->errors.size() != 0)
break;
auto [ty, tfti] = queue.back();
queue.pop_back();
serializeChildren(ty, tfti);
}
}
std::optional<TypeFunctionTypeId> find(TypeId ty) const
{
if (auto result = types.find(ty))
return *result;
return std::nullopt;
}
std::optional<TypeFunctionTypePackId> find(TypePackId tp) const
{
if (auto result = packs.find(tp))
return *result;
return std::nullopt;
}
std::optional<TypeFunctionKind> find(Kind kind) const
{
if (auto ty = get<TypeId>(kind))
return find(*ty);
else if (auto tp = get<TypePackId>(kind))
return find(*tp);
else
{
LUAU_ASSERT(!"Unknown kind found at TypeFunctionRuntimeSerializer");
return std::nullopt;
}
}
TypeFunctionTypeId shallowSerialize(TypeId ty)
{
ty = follow(ty);
if (auto it = find(ty))
return *it;
// Create a shallow serialization
TypeFunctionTypeId target = {};
if (auto p = get<PrimitiveType>(ty))
{
switch (p->type)
{
case PrimitiveType::Type::NilType:
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::NilType));
break;
case PrimitiveType::Type::Boolean:
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Boolean));
break;
case PrimitiveType::Number:
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Number));
break;
case PrimitiveType::String:
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::String));
break;
case PrimitiveType::Thread:
case PrimitiveType::Function:
case PrimitiveType::Table:
case PrimitiveType::Buffer:
default:
{
std::string error = format("Argument of primitive type %s is not currently serializable by type functions", toString(ty).c_str());
state->errors.push_back(error);
}
}
}
else if (auto u = get<UnknownType>(ty))
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionUnknownType{});
else if (auto a = get<NeverType>(ty))
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionNeverType{});
else if (auto a = get<AnyType>(ty))
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionAnyType{});
else if (auto s = get<SingletonType>(ty))
{
if (auto bs = get<BooleanSingleton>(s))
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionSingletonType{TypeFunctionBooleanSingleton{bs->value}});
else if (auto ss = get<StringSingleton>(s))
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionSingletonType{TypeFunctionStringSingleton{ss->value}});
else
{
std::string error = format("Argument of singleton type %s is not currently serializable by type functions", toString(ty).c_str());
state->errors.push_back(error);
}
}
else if (auto u = get<UnionType>(ty))
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionUnionType{{}});
else if (auto i = get<IntersectionType>(ty))
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionIntersectionType{{}});
else if (auto n = get<NegationType>(ty))
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionNegationType{{}});
else if (auto t = get<TableType>(ty))
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionTableType{{}, std::nullopt, std::nullopt});
else if (auto m = get<MetatableType>(ty))
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionTableType{{}, std::nullopt, std::nullopt});
else if (auto f = get<FunctionType>(ty))
{
TypeFunctionTypePackId emptyTypePack = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{});
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionFunctionType{emptyTypePack, emptyTypePack});
}
else if (auto c = get<ClassType>(ty))
{
state->classesSerialized[c->name] = ty;
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionClassType{{}, std::nullopt, std::nullopt, std::nullopt, c->name});
}
else
{
std::string error = format("Argument of type %s is not currently serializable by type functions", toString(ty).c_str());
state->errors.push_back(error);
}
types[ty] = target;
queue.emplace_back(ty, target);
return target;
}
TypeFunctionTypePackId shallowSerialize(TypePackId tp)
{
tp = follow(tp);
if (auto it = find(tp))
return *it;
// Create a shallow serialization
TypeFunctionTypePackId target = {};
if (auto tPack = get<TypePack>(tp))
target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{{}});
else if (auto vPack = get<VariadicTypePack>(tp))
target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionVariadicTypePack{});
else
{
std::string error = format("Argument of type pack %s is not currently serializable by type functions", toString(tp).c_str());
state->errors.push_back(error);
}
packs[tp] = target;
queue.emplace_back(tp, target);
return target;
}
void serializeChildren(TypeId ty, TypeFunctionTypeId tfti)
{
if (auto [p1, p2] = std::tuple{getMutable<PrimitiveType>(ty), getMutable<TypeFunctionPrimitiveType>(tfti)}; p1 && p2)
serializeChildren(p1, p2);
else if (auto [u1, u2] = std::tuple{getMutable<UnknownType>(ty), getMutable<TypeFunctionUnknownType>(tfti)}; u1 && u2)
serializeChildren(u1, u2);
else if (auto [n1, n2] = std::tuple{getMutable<NeverType>(ty), getMutable<TypeFunctionNeverType>(tfti)}; n1 && n2)
serializeChildren(n1, n2);
else if (auto [a1, a2] = std::tuple{getMutable<AnyType>(ty), getMutable<TypeFunctionAnyType>(tfti)}; a1 && a2)
serializeChildren(a1, a2);
else if (auto [s1, s2] = std::tuple{getMutable<SingletonType>(ty), getMutable<TypeFunctionSingletonType>(tfti)}; s1 && s2)
serializeChildren(s1, s2);
else if (auto [u1, u2] = std::tuple{getMutable<UnionType>(ty), getMutable<TypeFunctionUnionType>(tfti)}; u1 && u2)
serializeChildren(u1, u2);
else if (auto [i1, i2] = std::tuple{getMutable<IntersectionType>(ty), getMutable<TypeFunctionIntersectionType>(tfti)}; i1 && i2)
serializeChildren(i1, i2);
else if (auto [n1, n2] = std::tuple{getMutable<NegationType>(ty), getMutable<TypeFunctionNegationType>(tfti)}; n1 && n2)
serializeChildren(n1, n2);
else if (auto [t1, t2] = std::tuple{getMutable<TableType>(ty), getMutable<TypeFunctionTableType>(tfti)}; t1 && t2)
serializeChildren(t1, t2);
else if (auto [m1, m2] = std::tuple{getMutable<MetatableType>(ty), getMutable<TypeFunctionTableType>(tfti)}; m1 && m2)
serializeChildren(m1, m2);
else if (auto [f1, f2] = std::tuple{getMutable<FunctionType>(ty), getMutable<TypeFunctionFunctionType>(tfti)}; f1 && f2)
serializeChildren(f1, f2);
else if (auto [c1, c2] = std::tuple{getMutable<ClassType>(ty), getMutable<TypeFunctionClassType>(tfti)}; c1 && c2)
serializeChildren(c1, c2);
else
{ // Either this or ty and tfti do not represent the same type
std::string error = format("Argument of type %s is not currently serializable by type functions", toString(ty).c_str());
state->errors.push_back(error);
}
}
void serializeChildren(TypePackId tp, TypeFunctionTypePackId tftp)
{
if (auto [tPack1, tPack2] = std::tuple{getMutable<TypePack>(tp), getMutable<TypeFunctionTypePack>(tftp)}; tPack1 && tPack2)
serializeChildren(tPack1, tPack2);
else if (auto [vPack1, vPack2] = std::tuple{getMutable<VariadicTypePack>(tp), getMutable<TypeFunctionVariadicTypePack>(tftp)};
vPack1 && vPack2)
serializeChildren(vPack1, vPack2);
else
{ // Either this or ty and tfti do not represent the same type
std::string error = format("Argument of type pack %s is not currently serializable by type functions", toString(tp).c_str());
state->errors.push_back(error);
}
}
void serializeChildren(Kind kind, TypeFunctionKind tfkind)
{
if (auto [ty, tfty] = std::tuple{get<TypeId>(kind), get<TypeFunctionTypeId>(tfkind)}; ty && tfty)
serializeChildren(*ty, *tfty);
else if (auto [tp, tftp] = std::tuple{get<TypePackId>(kind), get<TypeFunctionTypePackId>(tfkind)}; tp && tftp)
serializeChildren(*tp, *tftp);
else
state->ctx->ice->ice("Serializing user defined type function arguments: kind and tfkind do not represent the same type");
}
void serializeChildren(PrimitiveType* p1, TypeFunctionPrimitiveType* p2)
{
// noop.
}
void serializeChildren(UnknownType* u1, TypeFunctionUnknownType* u2)
{
// noop.
}
void serializeChildren(NeverType* n1, TypeFunctionNeverType* n2)
{
// noop.
}
void serializeChildren(AnyType* a1, TypeFunctionAnyType* a2)
{
// noop.
}
void serializeChildren(SingletonType* s1, TypeFunctionSingletonType* s2)
{
// noop.
}
void serializeChildren(UnionType* u1, TypeFunctionUnionType* u2)
{
for (TypeId& ty : u1->options)
u2->components.push_back(shallowSerialize(ty));
}
void serializeChildren(IntersectionType* i1, TypeFunctionIntersectionType* i2)
{
for (TypeId& ty : i1->parts)
i2->components.push_back(shallowSerialize(ty));
}
void serializeChildren(NegationType* n1, TypeFunctionNegationType* n2)
{
n2->type = shallowSerialize(n1->ty);
}
void serializeChildren(TableType* t1, TypeFunctionTableType* t2)
{
for (const auto& [k, p] : t1->props)
{
std::optional<TypeFunctionTypeId> readTy = std::nullopt;
if (p.readTy)
readTy = shallowSerialize(*p.readTy);
std::optional<TypeFunctionTypeId> writeTy = std::nullopt;
if (p.writeTy)
writeTy = shallowSerialize(*p.writeTy);
t2->props[k] = TypeFunctionProperty{readTy, writeTy};
}
if (t1->indexer)
t2->indexer = TypeFunctionTableIndexer(shallowSerialize(t1->indexer->indexType), shallowSerialize(t1->indexer->indexResultType));
}
void serializeChildren(MetatableType* m1, TypeFunctionTableType* m2)
{
auto tmpTable = get<TypeFunctionTableType>(shallowSerialize(m1->table));
if (!tmpTable)
state->ctx->ice->ice("Serializing user defined type function arguments: metatable's table is not a TableType");
m2->props = tmpTable->props;
m2->indexer = tmpTable->indexer;
m2->metatable = shallowSerialize(m1->metatable);
}
void serializeChildren(FunctionType* f1, TypeFunctionFunctionType* f2)
{
f2->argTypes = shallowSerialize(f1->argTypes);
f2->retTypes = shallowSerialize(f1->retTypes);
}
void serializeChildren(ClassType* c1, TypeFunctionClassType* c2)
{
for (const auto& [k, p] : c1->props)
{
std::optional<TypeFunctionTypeId> readTy = std::nullopt;
if (p.readTy)
readTy = shallowSerialize(*p.readTy);
std::optional<TypeFunctionTypeId> writeTy = std::nullopt;
if (p.writeTy)
writeTy = shallowSerialize(*p.writeTy);
c2->props[k] = TypeFunctionProperty{readTy, writeTy};
}
if (c1->indexer)
c2->indexer = TypeFunctionTableIndexer(shallowSerialize(c1->indexer->indexType), shallowSerialize(c1->indexer->indexResultType));
if (c1->metatable)
c2->metatable = shallowSerialize(*c1->metatable);
if (c1->parent)
c2->parent = shallowSerialize(*c1->parent);
}
void serializeChildren(TypePack* t1, TypeFunctionTypePack* t2)
{
for (TypeId& ty : t1->head)
t2->head.push_back(shallowSerialize(ty));
if (t1->tail.has_value())
t2->tail = shallowSerialize(*t1->tail);
}
void serializeChildren(VariadicTypePack* v1, TypeFunctionVariadicTypePack* v2)
{
v2->type = shallowSerialize(v1->ty);
}
};
// Complete inverse of TypeFunctionSerializer
class TypeFunctionDeserializer
{
using SeenTypes = DenseHashMap<TypeFunctionTypeId, TypeId>;
using SeenTypePacks = DenseHashMap<TypeFunctionTypePackId, TypePackId>;
TypeFunctionRuntimeBuilderState* state = nullptr;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
// A queue of TypeIds that have been deserialized, but whose interior types hasn't
// been updated to point to itself. Once all of its interior types
// has been updated, it gets removed from the queue.
// queue.back() should always return two of same type in their respective sides
// For example `auto [first, second] = queue.back()`: if first is TypeFunctionPrimitiveType,
// second must be PrimitiveType; else there should be an error
std::vector<std::tuple<TypeFunctionKind, Kind>> queue;
SeenTypes types; // Mapping of TypeFunctionTypeIds that have been shallow deserialized to TypeIds
SeenTypePacks packs; // Mapping of TypeFunctionTypePackIds that have been shallow deserialized to TypePackIds
int steps = 0;
public:
explicit TypeFunctionDeserializer(TypeFunctionRuntimeBuilderState* state)
: state(state)
, typeFunctionRuntime(state->ctx->typeFunctionRuntime)
, queue({})
, types({})
, packs({}){};
TypeId deserialize(TypeFunctionTypeId ty)
{
shallowDeserialize(ty);
run();
if (hasExceededIterationLimit() || state->errors.size() != 0)
{
TypeId error = state->ctx->builtins->errorRecoveryType();
types[ty] = error;
return error;
}
return find(ty).value_or(state->ctx->builtins->errorRecoveryType());
}
TypePackId deserialize(TypeFunctionTypePackId tp)
{
shallowDeserialize(tp);
run();
if (hasExceededIterationLimit() || state->errors.size() != 0)
{
TypePackId error = state->ctx->builtins->errorRecoveryTypePack();
packs[tp] = error;
return error;
}
return find(tp).value_or(state->ctx->builtins->errorRecoveryTypePack());
}
private:
bool hasExceededIterationLimit() const
{
if (DFInt::LuauTypeFunctionSerdeIterationLimit == 0)
return false;
return steps + queue.size() >= size_t(DFInt::LuauTypeFunctionSerdeIterationLimit);
}
void run()
{
while (!queue.empty())
{
++steps;
if (hasExceededIterationLimit() || state->errors.size() != 0)
break;
auto [tfti, ty] = queue.back();
queue.pop_back();
deserializeChildren(tfti, ty);
}
}
std::optional<TypeId> find(TypeFunctionTypeId ty) const
{
if (auto result = types.find(ty))
return *result;
return std::nullopt;
}
std::optional<TypePackId> find(TypeFunctionTypePackId tp) const
{
if (auto result = packs.find(tp))
return *result;
return std::nullopt;
}
std::optional<Kind> find(TypeFunctionKind kind) const
{
if (auto ty = get<TypeFunctionTypeId>(kind))
return find(*ty);
else if (auto tp = get<TypeFunctionTypePackId>(kind))
return find(*tp);
else
{
LUAU_ASSERT(!"Unknown kind found at TypeFunctionDeserializer");
return std::nullopt;
}
}
TypeId shallowDeserialize(TypeFunctionTypeId ty)
{
if (auto it = find(ty))
return *it;
// Create a shallow deserialization
TypeId target = {};
if (auto p = get<TypeFunctionPrimitiveType>(ty))
{
switch (p->type)
{
case TypeFunctionPrimitiveType::Type::NilType:
target = state->ctx->builtins->nilType;
break;
case TypeFunctionPrimitiveType::Type::Boolean:
target = state->ctx->builtins->booleanType;
break;
case TypeFunctionPrimitiveType::Type::Number:
target = state->ctx->builtins->numberType;
break;
case TypeFunctionPrimitiveType::Type::String:
target = state->ctx->builtins->stringType;
break;
default:
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized");
}
}
else if (auto u = get<TypeFunctionUnknownType>(ty))
target = state->ctx->builtins->unknownType;
else if (auto n = get<TypeFunctionNeverType>(ty))
target = state->ctx->builtins->neverType;
else if (auto a = get<TypeFunctionAnyType>(ty))
target = state->ctx->builtins->anyType;
else if (auto s = get<TypeFunctionSingletonType>(ty))
{
if (auto bs = get<TypeFunctionBooleanSingleton>(s))
target = state->ctx->arena->addType(SingletonType{BooleanSingleton{bs->value}});
else if (auto ss = get<TypeFunctionStringSingleton>(s))
target = state->ctx->arena->addType(SingletonType{StringSingleton{ss->value}});
else
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized");
}
else if (auto u = get<TypeFunctionUnionType>(ty))
target = state->ctx->arena->addTV(Type(UnionType{{}}));
else if (auto i = get<TypeFunctionIntersectionType>(ty))
target = state->ctx->arena->addTV(Type(IntersectionType{{}}));
else if (auto n = get<TypeFunctionNegationType>(ty))
target = state->ctx->arena->addType(NegationType{state->ctx->builtins->unknownType});
else if (auto t = get<TypeFunctionTableType>(ty); t && !t->metatable.has_value())
target = state->ctx->arena->addType(TableType{TableType::Props{}, std::nullopt, TypeLevel{}, TableState::Sealed});
else if (auto m = get<TypeFunctionTableType>(ty); m && m->metatable.has_value())
{
TypeId emptyTable = state->ctx->arena->addType(TableType{TableType::Props{}, std::nullopt, TypeLevel{}, TableState::Sealed});
target = state->ctx->arena->addType(MetatableType{emptyTable, emptyTable});
}
else if (auto f = get<TypeFunctionFunctionType>(ty))
{
TypePackId emptyTypePack = state->ctx->arena->addTypePack(TypePack{});
target = state->ctx->arena->addType(FunctionType{emptyTypePack, emptyTypePack, {}, false});
}
else if (auto c = get<TypeFunctionClassType>(ty))
{
if (auto result = state->classesSerialized.find(c->name))
target = *result;
else
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious class type is being deserialized");
}
else
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized");
types[ty] = target;
queue.emplace_back(ty, target);
return target;
}
TypePackId shallowDeserialize(TypeFunctionTypePackId tp)
{
if (auto it = find(tp))
return *it;
// Create a shallow deserialization
TypePackId target = {};
if (auto tPack = get<TypeFunctionTypePack>(tp))
target = state->ctx->arena->addTypePack(TypePack{});
else if (auto vPack = get<TypeFunctionVariadicTypePack>(tp))
target = state->ctx->arena->addTypePack(VariadicTypePack{});
else
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized");
packs[tp] = target;
queue.emplace_back(tp, target);
return target;
}
void deserializeChildren(TypeFunctionTypeId tfti, TypeId ty)
{
if (auto [p1, p2] = std::tuple{getMutable<PrimitiveType>(ty), getMutable<TypeFunctionPrimitiveType>(tfti)}; p1 && p2)
deserializeChildren(p2, p1);
else if (auto [u1, u2] = std::tuple{getMutable<UnknownType>(ty), getMutable<TypeFunctionUnknownType>(tfti)}; u1 && u2)
deserializeChildren(u2, u1);
else if (auto [n1, n2] = std::tuple{getMutable<NeverType>(ty), getMutable<TypeFunctionNeverType>(tfti)}; n1 && n2)
deserializeChildren(n2, n1);
else if (auto [a1, a2] = std::tuple{getMutable<AnyType>(ty), getMutable<TypeFunctionAnyType>(tfti)}; a1 && a2)
deserializeChildren(a2, a1);
else if (auto [s1, s2] = std::tuple{getMutable<SingletonType>(ty), getMutable<TypeFunctionSingletonType>(tfti)}; s1 && s2)
deserializeChildren(s2, s1);
else if (auto [u1, u2] = std::tuple{getMutable<UnionType>(ty), getMutable<TypeFunctionUnionType>(tfti)}; u1 && u2)
deserializeChildren(u2, u1);
else if (auto [i1, i2] = std::tuple{getMutable<IntersectionType>(ty), getMutable<TypeFunctionIntersectionType>(tfti)}; i1 && i2)
deserializeChildren(i2, i1);
else if (auto [n1, n2] = std::tuple{getMutable<NegationType>(ty), getMutable<TypeFunctionNegationType>(tfti)}; n1 && n2)
deserializeChildren(n2, n1);
else if (auto [t1, t2] = std::tuple{getMutable<TableType>(ty), getMutable<TypeFunctionTableType>(tfti)};
t1 && t2 && !t2->metatable.has_value())
deserializeChildren(t2, t1);
else if (auto [m1, m2] = std::tuple{getMutable<MetatableType>(ty), getMutable<TypeFunctionTableType>(tfti)};
m1 && m2 && m2->metatable.has_value())
deserializeChildren(m2, m1);
else if (auto [f1, f2] = std::tuple{getMutable<FunctionType>(ty), getMutable<TypeFunctionFunctionType>(tfti)}; f1 && f2)
deserializeChildren(f2, f1);
else if (auto [c1, c2] = std::tuple{getMutable<ClassType>(ty), getMutable<TypeFunctionClassType>(tfti)}; c1 && c2)
deserializeChildren(c2, c1);
else
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized");
}
void deserializeChildren(TypeFunctionTypePackId tftp, TypePackId tp)
{
if (auto [tPack1, tPack2] = std::tuple{getMutable<TypePack>(tp), getMutable<TypeFunctionTypePack>(tftp)}; tPack1 && tPack2)
deserializeChildren(tPack2, tPack1);
else if (auto [vPack1, vPack2] = std::tuple{getMutable<VariadicTypePack>(tp), getMutable<TypeFunctionVariadicTypePack>(tftp)};
vPack1 && vPack2)
deserializeChildren(vPack2, vPack1);
else
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized");
}
void deserializeChildren(TypeFunctionKind tfkind, Kind kind)
{
if (auto [ty, tfty] = std::tuple{get<TypeId>(kind), get<TypeFunctionTypeId>(tfkind)}; ty && tfty)
deserializeChildren(*tfty, *ty);
else if (auto [tp, tftp] = std::tuple{get<TypePackId>(kind), get<TypeFunctionTypePackId>(tfkind)}; tp && tftp)
deserializeChildren(*tftp, *tp);
else
state->ctx->ice->ice("Deserializing user defined type function arguments: tfkind and kind do not represent the same type");
}
void deserializeChildren(TypeFunctionPrimitiveType* p2, PrimitiveType* p1)
{
// noop.
}
void deserializeChildren(TypeFunctionUnknownType* u2, UnknownType* u1)
{
// noop.
}
void deserializeChildren(TypeFunctionNeverType* n2, NeverType* n1)
{
// noop.
}
void deserializeChildren(TypeFunctionAnyType* a2, AnyType* a1)
{
// noop.
}
void deserializeChildren(TypeFunctionSingletonType* s2, SingletonType* s1)
{
// noop.
}
void deserializeChildren(TypeFunctionUnionType* u2, UnionType* u1)
{
for (TypeFunctionTypeId& ty : u2->components)
u1->options.push_back(shallowDeserialize(ty));
}
void deserializeChildren(TypeFunctionIntersectionType* i2, IntersectionType* i1)
{
for (TypeFunctionTypeId& ty : i2->components)
i1->parts.push_back(shallowDeserialize(ty));
}
void deserializeChildren(TypeFunctionNegationType* n2, NegationType* n1)
{
n1->ty = shallowDeserialize(n2->type);
}
void deserializeChildren(TypeFunctionTableType* t2, TableType* t1)
{
for (const auto& [k, p] : t2->props)
{
if (p.readTy && p.writeTy)
t1->props[k] = Property::rw(shallowDeserialize(*p.readTy), shallowDeserialize(*p.writeTy));
else if (p.readTy)
t1->props[k] = Property::readonly(shallowDeserialize(*p.readTy));
else if (p.writeTy)
t1->props[k] = Property::writeonly(shallowDeserialize(*p.writeTy));
}
if (t2->indexer.has_value())
t1->indexer = TableIndexer(shallowDeserialize(t2->indexer->keyType), shallowDeserialize(t2->indexer->valueType));
}
void deserializeChildren(TypeFunctionTableType* m2, MetatableType* m1)
{
TypeFunctionTypeId temp = typeFunctionRuntime->typeArena.allocate(TypeFunctionTableType{m2->props, m2->indexer});
m1->table = shallowDeserialize(temp);
if (m2->metatable.has_value())
m1->metatable = shallowDeserialize(*m2->metatable);
}
void deserializeChildren(TypeFunctionFunctionType* f2, FunctionType* f1)
{
if (f2->argTypes)
f1->argTypes = shallowDeserialize(f2->argTypes);
if (f2->retTypes)
f1->retTypes = shallowDeserialize(f2->retTypes);
}
void deserializeChildren(TypeFunctionClassType* c2, ClassType* c1)
{
// noop.
}
void deserializeChildren(TypeFunctionTypePack* t2, TypePack* t1)
{
for (TypeFunctionTypeId& ty : t2->head)
t1->head.push_back(shallowDeserialize(ty));
if (t2->tail.has_value())
t1->tail = shallowDeserialize(*t2->tail);
}
void deserializeChildren(TypeFunctionVariadicTypePack* v2, VariadicTypePack* v1)
{
v1->ty = shallowDeserialize(v2->type);
}
};
TypeFunctionTypeId serialize(TypeId ty, TypeFunctionRuntimeBuilderState* state)
{
return TypeFunctionSerializer(state).serialize(ty);
}
TypeId deserialize(TypeFunctionTypeId ty, TypeFunctionRuntimeBuilderState* state)
{
return TypeFunctionDeserializer(state).deserialize(ty);
}
} // namespace Luau

View file

@ -33,7 +33,6 @@ LUAU_FASTFLAG(LuauKnowsTheDataModel3)
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false)
LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAGVARIABLE(LuauRemoveBadRelationalOperatorWarning, false)
LUAU_FASTFLAGVARIABLE(LuauOkWithIteratingOverTableProperties, false)
LUAU_FASTFLAGVARIABLE(LuauAcceptIndexingTableUnionsIntersections, false)
namespace Luau
@ -1284,20 +1283,11 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin)
for (size_t i = 2; i < varTypes.size(); ++i)
unify(nilType, varTypes[i], scope, forin.location);
}
else if (isNonstrictMode() || FFlag::LuauOkWithIteratingOverTableProperties)
else
{
for (TypeId var : varTypes)
unify(unknownType, var, scope, forin.location);
}
else
{
TypeId varTy = errorRecoveryType(loopScope);
for (TypeId var : varTypes)
unify(varTy, var, scope, forin.location);
reportError(firstValue->location, GenericError{"Cannot iterate over a table without indexer"});
}
return check(loopScope, *forin.body);
}

View file

@ -1,6 +1,11 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
#include "Luau/DenseHash.h"
#include <vector>
namespace Luau
{
@ -12,10 +17,17 @@ enum class Mode
Definition, // Type definition module, has special parsing rules
};
struct FragmentParseResumeSettings
{
DenseHashMap<AstName, AstLocal*> localMap{AstName()};
std::vector<AstLocal*> localStack;
};
struct ParseOptions
{
bool allowDeclarationSyntax = false;
bool captureComments = false;
std::optional<FragmentParseResumeSettings> parseFragment = std::nullopt;
};
} // namespace Luau

View file

@ -452,4 +452,4 @@ private:
std::string scratchData;
};
} // namespace Luau
} // namespace Luau

View file

@ -7,6 +7,7 @@
#include <memory>
#include <stdint.h>
#include <string.h>
LUAU_FASTFLAG(DebugLuauTimeTracing)

View file

@ -7,8 +7,6 @@
#include <limits.h>
LUAU_FASTFLAGVARIABLE(LuauLexerLookaheadRemembersBraceType, false)
namespace Luau
{
@ -434,13 +432,11 @@ Lexeme Lexer::lookahead()
lineOffset = currentLineOffset;
lexeme = currentLexeme;
prevLocation = currentPrevLocation;
if (FFlag::LuauLexerLookaheadRemembersBraceType)
{
if (braceStack.size() < currentBraceStackSize)
braceStack.push_back(currentBraceType);
else if (braceStack.size() > currentBraceStackSize)
braceStack.pop_back();
}
if (braceStack.size() < currentBraceStackSize)
braceStack.push_back(currentBraceType);
else if (braceStack.size() > currentBraceStackSize)
braceStack.pop_back();
return result;
}

View file

@ -19,7 +19,8 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100)
LUAU_FASTFLAGVARIABLE(LuauSolverV2, false)
LUAU_FASTFLAGVARIABLE(LuauNativeAttribute, false)
LUAU_FASTFLAGVARIABLE(LuauAttributeSyntaxFunExpr, false)
LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctions, false)
LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionsSyntax, false)
LUAU_FASTFLAGVARIABLE(LuauAllowFragmentParsing, false)
namespace Luau
{
@ -211,6 +212,15 @@ Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Alloc
scratchExpr.reserve(16);
scratchLocal.reserve(16);
scratchBinding.reserve(16);
if (FFlag::LuauAllowFragmentParsing)
{
if (options.parseFragment)
{
localMap = options.parseFragment->localMap;
localStack = options.parseFragment->localStack;
}
}
}
bool Parser::blockFollow(const Lexeme& l)
@ -891,7 +901,7 @@ AstStat* Parser::parseReturn()
AstStat* Parser::parseTypeAlias(const Location& start, bool exported)
{
// parsing a type function
if (FFlag::LuauUserDefinedTypeFunctions)
if (FFlag::LuauUserDefinedTypeFunctionsSyntax)
{
if (lexer.current().type == Lexeme::ReservedFunction)
return parseTypeFunction(start);

View file

@ -3,6 +3,7 @@
#include "Luau/StringUtils.h"
#include <algorithm>
#include <mutex>
#include <string>

View file

@ -85,6 +85,7 @@ target_link_libraries(Luau.Config PUBLIC Luau.Ast)
target_compile_features(Luau.Analysis PUBLIC cxx_std_17)
target_include_directories(Luau.Analysis PUBLIC Analysis/include)
target_link_libraries(Luau.Analysis PUBLIC Luau.Ast Luau.EqSat Luau.Config)
target_link_libraries(Luau.Analysis PRIVATE Luau.Compiler Luau.VM)
target_compile_features(Luau.EqSat PUBLIC cxx_std_17)
target_include_directories(Luau.EqSat PUBLIC EqSat/include)
@ -276,7 +277,7 @@ foreach(LIB Luau.Ast Luau.Compiler Luau.Config Luau.Analysis Luau.EqSat Luau.Cod
if(LIB MATCHES "CodeGen|VM" AND DEPENDS MATCHES "Ast|Analysis|Config|Compiler")
message(FATAL_ERROR ${LIB} " is a runtime component but it depends on one of the offline components")
endif()
if(LIB MATCHES "Ast|Analysis|EqSat|Compiler" AND DEPENDS MATCHES "CodeGen|VM")
if(LIB MATCHES "Ast|EqSat|Compiler" AND DEPENDS MATCHES "CodeGen|VM")
message(FATAL_ERROR ${LIB} " is an offline component but it depends on one of the runtime components")
endif()
if(LIB MATCHES "Ast|Compiler" AND DEPENDS MATCHES "Analysis|Config")

View file

@ -11,8 +11,6 @@
#include "lstate.h"
#include "lgc.h"
LUAU_FASTFLAGVARIABLE(LuauCodegenArmNumToVecFix, false)
namespace Luau
{
namespace CodeGen
@ -1121,7 +1119,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
else
{
RegisterA64 tempd = tempDouble(inst.a);
RegisterA64 temps = FFlag::LuauCodegenArmNumToVecFix ? regs.allocTemp(KindA64::s) : castReg(KindA64::s, tempd);
RegisterA64 temps = regs.allocTemp(KindA64::s);
build.fcvt(temps, tempd);
build.dup_4s(inst.regA64, castReg(KindA64::q, temps), 0);

View file

@ -142,7 +142,7 @@ endif
$(AST_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include
$(COMPILER_OBJECTS): CXXFLAGS+=-std=c++17 -ICompiler/include -ICommon/include -IAst/include
$(CONFIG_OBJECTS): CXXFLAGS+=-std=c++17 -IConfig/include -ICommon/include -IAst/include
$(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IEqSat/include -IConfig/include
$(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IEqSat/include -IConfig/include -ICompiler/include -IVM/include
$(EQSAT_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IEqSat/include
$(CODEGEN_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -ICodeGen/include -IVM/include -IVM/src # Code generation needs VM internals
$(VM_OBJECTS): CXXFLAGS+=-std=c++11 -ICommon/include -IVM/include
@ -227,7 +227,7 @@ luau-tests: $(TESTS_TARGET)
# executable targets
$(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(COMPILER_TARGET) $(CONFIG_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET)
$(REPL_CLI_TARGET): $(REPL_CLI_OBJECTS) $(COMPILER_TARGET) $(CONFIG_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET)
$(ANALYZE_CLI_TARGET): $(ANALYZE_CLI_OBJECTS) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(AST_TARGET) $(CONFIG_TARGET)
$(ANALYZE_CLI_TARGET): $(ANALYZE_CLI_OBJECTS) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(COMPILER_TARGET) $(VM_TARGET)
$(COMPILE_CLI_TARGET): $(COMPILE_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET)
$(BYTECODE_CLI_TARGET): $(BYTECODE_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET)

View file

@ -182,6 +182,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/Documentation.h
Analysis/include/Luau/Error.h
Analysis/include/Luau/FileResolver.h
Analysis/include/Luau/FragmentAutocomplete.h
Analysis/include/Luau/Frontend.h
Analysis/include/Luau/Generalization.h
Analysis/include/Luau/GlobalTypes.h
@ -223,6 +224,8 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/TypedAllocator.h
Analysis/include/Luau/TypeFunction.h
Analysis/include/Luau/TypeFunctionReductionGuesser.h
Analysis/include/Luau/TypeFunctionRuntime.h
Analysis/include/Luau/TypeFunctionRuntimeBuilder.h
Analysis/include/Luau/TypeFwd.h
Analysis/include/Luau/TypeInfer.h
Analysis/include/Luau/TypeOrPack.h
@ -253,6 +256,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/Differ.cpp
Analysis/src/EmbeddedBuiltinDefinitions.cpp
Analysis/src/Error.cpp
Analysis/src/FragmentAutocomplete.cpp
Analysis/src/Frontend.cpp
Analysis/src/Generalization.cpp
Analysis/src/GlobalTypes.cpp
@ -287,6 +291,8 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/TypedAllocator.cpp
Analysis/src/TypeFunction.cpp
Analysis/src/TypeFunctionReductionGuesser.cpp
Analysis/src/TypeFunctionRuntime.cpp
Analysis/src/TypeFunctionRuntimeBuilder.cpp
Analysis/src/TypeInfer.cpp
Analysis/src/TypeOrPack.cpp
Analysis/src/TypePack.cpp
@ -440,6 +446,7 @@ if(TARGET Luau.UnitTest)
tests/Error.test.cpp
tests/Fixture.cpp
tests/Fixture.h
tests/FragmentAutocomplete.test.cpp
tests/Frontend.test.cpp
tests/Generalization.test.cpp
tests/InsertionOrderedMap.test.cpp
@ -474,6 +481,7 @@ if(TARGET Luau.UnitTest)
tests/Transpiler.test.cpp
tests/TxnLog.test.cpp
tests/TypeFunction.test.cpp
tests/TypeFunction.user.test.cpp
tests/TypeInfer.aliases.test.cpp
tests/TypeInfer.annotations.test.cpp
tests/TypeInfer.anyerror.test.cpp

View file

@ -10,8 +10,6 @@
#include <string.h>
LUAU_FASTFLAGVARIABLE(LuauPreserveLudataRenaming, false)
// clang-format off
const char* const luaT_typenames[] = {
// ORDER TYPE
@ -124,74 +122,40 @@ const TValue* luaT_gettmbyobj(lua_State* L, const TValue* o, TMS event)
const TString* luaT_objtypenamestr(lua_State* L, const TValue* o)
{
if (FFlag::LuauPreserveLudataRenaming)
// Userdata created by the environment can have a custom type name set in the individual metatable
// If there is no custom name, 'userdata' is returned
if (ttisuserdata(o) && uvalue(o)->tag != UTAG_PROXY && uvalue(o)->metatable)
{
// Userdata created by the environment can have a custom type name set in the individual metatable
// If there is no custom name, 'userdata' is returned
if (ttisuserdata(o) && uvalue(o)->tag != UTAG_PROXY && uvalue(o)->metatable)
{
const TValue* type = luaH_getstr(uvalue(o)->metatable, L->global->tmname[TM_TYPE]);
const TValue* type = luaH_getstr(uvalue(o)->metatable, L->global->tmname[TM_TYPE]);
if (ttisstring(type))
return tsvalue(type);
return L->global->ttname[ttype(o)];
}
// Tagged lightuserdata can be named using lua_setlightuserdataname
if (ttislightuserdata(o))
{
int tag = lightuserdatatag(o);
if (unsigned(tag) < LUA_LUTAG_LIMIT)
{
if (const TString* name = L->global->lightuserdataname[tag])
return name;
}
}
// For all types except userdata and table, a global metatable can be set with a global name override
if (Table* mt = L->global->mt[ttype(o)])
{
const TValue* type = luaH_getstr(mt, L->global->tmname[TM_TYPE]);
if (ttisstring(type))
return tsvalue(type);
}
if (ttisstring(type))
return tsvalue(type);
return L->global->ttname[ttype(o)];
}
else
// Tagged lightuserdata can be named using lua_setlightuserdataname
if (ttislightuserdata(o))
{
if (ttisuserdata(o) && uvalue(o)->tag != UTAG_PROXY && uvalue(o)->metatable)
int tag = lightuserdatatag(o);
if (unsigned(tag) < LUA_LUTAG_LIMIT)
{
const TValue* type = luaH_getstr(uvalue(o)->metatable, L->global->tmname[TM_TYPE]);
if (ttisstring(type))
return tsvalue(type);
if (const TString* name = L->global->lightuserdataname[tag])
return name;
}
else if (ttislightuserdata(o))
{
int tag = lightuserdatatag(o);
if (unsigned(tag) < LUA_LUTAG_LIMIT)
{
const TString* name = L->global->lightuserdataname[tag];
if (name)
return name;
}
}
else if (Table* mt = L->global->mt[ttype(o)])
{
const TValue* type = luaH_getstr(mt, L->global->tmname[TM_TYPE]);
if (ttisstring(type))
return tsvalue(type);
}
return L->global->ttname[ttype(o)];
}
// For all types except userdata and table, a global metatable can be set with a global name override
if (Table* mt = L->global->mt[ttype(o)])
{
const TValue* type = luaH_getstr(mt, L->global->tmname[TM_TYPE]);
if (ttisstring(type))
return tsvalue(type);
}
return L->global->ttname[ttype(o)];
}
const char* luaT_objtypename(lua_State* L, const TValue* o)

View file

@ -34,8 +34,6 @@ void luaC_validate(lua_State* L);
LUAU_FASTFLAG(DebugLuauAbortingChecks)
LUAU_FASTINT(CodegenHeuristicsInstructionLimit)
LUAU_FASTFLAG(LuauNativeAttribute)
LUAU_FASTFLAG(LuauPreserveLudataRenaming)
LUAU_FASTFLAG(LuauCodegenArmNumToVecFix)
static lua_CompileOptions defaultOptions()
{
@ -825,8 +823,6 @@ TEST_CASE("Pack")
TEST_CASE("Vector")
{
ScopedFastFlag luauCodegenArmNumToVecFix{FFlag::LuauCodegenArmNumToVecFix, true};
lua_CompileOptions copts = defaultOptions();
Luau::CodeGen::CompilationOptions nativeOpts = defaultCodegenOptions();
@ -2251,20 +2247,17 @@ TEST_CASE("LightuserdataApi")
lua_pop(L, 1);
if (FFlag::LuauPreserveLudataRenaming)
{
// Still possible to rename the global lightuserdata name using a metatable
lua_pushlightuserdata(L, value);
CHECK(strcmp(luaL_typename(L, -1), "userdata") == 0);
// Still possible to rename the global lightuserdata name using a metatable
lua_pushlightuserdata(L, value);
CHECK(strcmp(luaL_typename(L, -1), "userdata") == 0);
lua_createtable(L, 0, 1);
lua_pushstring(L, "luserdata");
lua_setfield(L, -2, "__type");
lua_setmetatable(L, -2);
lua_createtable(L, 0, 1);
lua_pushstring(L, "luserdata");
lua_setfield(L, -2, "__type");
lua_setmetatable(L, -2);
CHECK(strcmp(luaL_typename(L, -1), "luserdata") == 0);
lua_pop(L, 1);
}
CHECK(strcmp(luaL_typename(L, -1), "luserdata") == 0);
lua_pop(L, 1);
globalState.reset();
}

View file

@ -42,7 +42,9 @@ void ConstraintGeneratorFixture::generateConstraints(const std::string& code)
void ConstraintGeneratorFixture::solve(const std::string& code)
{
generateConstraints(code);
ConstraintSolver cs{NotNull{&normalizer}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger, {}};
ConstraintSolver cs{
NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger, {}
};
cs.run();
}

View file

@ -20,6 +20,7 @@ struct ConstraintGeneratorFixture : Fixture
DcrLogger logger;
UnifierSharedState sharedState{&ice};
Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}};
TypeFunctionRuntime typeFunctionRuntime;
std::unique_ptr<DataFlowGraph> dfg;
std::unique_ptr<ConstraintGenerator> cg;

View file

@ -0,0 +1,139 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/FragmentAutocomplete.h"
#include "Fixture.h"
#include "Luau/Ast.h"
#include "Luau/AstQuery.h"
using namespace Luau;
struct FragmentAutocompleteFixture : Fixture
{
FragmentAutocompleteAncestryResult runAutocompleteVisitor(const std::string& source, const Position& cursorPos)
{
ParseResult p = tryParse(source); // We don't care about parsing incomplete asts
REQUIRE(p.root);
return findAncestryForFragmentParse(p.root, cursorPos);
}
};
TEST_SUITE_BEGIN("FragmentAutocompleteTraversalTest");
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "just_two_locals")
{
auto result = runAutocompleteVisitor(
R"(
local x = 4
local y = 5
)",
{2, 11}
);
CHECK_EQ(3, result.ancestry.size());
CHECK_EQ(2, result.localStack.size());
CHECK_EQ(result.localMap.size(), result.localStack.size());
REQUIRE(result.nearestStatement);
AstStatLocal* local = result.nearestStatement->as<AstStatLocal>();
REQUIRE(local);
CHECK(1 == local->vars.size);
CHECK_EQ("y", std::string(local->vars.data[0]->name.value));
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "cursor_within_scope_tracks_locals_from_previous_scope")
{
auto result = runAutocompleteVisitor(
R"(
local x = 4
local y = 5
if x == 4 then
local e = y
end
)",
{4, 15}
);
CHECK_EQ(5, result.ancestry.size());
CHECK_EQ(3, result.localStack.size());
CHECK_EQ(result.localMap.size(), result.localStack.size());
REQUIRE(result.nearestStatement);
CHECK_EQ("e", std::string(result.localStack.back()->name.value));
AstStatLocal* local = result.nearestStatement->as<AstStatLocal>();
REQUIRE(local);
CHECK(1 == local->vars.size);
CHECK_EQ("e", std::string(local->vars.data[0]->name.value));
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "cursor_that_comes_later_shouldnt_capture_locals_in_unavailable_scope")
{
auto result = runAutocompleteVisitor(
R"(
local x = 4
local y = 5
if x == 4 then
local e = y
end
local z = x + x
if y == 5 then
local q = x + y + z
end
)",
{8, 23}
);
CHECK_EQ(6, result.ancestry.size());
CHECK_EQ(4, result.localStack.size());
CHECK_EQ(result.localMap.size(), result.localStack.size());
REQUIRE(result.nearestStatement);
CHECK_EQ("q", std::string(result.localStack.back()->name.value));
AstStatLocal* local = result.nearestStatement->as<AstStatLocal>();
REQUIRE(local);
CHECK(1 == local->vars.size);
CHECK_EQ("q", std::string(local->vars.data[0]->name.value));
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "nearest_enclosing_statement_can_be_non_local")
{
auto result = runAutocompleteVisitor(
R"(
local x = 4
local y = 5
if x == 4 then
)",
{3, 4}
);
CHECK_EQ(4, result.ancestry.size());
CHECK_EQ(2, result.localStack.size());
CHECK_EQ(result.localMap.size(), result.localStack.size());
REQUIRE(result.nearestStatement);
CHECK_EQ("y", std::string(result.localStack.back()->name.value));
AstStatIf* ifS = result.nearestStatement->as<AstStatIf>();
CHECK(ifS != nullptr);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "local_funcs_show_up_in_local_stack")
{
auto result = runAutocompleteVisitor(
R"(
local function foo() return 4 end
local x = foo()
local function bar() return x + foo() end
)",
{3, 32}
);
CHECK_EQ(8, result.ancestry.size());
CHECK_EQ(3, result.localStack.size());
CHECK_EQ(result.localMap.size(), result.localStack.size());
CHECK_EQ("bar", std::string(result.localStack.back()->name.value));
auto returnSt = result.nearestStatement->as<AstStatReturn>();
CHECK(returnSt != nullptr);
}
TEST_SUITE_END();

View file

@ -3,6 +3,7 @@
#include "AstQueryDsl.h"
#include "Fixture.h"
#include "Luau/Common.h"
#include "ScopedFlags.h"
#include "doctest.h"
@ -11,13 +12,12 @@
using namespace Luau;
LUAU_FASTFLAG(LuauLexerLookaheadRemembersBraceType);
LUAU_FASTINT(LuauRecursionLimit);
LUAU_FASTINT(LuauTypeLengthLimit);
LUAU_FASTINT(LuauParseErrorLimit);
LUAU_FASTFLAG(LuauSolverV2);
LUAU_FASTFLAG(LuauAttributeSyntaxFunExpr);
LUAU_FASTFLAG(LuauUserDefinedTypeFunctions);
LUAU_FASTINT(LuauRecursionLimit)
LUAU_FASTINT(LuauTypeLengthLimit)
LUAU_FASTINT(LuauParseErrorLimit)
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauAttributeSyntaxFunExpr)
LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax)
namespace
{
@ -2380,7 +2380,7 @@ TEST_CASE_FIXTURE(Fixture, "invalid_type_forms")
TEST_CASE_FIXTURE(Fixture, "parse_user_defined_type_functions")
{
ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctions, true};
ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax, true};
AstStat* stat = parse(R"(
type function foo()
@ -3138,8 +3138,6 @@ TEST_CASE_FIXTURE(Fixture, "do_block_with_no_end")
TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_with_lookahead_involved")
{
ScopedFastFlag sff{FFlag::LuauLexerLookaheadRemembersBraceType, true};
ParseResult result = tryParse(R"(
local x = `{ {y} }`
)");
@ -3149,8 +3147,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_with_lookahead_involved")
TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_with_lookahead_involved2")
{
ScopedFastFlag sff{FFlag::LuauLexerLookaheadRemembersBraceType, true};
ParseResult result = tryParse(R"(
local x = `{ { y{} } }`
)");

View file

@ -66,6 +66,7 @@ struct SubtypeFixture : Fixture
InternalErrorReporter iceReporter;
UnifierSharedState sharedState{&ice};
Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}};
TypeFunctionRuntime typeFunctionRuntime;
ScopedFastFlag sff{FFlag::LuauSolverV2, true};
@ -77,7 +78,7 @@ struct SubtypeFixture : Fixture
Subtyping mkSubtyping()
{
return Subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&iceReporter}};
return Subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter}};
}
TypePackId pack(std::initializer_list<TypeId> tys)

View file

@ -12,7 +12,7 @@
using namespace Luau;
LUAU_FASTFLAG(LuauUserDefinedTypeFunctions);
LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax)
TEST_SUITE_BEGIN("TranspilerTests");
@ -698,7 +698,7 @@ TEST_CASE_FIXTURE(Fixture, "transpile_string_literal_escape")
TEST_CASE_FIXTURE(Fixture, "transpile_type_functions")
{
ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctions, true};
ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax, true};
std::string code = R"( type function foo(arg1, arg2) if arg1 == arg2 then return arg1 end return arg2 end )";

View file

@ -1247,18 +1247,4 @@ TEST_CASE_FIXTURE(ClassFixture, "rawget_type_function_errors_w_classes")
CHECK(toString(result.errors[0]) == "Property '\"BaseField\"' does not exist on type 'BaseClass'");
}
TEST_CASE_FIXTURE(Fixture, "user_defined_type_function_errors")
{
if (!FFlag::LuauUserDefinedTypeFunctions)
return;
CheckResult result = check(R"(
type function foo()
return nil
end
)");
LUAU_CHECK_ERROR_COUNT(1, result);
CHECK(toString(result.errors[0]) == "This syntax is not supported");
}
TEST_SUITE_END();

File diff suppressed because it is too large Load diff

View file

@ -9,6 +9,7 @@
using namespace Luau;
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax)
LUAU_FASTFLAG(LuauUserDefinedTypeFunctions)
TEST_SUITE_BEGIN("TypeAliases");
@ -1169,8 +1170,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_adds_reduce_constraint_for_type_f
TEST_CASE_FIXTURE(Fixture, "user_defined_type_function_errors")
{
if (!FFlag::LuauUserDefinedTypeFunctions)
return;
ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax, true};
ScopedFastFlag noUDTFimpl{FFlag::LuauUserDefinedTypeFunctions, false};
CheckResult result = check(R"(
type function foo()

View file

@ -1427,4 +1427,18 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types3")
CHECK_EQ(toString(requireType("e")), "number?");
}
TEST_CASE_FIXTURE(BuiltinsFixture, "string_find_should_not_crash")
{
ScopedFastFlag _{FFlag::LuauSolverV2, true};
LUAU_REQUIRE_NO_ERRORS(check(R"(
local function StringSplit(input, separator)
string.find(input, separator)
if not separator then
separator = "%s+"
end
end
)"));
}
TEST_SUITE_END();

View file

@ -15,7 +15,6 @@
using namespace Luau;
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauOkWithIteratingOverTableProperties)
LUAU_DYNAMIC_FASTFLAG(LuauImproveNonFunctionCallError)
@ -699,8 +698,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "loop_typecheck_crash_on_empty_optional")
if (FFlag::LuauSolverV2)
return;
ScopedFastFlag sff{FFlag::LuauOkWithIteratingOverTableProperties, true};
CheckResult result = check(R"(
local t = {}
for _ in t do
@ -784,7 +781,6 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer_strict")
// CLI-116498 Sometimes you can iterate over tables with no indexers.
ScopedFastFlag sff[] = {
{FFlag::LuauSolverV2, false},
{FFlag::LuauOkWithIteratingOverTableProperties, true}
};
CheckResult result = check(R"(
@ -937,8 +933,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "cli_68448_iterators_need_not_accept_nil")
TEST_CASE_FIXTURE(Fixture, "iterate_over_free_table")
{
ScopedFastFlag sff{FFlag::LuauOkWithIteratingOverTableProperties, true};
CheckResult result = check(R"(
function print(x) end
@ -1095,8 +1089,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "iterate_over_properties")
// CLI-116498 - Sometimes you can iterate over tables with no indexer.
ScopedFastFlag sff0{FFlag::LuauSolverV2, false};
ScopedFastFlag sff{FFlag::LuauOkWithIteratingOverTableProperties, true};
CheckResult result = check(R"(
local function f()
local t = { p = 5, q = "hello" }
@ -1118,8 +1110,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "iterate_over_properties")
TEST_CASE_FIXTURE(BuiltinsFixture, "iterate_over_properties_nonstrict")
{
ScopedFastFlag sff{FFlag::LuauOkWithIteratingOverTableProperties, true};
CheckResult result = check(R"(
--!nonstrict
local function f()

View file

@ -530,4 +530,82 @@ return l0
CHECK(mod->scopes[3].second->importedModules["l1"] == "game/A");
}
TEST_CASE_FIXTURE(BuiltinsFixture, "ensure_scope_is_nullptr_after_shallow_copy")
{
ScopedFastFlag _{FFlag::LuauSolverV2, true};
frontend.options.retainFullTypeGraphs = false;
fileResolver.source["game/A"] = R"(
-- Roughly taken from ReactTypes.lua
type CoreBinding<T> = {}
type BindingMap = {}
export type Binding<T> = CoreBinding<T> & BindingMap
return {}
)";
LUAU_REQUIRE_NO_ERRORS(check(R"(
local Types = require(game.A)
type Binding<T> = Types.Binding<T>
)"));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "ensure_free_variables_are_generialized_across_function_boundaries")
{
ScopedFastFlag _{FFlag::LuauSolverV2, true};
fileResolver.source["game/A"] = R"(
-- Roughly taken from react-shallow-renderer
function createUpdater(renderer)
local updater = {
_renderer = renderer,
}
function updater.enqueueForceUpdate(publicInstance, callback, _callerName)
updater._renderer.render(
updater._renderer,
updater._renderer._element,
updater._renderer._context
)
end
function updater.enqueueReplaceState(
publicInstance,
completeState,
callback,
_callerName
)
updater._renderer.render(
updater._renderer,
updater._renderer._element,
updater._renderer._context
)
end
function updater.enqueueSetState(publicInstance, partialState, callback, _callerName)
local currentState = updater._renderer._newState or publicInstance.state
updater._renderer.render(
updater._renderer,
updater._renderer._element,
updater._renderer._context
)
end
return updater
end
local ReactShallowRenderer = {}
function ReactShallowRenderer:_reset()
self._updater = createUpdater(self)
end
return ReactShallowRenderer
)";
LUAU_REQUIRE_NO_ERRORS(check(R"(
local ReactShallowRenderer = require(game.A);
)"));
}
TEST_SUITE_END();