Sync to upstream/release/645 (#1440)

In this update, we continue to improve the overall stability of the new
type solver. We're also shipping some early bits of two new features,
one of the language and one of the analysis API: user-defined type
functions and an incremental typechecking API.

If you use the new solver and want to use all new fixes included in this
release, you have to reference an additional Luau flag:
```c++
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
```
And set its value to `645`:
```c++
DFInt::LuauTypeSolverRelease.value = 645; // Or a higher value for future updates
```

## New Solver

* Fix a crash where scopes are incorrectly accessed cross-module after
they've been deallocated by appropriately zeroing out associated scope
pointers for free types, generic types, table types, etc.
* Fix a crash where we were incorrectly caching results for bound types
in generalization.
* Eliminated some unnecessary intermediate allocations in the constraint
solver and type function infrastructure.
* Built some initial groundwork for an incremental typecheck API for use
by language servers.
* Built an initial technical preview for [user-defined type
functions](https://rfcs.luau-lang.org/user-defined-type-functions.html),
more work still to come (including calling type functions from other
type functions), but adventurous folks wanting to experiment with it can
try it out by enabling `FFlag::LuauUserDefinedTypeFunctionsSyntax` and
`FFlag::LuauUserDefinedTypeFunction` in their local environment. Special
thanks to @joonyoo181 who built up all the initial infrastructure for
this during his internship!

## Miscellaneous changes

* Fix a compilation error on Ubuntu (fixes #1437)

---

Internal Contributors:

Co-authored-by: Aaron Weiss <aaronweiss@roblox.com>
Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: Jeremy Yoo <jyoo@roblox.com>
Co-authored-by: Vighnesh Vijay <vvijay@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>

---------

Co-authored-by: Alexander McCord <amccord@roblox.com>
Co-authored-by: Andy Friesen <afriesen@roblox.com>
Co-authored-by: Vighnesh <vvijay@roblox.com>
Co-authored-by: Aviral Goel <agoel@roblox.com>
Co-authored-by: David Cope <dcope@roblox.com>
Co-authored-by: Lily Brown <lbrown@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>
Co-authored-by: Junseo Yoo <jyoo@roblox.com>
This commit is contained in:
aaron 2024-09-27 11:58:21 -07:00 committed by GitHub
parent c188715605
commit 02241b6d24
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
52 changed files with 5037 additions and 221 deletions

View file

@ -12,6 +12,7 @@
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Luau/TypeCheckLimits.h" #include "Luau/TypeCheckLimits.h"
#include "Luau/TypeFunction.h"
#include "Luau/TypeFwd.h" #include "Luau/TypeFwd.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
@ -62,6 +63,7 @@ struct ConstraintSolver
NotNull<BuiltinTypes> builtinTypes; NotNull<BuiltinTypes> builtinTypes;
InternalErrorReporter iceReporter; InternalErrorReporter iceReporter;
NotNull<Normalizer> normalizer; NotNull<Normalizer> normalizer;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
// The entire set of constraints that the solver is trying to resolve. // The entire set of constraints that the solver is trying to resolve.
std::vector<NotNull<Constraint>> constraints; std::vector<NotNull<Constraint>> constraints;
NotNull<Scope> rootScope; NotNull<Scope> rootScope;
@ -111,6 +113,7 @@ struct ConstraintSolver
explicit ConstraintSolver( explicit ConstraintSolver(
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> rootScope, NotNull<Scope> rootScope,
std::vector<NotNull<Constraint>> constraints, std::vector<NotNull<Constraint>> constraints,
ModuleName moduleName, ModuleName moduleName,
@ -278,18 +281,18 @@ public:
/** /**
* @returns true if the TypeId is in a blocked state. * @returns true if the TypeId is in a blocked state.
*/ */
bool isBlocked(TypeId ty); bool isBlocked(TypeId ty) const;
/** /**
* @returns true if the TypePackId is in a blocked state. * @returns true if the TypePackId is in a blocked state.
*/ */
bool isBlocked(TypePackId tp); bool isBlocked(TypePackId tp) const;
/** /**
* Returns whether the constraint is blocked on anything. * Returns whether the constraint is blocked on anything.
* @param constraint the constraint to check. * @param constraint the constraint to check.
*/ */
bool isBlocked(NotNull<const Constraint> constraint); bool isBlocked(NotNull<const Constraint> constraint) const;
/** Pushes a new solver constraint to the solver. /** Pushes a new solver constraint to the solver.
* @param cv the body of the constraint. * @param cv the body of the constraint.
@ -381,8 +384,8 @@ public:
TypePackId anyifyModuleReturnTypePackGenerics(TypePackId tp); TypePackId anyifyModuleReturnTypePackGenerics(TypePackId tp);
void throwTimeLimitError(); void throwTimeLimitError() const;
void throwUserCancelError(); void throwUserCancelError() const;
ToStringOptions opts; ToStringOptions opts;
}; };

View file

@ -448,6 +448,13 @@ struct UnexpectedTypePackInSubtyping
bool operator==(const UnexpectedTypePackInSubtyping& rhs) const; bool operator==(const UnexpectedTypePackInSubtyping& rhs) const;
}; };
struct UserDefinedTypeFunctionError
{
std::string message;
bool operator==(const UserDefinedTypeFunctionError& rhs) const;
};
using TypeErrorData = Variant< using TypeErrorData = Variant<
TypeMismatch, TypeMismatch,
UnknownSymbol, UnknownSymbol,
@ -496,7 +503,8 @@ using TypeErrorData = Variant<
CheckedFunctionIncorrectArgs, CheckedFunctionIncorrectArgs,
UnexpectedTypeInSubtyping, UnexpectedTypeInSubtyping,
UnexpectedTypePackInSubtyping, UnexpectedTypePackInSubtyping,
ExplicitFunctionAnnotationRecommended>; ExplicitFunctionAnnotationRecommended,
UserDefinedTypeFunctionError>;
struct TypeErrorSummary struct TypeErrorSummary
{ {

View file

@ -0,0 +1,23 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/DenseHash.h"
#include "Luau/Ast.h"
#include <vector>
namespace Luau
{
struct FragmentAutocompleteAncestryResult
{
DenseHashMap<AstName, AstLocal*> localMap{AstName()};
std::vector<AstLocal*> localStack;
std::vector<AstNode*> ancestry;
AstStat* nearestStatement;
};
FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos);
} // namespace Luau

View file

@ -35,6 +35,7 @@ struct OverloadResolver
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena, NotNull<TypeArena> arena,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> scope, NotNull<Scope> scope,
NotNull<InternalErrorReporter> reporter, NotNull<InternalErrorReporter> reporter,
NotNull<TypeCheckLimits> limits, NotNull<TypeCheckLimits> limits,
@ -44,6 +45,7 @@ struct OverloadResolver
NotNull<BuiltinTypes> builtinTypes; NotNull<BuiltinTypes> builtinTypes;
NotNull<TypeArena> arena; NotNull<TypeArena> arena;
NotNull<Normalizer> normalizer; NotNull<Normalizer> normalizer;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
NotNull<Scope> scope; NotNull<Scope> scope;
NotNull<InternalErrorReporter> ice; NotNull<InternalErrorReporter> ice;
NotNull<TypeCheckLimits> limits; NotNull<TypeCheckLimits> limits;
@ -109,6 +111,7 @@ SolveResult solveFunctionCall(
NotNull<TypeArena> arena, NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> iceReporter, NotNull<InternalErrorReporter> iceReporter,
NotNull<TypeCheckLimits> limits, NotNull<TypeCheckLimits> limits,
NotNull<Scope> scope, NotNull<Scope> scope,

View file

@ -135,6 +135,7 @@ struct Subtyping
NotNull<BuiltinTypes> builtinTypes; NotNull<BuiltinTypes> builtinTypes;
NotNull<TypeArena> arena; NotNull<TypeArena> arena;
NotNull<Normalizer> normalizer; NotNull<Normalizer> normalizer;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
NotNull<InternalErrorReporter> iceReporter; NotNull<InternalErrorReporter> iceReporter;
TypeCheckLimits limits; TypeCheckLimits limits;
@ -155,6 +156,7 @@ struct Subtyping
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> typeArena, NotNull<TypeArena> typeArena,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> iceReporter NotNull<InternalErrorReporter> iceReporter
); );

View file

@ -83,6 +83,7 @@ struct TypeChecker2
DenseHashSet<TypeId> seenTypeFunctionInstances{nullptr}; DenseHashSet<TypeId> seenTypeFunctionInstances{nullptr};
Normalizer normalizer; Normalizer normalizer;
TypeFunctionRuntime typeFunctionRuntime;
Subtyping _subtyping; Subtyping _subtyping;
NotNull<Subtyping> subtyping; NotNull<Subtyping> subtyping;

View file

@ -1,10 +1,11 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/ConstraintSolver.h" #include "Luau/Constraint.h"
#include "Luau/Error.h" #include "Luau/Error.h"
#include "Luau/NotNull.h" #include "Luau/NotNull.h"
#include "Luau/TypeCheckLimits.h" #include "Luau/TypeCheckLimits.h"
#include "Luau/TypeFunctionRuntime.h"
#include "Luau/TypeFwd.h" #include "Luau/TypeFwd.h"
#include <functional> #include <functional>
@ -16,14 +17,23 @@ namespace Luau
struct TypeArena; struct TypeArena;
struct TxnLog; struct TxnLog;
struct ConstraintSolver;
class Normalizer; class Normalizer;
struct TypeFunctionRuntime
{
// For user-defined type functions, we store all generated types and packs for the duration of the typecheck
TypedAllocator<TypeFunctionType> typeArena;
TypedAllocator<TypeFunctionTypePackVar> typePackArena;
};
struct TypeFunctionContext struct TypeFunctionContext
{ {
NotNull<TypeArena> arena; NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtins; NotNull<BuiltinTypes> builtins;
NotNull<Scope> scope; NotNull<Scope> scope;
NotNull<Normalizer> normalizer; NotNull<Normalizer> normalizer;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
NotNull<InternalErrorReporter> ice; NotNull<InternalErrorReporter> ice;
NotNull<TypeCheckLimits> limits; NotNull<TypeCheckLimits> limits;
@ -35,23 +45,14 @@ struct TypeFunctionContext
std::optional<AstName> userFuncName; // Name of the user-defined type function; only available for UDTFs std::optional<AstName> userFuncName; // Name of the user-defined type function; only available for UDTFs
std::optional<AstExprFunction*> userFuncBody; // Body of the user-defined type function; only available for UDTFs std::optional<AstExprFunction*> userFuncBody; // Body of the user-defined type function; only available for UDTFs
TypeFunctionContext(NotNull<ConstraintSolver> cs, NotNull<Scope> scope, NotNull<const Constraint> constraint) TypeFunctionContext(NotNull<ConstraintSolver> cs, NotNull<Scope> scope, NotNull<const Constraint> constraint);
: arena(cs->arena)
, builtins(cs->builtinTypes)
, scope(scope)
, normalizer(cs->normalizer)
, ice(NotNull{&cs->iceReporter})
, limits(NotNull{&cs->limits})
, solver(cs.get())
, constraint(constraint.get())
{
}
TypeFunctionContext( TypeFunctionContext(
NotNull<TypeArena> arena, NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtins, NotNull<BuiltinTypes> builtins,
NotNull<Scope> scope, NotNull<Scope> scope,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> ice, NotNull<InternalErrorReporter> ice,
NotNull<TypeCheckLimits> limits NotNull<TypeCheckLimits> limits
) )
@ -59,6 +60,7 @@ struct TypeFunctionContext
, builtins(builtins) , builtins(builtins)
, scope(scope) , scope(scope)
, normalizer(normalizer) , normalizer(normalizer)
, typeFunctionRuntime(typeFunctionRuntime)
, ice(ice) , ice(ice)
, limits(limits) , limits(limits)
, solver(nullptr) , solver(nullptr)
@ -66,7 +68,7 @@ struct TypeFunctionContext
{ {
} }
NotNull<Constraint> pushConstraint(ConstraintV&& c); NotNull<Constraint> pushConstraint(ConstraintV&& c) const;
}; };
/// Represents a reduction result, which may have successfully reduced the type, /// Represents a reduction result, which may have successfully reduced the type,
@ -88,6 +90,8 @@ struct TypeFunctionReductionResult
/// Any type packs that need to be progressed or mutated before the /// Any type packs that need to be progressed or mutated before the
/// reduction may proceed. /// reduction may proceed.
std::vector<TypePackId> blockedPacks; std::vector<TypePackId> blockedPacks;
/// A runtime error message from user-defined type functions
std::optional<std::string> error;
}; };
template<typename T> template<typename T>

View file

@ -0,0 +1,267 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Common.h"
#include "Luau/Variant.h"
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>
using lua_State = struct lua_State;
namespace Luau
{
void* typeFunctionAlloc(void* ud, void* ptr, size_t osize, size_t nsize);
// Replica of types from Type.h
struct TypeFunctionType;
using TypeFunctionTypeId = const TypeFunctionType*;
struct TypeFunctionTypePackVar;
using TypeFunctionTypePackId = const TypeFunctionTypePackVar*;
struct TypeFunctionPrimitiveType
{
enum Type
{
NilType,
Boolean,
Number,
String,
};
Type type;
TypeFunctionPrimitiveType(Type type)
: type(type)
{
}
};
struct TypeFunctionBooleanSingleton
{
bool value = false;
};
struct TypeFunctionStringSingleton
{
std::string value;
};
using TypeFunctionSingletonVariant = Variant<TypeFunctionBooleanSingleton, TypeFunctionStringSingleton>;
struct TypeFunctionSingletonType
{
TypeFunctionSingletonVariant variant;
explicit TypeFunctionSingletonType(TypeFunctionSingletonVariant variant)
: variant(std::move(variant))
{
}
};
template<typename T>
const T* get(const TypeFunctionSingletonType* tv)
{
LUAU_ASSERT(tv);
return tv ? get_if<T>(&tv->variant) : nullptr;
}
template<typename T>
T* getMutable(const TypeFunctionSingletonType* tv)
{
LUAU_ASSERT(tv);
return tv ? get_if<T>(&const_cast<TypeFunctionSingletonType*>(tv)->variant) : nullptr;
}
struct TypeFunctionUnionType
{
std::vector<TypeFunctionTypeId> components;
};
struct TypeFunctionIntersectionType
{
std::vector<TypeFunctionTypeId> components;
};
struct TypeFunctionAnyType
{
};
struct TypeFunctionUnknownType
{
};
struct TypeFunctionNeverType
{
};
struct TypeFunctionNegationType
{
TypeFunctionTypeId type;
};
struct TypeFunctionTypePack
{
std::vector<TypeFunctionTypeId> head;
std::optional<TypeFunctionTypePackId> tail;
};
struct TypeFunctionVariadicTypePack
{
TypeFunctionTypeId type;
};
using TypeFunctionTypePackVariant = Variant<TypeFunctionTypePack, TypeFunctionVariadicTypePack>;
struct TypeFunctionTypePackVar
{
TypeFunctionTypePackVariant type;
TypeFunctionTypePackVar(TypeFunctionTypePackVariant type)
: type(std::move(type))
{
}
bool operator==(const TypeFunctionTypePackVar& rhs) const;
};
struct TypeFunctionFunctionType
{
TypeFunctionTypePackId argTypes;
TypeFunctionTypePackId retTypes;
};
template<typename T>
const T* get(TypeFunctionTypePackId tv)
{
LUAU_ASSERT(tv);
return tv ? get_if<T>(&tv->type) : nullptr;
}
template<typename T>
T* getMutable(TypeFunctionTypePackId tv)
{
LUAU_ASSERT(tv);
return tv ? get_if<T>(&const_cast<TypeFunctionTypePackVar*>(tv)->type) : nullptr;
}
struct TypeFunctionTableIndexer
{
TypeFunctionTableIndexer(TypeFunctionTypeId keyType, TypeFunctionTypeId valueType)
: keyType(keyType)
, valueType(valueType)
{
}
TypeFunctionTypeId keyType;
TypeFunctionTypeId valueType;
};
struct TypeFunctionProperty
{
static TypeFunctionProperty readonly(TypeFunctionTypeId ty);
static TypeFunctionProperty writeonly(TypeFunctionTypeId ty);
static TypeFunctionProperty rw(TypeFunctionTypeId ty); // Shared read-write type.
static TypeFunctionProperty rw(TypeFunctionTypeId read, TypeFunctionTypeId write); // Separate read-write type.
bool isReadOnly() const;
bool isWriteOnly() const;
std::optional<TypeFunctionTypeId> readTy;
std::optional<TypeFunctionTypeId> writeTy;
};
struct TypeFunctionTableType
{
using Name = std::string;
using Props = std::unordered_map<Name, TypeFunctionProperty>;
Props props;
std::optional<TypeFunctionTableIndexer> indexer;
// Should always be a TypeFunctionTableType
std::optional<TypeFunctionTypeId> metatable;
};
struct TypeFunctionClassType
{
using Name = std::string;
using Props = std::unordered_map<Name, TypeFunctionProperty>;
Props props;
std::optional<TypeFunctionTableIndexer> indexer;
std::optional<TypeFunctionTypeId> metatable; // metaclass?
std::optional<TypeFunctionTypeId> parent;
std::string name;
};
using TypeFunctionTypeVariant = Luau::Variant<
TypeFunctionPrimitiveType,
TypeFunctionAnyType,
TypeFunctionUnknownType,
TypeFunctionNeverType,
TypeFunctionSingletonType,
TypeFunctionUnionType,
TypeFunctionIntersectionType,
TypeFunctionNegationType,
TypeFunctionFunctionType,
TypeFunctionTableType,
TypeFunctionClassType>;
struct TypeFunctionType
{
TypeFunctionTypeVariant type;
TypeFunctionType(TypeFunctionTypeVariant type)
: type(std::move(type))
{
}
bool operator==(const TypeFunctionType& rhs) const;
};
template<typename T>
const T* get(TypeFunctionTypeId tv)
{
LUAU_ASSERT(tv);
return tv ? Luau::get_if<T>(&tv->type) : nullptr;
}
template<typename T>
T* getMutable(TypeFunctionTypeId tv)
{
LUAU_ASSERT(tv);
return tv ? Luau::get_if<T>(&const_cast<TypeFunctionType*>(tv)->type) : nullptr;
}
std::optional<std::string> checkResultForError(lua_State* L, const char* typeFunctionName, int luaResult);
TypeFunctionType* allocateTypeFunctionType(lua_State* L, TypeFunctionTypeVariant type);
TypeFunctionTypePackVar* allocateTypeFunctionTypePack(lua_State* L, TypeFunctionTypePackVariant type);
void allocTypeUserData(lua_State* L, TypeFunctionTypeVariant type);
bool isTypeUserData(lua_State* L, int idx);
TypeFunctionTypeId getTypeUserData(lua_State* L, int idx);
std::optional<TypeFunctionTypeId> optionalTypeUserData(lua_State* L, int idx);
void registerTypeUserData(lua_State* L);
void setTypeFunctionEnvironment(lua_State* L);
} // namespace Luau

View file

@ -0,0 +1,52 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Type.h"
#include "Luau/TypeFunction.h"
#include "Luau/TypeFunctionRuntime.h"
namespace Luau
{
using Kind = Variant<TypeId, TypePackId>;
template<typename T>
const T* get(const Kind& kind)
{
return get_if<T>(&kind);
}
using TypeFunctionKind = Variant<TypeFunctionTypeId, TypeFunctionTypePackId>;
template<typename T>
const T* get(const TypeFunctionKind& tfkind)
{
return get_if<T>(&tfkind);
}
struct TypeFunctionRuntimeBuilderState
{
NotNull<TypeFunctionContext> ctx;
// Mapping of class name to ClassType
// Invariant: users can not create a new class types -> any class types that get deserialized must have been an argument to the type function
// Using this invariant, whenever a ClassType is serialized, we can put it into this map
// whenever a ClassType is deserialized, we can use this map to return the corresponding value
DenseHashMap<std::string, TypeId> classesSerialized{{}};
// List of errors that occur during serialization/deserialization
// At every iteration of serialization/deserialzation, if this list.size() != 0, we halt the process
std::vector<std::string> errors{};
TypeFunctionRuntimeBuilderState(NotNull<TypeFunctionContext> ctx)
: ctx(ctx)
, classesSerialized({})
, errors({})
{
}
};
TypeFunctionTypeId serialize(TypeId ty, TypeFunctionRuntimeBuilderState* state);
TypeId deserialize(TypeFunctionTypeId ty, TypeFunctionRuntimeBuilderState* state);
} // namespace Luau

View file

@ -149,13 +149,15 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull<Scope> scope, T
if (FFlag::LuauSolverV2) if (FFlag::LuauSolverV2)
{ {
TypeFunctionRuntime typeFunctionRuntime; // TODO: maybe subtyping checks should not invoke user-defined type function runtime
if (FFlag::LuauAutocompleteNewSolverLimit) if (FFlag::LuauAutocompleteNewSolverLimit)
{ {
unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;
unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit; unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit;
} }
Subtyping subtyping{builtinTypes, NotNull{typeArena}, NotNull{&normalizer}, NotNull{&iceReporter}}; Subtyping subtyping{builtinTypes, NotNull{typeArena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter}};
return subtyping.isSubtype(subTy, superTy, scope).isSubtype; return subtyping.isSubtype(subTy, superTy, scope).isSubtype;
} }

View file

@ -321,6 +321,7 @@ struct InstantiationQueuer : TypeOnceVisitor
ConstraintSolver::ConstraintSolver( ConstraintSolver::ConstraintSolver(
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> rootScope, NotNull<Scope> rootScope,
std::vector<NotNull<Constraint>> constraints, std::vector<NotNull<Constraint>> constraints,
ModuleName moduleName, ModuleName moduleName,
@ -332,11 +333,12 @@ ConstraintSolver::ConstraintSolver(
: arena(normalizer->arena) : arena(normalizer->arena)
, builtinTypes(normalizer->builtinTypes) , builtinTypes(normalizer->builtinTypes)
, normalizer(normalizer) , normalizer(normalizer)
, typeFunctionRuntime(typeFunctionRuntime)
, constraints(std::move(constraints)) , constraints(std::move(constraints))
, rootScope(rootScope) , rootScope(rootScope)
, currentModuleName(std::move(moduleName)) , currentModuleName(std::move(moduleName))
, moduleResolver(moduleResolver) , moduleResolver(moduleResolver)
, requireCycles(requireCycles) , requireCycles(std::move(requireCycles))
, logger(logger) , logger(logger)
, limits(std::move(limits)) , limits(std::move(limits))
{ {
@ -344,7 +346,7 @@ ConstraintSolver::ConstraintSolver(
for (NotNull<Constraint> c : this->constraints) for (NotNull<Constraint> c : this->constraints)
{ {
unsolvedConstraints.push_back(c); unsolvedConstraints.emplace_back(c);
// initialize the reference counts for the free types in this constraint. // initialize the reference counts for the free types in this constraint.
for (auto ty : c->getMaybeMutatedFreeTypes()) for (auto ty : c->getMaybeMutatedFreeTypes())
@ -1240,7 +1242,14 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
} }
OverloadResolver resolver{ OverloadResolver resolver{
builtinTypes, NotNull{arena}, normalizer, constraint->scope, NotNull{&iceReporter}, NotNull{&limits}, constraint->location builtinTypes,
NotNull{arena},
normalizer,
typeFunctionRuntime,
constraint->scope,
NotNull{&iceReporter},
NotNull{&limits},
constraint->location
}; };
auto [status, overload] = resolver.selectOverload(fn, argsPack); auto [status, overload] = resolver.selectOverload(fn, argsPack);
TypeId overloadToUse = fn; TypeId overloadToUse = fn;
@ -1270,7 +1279,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
for (const auto& [expanded, additions] : u2.expandedFreeTypes) for (const auto& [expanded, additions] : u2.expandedFreeTypes)
{ {
for (TypeId addition : additions) for (TypeId addition : additions)
upperBoundContributors[expanded].push_back(std::make_pair(constraint->location, addition)); upperBoundContributors[expanded].emplace_back(constraint->location, addition);
} }
if (occursCheckPassed && c.callSite) if (occursCheckPassed && c.callSite)
@ -1437,8 +1446,17 @@ bool ConstraintSolver::tryDispatch(const PrimitiveTypeConstraint& c, NotNull<con
else if (expectedType && maybeSingleton(*expectedType)) else if (expectedType && maybeSingleton(*expectedType))
bindTo = freeType->lowerBound; bindTo = freeType->lowerBound;
shiftReferences(c.freeType, bindTo); if (DFInt::LuauTypeSolverRelease >= 645)
bind(constraint, c.freeType, bindTo); {
auto ty = follow(c.freeType);
shiftReferences(ty, bindTo);
bind(constraint, ty, bindTo);
}
else
{
shiftReferences(c.freeType, bindTo);
bind(constraint, c.freeType, bindTo);
}
return true; return true;
} }
@ -2603,7 +2621,7 @@ bool ConstraintSolver::unify(NotNull<const Constraint> constraint, TID subTy, TI
for (const auto& [expanded, additions] : u2.expandedFreeTypes) for (const auto& [expanded, additions] : u2.expandedFreeTypes)
{ {
for (TypeId addition : additions) for (TypeId addition : additions)
upperBoundContributors[expanded].push_back(std::make_pair(constraint->location, addition)); upperBoundContributors[expanded].emplace_back(constraint->location, addition);
} }
} }
else else
@ -2820,7 +2838,7 @@ void ConstraintSolver::reproduceConstraints(NotNull<Scope> scope, const Location
} }
} }
bool ConstraintSolver::isBlocked(TypeId ty) bool ConstraintSolver::isBlocked(TypeId ty) const
{ {
ty = follow(ty); ty = follow(ty);
@ -2830,7 +2848,7 @@ bool ConstraintSolver::isBlocked(TypeId ty)
return nullptr != get<BlockedType>(ty) || nullptr != get<PendingExpansionType>(ty); return nullptr != get<BlockedType>(ty) || nullptr != get<PendingExpansionType>(ty);
} }
bool ConstraintSolver::isBlocked(TypePackId tp) bool ConstraintSolver::isBlocked(TypePackId tp) const
{ {
tp = follow(tp); tp = follow(tp);
@ -2840,7 +2858,7 @@ bool ConstraintSolver::isBlocked(TypePackId tp)
return nullptr != get<BlockedTypePack>(tp); return nullptr != get<BlockedTypePack>(tp);
} }
bool ConstraintSolver::isBlocked(NotNull<const Constraint> constraint) bool ConstraintSolver::isBlocked(NotNull<const Constraint> constraint) const
{ {
auto blockedIt = blockedConstraints.find(constraint); auto blockedIt = blockedConstraints.find(constraint);
return blockedIt != blockedConstraints.end() && blockedIt->second > 0; return blockedIt != blockedConstraints.end() && blockedIt->second > 0;
@ -2851,7 +2869,7 @@ NotNull<Constraint> ConstraintSolver::pushConstraint(NotNull<Scope> scope, const
std::unique_ptr<Constraint> c = std::make_unique<Constraint>(scope, location, std::move(cv)); std::unique_ptr<Constraint> c = std::make_unique<Constraint>(scope, location, std::move(cv));
NotNull<Constraint> borrow = NotNull(c.get()); NotNull<Constraint> borrow = NotNull(c.get());
solverConstraints.push_back(std::move(c)); solverConstraints.push_back(std::move(c));
unsolvedConstraints.push_back(borrow); unsolvedConstraints.emplace_back(borrow);
return borrow; return borrow;
} }
@ -2997,12 +3015,12 @@ TypePackId ConstraintSolver::anyifyModuleReturnTypePackGenerics(TypePackId tp)
return arena->addTypePack(resultTypes, resultTail); return arena->addTypePack(resultTypes, resultTail);
} }
LUAU_NOINLINE void ConstraintSolver::throwTimeLimitError() LUAU_NOINLINE void ConstraintSolver::throwTimeLimitError() const
{ {
throw TimeLimitError(currentModuleName); throw TimeLimitError(currentModuleName);
} }
LUAU_NOINLINE void ConstraintSolver::throwUserCancelError() LUAU_NOINLINE void ConstraintSolver::throwUserCancelError() const
{ {
throw UserCancelError(currentModuleName); throw UserCancelError(currentModuleName);
} }

View file

@ -793,6 +793,11 @@ struct ErrorConverter
return "Encountered an unexpected type pack in subtyping: " + toString(e.tp); return "Encountered an unexpected type pack in subtyping: " + toString(e.tp);
} }
std::string operator()(const UserDefinedTypeFunctionError& e) const
{
return e.message;
}
std::string operator()(const CannotAssignToNever& e) const std::string operator()(const CannotAssignToNever& e) const
{ {
std::string result = "Cannot assign a value of type " + toString(e.rhsType) + " to a field of type never"; std::string result = "Cannot assign a value of type " + toString(e.rhsType) + " to a field of type never";
@ -1175,6 +1180,11 @@ bool UnexpectedTypePackInSubtyping::operator==(const UnexpectedTypePackInSubtypi
return tp == rhs.tp; return tp == rhs.tp;
} }
bool UserDefinedTypeFunctionError::operator==(const UserDefinedTypeFunctionError& rhs) const
{
return message == rhs.message;
}
bool CannotAssignToNever::operator==(const CannotAssignToNever& rhs) const bool CannotAssignToNever::operator==(const CannotAssignToNever& rhs) const
{ {
if (cause.size() != rhs.cause.size()) if (cause.size() != rhs.cause.size())
@ -1384,6 +1394,9 @@ void copyError(T& e, TypeArena& destArena, CloneState& cloneState)
e.ty = clone(e.ty); e.ty = clone(e.ty);
else if constexpr (std::is_same_v<T, UnexpectedTypePackInSubtyping>) else if constexpr (std::is_same_v<T, UnexpectedTypePackInSubtyping>)
e.tp = clone(e.tp); e.tp = clone(e.tp);
else if constexpr (std::is_same_v<T, UserDefinedTypeFunctionError>)
{
}
else if constexpr (std::is_same_v<T, CannotAssignToNever>) else if constexpr (std::is_same_v<T, CannotAssignToNever>)
{ {
e.rhsType = clone(e.rhsType); e.rhsType = clone(e.rhsType);

View file

@ -0,0 +1,48 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/FragmentAutocomplete.h"
#include "Luau/Ast.h"
#include "Luau/AstQuery.h"
namespace Luau
{
FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos)
{
std::vector<AstNode*> ancestry = findAncestryAtPositionForAutocomplete(root, cursorPos);
DenseHashMap<AstName, AstLocal*> localMap{AstName()};
std::vector<AstLocal*> localStack;
AstStat* nearestStatement = nullptr;
for (AstNode* node : ancestry)
{
if (auto block = node->as<AstStatBlock>())
{
for (auto stat : block->body)
{
if (stat->location.begin <= cursorPos)
nearestStatement = stat;
if (stat->location.begin <= cursorPos)
{
// This statement precedes the current one
if (auto loc = stat->as<AstStatLocal>())
{
for (auto v : loc->vars)
{
localStack.push_back(v);
localMap[v->name] = v;
}
}
else if (auto locFun = stat->as<AstStatLocalFunction>())
{
localStack.push_back(locFun->name);
localMap[locFun->name->name] = locFun->name;
}
}
}
}
}
return {std::move(localMap), std::move(localStack), std::move(ancestry), std::move(nearestStatement)};
}
} // namespace Luau

View file

@ -1383,6 +1383,7 @@ ModulePtr check(
unifierState.counters.iterationLimit = limits.unifierIterationLimit.value_or(FInt::LuauTypeInferIterationLimit); unifierState.counters.iterationLimit = limits.unifierIterationLimit.value_or(FInt::LuauTypeInferIterationLimit);
Normalizer normalizer{&result->internalTypes, builtinTypes, NotNull{&unifierState}}; Normalizer normalizer{&result->internalTypes, builtinTypes, NotNull{&unifierState}};
TypeFunctionRuntime typeFunctionRuntime;
ConstraintGenerator cg{ ConstraintGenerator cg{
result, result,
@ -1402,6 +1403,7 @@ ModulePtr check(
ConstraintSolver cs{ ConstraintSolver cs{
NotNull{&normalizer}, NotNull{&normalizer},
NotNull{&typeFunctionRuntime},
NotNull(cg.rootScope), NotNull(cg.rootScope),
borrowConstraints(cg.constraints), borrowConstraints(cg.constraints),
result->name, result->name,

View file

@ -9,6 +9,8 @@
#include "Luau/TypePack.h" #include "Luau/TypePack.h"
#include "Luau/VisitType.h" #include "Luau/VisitType.h"
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
namespace Luau namespace Luau
{ {
@ -871,6 +873,17 @@ struct TypeCacher : TypeOnceVisitor
markUncacheable(tp); markUncacheable(tp);
return false; return false;
} }
bool visit(TypePackId tp, const BoundTypePack& btp) override {
if (DFInt::LuauTypeSolverRelease >= 645) {
traverse(btp.boundTo);
if (isUncacheable(btp.boundTo))
markUncacheable(tp);
return false;
}
return true;
}
}; };
std::optional<TypeId> generalize( std::optional<TypeId> generalize(

View file

@ -227,6 +227,8 @@ static void errorToString(std::ostream& stream, const T& err)
stream << "UnexpectedTypeInSubtyping { ty = '" + toString(err.ty) + "' }"; stream << "UnexpectedTypeInSubtyping { ty = '" + toString(err.ty) + "' }";
else if constexpr (std::is_same_v<T, UnexpectedTypePackInSubtyping>) else if constexpr (std::is_same_v<T, UnexpectedTypePackInSubtyping>)
stream << "UnexpectedTypePackInSubtyping { tp = '" + toString(err.tp) + "' }"; stream << "UnexpectedTypePackInSubtyping { tp = '" + toString(err.tp) + "' }";
else if constexpr (std::is_same_v<T, UserDefinedTypeFunctionError>)
stream << "UserDefinedTypeFunctionError { " << err.message << " }";
else if constexpr (std::is_same_v<T, CannotAssignToNever>) else if constexpr (std::is_same_v<T, CannotAssignToNever>)
{ {
stream << "CannotAssignToNever { rvalueType = '" << toString(err.rhsType) << "', reason = '" << err.reason << "', cause = { "; stream << "CannotAssignToNever { rvalueType = '" << toString(err.rhsType) << "', reason = '" << err.reason << "', cause = { ";

View file

@ -15,6 +15,7 @@
#include <algorithm> #include <algorithm>
LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAG(LuauSolverV2);
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
namespace Luau namespace Luau
{ {
@ -131,10 +132,26 @@ struct ClonePublicInterface : Substitution
} }
ftv->level = TypeLevel{0, 0}; ftv->level = TypeLevel{0, 0};
if (FFlag::LuauSolverV2 && DFInt::LuauTypeSolverRelease >= 645)
ftv->scope = nullptr;
} }
else if (TableType* ttv = getMutable<TableType>(result)) else if (TableType* ttv = getMutable<TableType>(result))
{ {
ttv->level = TypeLevel{0, 0}; ttv->level = TypeLevel{0, 0};
if (FFlag::LuauSolverV2 && DFInt::LuauTypeSolverRelease >= 645)
ttv->scope = nullptr;
}
if (FFlag::LuauSolverV2 && DFInt::LuauTypeSolverRelease >= 645)
{
if (auto freety = getMutable<FreeType>(result))
{
freety->scope = nullptr;
}
else if (auto genericty = getMutable<GenericType>(result))
{
genericty->scope = nullptr;
}
} }
return result; return result;

View file

@ -160,6 +160,7 @@ struct NonStrictTypeChecker
NotNull<TypeArena> arena; NotNull<TypeArena> arena;
Module* module; Module* module;
Normalizer normalizer; Normalizer normalizer;
TypeFunctionRuntime typeFunctionRuntime;
Subtyping subtyping; Subtyping subtyping;
NotNull<const DataFlowGraph> dfg; NotNull<const DataFlowGraph> dfg;
DenseHashSet<TypeId> noTypeFunctionErrors{nullptr}; DenseHashSet<TypeId> noTypeFunctionErrors{nullptr};
@ -182,7 +183,7 @@ struct NonStrictTypeChecker
, arena(arena) , arena(arena)
, module(module) , module(module)
, normalizer{arena, builtinTypes, unifierState, /* cache inhabitance */ true} , normalizer{arena, builtinTypes, unifierState, /* cache inhabitance */ true}
, subtyping{builtinTypes, arena, NotNull(&normalizer), ice} , subtyping{builtinTypes, arena, NotNull(&normalizer), NotNull(&typeFunctionRuntime), ice}
, dfg(dfg) , dfg(dfg)
, limits(limits) , limits(limits)
{ {
@ -228,7 +229,12 @@ struct NonStrictTypeChecker
return instance; return instance;
ErrorVec errors = ErrorVec errors =
reduceTypeFunctions(instance, location, TypeFunctionContext{arena, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, true) reduceTypeFunctions(
instance,
location,
TypeFunctionContext{arena, builtinTypes, stack.back(), NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, ice, limits},
true
)
.errors; .errors;
if (errors.empty()) if (errors.empty())

View file

@ -3434,11 +3434,12 @@ bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<Built
UnifierSharedState sharedState{&ice}; UnifierSharedState sharedState{&ice};
TypeArena arena; TypeArena arena;
Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}};
TypeFunctionRuntime typeFunctionRuntime; // TODO: maybe subtyping checks should not invoke user-defined type function runtime
// Subtyping under DCR is not implemented using unification! // Subtyping under DCR is not implemented using unification!
if (FFlag::LuauSolverV2) if (FFlag::LuauSolverV2)
{ {
Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&ice}}; Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&ice}};
return subtyping.isSubtype(subTy, superTy, scope).isSubtype; return subtyping.isSubtype(subTy, superTy, scope).isSubtype;
} }
@ -3456,11 +3457,12 @@ bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull<Scope> scope, N
UnifierSharedState sharedState{&ice}; UnifierSharedState sharedState{&ice};
TypeArena arena; TypeArena arena;
Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}};
TypeFunctionRuntime typeFunctionRuntime; // TODO: maybe subtyping checks should not invoke user-defined type function runtime
// Subtyping under DCR is not implemented using unification! // Subtyping under DCR is not implemented using unification!
if (FFlag::LuauSolverV2) if (FFlag::LuauSolverV2)
{ {
Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&ice}}; Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&ice}};
return subtyping.isSubtype(subPack, superPack, scope).isSubtype; return subtyping.isSubtype(subPack, superPack, scope).isSubtype;
} }

View file

@ -17,6 +17,7 @@ OverloadResolver::OverloadResolver(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena, NotNull<TypeArena> arena,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> scope, NotNull<Scope> scope,
NotNull<InternalErrorReporter> reporter, NotNull<InternalErrorReporter> reporter,
NotNull<TypeCheckLimits> limits, NotNull<TypeCheckLimits> limits,
@ -25,10 +26,11 @@ OverloadResolver::OverloadResolver(
: builtinTypes(builtinTypes) : builtinTypes(builtinTypes)
, arena(arena) , arena(arena)
, normalizer(normalizer) , normalizer(normalizer)
, typeFunctionRuntime(typeFunctionRuntime)
, scope(scope) , scope(scope)
, ice(reporter) , ice(reporter)
, limits(limits) , limits(limits)
, subtyping({builtinTypes, arena, normalizer, ice}) , subtyping({builtinTypes, arena, normalizer, typeFunctionRuntime, ice})
, callLoc(callLocation) , callLoc(callLocation)
{ {
} }
@ -199,8 +201,9 @@ std::pair<OverloadResolver::Analysis, ErrorVec> OverloadResolver::checkOverload_
const std::vector<AstExpr*>* argExprs const std::vector<AstExpr*>* argExprs
) )
{ {
FunctionGraphReductionResult result = FunctionGraphReductionResult result = reduceTypeFunctions(
reduceTypeFunctions(fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, normalizer, ice, limits}, /*force=*/true); fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, normalizer, typeFunctionRuntime, ice, limits}, /*force=*/true
);
if (!result.errors.empty()) if (!result.errors.empty())
return {OverloadIsNonviable, result.errors}; return {OverloadIsNonviable, result.errors};
@ -405,6 +408,7 @@ std::optional<TypeId> selectOverload(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena, NotNull<TypeArena> arena,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> scope, NotNull<Scope> scope,
NotNull<InternalErrorReporter> iceReporter, NotNull<InternalErrorReporter> iceReporter,
NotNull<TypeCheckLimits> limits, NotNull<TypeCheckLimits> limits,
@ -413,7 +417,7 @@ std::optional<TypeId> selectOverload(
TypePackId argsPack TypePackId argsPack
) )
{ {
OverloadResolver resolver{builtinTypes, arena, normalizer, scope, iceReporter, limits, location}; OverloadResolver resolver{builtinTypes, arena, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location};
auto [status, overload] = resolver.selectOverload(fn, argsPack); auto [status, overload] = resolver.selectOverload(fn, argsPack);
if (status == OverloadResolver::Analysis::Ok) if (status == OverloadResolver::Analysis::Ok)
@ -429,6 +433,7 @@ SolveResult solveFunctionCall(
NotNull<TypeArena> arena, NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> iceReporter, NotNull<InternalErrorReporter> iceReporter,
NotNull<TypeCheckLimits> limits, NotNull<TypeCheckLimits> limits,
NotNull<Scope> scope, NotNull<Scope> scope,
@ -437,7 +442,8 @@ SolveResult solveFunctionCall(
TypePackId argsPack TypePackId argsPack
) )
{ {
std::optional<TypeId> overloadToUse = selectOverload(builtinTypes, arena, normalizer, scope, iceReporter, limits, location, fn, argsPack); std::optional<TypeId> overloadToUse =
selectOverload(builtinTypes, arena, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location, fn, argsPack);
if (!overloadToUse) if (!overloadToUse)
return {SolveResult::NoMatchingOverload}; return {SolveResult::NoMatchingOverload};

View file

@ -440,11 +440,13 @@ Subtyping::Subtyping(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> typeArena, NotNull<TypeArena> typeArena,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<InternalErrorReporter> iceReporter NotNull<InternalErrorReporter> iceReporter
) )
: builtinTypes(builtinTypes) : builtinTypes(builtinTypes)
, arena(typeArena) , arena(typeArena)
, normalizer(normalizer) , normalizer(normalizer)
, typeFunctionRuntime(typeFunctionRuntime)
, iceReporter(iceReporter) , iceReporter(iceReporter)
{ {
} }
@ -1911,7 +1913,7 @@ TypeId Subtyping::makeAggregateType(const Container& container, TypeId orElse)
std::pair<TypeId, ErrorVec> Subtyping::handleTypeFunctionReductionResult(const TypeFunctionInstanceType* functionInstance, NotNull<Scope> scope) std::pair<TypeId, ErrorVec> Subtyping::handleTypeFunctionReductionResult(const TypeFunctionInstanceType* functionInstance, NotNull<Scope> scope)
{ {
TypeFunctionContext context{arena, builtinTypes, scope, normalizer, iceReporter, NotNull{&limits}}; TypeFunctionContext context{arena, builtinTypes, scope, normalizer, typeFunctionRuntime, iceReporter, NotNull{&limits}};
TypeId function = arena->addType(*functionInstance); TypeId function = arena->addType(*functionInstance);
FunctionGraphReductionResult result = reduceTypeFunctions(function, {}, context, true); FunctionGraphReductionResult result = reduceTypeFunctions(function, {}, context, true);
ErrorVec errors; ErrorVec errors;

View file

@ -1040,6 +1040,7 @@ struct TypeStringifier
state.emit(tfitv.userFuncName->value); state.emit(tfitv.userFuncName->value);
else else
state.emit(tfitv.function->name); state.emit(tfitv.function->name);
state.emit("<"); state.emit("<");
bool comma = false; bool comma = false;

View file

@ -31,6 +31,7 @@
#include <ostream> #include <ostream>
LUAU_FASTFLAG(DebugLuauMagicTypes) LUAU_FASTFLAG(DebugLuauMagicTypes)
LUAU_FASTFLAG(LuauUserDefinedTypeFunctions)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
namespace Luau namespace Luau
@ -306,7 +307,7 @@ TypeChecker2::TypeChecker2(
, sourceModule(sourceModule) , sourceModule(sourceModule)
, module(module) , module(module)
, normalizer{&module->internalTypes, builtinTypes, unifierState, /* cacheInhabitance */ true} , normalizer{&module->internalTypes, builtinTypes, unifierState, /* cacheInhabitance */ true}
, _subtyping{builtinTypes, NotNull{&module->internalTypes}, NotNull{&normalizer}, NotNull{unifierState->iceHandler}} , _subtyping{builtinTypes, NotNull{&module->internalTypes}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{unifierState->iceHandler}}
, subtyping(&_subtyping) , subtyping(&_subtyping)
{ {
} }
@ -484,13 +485,16 @@ TypeId TypeChecker2::checkForTypeFunctionInhabitance(TypeId instance, Location l
return instance; return instance;
seenTypeFunctionInstances.insert(instance); seenTypeFunctionInstances.insert(instance);
ErrorVec errors = reduceTypeFunctions( ErrorVec errors =
instance, reduceTypeFunctions(
location, instance,
TypeFunctionContext{NotNull{&module->internalTypes}, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, location,
true TypeFunctionContext{
) NotNull{&module->internalTypes}, builtinTypes, stack.back(), NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, ice, limits
.errors; },
true
)
.errors;
if (!isErrorSuppressing(location, instance)) if (!isErrorSuppressing(location, instance))
reportErrors(std::move(errors)); reportErrors(std::move(errors));
return instance; return instance;
@ -1194,8 +1198,8 @@ void TypeChecker2::visit(AstStatTypeAlias* stat)
void TypeChecker2::visit(AstStatTypeFunction* stat) void TypeChecker2::visit(AstStatTypeFunction* stat)
{ {
// TODO: add type checking for user-defined type functions // TODO: add type checking for user-defined type functions
if (!FFlag::LuauUserDefinedTypeFunctions)
reportError(TypeError{stat->location, GenericError{"This syntax is not supported"}}); reportError(TypeError{stat->location, GenericError{"This syntax is not supported"}});
} }
void TypeChecker2::visit(AstTypeList types) void TypeChecker2::visit(AstTypeList types)
@ -1446,6 +1450,7 @@ void TypeChecker2::visitCall(AstExprCall* call)
builtinTypes, builtinTypes,
NotNull{&module->internalTypes}, NotNull{&module->internalTypes},
NotNull{&normalizer}, NotNull{&normalizer},
NotNull{&typeFunctionRuntime},
NotNull{stack.back()}, NotNull{stack.back()},
ice, ice,
limits, limits,

View file

@ -2,7 +2,9 @@
#include "Luau/TypeFunction.h" #include "Luau/TypeFunction.h"
#include "Luau/BytecodeBuilder.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/Compiler.h"
#include "Luau/ConstraintSolver.h" #include "Luau/ConstraintSolver.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/Instantiation.h" #include "Luau/Instantiation.h"
@ -12,17 +14,25 @@
#include "Luau/Set.h" #include "Luau/Set.h"
#include "Luau/Simplify.h" #include "Luau/Simplify.h"
#include "Luau/Subtyping.h" #include "Luau/Subtyping.h"
#include "Luau/TimeTrace.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include "Luau/TxnLog.h" #include "Luau/TxnLog.h"
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Luau/TypeFunctionReductionGuesser.h" #include "Luau/TypeFunctionReductionGuesser.h"
#include "Luau/TypeFunctionRuntime.h"
#include "Luau/TypeFunctionRuntimeBuilder.h"
#include "Luau/TypeFwd.h" #include "Luau/TypeFwd.h"
#include "Luau/TypeUtils.h" #include "Luau/TypeUtils.h"
#include "Luau/Unifier2.h" #include "Luau/Unifier2.h"
#include "Luau/VecDeque.h" #include "Luau/VecDeque.h"
#include "Luau/VisitType.h" #include "Luau/VisitType.h"
#include "lua.h"
#include "lualib.h"
#include <iterator> #include <iterator>
#include <memory>
#include <unordered_map>
// used to control emitting CodeTooComplex warnings on type function reduction // used to control emitting CodeTooComplex warnings on type function reduction
LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyGraphReductionMaximumSteps, 1'000'000); LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyGraphReductionMaximumSteps, 1'000'000);
@ -35,7 +45,8 @@ LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyApplicationCartesianProductLimit, 5'0
// when this value is set to a negative value, guessing will be totally disabled. // when this value is set to a negative value, guessing will be totally disabled.
LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyUseGuesserDepth, -1); LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyUseGuesserDepth, -1);
LUAU_FASTFLAGVARIABLE(DebugLuauLogTypeFamilies, false); LUAU_FASTFLAGVARIABLE(DebugLuauLogTypeFamilies, false)
LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctions, false)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
@ -166,7 +177,7 @@ struct TypeFunctionReducer
return SkipTestResult::Okay; return SkipTestResult::Okay;
} }
SkipTestResult testForSkippability(TypePackId ty) SkipTestResult testForSkippability(TypePackId ty) const
{ {
ty = follow(ty); ty = follow(ty);
@ -214,15 +225,18 @@ struct TypeFunctionReducer
{ {
irreducible.insert(subject); irreducible.insert(subject);
if (reduction.error.has_value())
result.errors.emplace_back(location, UserDefinedTypeFunctionError{*reduction.error});
if (reduction.uninhabited || force) if (reduction.uninhabited || force)
{ {
if (FFlag::DebugLuauLogTypeFamilies) if (FFlag::DebugLuauLogTypeFamilies)
printf("%s is uninhabited\n", toString(subject, {true}).c_str()); printf("%s is uninhabited\n", toString(subject, {true}).c_str());
if constexpr (std::is_same_v<T, TypeId>) if constexpr (std::is_same_v<T, TypeId>)
result.errors.push_back(TypeError{location, UninhabitedTypeFunction{subject}}); result.errors.emplace_back(location, UninhabitedTypeFunction{subject});
else if constexpr (std::is_same_v<T, TypePackId>) else if constexpr (std::is_same_v<T, TypePackId>)
result.errors.push_back(TypeError{location, UninhabitedTypePackFunction{subject}}); result.errors.emplace_back(location, UninhabitedTypePackFunction{subject});
} }
else if (!reduction.uninhabited && !force) else if (!reduction.uninhabited && !force)
{ {
@ -243,7 +257,7 @@ struct TypeFunctionReducer
} }
} }
bool done() bool done() const
{ {
return queuedTys.empty() && queuedTps.empty(); return queuedTys.empty() && queuedTps.empty();
} }
@ -422,7 +436,7 @@ static FunctionGraphReductionResult reduceFunctionsInternal(
++iterationCount; ++iterationCount;
if (iterationCount > DFInt::LuauTypeFamilyGraphReductionMaximumSteps) if (iterationCount > DFInt::LuauTypeFamilyGraphReductionMaximumSteps)
{ {
reducer.result.errors.push_back(TypeError{location, CodeTooComplex{}}); reducer.result.errors.emplace_back(location, CodeTooComplex{});
break; break;
} }
} }
@ -506,7 +520,7 @@ static std::optional<TypeFunctionReductionResult<TypeId>> tryDistributeTypeFunct
size_t cartesianProductSize = 1; size_t cartesianProductSize = 1;
const UnionType* firstUnion = nullptr; const UnionType* firstUnion = nullptr;
size_t unionIndex; size_t unionIndex = 0;
std::vector<TypeId> arguments = typeParams; std::vector<TypeId> arguments = typeParams;
for (size_t i = 0; i < arguments.size(); ++i) for (size_t i = 0; i < arguments.size(); ++i)
@ -572,6 +586,8 @@ static std::optional<TypeFunctionReductionResult<TypeId>> tryDistributeTypeFunct
return std::nullopt; return std::nullopt;
} }
using StateRef = std::unique_ptr<lua_State, void (*)(lua_State*)>;
TypeFunctionReductionResult<TypeId> userDefinedTypeFunction( TypeFunctionReductionResult<TypeId> userDefinedTypeFunction(
TypeId instance, TypeId instance,
const std::vector<TypeId>& typeParams, const std::vector<TypeId>& typeParams,
@ -585,9 +601,122 @@ TypeFunctionReductionResult<TypeId> userDefinedTypeFunction(
return {std::nullopt, true, {}, {}}; return {std::nullopt, true, {}, {}};
} }
// TODO: implementation of user-defined type functions goes here for (auto typeParam : typeParams)
{
TypeId ty = follow(typeParam);
return {std::nullopt, true, {}, {}}; // block if we need to
if (isPending(ty, ctx->solver))
return {std::nullopt, false, {ty}, {}};
}
AstName name = *ctx->userFuncName;
AstExprFunction* function = *ctx->userFuncBody;
// Construct ParseResult containing the type function
Allocator allocator;
AstNameTable names(allocator);
AstExprGlobal globalName{Location{}, name};
AstStatFunction typeFunction{Location{}, &globalName, function};
AstStat* stmtArray[] = {&typeFunction};
AstArray<AstStat*> stmts{stmtArray, 1};
AstStatBlock exec{Location{}, stmts};
ParseResult parseResult{&exec, 1};
BytecodeBuilder builder;
try
{
compileOrThrow(builder, parseResult, names);
}
catch (CompileError& e)
{
std::string errMsg = format("'%s' type function failed to compile with error message: %s", name.value, e.what());
return {std::nullopt, true, {}, {}, errMsg};
}
std::string bytecode = builder.getBytecode();
// Initialize Lua state
StateRef globalState(lua_newstate(typeFunctionAlloc, nullptr), lua_close);
lua_State* L = globalState.get();
lua_setthreaddata(L, ctx.get());
setTypeFunctionEnvironment(L);
// Register type userdata
registerTypeUserData(L);
luaL_sandbox(L);
luaL_sandboxthread(L);
// Load bytecode into Luau state
if (auto error = checkResultForError(L, name.value, luau_load(L, name.value, bytecode.data(), bytecode.size(), 0)))
return {std::nullopt, true, {}, {}, error};
// Execute the loaded chunk to register the function in the global environment
if (auto error = checkResultForError(L, name.value, lua_pcall(L, 0, 0, 0)))
return {std::nullopt, true, {}, {}, error};
// Get type function from the global environment
lua_getglobal(L, name.value);
if (!lua_isfunction(L, -1))
{
std::string errMsg = format("Could not find '%s' type function in the global scope", name.value);
return {std::nullopt, true, {}, {}, errMsg};
}
// Push serialized arguments onto the stack
// Since there aren't any new class types being created in type functions, there isn't a deserialization function
// class types. Instead, we can keep this map and return the mapping as the "deserialized value"
std::unique_ptr<TypeFunctionRuntimeBuilderState> runtimeBuilder = std::make_unique<TypeFunctionRuntimeBuilderState>(ctx);
for (auto typeParam : typeParams)
{
TypeId ty = follow(typeParam);
// This is checked at the top of the function, and should still be true.
LUAU_ASSERT(!isPending(ty, ctx->solver));
TypeFunctionTypeId serializedTy = serialize(ty, runtimeBuilder.get());
// Check if there were any errors while serializing
if (runtimeBuilder->errors.size() != 0)
return {std::nullopt, true, {}, {}, runtimeBuilder->errors.front()};
allocTypeUserData(L, serializedTy->type);
}
// Set up an interrupt handler for type functions to respect type checking limits and LSP cancellation requests.
lua_callbacks(L)->interrupt = [](lua_State* L, int gc)
{
auto ctx = static_cast<const TypeFunctionContext*>(lua_getthreaddata(lua_mainthread(L)));
if (ctx->limits->finishTime && TimeTrace::getClock() > *ctx->limits->finishTime)
ctx->solver->throwTimeLimitError();
if (ctx->limits->cancellationToken && ctx->limits->cancellationToken->requested())
ctx->solver->throwUserCancelError();
};
if (auto error = checkResultForError(L, name.value, lua_resume(L, nullptr, int(typeParams.size()))))
return {std::nullopt, true, {}, {}, error};
// If the return value is not a type userdata, return with error message
if (!isTypeUserData(L, 1))
return {std::nullopt, true, {}, {}, format("'%s' type function: returned a non-type value", name.value)};
TypeFunctionTypeId retTypeFunctionTypeId = getTypeUserData(L, 1);
// No errors should be present here since we should've returned already if any were raised during serialization.
LUAU_ASSERT(runtimeBuilder->errors.size() == 0);
TypeId retTypeId = deserialize(retTypeFunctionTypeId, runtimeBuilder.get());
// At least 1 error occured while deserializing
if (runtimeBuilder->errors.size() > 0)
return {std::nullopt, true, {}, {}, runtimeBuilder->errors.front()};
return {retTypeId, false, {}, {}};
} }
TypeFunctionReductionResult<TypeId> notTypeFunction( TypeFunctionReductionResult<TypeId> notTypeFunction(
@ -711,7 +840,7 @@ TypeFunctionReductionResult<TypeId> lenTypeFunction(
if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes))
return {std::nullopt, true, {}, {}}; // occurs check failed return {std::nullopt, true, {}, {}}; // occurs check failed
Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice}; Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice};
if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance?
return {std::nullopt, true, {}, {}}; return {std::nullopt, true, {}, {}};
@ -808,7 +937,7 @@ TypeFunctionReductionResult<TypeId> unmTypeFunction(
if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes))
return {std::nullopt, true, {}, {}}; // occurs check failed return {std::nullopt, true, {}, {}}; // occurs check failed
Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice}; Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice};
if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance?
return {std::nullopt, true, {}, {}}; return {std::nullopt, true, {}, {}};
@ -818,7 +947,20 @@ TypeFunctionReductionResult<TypeId> unmTypeFunction(
return {std::nullopt, true, {}, {}}; return {std::nullopt, true, {}, {}};
} }
NotNull<Constraint> TypeFunctionContext::pushConstraint(ConstraintV&& c) TypeFunctionContext::TypeFunctionContext(NotNull<ConstraintSolver> cs, NotNull<Scope> scope, NotNull<const Constraint> constraint)
: arena(cs->arena)
, builtins(cs->builtinTypes)
, scope(scope)
, normalizer(cs->normalizer)
, typeFunctionRuntime(cs->typeFunctionRuntime)
, ice(NotNull{&cs->iceReporter})
, limits(NotNull{&cs->limits})
, solver(cs.get())
, constraint(constraint.get())
{
}
NotNull<Constraint> TypeFunctionContext::pushConstraint(ConstraintV&& c) const
{ {
LUAU_ASSERT(solver); LUAU_ASSERT(solver);
NotNull<Constraint> newConstraint = solver->pushConstraint(scope, constraint ? constraint->location : Location{}, std::move(c)); NotNull<Constraint> newConstraint = solver->pushConstraint(scope, constraint ? constraint->location : Location{}, std::move(c));
@ -921,12 +1063,16 @@ TypeFunctionReductionResult<TypeId> numericBinopTypeFunction(
SolveResult solveResult; SolveResult solveResult;
if (!reversed) if (!reversed)
solveResult = solveFunctionCall(ctx->arena, ctx->builtins, ctx->normalizer, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack); solveResult = solveFunctionCall(
ctx->arena, ctx->builtins, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack
);
else else
{ {
TypePack* p = getMutable<TypePack>(argPack); TypePack* p = getMutable<TypePack>(argPack);
std::swap(p->head.front(), p->head.back()); std::swap(p->head.front(), p->head.back());
solveResult = solveFunctionCall(ctx->arena, ctx->builtins, ctx->normalizer, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack); solveResult = solveFunctionCall(
ctx->arena, ctx->builtins, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack
);
} }
if (!solveResult.typePackId.has_value()) if (!solveResult.typePackId.has_value())
@ -1156,7 +1302,7 @@ TypeFunctionReductionResult<TypeId> concatTypeFunction(
if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes))
return {std::nullopt, true, {}, {}}; // occurs check failed return {std::nullopt, true, {}, {}}; // occurs check failed
Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice}; Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice};
if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance?
return {std::nullopt, true, {}, {}}; return {std::nullopt, true, {}, {}};
@ -1410,7 +1556,7 @@ static TypeFunctionReductionResult<TypeId> comparisonTypeFunction(
if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes))
return {std::nullopt, true, {}, {}}; // occurs check failed return {std::nullopt, true, {}, {}}; // occurs check failed
Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice}; Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice};
if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance?
return {std::nullopt, true, {}, {}}; return {std::nullopt, true, {}, {}};
@ -1554,7 +1700,7 @@ TypeFunctionReductionResult<TypeId> eqTypeFunction(
if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes))
return {std::nullopt, true, {}, {}}; // occurs check failed return {std::nullopt, true, {}, {}}; // occurs check failed
Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice}; Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice};
if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance?
return {std::nullopt, true, {}, {}}; return {std::nullopt, true, {}, {}};
@ -2004,7 +2150,7 @@ TypeFunctionReductionResult<TypeId> keyofFunctionImpl(
if (!computeKeysOf(*classesIter, localKeys, seen, isRaw, ctx)) if (!computeKeysOf(*classesIter, localKeys, seen, isRaw, ctx))
continue; continue;
for (auto key : keys) for (auto& key : keys)
{ {
// remove any keys that are not present in each class // remove any keys that are not present in each class
if (!localKeys.contains(key)) if (!localKeys.contains(key))
@ -2039,7 +2185,7 @@ TypeFunctionReductionResult<TypeId> keyofFunctionImpl(
if (!computeKeysOf(*tablesIter, localKeys, seen, isRaw, ctx)) if (!computeKeysOf(*tablesIter, localKeys, seen, isRaw, ctx))
continue; continue;
for (auto key : keys) for (auto& key : keys)
{ {
// remove any keys that are not present in each table // remove any keys that are not present in each table
if (!localKeys.contains(key)) if (!localKeys.contains(key))
@ -2239,7 +2385,7 @@ TypeFunctionReductionResult<TypeId> indexFunctionImpl(
return {std::nullopt, true, {}, {}}; return {std::nullopt, true, {}, {}};
// indexer can be a union —> break them down into a vector // indexer can be a union —> break them down into a vector
const std::vector<TypeId>* typesToFind; const std::vector<TypeId>* typesToFind = nullptr;
const std::vector<TypeId> singleType{indexerTy}; const std::vector<TypeId> singleType{indexerTy};
if (auto unionTy = get<UnionType>(indexerTy)) if (auto unionTy = get<UnionType>(indexerTy))
typesToFind = &unionTy->options; typesToFind = &unionTy->options;

View file

@ -3,6 +3,7 @@
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/Normalize.h" #include "Luau/Normalize.h"
#include "Luau/ToString.h"
#include "Luau/TypeFunction.h" #include "Luau/TypeFunction.h"
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Luau/TypePack.h" #include "Luau/TypePack.h"

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,788 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TypeFunctionRuntimeBuilder.h"
#include "Luau/Ast.h"
#include "Luau/BuiltinDefinitions.h"
#include "Luau/Common.h"
#include "Luau/DenseHash.h"
#include "Luau/StringUtils.h"
#include "Luau/Type.h"
#include "Luau/TypeArena.h"
#include "Luau/TypeFwd.h"
#include "Luau/TypeFunctionRuntime.h"
#include "Luau/TypePack.h"
#include "Luau/ToString.h"
#include <optional>
// used to control the recursion limit of any operations done by user-defined type functions
// currently, controls serialization, deserialization, and `type.copy`
LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFunctionSerdeIterationLimit, 100'000);
namespace Luau
{
// Forked version of Clone.cpp
class TypeFunctionSerializer
{
using SeenTypes = DenseHashMap<TypeId, TypeFunctionTypeId>;
using SeenTypePacks = DenseHashMap<TypePackId, TypeFunctionTypePackId>;
TypeFunctionRuntimeBuilderState* state = nullptr;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
// A queue of TypeFunctionTypeIds that have been serialized, but whose interior types hasn't
// been updated to point to itself. Once all of its interior types
// has been updated, it gets removed from the queue.
// queue.back() should always return two of same type in their respective sides
// For example `auto [first, second] = queue.back()`: if first is PrimitiveType,
// second must be TypeFunctionPrimitiveType; else there should be an error
std::vector<std::tuple<Kind, TypeFunctionKind>> queue;
SeenTypes types; // Mapping of TypeIds that have been shallow serialized to TypeFunctionTypeIds
SeenTypePacks packs; // Mapping of TypePackIds that have been shallow serialized to TypeFunctionTypePackIds
int steps = 0;
public:
explicit TypeFunctionSerializer(TypeFunctionRuntimeBuilderState* state)
: state(state)
, typeFunctionRuntime(state->ctx->typeFunctionRuntime)
, queue({})
, types({})
, packs({})
{
}
TypeFunctionTypeId serialize(TypeId ty)
{
shallowSerialize(ty);
run();
if (hasExceededIterationLimit() || state->errors.size() != 0)
return nullptr;
return find(ty).value_or(nullptr);
}
TypeFunctionTypePackId serialize(TypePackId tp)
{
shallowSerialize(tp);
run();
if (hasExceededIterationLimit() || state->errors.size() != 0)
return nullptr;
return find(tp).value_or(nullptr);
}
private:
bool hasExceededIterationLimit() const
{
if (DFInt::LuauTypeFunctionSerdeIterationLimit == 0)
return false;
return steps + queue.size() >= size_t(DFInt::LuauTypeFunctionSerdeIterationLimit);
}
void run()
{
while (!queue.empty())
{
++steps;
if (hasExceededIterationLimit() || state->errors.size() != 0)
break;
auto [ty, tfti] = queue.back();
queue.pop_back();
serializeChildren(ty, tfti);
}
}
std::optional<TypeFunctionTypeId> find(TypeId ty) const
{
if (auto result = types.find(ty))
return *result;
return std::nullopt;
}
std::optional<TypeFunctionTypePackId> find(TypePackId tp) const
{
if (auto result = packs.find(tp))
return *result;
return std::nullopt;
}
std::optional<TypeFunctionKind> find(Kind kind) const
{
if (auto ty = get<TypeId>(kind))
return find(*ty);
else if (auto tp = get<TypePackId>(kind))
return find(*tp);
else
{
LUAU_ASSERT(!"Unknown kind found at TypeFunctionRuntimeSerializer");
return std::nullopt;
}
}
TypeFunctionTypeId shallowSerialize(TypeId ty)
{
ty = follow(ty);
if (auto it = find(ty))
return *it;
// Create a shallow serialization
TypeFunctionTypeId target = {};
if (auto p = get<PrimitiveType>(ty))
{
switch (p->type)
{
case PrimitiveType::Type::NilType:
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::NilType));
break;
case PrimitiveType::Type::Boolean:
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Boolean));
break;
case PrimitiveType::Number:
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Number));
break;
case PrimitiveType::String:
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::String));
break;
case PrimitiveType::Thread:
case PrimitiveType::Function:
case PrimitiveType::Table:
case PrimitiveType::Buffer:
default:
{
std::string error = format("Argument of primitive type %s is not currently serializable by type functions", toString(ty).c_str());
state->errors.push_back(error);
}
}
}
else if (auto u = get<UnknownType>(ty))
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionUnknownType{});
else if (auto a = get<NeverType>(ty))
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionNeverType{});
else if (auto a = get<AnyType>(ty))
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionAnyType{});
else if (auto s = get<SingletonType>(ty))
{
if (auto bs = get<BooleanSingleton>(s))
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionSingletonType{TypeFunctionBooleanSingleton{bs->value}});
else if (auto ss = get<StringSingleton>(s))
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionSingletonType{TypeFunctionStringSingleton{ss->value}});
else
{
std::string error = format("Argument of singleton type %s is not currently serializable by type functions", toString(ty).c_str());
state->errors.push_back(error);
}
}
else if (auto u = get<UnionType>(ty))
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionUnionType{{}});
else if (auto i = get<IntersectionType>(ty))
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionIntersectionType{{}});
else if (auto n = get<NegationType>(ty))
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionNegationType{{}});
else if (auto t = get<TableType>(ty))
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionTableType{{}, std::nullopt, std::nullopt});
else if (auto m = get<MetatableType>(ty))
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionTableType{{}, std::nullopt, std::nullopt});
else if (auto f = get<FunctionType>(ty))
{
TypeFunctionTypePackId emptyTypePack = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{});
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionFunctionType{emptyTypePack, emptyTypePack});
}
else if (auto c = get<ClassType>(ty))
{
state->classesSerialized[c->name] = ty;
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionClassType{{}, std::nullopt, std::nullopt, std::nullopt, c->name});
}
else
{
std::string error = format("Argument of type %s is not currently serializable by type functions", toString(ty).c_str());
state->errors.push_back(error);
}
types[ty] = target;
queue.emplace_back(ty, target);
return target;
}
TypeFunctionTypePackId shallowSerialize(TypePackId tp)
{
tp = follow(tp);
if (auto it = find(tp))
return *it;
// Create a shallow serialization
TypeFunctionTypePackId target = {};
if (auto tPack = get<TypePack>(tp))
target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{{}});
else if (auto vPack = get<VariadicTypePack>(tp))
target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionVariadicTypePack{});
else
{
std::string error = format("Argument of type pack %s is not currently serializable by type functions", toString(tp).c_str());
state->errors.push_back(error);
}
packs[tp] = target;
queue.emplace_back(tp, target);
return target;
}
void serializeChildren(TypeId ty, TypeFunctionTypeId tfti)
{
if (auto [p1, p2] = std::tuple{getMutable<PrimitiveType>(ty), getMutable<TypeFunctionPrimitiveType>(tfti)}; p1 && p2)
serializeChildren(p1, p2);
else if (auto [u1, u2] = std::tuple{getMutable<UnknownType>(ty), getMutable<TypeFunctionUnknownType>(tfti)}; u1 && u2)
serializeChildren(u1, u2);
else if (auto [n1, n2] = std::tuple{getMutable<NeverType>(ty), getMutable<TypeFunctionNeverType>(tfti)}; n1 && n2)
serializeChildren(n1, n2);
else if (auto [a1, a2] = std::tuple{getMutable<AnyType>(ty), getMutable<TypeFunctionAnyType>(tfti)}; a1 && a2)
serializeChildren(a1, a2);
else if (auto [s1, s2] = std::tuple{getMutable<SingletonType>(ty), getMutable<TypeFunctionSingletonType>(tfti)}; s1 && s2)
serializeChildren(s1, s2);
else if (auto [u1, u2] = std::tuple{getMutable<UnionType>(ty), getMutable<TypeFunctionUnionType>(tfti)}; u1 && u2)
serializeChildren(u1, u2);
else if (auto [i1, i2] = std::tuple{getMutable<IntersectionType>(ty), getMutable<TypeFunctionIntersectionType>(tfti)}; i1 && i2)
serializeChildren(i1, i2);
else if (auto [n1, n2] = std::tuple{getMutable<NegationType>(ty), getMutable<TypeFunctionNegationType>(tfti)}; n1 && n2)
serializeChildren(n1, n2);
else if (auto [t1, t2] = std::tuple{getMutable<TableType>(ty), getMutable<TypeFunctionTableType>(tfti)}; t1 && t2)
serializeChildren(t1, t2);
else if (auto [m1, m2] = std::tuple{getMutable<MetatableType>(ty), getMutable<TypeFunctionTableType>(tfti)}; m1 && m2)
serializeChildren(m1, m2);
else if (auto [f1, f2] = std::tuple{getMutable<FunctionType>(ty), getMutable<TypeFunctionFunctionType>(tfti)}; f1 && f2)
serializeChildren(f1, f2);
else if (auto [c1, c2] = std::tuple{getMutable<ClassType>(ty), getMutable<TypeFunctionClassType>(tfti)}; c1 && c2)
serializeChildren(c1, c2);
else
{ // Either this or ty and tfti do not represent the same type
std::string error = format("Argument of type %s is not currently serializable by type functions", toString(ty).c_str());
state->errors.push_back(error);
}
}
void serializeChildren(TypePackId tp, TypeFunctionTypePackId tftp)
{
if (auto [tPack1, tPack2] = std::tuple{getMutable<TypePack>(tp), getMutable<TypeFunctionTypePack>(tftp)}; tPack1 && tPack2)
serializeChildren(tPack1, tPack2);
else if (auto [vPack1, vPack2] = std::tuple{getMutable<VariadicTypePack>(tp), getMutable<TypeFunctionVariadicTypePack>(tftp)};
vPack1 && vPack2)
serializeChildren(vPack1, vPack2);
else
{ // Either this or ty and tfti do not represent the same type
std::string error = format("Argument of type pack %s is not currently serializable by type functions", toString(tp).c_str());
state->errors.push_back(error);
}
}
void serializeChildren(Kind kind, TypeFunctionKind tfkind)
{
if (auto [ty, tfty] = std::tuple{get<TypeId>(kind), get<TypeFunctionTypeId>(tfkind)}; ty && tfty)
serializeChildren(*ty, *tfty);
else if (auto [tp, tftp] = std::tuple{get<TypePackId>(kind), get<TypeFunctionTypePackId>(tfkind)}; tp && tftp)
serializeChildren(*tp, *tftp);
else
state->ctx->ice->ice("Serializing user defined type function arguments: kind and tfkind do not represent the same type");
}
void serializeChildren(PrimitiveType* p1, TypeFunctionPrimitiveType* p2)
{
// noop.
}
void serializeChildren(UnknownType* u1, TypeFunctionUnknownType* u2)
{
// noop.
}
void serializeChildren(NeverType* n1, TypeFunctionNeverType* n2)
{
// noop.
}
void serializeChildren(AnyType* a1, TypeFunctionAnyType* a2)
{
// noop.
}
void serializeChildren(SingletonType* s1, TypeFunctionSingletonType* s2)
{
// noop.
}
void serializeChildren(UnionType* u1, TypeFunctionUnionType* u2)
{
for (TypeId& ty : u1->options)
u2->components.push_back(shallowSerialize(ty));
}
void serializeChildren(IntersectionType* i1, TypeFunctionIntersectionType* i2)
{
for (TypeId& ty : i1->parts)
i2->components.push_back(shallowSerialize(ty));
}
void serializeChildren(NegationType* n1, TypeFunctionNegationType* n2)
{
n2->type = shallowSerialize(n1->ty);
}
void serializeChildren(TableType* t1, TypeFunctionTableType* t2)
{
for (const auto& [k, p] : t1->props)
{
std::optional<TypeFunctionTypeId> readTy = std::nullopt;
if (p.readTy)
readTy = shallowSerialize(*p.readTy);
std::optional<TypeFunctionTypeId> writeTy = std::nullopt;
if (p.writeTy)
writeTy = shallowSerialize(*p.writeTy);
t2->props[k] = TypeFunctionProperty{readTy, writeTy};
}
if (t1->indexer)
t2->indexer = TypeFunctionTableIndexer(shallowSerialize(t1->indexer->indexType), shallowSerialize(t1->indexer->indexResultType));
}
void serializeChildren(MetatableType* m1, TypeFunctionTableType* m2)
{
auto tmpTable = get<TypeFunctionTableType>(shallowSerialize(m1->table));
if (!tmpTable)
state->ctx->ice->ice("Serializing user defined type function arguments: metatable's table is not a TableType");
m2->props = tmpTable->props;
m2->indexer = tmpTable->indexer;
m2->metatable = shallowSerialize(m1->metatable);
}
void serializeChildren(FunctionType* f1, TypeFunctionFunctionType* f2)
{
f2->argTypes = shallowSerialize(f1->argTypes);
f2->retTypes = shallowSerialize(f1->retTypes);
}
void serializeChildren(ClassType* c1, TypeFunctionClassType* c2)
{
for (const auto& [k, p] : c1->props)
{
std::optional<TypeFunctionTypeId> readTy = std::nullopt;
if (p.readTy)
readTy = shallowSerialize(*p.readTy);
std::optional<TypeFunctionTypeId> writeTy = std::nullopt;
if (p.writeTy)
writeTy = shallowSerialize(*p.writeTy);
c2->props[k] = TypeFunctionProperty{readTy, writeTy};
}
if (c1->indexer)
c2->indexer = TypeFunctionTableIndexer(shallowSerialize(c1->indexer->indexType), shallowSerialize(c1->indexer->indexResultType));
if (c1->metatable)
c2->metatable = shallowSerialize(*c1->metatable);
if (c1->parent)
c2->parent = shallowSerialize(*c1->parent);
}
void serializeChildren(TypePack* t1, TypeFunctionTypePack* t2)
{
for (TypeId& ty : t1->head)
t2->head.push_back(shallowSerialize(ty));
if (t1->tail.has_value())
t2->tail = shallowSerialize(*t1->tail);
}
void serializeChildren(VariadicTypePack* v1, TypeFunctionVariadicTypePack* v2)
{
v2->type = shallowSerialize(v1->ty);
}
};
// Complete inverse of TypeFunctionSerializer
class TypeFunctionDeserializer
{
using SeenTypes = DenseHashMap<TypeFunctionTypeId, TypeId>;
using SeenTypePacks = DenseHashMap<TypeFunctionTypePackId, TypePackId>;
TypeFunctionRuntimeBuilderState* state = nullptr;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
// A queue of TypeIds that have been deserialized, but whose interior types hasn't
// been updated to point to itself. Once all of its interior types
// has been updated, it gets removed from the queue.
// queue.back() should always return two of same type in their respective sides
// For example `auto [first, second] = queue.back()`: if first is TypeFunctionPrimitiveType,
// second must be PrimitiveType; else there should be an error
std::vector<std::tuple<TypeFunctionKind, Kind>> queue;
SeenTypes types; // Mapping of TypeFunctionTypeIds that have been shallow deserialized to TypeIds
SeenTypePacks packs; // Mapping of TypeFunctionTypePackIds that have been shallow deserialized to TypePackIds
int steps = 0;
public:
explicit TypeFunctionDeserializer(TypeFunctionRuntimeBuilderState* state)
: state(state)
, typeFunctionRuntime(state->ctx->typeFunctionRuntime)
, queue({})
, types({})
, packs({}){};
TypeId deserialize(TypeFunctionTypeId ty)
{
shallowDeserialize(ty);
run();
if (hasExceededIterationLimit() || state->errors.size() != 0)
{
TypeId error = state->ctx->builtins->errorRecoveryType();
types[ty] = error;
return error;
}
return find(ty).value_or(state->ctx->builtins->errorRecoveryType());
}
TypePackId deserialize(TypeFunctionTypePackId tp)
{
shallowDeserialize(tp);
run();
if (hasExceededIterationLimit() || state->errors.size() != 0)
{
TypePackId error = state->ctx->builtins->errorRecoveryTypePack();
packs[tp] = error;
return error;
}
return find(tp).value_or(state->ctx->builtins->errorRecoveryTypePack());
}
private:
bool hasExceededIterationLimit() const
{
if (DFInt::LuauTypeFunctionSerdeIterationLimit == 0)
return false;
return steps + queue.size() >= size_t(DFInt::LuauTypeFunctionSerdeIterationLimit);
}
void run()
{
while (!queue.empty())
{
++steps;
if (hasExceededIterationLimit() || state->errors.size() != 0)
break;
auto [tfti, ty] = queue.back();
queue.pop_back();
deserializeChildren(tfti, ty);
}
}
std::optional<TypeId> find(TypeFunctionTypeId ty) const
{
if (auto result = types.find(ty))
return *result;
return std::nullopt;
}
std::optional<TypePackId> find(TypeFunctionTypePackId tp) const
{
if (auto result = packs.find(tp))
return *result;
return std::nullopt;
}
std::optional<Kind> find(TypeFunctionKind kind) const
{
if (auto ty = get<TypeFunctionTypeId>(kind))
return find(*ty);
else if (auto tp = get<TypeFunctionTypePackId>(kind))
return find(*tp);
else
{
LUAU_ASSERT(!"Unknown kind found at TypeFunctionDeserializer");
return std::nullopt;
}
}
TypeId shallowDeserialize(TypeFunctionTypeId ty)
{
if (auto it = find(ty))
return *it;
// Create a shallow deserialization
TypeId target = {};
if (auto p = get<TypeFunctionPrimitiveType>(ty))
{
switch (p->type)
{
case TypeFunctionPrimitiveType::Type::NilType:
target = state->ctx->builtins->nilType;
break;
case TypeFunctionPrimitiveType::Type::Boolean:
target = state->ctx->builtins->booleanType;
break;
case TypeFunctionPrimitiveType::Type::Number:
target = state->ctx->builtins->numberType;
break;
case TypeFunctionPrimitiveType::Type::String:
target = state->ctx->builtins->stringType;
break;
default:
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized");
}
}
else if (auto u = get<TypeFunctionUnknownType>(ty))
target = state->ctx->builtins->unknownType;
else if (auto n = get<TypeFunctionNeverType>(ty))
target = state->ctx->builtins->neverType;
else if (auto a = get<TypeFunctionAnyType>(ty))
target = state->ctx->builtins->anyType;
else if (auto s = get<TypeFunctionSingletonType>(ty))
{
if (auto bs = get<TypeFunctionBooleanSingleton>(s))
target = state->ctx->arena->addType(SingletonType{BooleanSingleton{bs->value}});
else if (auto ss = get<TypeFunctionStringSingleton>(s))
target = state->ctx->arena->addType(SingletonType{StringSingleton{ss->value}});
else
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized");
}
else if (auto u = get<TypeFunctionUnionType>(ty))
target = state->ctx->arena->addTV(Type(UnionType{{}}));
else if (auto i = get<TypeFunctionIntersectionType>(ty))
target = state->ctx->arena->addTV(Type(IntersectionType{{}}));
else if (auto n = get<TypeFunctionNegationType>(ty))
target = state->ctx->arena->addType(NegationType{state->ctx->builtins->unknownType});
else if (auto t = get<TypeFunctionTableType>(ty); t && !t->metatable.has_value())
target = state->ctx->arena->addType(TableType{TableType::Props{}, std::nullopt, TypeLevel{}, TableState::Sealed});
else if (auto m = get<TypeFunctionTableType>(ty); m && m->metatable.has_value())
{
TypeId emptyTable = state->ctx->arena->addType(TableType{TableType::Props{}, std::nullopt, TypeLevel{}, TableState::Sealed});
target = state->ctx->arena->addType(MetatableType{emptyTable, emptyTable});
}
else if (auto f = get<TypeFunctionFunctionType>(ty))
{
TypePackId emptyTypePack = state->ctx->arena->addTypePack(TypePack{});
target = state->ctx->arena->addType(FunctionType{emptyTypePack, emptyTypePack, {}, false});
}
else if (auto c = get<TypeFunctionClassType>(ty))
{
if (auto result = state->classesSerialized.find(c->name))
target = *result;
else
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious class type is being deserialized");
}
else
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized");
types[ty] = target;
queue.emplace_back(ty, target);
return target;
}
TypePackId shallowDeserialize(TypeFunctionTypePackId tp)
{
if (auto it = find(tp))
return *it;
// Create a shallow deserialization
TypePackId target = {};
if (auto tPack = get<TypeFunctionTypePack>(tp))
target = state->ctx->arena->addTypePack(TypePack{});
else if (auto vPack = get<TypeFunctionVariadicTypePack>(tp))
target = state->ctx->arena->addTypePack(VariadicTypePack{});
else
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized");
packs[tp] = target;
queue.emplace_back(tp, target);
return target;
}
void deserializeChildren(TypeFunctionTypeId tfti, TypeId ty)
{
if (auto [p1, p2] = std::tuple{getMutable<PrimitiveType>(ty), getMutable<TypeFunctionPrimitiveType>(tfti)}; p1 && p2)
deserializeChildren(p2, p1);
else if (auto [u1, u2] = std::tuple{getMutable<UnknownType>(ty), getMutable<TypeFunctionUnknownType>(tfti)}; u1 && u2)
deserializeChildren(u2, u1);
else if (auto [n1, n2] = std::tuple{getMutable<NeverType>(ty), getMutable<TypeFunctionNeverType>(tfti)}; n1 && n2)
deserializeChildren(n2, n1);
else if (auto [a1, a2] = std::tuple{getMutable<AnyType>(ty), getMutable<TypeFunctionAnyType>(tfti)}; a1 && a2)
deserializeChildren(a2, a1);
else if (auto [s1, s2] = std::tuple{getMutable<SingletonType>(ty), getMutable<TypeFunctionSingletonType>(tfti)}; s1 && s2)
deserializeChildren(s2, s1);
else if (auto [u1, u2] = std::tuple{getMutable<UnionType>(ty), getMutable<TypeFunctionUnionType>(tfti)}; u1 && u2)
deserializeChildren(u2, u1);
else if (auto [i1, i2] = std::tuple{getMutable<IntersectionType>(ty), getMutable<TypeFunctionIntersectionType>(tfti)}; i1 && i2)
deserializeChildren(i2, i1);
else if (auto [n1, n2] = std::tuple{getMutable<NegationType>(ty), getMutable<TypeFunctionNegationType>(tfti)}; n1 && n2)
deserializeChildren(n2, n1);
else if (auto [t1, t2] = std::tuple{getMutable<TableType>(ty), getMutable<TypeFunctionTableType>(tfti)};
t1 && t2 && !t2->metatable.has_value())
deserializeChildren(t2, t1);
else if (auto [m1, m2] = std::tuple{getMutable<MetatableType>(ty), getMutable<TypeFunctionTableType>(tfti)};
m1 && m2 && m2->metatable.has_value())
deserializeChildren(m2, m1);
else if (auto [f1, f2] = std::tuple{getMutable<FunctionType>(ty), getMutable<TypeFunctionFunctionType>(tfti)}; f1 && f2)
deserializeChildren(f2, f1);
else if (auto [c1, c2] = std::tuple{getMutable<ClassType>(ty), getMutable<TypeFunctionClassType>(tfti)}; c1 && c2)
deserializeChildren(c2, c1);
else
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized");
}
void deserializeChildren(TypeFunctionTypePackId tftp, TypePackId tp)
{
if (auto [tPack1, tPack2] = std::tuple{getMutable<TypePack>(tp), getMutable<TypeFunctionTypePack>(tftp)}; tPack1 && tPack2)
deserializeChildren(tPack2, tPack1);
else if (auto [vPack1, vPack2] = std::tuple{getMutable<VariadicTypePack>(tp), getMutable<TypeFunctionVariadicTypePack>(tftp)};
vPack1 && vPack2)
deserializeChildren(vPack2, vPack1);
else
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized");
}
void deserializeChildren(TypeFunctionKind tfkind, Kind kind)
{
if (auto [ty, tfty] = std::tuple{get<TypeId>(kind), get<TypeFunctionTypeId>(tfkind)}; ty && tfty)
deserializeChildren(*tfty, *ty);
else if (auto [tp, tftp] = std::tuple{get<TypePackId>(kind), get<TypeFunctionTypePackId>(tfkind)}; tp && tftp)
deserializeChildren(*tftp, *tp);
else
state->ctx->ice->ice("Deserializing user defined type function arguments: tfkind and kind do not represent the same type");
}
void deserializeChildren(TypeFunctionPrimitiveType* p2, PrimitiveType* p1)
{
// noop.
}
void deserializeChildren(TypeFunctionUnknownType* u2, UnknownType* u1)
{
// noop.
}
void deserializeChildren(TypeFunctionNeverType* n2, NeverType* n1)
{
// noop.
}
void deserializeChildren(TypeFunctionAnyType* a2, AnyType* a1)
{
// noop.
}
void deserializeChildren(TypeFunctionSingletonType* s2, SingletonType* s1)
{
// noop.
}
void deserializeChildren(TypeFunctionUnionType* u2, UnionType* u1)
{
for (TypeFunctionTypeId& ty : u2->components)
u1->options.push_back(shallowDeserialize(ty));
}
void deserializeChildren(TypeFunctionIntersectionType* i2, IntersectionType* i1)
{
for (TypeFunctionTypeId& ty : i2->components)
i1->parts.push_back(shallowDeserialize(ty));
}
void deserializeChildren(TypeFunctionNegationType* n2, NegationType* n1)
{
n1->ty = shallowDeserialize(n2->type);
}
void deserializeChildren(TypeFunctionTableType* t2, TableType* t1)
{
for (const auto& [k, p] : t2->props)
{
if (p.readTy && p.writeTy)
t1->props[k] = Property::rw(shallowDeserialize(*p.readTy), shallowDeserialize(*p.writeTy));
else if (p.readTy)
t1->props[k] = Property::readonly(shallowDeserialize(*p.readTy));
else if (p.writeTy)
t1->props[k] = Property::writeonly(shallowDeserialize(*p.writeTy));
}
if (t2->indexer.has_value())
t1->indexer = TableIndexer(shallowDeserialize(t2->indexer->keyType), shallowDeserialize(t2->indexer->valueType));
}
void deserializeChildren(TypeFunctionTableType* m2, MetatableType* m1)
{
TypeFunctionTypeId temp = typeFunctionRuntime->typeArena.allocate(TypeFunctionTableType{m2->props, m2->indexer});
m1->table = shallowDeserialize(temp);
if (m2->metatable.has_value())
m1->metatable = shallowDeserialize(*m2->metatable);
}
void deserializeChildren(TypeFunctionFunctionType* f2, FunctionType* f1)
{
if (f2->argTypes)
f1->argTypes = shallowDeserialize(f2->argTypes);
if (f2->retTypes)
f1->retTypes = shallowDeserialize(f2->retTypes);
}
void deserializeChildren(TypeFunctionClassType* c2, ClassType* c1)
{
// noop.
}
void deserializeChildren(TypeFunctionTypePack* t2, TypePack* t1)
{
for (TypeFunctionTypeId& ty : t2->head)
t1->head.push_back(shallowDeserialize(ty));
if (t2->tail.has_value())
t1->tail = shallowDeserialize(*t2->tail);
}
void deserializeChildren(TypeFunctionVariadicTypePack* v2, VariadicTypePack* v1)
{
v1->ty = shallowDeserialize(v2->type);
}
};
TypeFunctionTypeId serialize(TypeId ty, TypeFunctionRuntimeBuilderState* state)
{
return TypeFunctionSerializer(state).serialize(ty);
}
TypeId deserialize(TypeFunctionTypeId ty, TypeFunctionRuntimeBuilderState* state)
{
return TypeFunctionDeserializer(state).deserialize(ty);
}
} // namespace Luau

View file

@ -33,7 +33,6 @@ LUAU_FASTFLAG(LuauKnowsTheDataModel3)
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false)
LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAGVARIABLE(LuauRemoveBadRelationalOperatorWarning, false) LUAU_FASTFLAGVARIABLE(LuauRemoveBadRelationalOperatorWarning, false)
LUAU_FASTFLAGVARIABLE(LuauOkWithIteratingOverTableProperties, false)
LUAU_FASTFLAGVARIABLE(LuauAcceptIndexingTableUnionsIntersections, false) LUAU_FASTFLAGVARIABLE(LuauAcceptIndexingTableUnionsIntersections, false)
namespace Luau namespace Luau
@ -1284,20 +1283,11 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin)
for (size_t i = 2; i < varTypes.size(); ++i) for (size_t i = 2; i < varTypes.size(); ++i)
unify(nilType, varTypes[i], scope, forin.location); unify(nilType, varTypes[i], scope, forin.location);
} }
else if (isNonstrictMode() || FFlag::LuauOkWithIteratingOverTableProperties) else
{ {
for (TypeId var : varTypes) for (TypeId var : varTypes)
unify(unknownType, var, scope, forin.location); unify(unknownType, var, scope, forin.location);
} }
else
{
TypeId varTy = errorRecoveryType(loopScope);
for (TypeId var : varTypes)
unify(varTy, var, scope, forin.location);
reportError(firstValue->location, GenericError{"Cannot iterate over a table without indexer"});
}
return check(loopScope, *forin.body); return check(loopScope, *forin.body);
} }

View file

@ -1,6 +1,11 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/Ast.h"
#include "Luau/DenseHash.h"
#include <vector>
namespace Luau namespace Luau
{ {
@ -12,10 +17,17 @@ enum class Mode
Definition, // Type definition module, has special parsing rules Definition, // Type definition module, has special parsing rules
}; };
struct FragmentParseResumeSettings
{
DenseHashMap<AstName, AstLocal*> localMap{AstName()};
std::vector<AstLocal*> localStack;
};
struct ParseOptions struct ParseOptions
{ {
bool allowDeclarationSyntax = false; bool allowDeclarationSyntax = false;
bool captureComments = false; bool captureComments = false;
std::optional<FragmentParseResumeSettings> parseFragment = std::nullopt;
}; };
} // namespace Luau } // namespace Luau

View file

@ -7,6 +7,7 @@
#include <memory> #include <memory>
#include <stdint.h> #include <stdint.h>
#include <string.h>
LUAU_FASTFLAG(DebugLuauTimeTracing) LUAU_FASTFLAG(DebugLuauTimeTracing)

View file

@ -7,8 +7,6 @@
#include <limits.h> #include <limits.h>
LUAU_FASTFLAGVARIABLE(LuauLexerLookaheadRemembersBraceType, false)
namespace Luau namespace Luau
{ {
@ -434,13 +432,11 @@ Lexeme Lexer::lookahead()
lineOffset = currentLineOffset; lineOffset = currentLineOffset;
lexeme = currentLexeme; lexeme = currentLexeme;
prevLocation = currentPrevLocation; prevLocation = currentPrevLocation;
if (FFlag::LuauLexerLookaheadRemembersBraceType)
{ if (braceStack.size() < currentBraceStackSize)
if (braceStack.size() < currentBraceStackSize) braceStack.push_back(currentBraceType);
braceStack.push_back(currentBraceType); else if (braceStack.size() > currentBraceStackSize)
else if (braceStack.size() > currentBraceStackSize) braceStack.pop_back();
braceStack.pop_back();
}
return result; return result;
} }

View file

@ -19,7 +19,8 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100)
LUAU_FASTFLAGVARIABLE(LuauSolverV2, false) LUAU_FASTFLAGVARIABLE(LuauSolverV2, false)
LUAU_FASTFLAGVARIABLE(LuauNativeAttribute, false) LUAU_FASTFLAGVARIABLE(LuauNativeAttribute, false)
LUAU_FASTFLAGVARIABLE(LuauAttributeSyntaxFunExpr, false) LUAU_FASTFLAGVARIABLE(LuauAttributeSyntaxFunExpr, false)
LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctions, false) LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionsSyntax, false)
LUAU_FASTFLAGVARIABLE(LuauAllowFragmentParsing, false)
namespace Luau namespace Luau
{ {
@ -211,6 +212,15 @@ Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Alloc
scratchExpr.reserve(16); scratchExpr.reserve(16);
scratchLocal.reserve(16); scratchLocal.reserve(16);
scratchBinding.reserve(16); scratchBinding.reserve(16);
if (FFlag::LuauAllowFragmentParsing)
{
if (options.parseFragment)
{
localMap = options.parseFragment->localMap;
localStack = options.parseFragment->localStack;
}
}
} }
bool Parser::blockFollow(const Lexeme& l) bool Parser::blockFollow(const Lexeme& l)
@ -891,7 +901,7 @@ AstStat* Parser::parseReturn()
AstStat* Parser::parseTypeAlias(const Location& start, bool exported) AstStat* Parser::parseTypeAlias(const Location& start, bool exported)
{ {
// parsing a type function // parsing a type function
if (FFlag::LuauUserDefinedTypeFunctions) if (FFlag::LuauUserDefinedTypeFunctionsSyntax)
{ {
if (lexer.current().type == Lexeme::ReservedFunction) if (lexer.current().type == Lexeme::ReservedFunction)
return parseTypeFunction(start); return parseTypeFunction(start);

View file

@ -3,6 +3,7 @@
#include "Luau/StringUtils.h" #include "Luau/StringUtils.h"
#include <algorithm>
#include <mutex> #include <mutex>
#include <string> #include <string>

View file

@ -85,6 +85,7 @@ target_link_libraries(Luau.Config PUBLIC Luau.Ast)
target_compile_features(Luau.Analysis PUBLIC cxx_std_17) target_compile_features(Luau.Analysis PUBLIC cxx_std_17)
target_include_directories(Luau.Analysis PUBLIC Analysis/include) target_include_directories(Luau.Analysis PUBLIC Analysis/include)
target_link_libraries(Luau.Analysis PUBLIC Luau.Ast Luau.EqSat Luau.Config) target_link_libraries(Luau.Analysis PUBLIC Luau.Ast Luau.EqSat Luau.Config)
target_link_libraries(Luau.Analysis PRIVATE Luau.Compiler Luau.VM)
target_compile_features(Luau.EqSat PUBLIC cxx_std_17) target_compile_features(Luau.EqSat PUBLIC cxx_std_17)
target_include_directories(Luau.EqSat PUBLIC EqSat/include) target_include_directories(Luau.EqSat PUBLIC EqSat/include)
@ -276,7 +277,7 @@ foreach(LIB Luau.Ast Luau.Compiler Luau.Config Luau.Analysis Luau.EqSat Luau.Cod
if(LIB MATCHES "CodeGen|VM" AND DEPENDS MATCHES "Ast|Analysis|Config|Compiler") if(LIB MATCHES "CodeGen|VM" AND DEPENDS MATCHES "Ast|Analysis|Config|Compiler")
message(FATAL_ERROR ${LIB} " is a runtime component but it depends on one of the offline components") message(FATAL_ERROR ${LIB} " is a runtime component but it depends on one of the offline components")
endif() endif()
if(LIB MATCHES "Ast|Analysis|EqSat|Compiler" AND DEPENDS MATCHES "CodeGen|VM") if(LIB MATCHES "Ast|EqSat|Compiler" AND DEPENDS MATCHES "CodeGen|VM")
message(FATAL_ERROR ${LIB} " is an offline component but it depends on one of the runtime components") message(FATAL_ERROR ${LIB} " is an offline component but it depends on one of the runtime components")
endif() endif()
if(LIB MATCHES "Ast|Compiler" AND DEPENDS MATCHES "Analysis|Config") if(LIB MATCHES "Ast|Compiler" AND DEPENDS MATCHES "Analysis|Config")

View file

@ -11,8 +11,6 @@
#include "lstate.h" #include "lstate.h"
#include "lgc.h" #include "lgc.h"
LUAU_FASTFLAGVARIABLE(LuauCodegenArmNumToVecFix, false)
namespace Luau namespace Luau
{ {
namespace CodeGen namespace CodeGen
@ -1121,7 +1119,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
else else
{ {
RegisterA64 tempd = tempDouble(inst.a); RegisterA64 tempd = tempDouble(inst.a);
RegisterA64 temps = FFlag::LuauCodegenArmNumToVecFix ? regs.allocTemp(KindA64::s) : castReg(KindA64::s, tempd); RegisterA64 temps = regs.allocTemp(KindA64::s);
build.fcvt(temps, tempd); build.fcvt(temps, tempd);
build.dup_4s(inst.regA64, castReg(KindA64::q, temps), 0); build.dup_4s(inst.regA64, castReg(KindA64::q, temps), 0);

View file

@ -142,7 +142,7 @@ endif
$(AST_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include $(AST_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include
$(COMPILER_OBJECTS): CXXFLAGS+=-std=c++17 -ICompiler/include -ICommon/include -IAst/include $(COMPILER_OBJECTS): CXXFLAGS+=-std=c++17 -ICompiler/include -ICommon/include -IAst/include
$(CONFIG_OBJECTS): CXXFLAGS+=-std=c++17 -IConfig/include -ICommon/include -IAst/include $(CONFIG_OBJECTS): CXXFLAGS+=-std=c++17 -IConfig/include -ICommon/include -IAst/include
$(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IEqSat/include -IConfig/include $(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IEqSat/include -IConfig/include -ICompiler/include -IVM/include
$(EQSAT_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IEqSat/include $(EQSAT_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IEqSat/include
$(CODEGEN_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -ICodeGen/include -IVM/include -IVM/src # Code generation needs VM internals $(CODEGEN_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -ICodeGen/include -IVM/include -IVM/src # Code generation needs VM internals
$(VM_OBJECTS): CXXFLAGS+=-std=c++11 -ICommon/include -IVM/include $(VM_OBJECTS): CXXFLAGS+=-std=c++11 -ICommon/include -IVM/include
@ -227,7 +227,7 @@ luau-tests: $(TESTS_TARGET)
# executable targets # executable targets
$(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(COMPILER_TARGET) $(CONFIG_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET) $(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(COMPILER_TARGET) $(CONFIG_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET)
$(REPL_CLI_TARGET): $(REPL_CLI_OBJECTS) $(COMPILER_TARGET) $(CONFIG_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET) $(REPL_CLI_TARGET): $(REPL_CLI_OBJECTS) $(COMPILER_TARGET) $(CONFIG_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET)
$(ANALYZE_CLI_TARGET): $(ANALYZE_CLI_OBJECTS) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(ANALYZE_CLI_TARGET): $(ANALYZE_CLI_OBJECTS) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(COMPILER_TARGET) $(VM_TARGET)
$(COMPILE_CLI_TARGET): $(COMPILE_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(COMPILE_CLI_TARGET): $(COMPILE_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET)
$(BYTECODE_CLI_TARGET): $(BYTECODE_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(BYTECODE_CLI_TARGET): $(BYTECODE_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET)

View file

@ -182,6 +182,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/Documentation.h Analysis/include/Luau/Documentation.h
Analysis/include/Luau/Error.h Analysis/include/Luau/Error.h
Analysis/include/Luau/FileResolver.h Analysis/include/Luau/FileResolver.h
Analysis/include/Luau/FragmentAutocomplete.h
Analysis/include/Luau/Frontend.h Analysis/include/Luau/Frontend.h
Analysis/include/Luau/Generalization.h Analysis/include/Luau/Generalization.h
Analysis/include/Luau/GlobalTypes.h Analysis/include/Luau/GlobalTypes.h
@ -223,6 +224,8 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/TypedAllocator.h Analysis/include/Luau/TypedAllocator.h
Analysis/include/Luau/TypeFunction.h Analysis/include/Luau/TypeFunction.h
Analysis/include/Luau/TypeFunctionReductionGuesser.h Analysis/include/Luau/TypeFunctionReductionGuesser.h
Analysis/include/Luau/TypeFunctionRuntime.h
Analysis/include/Luau/TypeFunctionRuntimeBuilder.h
Analysis/include/Luau/TypeFwd.h Analysis/include/Luau/TypeFwd.h
Analysis/include/Luau/TypeInfer.h Analysis/include/Luau/TypeInfer.h
Analysis/include/Luau/TypeOrPack.h Analysis/include/Luau/TypeOrPack.h
@ -253,6 +256,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/Differ.cpp Analysis/src/Differ.cpp
Analysis/src/EmbeddedBuiltinDefinitions.cpp Analysis/src/EmbeddedBuiltinDefinitions.cpp
Analysis/src/Error.cpp Analysis/src/Error.cpp
Analysis/src/FragmentAutocomplete.cpp
Analysis/src/Frontend.cpp Analysis/src/Frontend.cpp
Analysis/src/Generalization.cpp Analysis/src/Generalization.cpp
Analysis/src/GlobalTypes.cpp Analysis/src/GlobalTypes.cpp
@ -287,6 +291,8 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/TypedAllocator.cpp Analysis/src/TypedAllocator.cpp
Analysis/src/TypeFunction.cpp Analysis/src/TypeFunction.cpp
Analysis/src/TypeFunctionReductionGuesser.cpp Analysis/src/TypeFunctionReductionGuesser.cpp
Analysis/src/TypeFunctionRuntime.cpp
Analysis/src/TypeFunctionRuntimeBuilder.cpp
Analysis/src/TypeInfer.cpp Analysis/src/TypeInfer.cpp
Analysis/src/TypeOrPack.cpp Analysis/src/TypeOrPack.cpp
Analysis/src/TypePack.cpp Analysis/src/TypePack.cpp
@ -440,6 +446,7 @@ if(TARGET Luau.UnitTest)
tests/Error.test.cpp tests/Error.test.cpp
tests/Fixture.cpp tests/Fixture.cpp
tests/Fixture.h tests/Fixture.h
tests/FragmentAutocomplete.test.cpp
tests/Frontend.test.cpp tests/Frontend.test.cpp
tests/Generalization.test.cpp tests/Generalization.test.cpp
tests/InsertionOrderedMap.test.cpp tests/InsertionOrderedMap.test.cpp
@ -474,6 +481,7 @@ if(TARGET Luau.UnitTest)
tests/Transpiler.test.cpp tests/Transpiler.test.cpp
tests/TxnLog.test.cpp tests/TxnLog.test.cpp
tests/TypeFunction.test.cpp tests/TypeFunction.test.cpp
tests/TypeFunction.user.test.cpp
tests/TypeInfer.aliases.test.cpp tests/TypeInfer.aliases.test.cpp
tests/TypeInfer.annotations.test.cpp tests/TypeInfer.annotations.test.cpp
tests/TypeInfer.anyerror.test.cpp tests/TypeInfer.anyerror.test.cpp

View file

@ -10,8 +10,6 @@
#include <string.h> #include <string.h>
LUAU_FASTFLAGVARIABLE(LuauPreserveLudataRenaming, false)
// clang-format off // clang-format off
const char* const luaT_typenames[] = { const char* const luaT_typenames[] = {
// ORDER TYPE // ORDER TYPE
@ -124,74 +122,40 @@ const TValue* luaT_gettmbyobj(lua_State* L, const TValue* o, TMS event)
const TString* luaT_objtypenamestr(lua_State* L, const TValue* o) const TString* luaT_objtypenamestr(lua_State* L, const TValue* o)
{ {
if (FFlag::LuauPreserveLudataRenaming) // Userdata created by the environment can have a custom type name set in the individual metatable
// If there is no custom name, 'userdata' is returned
if (ttisuserdata(o) && uvalue(o)->tag != UTAG_PROXY && uvalue(o)->metatable)
{ {
// Userdata created by the environment can have a custom type name set in the individual metatable const TValue* type = luaH_getstr(uvalue(o)->metatable, L->global->tmname[TM_TYPE]);
// If there is no custom name, 'userdata' is returned
if (ttisuserdata(o) && uvalue(o)->tag != UTAG_PROXY && uvalue(o)->metatable)
{
const TValue* type = luaH_getstr(uvalue(o)->metatable, L->global->tmname[TM_TYPE]);
if (ttisstring(type)) if (ttisstring(type))
return tsvalue(type); return tsvalue(type);
return L->global->ttname[ttype(o)];
}
// Tagged lightuserdata can be named using lua_setlightuserdataname
if (ttislightuserdata(o))
{
int tag = lightuserdatatag(o);
if (unsigned(tag) < LUA_LUTAG_LIMIT)
{
if (const TString* name = L->global->lightuserdataname[tag])
return name;
}
}
// For all types except userdata and table, a global metatable can be set with a global name override
if (Table* mt = L->global->mt[ttype(o)])
{
const TValue* type = luaH_getstr(mt, L->global->tmname[TM_TYPE]);
if (ttisstring(type))
return tsvalue(type);
}
return L->global->ttname[ttype(o)]; return L->global->ttname[ttype(o)];
} }
else
// Tagged lightuserdata can be named using lua_setlightuserdataname
if (ttislightuserdata(o))
{ {
if (ttisuserdata(o) && uvalue(o)->tag != UTAG_PROXY && uvalue(o)->metatable) int tag = lightuserdatatag(o);
if (unsigned(tag) < LUA_LUTAG_LIMIT)
{ {
const TValue* type = luaH_getstr(uvalue(o)->metatable, L->global->tmname[TM_TYPE]); if (const TString* name = L->global->lightuserdataname[tag])
return name;
if (ttisstring(type))
return tsvalue(type);
} }
else if (ttislightuserdata(o))
{
int tag = lightuserdatatag(o);
if (unsigned(tag) < LUA_LUTAG_LIMIT)
{
const TString* name = L->global->lightuserdataname[tag];
if (name)
return name;
}
}
else if (Table* mt = L->global->mt[ttype(o)])
{
const TValue* type = luaH_getstr(mt, L->global->tmname[TM_TYPE]);
if (ttisstring(type))
return tsvalue(type);
}
return L->global->ttname[ttype(o)];
} }
// For all types except userdata and table, a global metatable can be set with a global name override
if (Table* mt = L->global->mt[ttype(o)])
{
const TValue* type = luaH_getstr(mt, L->global->tmname[TM_TYPE]);
if (ttisstring(type))
return tsvalue(type);
}
return L->global->ttname[ttype(o)];
} }
const char* luaT_objtypename(lua_State* L, const TValue* o) const char* luaT_objtypename(lua_State* L, const TValue* o)

View file

@ -34,8 +34,6 @@ void luaC_validate(lua_State* L);
LUAU_FASTFLAG(DebugLuauAbortingChecks) LUAU_FASTFLAG(DebugLuauAbortingChecks)
LUAU_FASTINT(CodegenHeuristicsInstructionLimit) LUAU_FASTINT(CodegenHeuristicsInstructionLimit)
LUAU_FASTFLAG(LuauNativeAttribute) LUAU_FASTFLAG(LuauNativeAttribute)
LUAU_FASTFLAG(LuauPreserveLudataRenaming)
LUAU_FASTFLAG(LuauCodegenArmNumToVecFix)
static lua_CompileOptions defaultOptions() static lua_CompileOptions defaultOptions()
{ {
@ -825,8 +823,6 @@ TEST_CASE("Pack")
TEST_CASE("Vector") TEST_CASE("Vector")
{ {
ScopedFastFlag luauCodegenArmNumToVecFix{FFlag::LuauCodegenArmNumToVecFix, true};
lua_CompileOptions copts = defaultOptions(); lua_CompileOptions copts = defaultOptions();
Luau::CodeGen::CompilationOptions nativeOpts = defaultCodegenOptions(); Luau::CodeGen::CompilationOptions nativeOpts = defaultCodegenOptions();
@ -2251,20 +2247,17 @@ TEST_CASE("LightuserdataApi")
lua_pop(L, 1); lua_pop(L, 1);
if (FFlag::LuauPreserveLudataRenaming) // Still possible to rename the global lightuserdata name using a metatable
{ lua_pushlightuserdata(L, value);
// Still possible to rename the global lightuserdata name using a metatable CHECK(strcmp(luaL_typename(L, -1), "userdata") == 0);
lua_pushlightuserdata(L, value);
CHECK(strcmp(luaL_typename(L, -1), "userdata") == 0);
lua_createtable(L, 0, 1); lua_createtable(L, 0, 1);
lua_pushstring(L, "luserdata"); lua_pushstring(L, "luserdata");
lua_setfield(L, -2, "__type"); lua_setfield(L, -2, "__type");
lua_setmetatable(L, -2); lua_setmetatable(L, -2);
CHECK(strcmp(luaL_typename(L, -1), "luserdata") == 0); CHECK(strcmp(luaL_typename(L, -1), "luserdata") == 0);
lua_pop(L, 1); lua_pop(L, 1);
}
globalState.reset(); globalState.reset();
} }

View file

@ -42,7 +42,9 @@ void ConstraintGeneratorFixture::generateConstraints(const std::string& code)
void ConstraintGeneratorFixture::solve(const std::string& code) void ConstraintGeneratorFixture::solve(const std::string& code)
{ {
generateConstraints(code); generateConstraints(code);
ConstraintSolver cs{NotNull{&normalizer}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger, {}}; ConstraintSolver cs{
NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger, {}
};
cs.run(); cs.run();
} }

View file

@ -20,6 +20,7 @@ struct ConstraintGeneratorFixture : Fixture
DcrLogger logger; DcrLogger logger;
UnifierSharedState sharedState{&ice}; UnifierSharedState sharedState{&ice};
Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}};
TypeFunctionRuntime typeFunctionRuntime;
std::unique_ptr<DataFlowGraph> dfg; std::unique_ptr<DataFlowGraph> dfg;
std::unique_ptr<ConstraintGenerator> cg; std::unique_ptr<ConstraintGenerator> cg;

View file

@ -0,0 +1,139 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/FragmentAutocomplete.h"
#include "Fixture.h"
#include "Luau/Ast.h"
#include "Luau/AstQuery.h"
using namespace Luau;
struct FragmentAutocompleteFixture : Fixture
{
FragmentAutocompleteAncestryResult runAutocompleteVisitor(const std::string& source, const Position& cursorPos)
{
ParseResult p = tryParse(source); // We don't care about parsing incomplete asts
REQUIRE(p.root);
return findAncestryForFragmentParse(p.root, cursorPos);
}
};
TEST_SUITE_BEGIN("FragmentAutocompleteTraversalTest");
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "just_two_locals")
{
auto result = runAutocompleteVisitor(
R"(
local x = 4
local y = 5
)",
{2, 11}
);
CHECK_EQ(3, result.ancestry.size());
CHECK_EQ(2, result.localStack.size());
CHECK_EQ(result.localMap.size(), result.localStack.size());
REQUIRE(result.nearestStatement);
AstStatLocal* local = result.nearestStatement->as<AstStatLocal>();
REQUIRE(local);
CHECK(1 == local->vars.size);
CHECK_EQ("y", std::string(local->vars.data[0]->name.value));
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "cursor_within_scope_tracks_locals_from_previous_scope")
{
auto result = runAutocompleteVisitor(
R"(
local x = 4
local y = 5
if x == 4 then
local e = y
end
)",
{4, 15}
);
CHECK_EQ(5, result.ancestry.size());
CHECK_EQ(3, result.localStack.size());
CHECK_EQ(result.localMap.size(), result.localStack.size());
REQUIRE(result.nearestStatement);
CHECK_EQ("e", std::string(result.localStack.back()->name.value));
AstStatLocal* local = result.nearestStatement->as<AstStatLocal>();
REQUIRE(local);
CHECK(1 == local->vars.size);
CHECK_EQ("e", std::string(local->vars.data[0]->name.value));
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "cursor_that_comes_later_shouldnt_capture_locals_in_unavailable_scope")
{
auto result = runAutocompleteVisitor(
R"(
local x = 4
local y = 5
if x == 4 then
local e = y
end
local z = x + x
if y == 5 then
local q = x + y + z
end
)",
{8, 23}
);
CHECK_EQ(6, result.ancestry.size());
CHECK_EQ(4, result.localStack.size());
CHECK_EQ(result.localMap.size(), result.localStack.size());
REQUIRE(result.nearestStatement);
CHECK_EQ("q", std::string(result.localStack.back()->name.value));
AstStatLocal* local = result.nearestStatement->as<AstStatLocal>();
REQUIRE(local);
CHECK(1 == local->vars.size);
CHECK_EQ("q", std::string(local->vars.data[0]->name.value));
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "nearest_enclosing_statement_can_be_non_local")
{
auto result = runAutocompleteVisitor(
R"(
local x = 4
local y = 5
if x == 4 then
)",
{3, 4}
);
CHECK_EQ(4, result.ancestry.size());
CHECK_EQ(2, result.localStack.size());
CHECK_EQ(result.localMap.size(), result.localStack.size());
REQUIRE(result.nearestStatement);
CHECK_EQ("y", std::string(result.localStack.back()->name.value));
AstStatIf* ifS = result.nearestStatement->as<AstStatIf>();
CHECK(ifS != nullptr);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "local_funcs_show_up_in_local_stack")
{
auto result = runAutocompleteVisitor(
R"(
local function foo() return 4 end
local x = foo()
local function bar() return x + foo() end
)",
{3, 32}
);
CHECK_EQ(8, result.ancestry.size());
CHECK_EQ(3, result.localStack.size());
CHECK_EQ(result.localMap.size(), result.localStack.size());
CHECK_EQ("bar", std::string(result.localStack.back()->name.value));
auto returnSt = result.nearestStatement->as<AstStatReturn>();
CHECK(returnSt != nullptr);
}
TEST_SUITE_END();

View file

@ -3,6 +3,7 @@
#include "AstQueryDsl.h" #include "AstQueryDsl.h"
#include "Fixture.h" #include "Fixture.h"
#include "Luau/Common.h"
#include "ScopedFlags.h" #include "ScopedFlags.h"
#include "doctest.h" #include "doctest.h"
@ -11,13 +12,12 @@
using namespace Luau; using namespace Luau;
LUAU_FASTFLAG(LuauLexerLookaheadRemembersBraceType); LUAU_FASTINT(LuauRecursionLimit)
LUAU_FASTINT(LuauRecursionLimit); LUAU_FASTINT(LuauTypeLengthLimit)
LUAU_FASTINT(LuauTypeLengthLimit); LUAU_FASTINT(LuauParseErrorLimit)
LUAU_FASTINT(LuauParseErrorLimit); LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAG(LuauAttributeSyntaxFunExpr)
LUAU_FASTFLAG(LuauAttributeSyntaxFunExpr); LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax)
LUAU_FASTFLAG(LuauUserDefinedTypeFunctions);
namespace namespace
{ {
@ -2380,7 +2380,7 @@ TEST_CASE_FIXTURE(Fixture, "invalid_type_forms")
TEST_CASE_FIXTURE(Fixture, "parse_user_defined_type_functions") TEST_CASE_FIXTURE(Fixture, "parse_user_defined_type_functions")
{ {
ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctions, true}; ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax, true};
AstStat* stat = parse(R"( AstStat* stat = parse(R"(
type function foo() type function foo()
@ -3138,8 +3138,6 @@ TEST_CASE_FIXTURE(Fixture, "do_block_with_no_end")
TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_with_lookahead_involved") TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_with_lookahead_involved")
{ {
ScopedFastFlag sff{FFlag::LuauLexerLookaheadRemembersBraceType, true};
ParseResult result = tryParse(R"( ParseResult result = tryParse(R"(
local x = `{ {y} }` local x = `{ {y} }`
)"); )");
@ -3149,8 +3147,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_with_lookahead_involved")
TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_with_lookahead_involved2") TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_with_lookahead_involved2")
{ {
ScopedFastFlag sff{FFlag::LuauLexerLookaheadRemembersBraceType, true};
ParseResult result = tryParse(R"( ParseResult result = tryParse(R"(
local x = `{ { y{} } }` local x = `{ { y{} } }`
)"); )");

View file

@ -66,6 +66,7 @@ struct SubtypeFixture : Fixture
InternalErrorReporter iceReporter; InternalErrorReporter iceReporter;
UnifierSharedState sharedState{&ice}; UnifierSharedState sharedState{&ice};
Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}};
TypeFunctionRuntime typeFunctionRuntime;
ScopedFastFlag sff{FFlag::LuauSolverV2, true}; ScopedFastFlag sff{FFlag::LuauSolverV2, true};
@ -77,7 +78,7 @@ struct SubtypeFixture : Fixture
Subtyping mkSubtyping() Subtyping mkSubtyping()
{ {
return Subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&iceReporter}}; return Subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter}};
} }
TypePackId pack(std::initializer_list<TypeId> tys) TypePackId pack(std::initializer_list<TypeId> tys)

View file

@ -12,7 +12,7 @@
using namespace Luau; using namespace Luau;
LUAU_FASTFLAG(LuauUserDefinedTypeFunctions); LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax)
TEST_SUITE_BEGIN("TranspilerTests"); TEST_SUITE_BEGIN("TranspilerTests");
@ -698,7 +698,7 @@ TEST_CASE_FIXTURE(Fixture, "transpile_string_literal_escape")
TEST_CASE_FIXTURE(Fixture, "transpile_type_functions") TEST_CASE_FIXTURE(Fixture, "transpile_type_functions")
{ {
ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctions, true}; ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax, true};
std::string code = R"( type function foo(arg1, arg2) if arg1 == arg2 then return arg1 end return arg2 end )"; std::string code = R"( type function foo(arg1, arg2) if arg1 == arg2 then return arg1 end return arg2 end )";

View file

@ -1247,18 +1247,4 @@ TEST_CASE_FIXTURE(ClassFixture, "rawget_type_function_errors_w_classes")
CHECK(toString(result.errors[0]) == "Property '\"BaseField\"' does not exist on type 'BaseClass'"); CHECK(toString(result.errors[0]) == "Property '\"BaseField\"' does not exist on type 'BaseClass'");
} }
TEST_CASE_FIXTURE(Fixture, "user_defined_type_function_errors")
{
if (!FFlag::LuauUserDefinedTypeFunctions)
return;
CheckResult result = check(R"(
type function foo()
return nil
end
)");
LUAU_CHECK_ERROR_COUNT(1, result);
CHECK(toString(result.errors[0]) == "This syntax is not supported");
}
TEST_SUITE_END(); TEST_SUITE_END();

File diff suppressed because it is too large Load diff

View file

@ -9,6 +9,7 @@
using namespace Luau; using namespace Luau;
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax)
LUAU_FASTFLAG(LuauUserDefinedTypeFunctions) LUAU_FASTFLAG(LuauUserDefinedTypeFunctions)
TEST_SUITE_BEGIN("TypeAliases"); TEST_SUITE_BEGIN("TypeAliases");
@ -1169,8 +1170,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_adds_reduce_constraint_for_type_f
TEST_CASE_FIXTURE(Fixture, "user_defined_type_function_errors") TEST_CASE_FIXTURE(Fixture, "user_defined_type_function_errors")
{ {
if (!FFlag::LuauUserDefinedTypeFunctions) ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax, true};
return; ScopedFastFlag noUDTFimpl{FFlag::LuauUserDefinedTypeFunctions, false};
CheckResult result = check(R"( CheckResult result = check(R"(
type function foo() type function foo()

View file

@ -1427,4 +1427,18 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types3")
CHECK_EQ(toString(requireType("e")), "number?"); CHECK_EQ(toString(requireType("e")), "number?");
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "string_find_should_not_crash")
{
ScopedFastFlag _{FFlag::LuauSolverV2, true};
LUAU_REQUIRE_NO_ERRORS(check(R"(
local function StringSplit(input, separator)
string.find(input, separator)
if not separator then
separator = "%s+"
end
end
)"));
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -15,7 +15,6 @@
using namespace Luau; using namespace Luau;
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauOkWithIteratingOverTableProperties)
LUAU_DYNAMIC_FASTFLAG(LuauImproveNonFunctionCallError) LUAU_DYNAMIC_FASTFLAG(LuauImproveNonFunctionCallError)
@ -699,8 +698,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "loop_typecheck_crash_on_empty_optional")
if (FFlag::LuauSolverV2) if (FFlag::LuauSolverV2)
return; return;
ScopedFastFlag sff{FFlag::LuauOkWithIteratingOverTableProperties, true};
CheckResult result = check(R"( CheckResult result = check(R"(
local t = {} local t = {}
for _ in t do for _ in t do
@ -784,7 +781,6 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer_strict")
// CLI-116498 Sometimes you can iterate over tables with no indexers. // CLI-116498 Sometimes you can iterate over tables with no indexers.
ScopedFastFlag sff[] = { ScopedFastFlag sff[] = {
{FFlag::LuauSolverV2, false}, {FFlag::LuauSolverV2, false},
{FFlag::LuauOkWithIteratingOverTableProperties, true}
}; };
CheckResult result = check(R"( CheckResult result = check(R"(
@ -937,8 +933,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "cli_68448_iterators_need_not_accept_nil")
TEST_CASE_FIXTURE(Fixture, "iterate_over_free_table") TEST_CASE_FIXTURE(Fixture, "iterate_over_free_table")
{ {
ScopedFastFlag sff{FFlag::LuauOkWithIteratingOverTableProperties, true};
CheckResult result = check(R"( CheckResult result = check(R"(
function print(x) end function print(x) end
@ -1095,8 +1089,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "iterate_over_properties")
// CLI-116498 - Sometimes you can iterate over tables with no indexer. // CLI-116498 - Sometimes you can iterate over tables with no indexer.
ScopedFastFlag sff0{FFlag::LuauSolverV2, false}; ScopedFastFlag sff0{FFlag::LuauSolverV2, false};
ScopedFastFlag sff{FFlag::LuauOkWithIteratingOverTableProperties, true};
CheckResult result = check(R"( CheckResult result = check(R"(
local function f() local function f()
local t = { p = 5, q = "hello" } local t = { p = 5, q = "hello" }
@ -1118,8 +1110,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "iterate_over_properties")
TEST_CASE_FIXTURE(BuiltinsFixture, "iterate_over_properties_nonstrict") TEST_CASE_FIXTURE(BuiltinsFixture, "iterate_over_properties_nonstrict")
{ {
ScopedFastFlag sff{FFlag::LuauOkWithIteratingOverTableProperties, true};
CheckResult result = check(R"( CheckResult result = check(R"(
--!nonstrict --!nonstrict
local function f() local function f()

View file

@ -530,4 +530,82 @@ return l0
CHECK(mod->scopes[3].second->importedModules["l1"] == "game/A"); CHECK(mod->scopes[3].second->importedModules["l1"] == "game/A");
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "ensure_scope_is_nullptr_after_shallow_copy")
{
ScopedFastFlag _{FFlag::LuauSolverV2, true};
frontend.options.retainFullTypeGraphs = false;
fileResolver.source["game/A"] = R"(
-- Roughly taken from ReactTypes.lua
type CoreBinding<T> = {}
type BindingMap = {}
export type Binding<T> = CoreBinding<T> & BindingMap
return {}
)";
LUAU_REQUIRE_NO_ERRORS(check(R"(
local Types = require(game.A)
type Binding<T> = Types.Binding<T>
)"));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "ensure_free_variables_are_generialized_across_function_boundaries")
{
ScopedFastFlag _{FFlag::LuauSolverV2, true};
fileResolver.source["game/A"] = R"(
-- Roughly taken from react-shallow-renderer
function createUpdater(renderer)
local updater = {
_renderer = renderer,
}
function updater.enqueueForceUpdate(publicInstance, callback, _callerName)
updater._renderer.render(
updater._renderer,
updater._renderer._element,
updater._renderer._context
)
end
function updater.enqueueReplaceState(
publicInstance,
completeState,
callback,
_callerName
)
updater._renderer.render(
updater._renderer,
updater._renderer._element,
updater._renderer._context
)
end
function updater.enqueueSetState(publicInstance, partialState, callback, _callerName)
local currentState = updater._renderer._newState or publicInstance.state
updater._renderer.render(
updater._renderer,
updater._renderer._element,
updater._renderer._context
)
end
return updater
end
local ReactShallowRenderer = {}
function ReactShallowRenderer:_reset()
self._updater = createUpdater(self)
end
return ReactShallowRenderer
)";
LUAU_REQUIRE_NO_ERRORS(check(R"(
local ReactShallowRenderer = require(game.A);
)"));
}
TEST_SUITE_END(); TEST_SUITE_END();